from typing import Callable, Dict, Optional, Tuple
import h5py
import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets import utils
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
import math
import bisect
from .. import configure
from ..datasets import np_savez
def cal_fixed_frames_number_segment_index_shd(events_t: np.ndarray, split_by: str, frames_num: int) -> tuple:
    '''
    :param events_t: events' timestamps, sorted in ascending order
    :type events_t: numpy.ndarray
    :param split_by: ``'time'`` or ``'number'``
    :type split_by: str
    :param frames_num: the number of frames
    :type frames_num: int
    :return: a tuple ``(j_l, j_r)`` such that the ``i``-th frame integrates the
        events whose indices lie in ``[j_l[i], j_r[i])``
    :rtype: tuple

    Split the event stream into ``frames_num`` segments. With
    ``split_by == 'number'`` each segment contains (almost) the same number of
    events; with ``split_by == 'time'`` each segment spans the same duration.
    '''
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)
    N = events_t.size
    if split_by == 'number':
        di = N // frames_num
        for i in range(frames_num):
            j_l[i] = i * di
            j_r[i] = j_l[i] + di
        # the last frame absorbs the remainder events
        j_r[-1] = N
    elif split_by == 'time':
        dt = (events_t[-1] - events_t[0]) / frames_num
        # Binary search instead of a per-frame boolean mask:
        # O(frames_num * log N) instead of O(frames_num * N), and a time bin
        # containing no events now yields an empty segment (j_l == j_r) instead
        # of raising IndexError as the previous idx_masked[0] lookup did.
        for i in range(frames_num):
            t_l = dt * i + events_t[0]
            t_r = t_l + dt
            j_l[i] = np.searchsorted(events_t, t_l, side='left')
            j_r[i] = np.searchsorted(events_t, t_r, side='left')
        # make sure events at the very last timestamp are not dropped
        j_r[-1] = N
    else:
        raise NotImplementedError
    return j_l, j_r
def integrate_events_segment_to_frame_shd(x: np.ndarray, W: int, j_l: int = 0, j_r: int = -1) -> np.ndarray:
    '''
    :param x: unit (channel) indices of the events
    :type x: numpy.ndarray
    :param W: the number of units, i.e., the width of the output frame
    :type W: int
    :param j_l: the inclusive left index of the segment
    :type j_l: int
    :param j_r: the exclusive right index of the segment
    :type j_r: int
    :return: a frame of shape ``[W]`` where ``frame[w]`` counts the events in
        ``x[j_l:j_r]`` with unit index ``w``
    :rtype: numpy.ndarray
    '''
    frame = np.zeros(shape=[W])
    # cast to plain int to avoid overflow when counting
    segment = x[j_l: j_r].astype(int)
    counts = np.bincount(segment)
    frame[:counts.size] += counts
    return frame
def integrate_events_by_fixed_frames_number_shd(events: Dict, split_by: str, frames_num: int, W: int) -> np.ndarray:
    '''
    :param events: a dict with keys ``'t'`` (timestamps) and ``'x'`` (unit indices)
    :type events: Dict
    :param split_by: ``'time'`` or ``'number'``
    :type split_by: str
    :param frames_num: the number of frames
    :type frames_num: int
    :param W: the number of units, i.e., the width of each frame
    :type W: int
    :return: the integrated frames with shape ``[frames_num, W]``
    :rtype: numpy.ndarray
    '''
    t = events['t']
    x = events['x']
    j_l, j_r = cal_fixed_frames_number_segment_index_shd(t, split_by, frames_num)
    frames = np.zeros([frames_num, W])
    for i, (left, right) in enumerate(zip(j_l, j_r)):
        frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
    return frames
def integrate_events_file_to_frames_file_by_fixed_frames_number_shd(h5_file: h5py.File, i: int, output_dir: str, split_by: str, frames_num: int, W: int, print_save: bool = False) -> None:
    '''
    Read the ``i``-th sample from ``h5_file``, integrate it into
    ``frames_num`` frames and save them to ``output_dir/<label>/<i>``.

    :param h5_file: an opened SHD/SSC HDF5 file with ``'spikes'`` and ``'labels'``
    :type h5_file: h5py.File
    :param i: the index of the sample
    :type i: int
    :param output_dir: root directory for the generated frame files
    :type output_dir: str
    :param split_by: ``'time'`` or ``'number'``
    :type split_by: str
    :param frames_num: the number of frames
    :type frames_num: int
    :param W: the number of units
    :type W: int
    :param print_save: whether to print a message after saving
    :type print_save: bool
    :return: None
    '''
    spikes = h5_file['spikes']
    events = {'t': spikes['times'][i], 'x': spikes['units'][i]}
    label = h5_file['labels'][i]
    fname = os.path.join(output_dir, str(label), str(i))
    np_savez(fname, frames=integrate_events_by_fixed_frames_number_shd(events, split_by, frames_num, W))
    if print_save:
        print(f'Frames [{fname}] saved.')
def integrate_events_by_fixed_duration_shd(events: Dict, duration: int, W: int) -> np.ndarray:
    '''
    :param events: a dict with keys ``'t'`` (timestamps in seconds) and ``'x'`` (unit indices)
    :type events: Dict
    :param duration: the duration of each frame in milliseconds
    :type duration: int
    :param W: the number of units, i.e., the width of each frame
    :type W: int
    :return: the integrated frames with shape ``[frames_num, W]``, where
        ``frames_num`` depends on the sample's total duration
    :rtype: numpy.ndarray
    '''
    x = events['x']
    t = 1000 * events['t']  # convert seconds to milliseconds
    t = t - t[0]            # make timestamps start at 0
    N = t.size
    # Guard with max(1, ...): when every event shares the same timestamp the
    # ceil() is 0, and the original code then crashed indexing frames[-1] of an
    # empty array. One frame containing all events is the natural result.
    frames_num = max(1, int(math.ceil(t[-1] / duration)))
    frames = np.zeros([frames_num, W])
    # frame_index[k] is the frame that event k falls into; it is sorted because
    # t is sorted, so searchsorted finds each frame's right boundary.
    frame_index = t // duration
    left = 0
    for i in range(frames_num - 1):
        right = np.searchsorted(frame_index, i + 1, side='left')
        frames[i] = integrate_events_segment_to_frame_shd(x, W, left, right)
        left = right
    # the last frame takes all remaining events
    frames[-1] = integrate_events_segment_to_frame_shd(x, W, left, N)
    return frames
def integrate_events_file_to_frames_file_by_fixed_duration_shd(h5_file: h5py.File, i: int, output_dir: str, duration: int, W: int, print_save: bool = False) -> None:
    '''
    Read the ``i``-th sample from ``h5_file``, integrate it into frames of
    fixed ``duration`` and save them to ``output_dir/<label>/<i>``.

    :param h5_file: an opened SHD/SSC HDF5 file with ``'spikes'`` and ``'labels'``
    :type h5_file: h5py.File
    :param i: the index of the sample
    :type i: int
    :param output_dir: root directory for the generated frame files
    :type output_dir: str
    :param duration: the duration of each frame in milliseconds
    :type duration: int
    :param W: the number of units
    :type W: int
    :param print_save: whether to print a message after saving
    :type print_save: bool
    :return: the number of generated frames
    '''
    spikes = h5_file['spikes']
    events = {'t': spikes['times'][i], 'x': spikes['units'][i]}
    label = h5_file['labels'][i]
    fname = os.path.join(output_dir, str(label), str(i))
    frames = integrate_events_by_fixed_duration_shd(events, duration, W)
    np_savez(fname, frames=frames)
    if print_save:
        print(f'Frames [{fname}] saved.')
    return frames.shape[0]
def custom_integrate_function_example(h5_file: h5py.File, i: int, output_dir: str, W: int):
    '''
    An example of ``custom_integrate_function``: split the ``i``-th sample's
    events at a random position into two frames and save them to
    ``output_dir/<label>/<i>``.

    :param h5_file: an opened SHD/SSC HDF5 file with ``'spikes'`` and ``'labels'``
    :type h5_file: h5py.File
    :param i: the index of the sample
    :type i: int
    :param output_dir: root directory for the generated frame files
    :type output_dir: str
    :param W: the number of units
    :type W: int
    '''
    spikes = h5_file['spikes']
    events = {'t': spikes['times'][i], 'x': spikes['units'][i]}
    label = h5_file['labels'][i]
    events_num = len(events['t'])
    # pick a random split point, then integrate the two halves separately
    index_split = np.random.randint(low=0, high=events_num)
    frames = np.zeros([2, W])
    frames[0] = integrate_events_segment_to_frame_shd(events['x'], W, 0, index_split)
    frames[1] = integrate_events_segment_to_frame_shd(events['x'], W, index_split, events_num)
    fname = os.path.join(output_dir, str(label), str(i))
    np_savez(fname, frames=frames)
class SpikingHeidelbergDigits(Dataset):
    def __init__(
        self,
        root: str,
        train: bool = None,
        data_type: str = 'event',
        frames_number: int = None,
        split_by: str = None,
        duration: int = None,
        custom_integrate_function: Callable = None,
        custom_integrated_frames_dir_name: str = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The Spiking Heidelberg Digits (SHD) dataset, which is proposed by `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.

        .. admonition:: Note
            :class: note

            Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``. Thus, this dataset is not inherited from :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` directly. But their procedures are similar.

        :class:`spikingjelly.datasets.shd.custom_integrate_function_example` is an example of ``custom_integrate_function``, which is similar to the custom function for DVS Gesture in the ``Neuromorphic Datasets Processing`` tutorial.
        """
        super().__init__()
        self.root = root
        self.train = train
        self.data_type = data_type
        self.frames_number = frames_number
        self.split_by = split_by
        self.duration = duration
        self.custom_integrate_function = custom_integrate_function
        self.custom_integrated_frames_dir_name = custom_integrated_frames_dir_name
        self.transform = transform
        self.target_transform = target_transform
        download_root = os.path.join(root, 'download')
        extract_root = os.path.join(root, 'extract')
        # Step 1: make sure the raw HDF5 files are downloaded and extracted.
        if not os.path.exists(extract_root):
            if os.path.exists(download_root):
                print(f'The [{download_root}] directory for saving downloaded files already exists, check files...')
                # check files
                resource_list = self.resource_url_md5()
                for i in range(resource_list.__len__()):
                    file_name, url, md5 = resource_list[i]
                    fpath = os.path.join(download_root, file_name)
                    if not utils.check_integrity(fpath=fpath, md5=md5):
                        print(f'The file [{fpath}] does not exist or is corrupted.')
                        if os.path.exists(fpath):
                            # If file is corrupted, we will remove it.
                            os.remove(fpath)
                            print(f'Remove [{fpath}]')
                        if self.downloadable():
                            # If file does not exist, we will download it.
                            print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                            utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                        else:
                            raise NotImplementedError(
                                f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.')
            else:
                os.mkdir(download_root)
                print(f'Mkdir [{download_root}] to save downloaded files.')
                resource_list = self.resource_url_md5()
                if self.downloadable():
                    # download and extract file
                    for i in range(resource_list.__len__()):
                        file_name, url, md5 = resource_list[i]
                        print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                        utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                else:
                    raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, '
                                              f'please download files manually and put files at [{download_root}]. '
                                              f'The resources file_name, url, and md5 are: \n{resource_list}')
            os.mkdir(extract_root)
            print(f'Mkdir [{extract_root}].')
            # NOTE(review): extract_downloaded_files is not defined in this
            # chunk — presumably provided elsewhere (see the identical static
            # method on SpikingSpeechCommands); verify against the full file.
            self.extract_downloaded_files(download_root, extract_root)
        else:
            print(f'The directory [{extract_root}] for saving extracted files already exists.\n'
                  f'SpikingJelly will not check the data integrity of extracted files.\n'
                  f'If extracted files are not integrated, please delete [{extract_root}] manually, '
                  f'then SpikingJelly will re-extract files from [{download_root}].')
            # shutil.rmtree(extract_root)
            # print(f'Delete [{extract_root}].')
        # Step 2: expose either raw events or preprocessed frames.
        if self.data_type == 'event':
            # raw events are read lazily from the kept-open HDF5 file
            if self.train:
                self.h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
            else:
                self.h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
            self.length = self.h5_file['labels'].__len__()
            return
        elif self.data_type == 'frame':
            if frames_number is not None:
                # integrate each sample into a fixed number of frames
                assert frames_number > 0 and isinstance(frames_number, int)
                assert split_by == 'time' or split_by == 'number'
                frames_np_root = os.path.join(root, f'frames_number_{frames_number}_split_by_{split_by}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    # create one sub-directory per class label for train and test
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i,
                                                          frames_np_train_root, self.split_by, frames_number, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i,
                                                          frames_np_test_root, self.split_by, frames_number, self.get_W(), True))
                        # abort on the first worker that raised
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            elif duration is not None:
                # integrate each sample into frames of a fixed duration
                assert duration > 0 and isinstance(duration, int)
                frames_np_root = os.path.join(root, f'duration_{duration}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i,
                                                          frames_np_train_root, self.duration, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i,
                                                          frames_np_test_root, self.duration, self.get_W(), True))
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            elif custom_integrate_function is not None:
                # integrate each sample with a user-provided function
                if custom_integrated_frames_dir_name is None:
                    custom_integrated_frames_dir_name = custom_integrate_function.__name__
                frames_np_root = os.path.join(root, custom_integrated_frames_dir_name)
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i,
                                                          frames_np_train_root, self.get_W()))
                        h5_file = h5py.File(os.path.join(extract_root, 'shd_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i,
                                                          frames_np_test_root, self.get_W()))
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            else:
                raise ValueError('At least one of "frames_number", "duration" and "custom_integrate_function" should not be None.')
            # Step 3: index every generated frame file together with its label.
            self.frames_path = []
            self.frames_label = []
            if self.train:
                sub_dir = 'train'
            else:
                sub_dir = 'test'
            for i in range(self.classes_number()):
                for fname in os.listdir(os.path.join(self.frames_np_root, sub_dir, str(i))):
                    self.frames_path.append(
                        os.path.join(self.frames_np_root, sub_dir, str(i), fname)
                    )
                    self.frames_label.append(i)
            self.length = self.frames_label.__len__()
        else:
            raise NotImplementedError(self.data_type)

    def classes_number(self):
        # SHD has 20 classes: the digits 0-9 spoken in English and German
        return 20

    def __len__(self):
        # number of samples in the selected (train/test) subset
        return self.length

    def __getitem__(self, i: int):
        # Return either the raw event dict or the pre-integrated frames for
        # sample i, depending on the data_type chosen at construction time.
        if self.data_type == 'event':
            events = {'t': self.h5_file['spikes']['times'][i], 'x': self.h5_file['spikes']['units'][i]}
            label = self.h5_file['labels'][i]
            if self.transform is not None:
                events = self.transform(events)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return events, label
        elif self.data_type == 'frame':
            frames = np.load(self.frames_path[i], allow_pickle=True)['frames'].astype(np.float32)
            label = self.frames_label[i]
            if self.transform is not None:
                frames = self.transform(frames)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return frames, label

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [
            ('shd_train.h5.zip', 'https://zenkelab.org/datasets/shd_train.h5.zip', 'f3252aeb598ac776c1b526422d90eecb'),
            ('shd_test.h5.zip', 'https://zenkelab.org/datasets/shd_test.h5.zip', '1503a5064faa34311c398fb0a1ed0a6f'),
        ]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user have to download it manually
        :rtype: bool
        '''
        return True

    @staticmethod
    def get_W():
        # number of input units (channels) of the SHD recordings
        return 700
class SpikingSpeechCommands(Dataset):
    def __init__(
        self,
        root: str,
        split: str = 'train',
        data_type: str = 'event',
        frames_number: int = None,
        split_by: str = None,
        duration: int = None,
        custom_integrate_function: Callable = None,
        custom_integrated_frames_dir_name: str = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The Spiking Speech Commands (SSC) dataset, which is proposed by `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.

        ``split`` selects the subset and must be ``'train'``, ``'valid'`` or ``'test'``.

        .. admonition:: Note
            :class: note

            Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``. Thus, this dataset is not inherited from :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` directly. But their procedures are similar.

        :class:`spikingjelly.datasets.shd.custom_integrate_function_example` is an example of ``custom_integrate_function``, which is similar to the custom function for DVS Gesture in the ``Neuromorphic Datasets Processing`` tutorial.
        """
        super().__init__()
        self.root = root
        self.split = split
        self.data_type = data_type
        self.frames_number = frames_number
        self.split_by = split_by
        self.duration = duration
        self.custom_integrate_function = custom_integrate_function
        self.custom_integrated_frames_dir_name = custom_integrated_frames_dir_name
        self.transform = transform
        self.target_transform = target_transform
        download_root = os.path.join(root, 'download')
        extract_root = os.path.join(root, 'extract')
        # Step 1: make sure the raw HDF5 files are downloaded and extracted.
        if not os.path.exists(extract_root):
            if os.path.exists(download_root):
                print(f'The [{download_root}] directory for saving downloaded files already exists, check files...')
                # check files
                resource_list = self.resource_url_md5()
                for i in range(resource_list.__len__()):
                    file_name, url, md5 = resource_list[i]
                    fpath = os.path.join(download_root, file_name)
                    if not utils.check_integrity(fpath=fpath, md5=md5):
                        print(f'The file [{fpath}] does not exist or is corrupted.')
                        if os.path.exists(fpath):
                            # If file is corrupted, we will remove it.
                            os.remove(fpath)
                            print(f'Remove [{fpath}]')
                        if self.downloadable():
                            # If file does not exist, we will download it.
                            print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                            utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                        else:
                            raise NotImplementedError(
                                f'This dataset can not be downloaded by SpikingJelly, please download [{file_name}] from [{url}] manually and put files at {download_root}.')
            else:
                os.mkdir(download_root)
                print(f'Mkdir [{download_root}] to save downloaded files.')
                resource_list = self.resource_url_md5()
                if self.downloadable():
                    # download and extract file
                    for i in range(resource_list.__len__()):
                        file_name, url, md5 = resource_list[i]
                        print(f'Download [{file_name}] from [{url}] to [{download_root}]')
                        utils.download_url(url=url, root=download_root, filename=file_name, md5=md5)
                else:
                    raise NotImplementedError(f'This dataset can not be downloaded by SpikingJelly, '
                                              f'please download files manually and put files at [{download_root}]. '
                                              f'The resources file_name, url, and md5 are: \n{resource_list}')
            os.mkdir(extract_root)
            print(f'Mkdir [{extract_root}].')
            self.extract_downloaded_files(download_root, extract_root)
        else:
            print(f'The directory [{extract_root}] for saving extracted files already exists.\n'
                  f'SpikingJelly will not check the data integrity of extracted files.\n'
                  f'If extracted files are not integrated, please delete [{extract_root}] manually, '
                  f'then SpikingJelly will re-extract files from [{download_root}].')
            # shutil.rmtree(extract_root)
            # print(f'Delete [{extract_root}].')
        # Step 2: expose either raw events or preprocessed frames.
        if self.data_type == 'event':
            # raw events are read lazily from the kept-open HDF5 file
            if self.split == 'train':
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
            elif self.split == 'valid':
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
            else:
                self.h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
            self.length = self.h5_file['labels'].__len__()
            return
        elif self.data_type == 'frame':
            if frames_number is not None:
                # integrate each sample into a fixed number of frames
                assert frames_number > 0 and isinstance(frames_number, int)
                assert split_by == 'time' or split_by == 'number'
                frames_np_root = os.path.join(root, f'frames_number_{frames_number}_split_by_{split_by}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    # create one sub-directory per class label for each subset
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file,
                                                          i, frames_np_train_root, self.split_by, frames_number, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            # BUGFIX: valid samples were previously written to
                            # frames_np_test_root, leaving the valid split empty
                            # and polluting the test split.
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i,
                                                          frames_np_valid_root, self.split_by, frames_number, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_frames_number_shd, h5_file, i,
                                                          frames_np_test_root, self.split_by, frames_number, self.get_W(), True))
                        # abort on the first worker that raised
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            elif duration is not None:
                # integrate each sample into frames of a fixed duration
                assert duration > 0 and isinstance(duration, int)
                frames_np_root = os.path.join(root, f'duration_{duration}')
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        sub_threads = []
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i,
                                                          frames_np_train_root, self.duration, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i,
                                                          frames_np_valid_root, self.duration, self.get_W(), True))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(integrate_events_file_to_frames_file_by_fixed_duration_shd, h5_file, i,
                                                          frames_np_test_root, self.duration, self.get_W(), True))
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            elif custom_integrate_function is not None:
                # integrate each sample with a user-provided function
                if custom_integrated_frames_dir_name is None:
                    custom_integrated_frames_dir_name = custom_integrate_function.__name__
                frames_np_root = os.path.join(root, custom_integrated_frames_dir_name)
                if os.path.exists(frames_np_root):
                    print(f'The directory [{frames_np_root}] already exists.')
                else:
                    os.mkdir(frames_np_root)
                    print(f'Mkdir [{frames_np_root}].')
                    frames_np_train_root = os.path.join(frames_np_root, 'train')
                    os.mkdir(frames_np_train_root)
                    print(f'Mkdir [{frames_np_train_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_train_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_train_root, str(i))}].')
                    frames_np_valid_root = os.path.join(frames_np_root, 'valid')
                    os.mkdir(frames_np_valid_root)
                    print(f'Mkdir [{frames_np_valid_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_valid_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_valid_root, str(i))}].')
                    frames_np_test_root = os.path.join(frames_np_root, 'test')
                    os.mkdir(frames_np_test_root)
                    print(f'Mkdir [{frames_np_test_root}].')
                    for i in range(self.classes_number()):
                        os.mkdir(os.path.join(frames_np_test_root, str(i)))
                        print(f'Mkdir [{os.path.join(frames_np_test_root, str(i))}].')
                    # use multi-thread to accelerate
                    t_ckp = time.time()
                    with ThreadPoolExecutor(max_workers=configure.max_threads_number_for_datasets_preprocess) as tpe:
                        print(f'Start ThreadPoolExecutor with max workers = [{tpe._max_workers}].')
                        sub_threads = []
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_train.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th train sample to frames and save to [{frames_np_train_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i,
                                                          frames_np_train_root, self.get_W()))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_valid.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th valid sample to frames and save to [{frames_np_valid_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i,
                                                          frames_np_valid_root, self.get_W()))
                        h5_file = h5py.File(os.path.join(extract_root, 'ssc_test.h5'))
                        for i in range(h5_file['labels'].__len__()):
                            print(
                                f'Start to integrate [{i}]-th test sample to frames and save to [{frames_np_test_root}].')
                            sub_threads.append(tpe.submit(custom_integrate_function, h5_file, i,
                                                          frames_np_test_root, self.get_W()))
                        for sub_thread in sub_threads:
                            if sub_thread.exception():
                                print(sub_thread.exception())
                                exit(-1)
                    print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
                self.frames_np_root = frames_np_root
            else:
                raise ValueError('At least one of "frames_number", "duration" and "custom_integrate_function" should not be None.')
            # Step 3: index every generated frame file together with its label.
            self.frames_path = []
            self.frames_label = []
            if self.split == 'train':
                sub_dir = 'train'
            elif self.split == 'valid':
                sub_dir = 'valid'
            else:
                sub_dir = 'test'
            for i in range(self.classes_number()):
                for fname in os.listdir(os.path.join(self.frames_np_root, sub_dir, str(i))):
                    self.frames_path.append(
                        os.path.join(self.frames_np_root, sub_dir, str(i), fname)
                    )
                    self.frames_label.append(i)
            self.length = self.frames_label.__len__()
        else:
            raise NotImplementedError(self.data_type)

    def classes_number(self):
        # SSC uses the 35 spoken-command classes of Speech Commands
        return 35

    def __len__(self):
        # number of samples in the selected (train/valid/test) subset
        return self.length

    def __getitem__(self, i: int):
        # Return either the raw event dict or the pre-integrated frames for
        # sample i, depending on the data_type chosen at construction time.
        if self.data_type == 'event':
            events = {'t': self.h5_file['spikes']['times'][i], 'x': self.h5_file['spikes']['units'][i]}
            label = self.h5_file['labels'][i]
            if self.transform is not None:
                events = self.transform(events)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return events, label
        elif self.data_type == 'frame':
            frames = np.load(self.frames_path[i], allow_pickle=True)['frames'].astype(np.float32)
            label = self.frames_label[i]
            if self.transform is not None:
                frames = self.transform(frames)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return frames, label

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` that ``url[i]`` is a tuple, which contains the i-th file's name, download link, and MD5
        :rtype: list
        '''
        return [
            ('ssc_train.h5.zip', 'https://zenkelab.org/datasets/ssc_train.h5.zip', 'd102be95e7144fcc0553d1f45ba94170'),
            ('ssc_valid.h5.zip', 'https://zenkelab.org/datasets/ssc_valid.h5.zip', 'b4eee3516a4a90dd0c71a6ac23a8ae43'),
            ('ssc_test.h5.zip', 'https://zenkelab.org/datasets/ssc_test.h5.zip', 'a35ff1e9cffdd02a20eb850c17c37748'),
        ]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be directly downloaded by python codes. If not, the user have to download it manually
        :rtype: bool
        '''
        return True

    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract download files.
        '''
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(download_root):
                zip_file = os.path.join(download_root, zip_file)
                print(f'Extract [{zip_file}] to [{extract_root}].')
                sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))
            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)

    @staticmethod
    def get_W():
        # number of input units (channels) of the SSC recordings
        return 700