from typing import Callable, Optional, Tuple
import os
from pathlib import Path
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
import math
import h5py
import numpy as np
from torchvision.datasets.utils import extract_archive
from .. import configure
from . import utils
from .base import NeuromorphicDatasetFolder
from .base import NeuromorphicDatasetBuilder
from .base import NeuromorphicDatasetConfig
#: Number of label classes (one output directory per label) in SHD.
SHD_N_CLASSES = 20
#: Number of label classes (one output directory per label) in SSC.
SSC_N_CLASSES = 35

# Public API of this module.
__all__ = [
    "SHD_N_CLASSES",
    "SpikingHeidelbergDigits",
    "SSC_N_CLASSES",
    "SpikingSpeechCommands",
]
def _cal_fixed_frames_number_segment_index(
    events_t: np.ndarray, split_by: str, frames_num: int
) -> tuple:
    """Compute the event-index range ``[j_l[i], j_r[i])`` of each of ``frames_num`` frames.

    :param events_t: event timestamps; assumed sorted in non-decreasing order
        (the ``"time"`` mode uses ``np.searchsorted``)
    :param split_by: ``"number"`` splits the events into ``frames_num`` groups of
        ``N // frames_num`` events (the last group takes the remainder);
        ``"time"`` splits ``[t[0], t[-1]]`` into ``frames_num`` intervals of
        equal duration
    :param frames_num: number of frames
    :return: ``(j_l, j_r)``, two int arrays of length ``frames_num``; frame ``i``
        integrates events ``j_l[i]:j_r[i]``. The last frame always ends at ``N``.
        In ``"time"`` mode an interval without events yields an empty range
        (``j_l[i] == j_r[i]``) instead of raising.
    :raises NotImplementedError: if ``split_by`` is neither ``"number"`` nor ``"time"``
    """
    j_l = np.zeros(shape=[frames_num], dtype=int)
    j_r = np.zeros(shape=[frames_num], dtype=int)
    N = events_t.size
    if split_by == "number":
        di = N // frames_num
        j_l[:] = np.arange(frames_num) * di
        j_r[:] = j_l + di
    elif split_by == "time":
        if N > 0:
            t0 = events_t[0]
            # NOTE: true division; intentionally different from
            # utils.cal_fixed_frames_number_segment_index
            dt = (events_t[-1] - t0) / frames_num
            # One shared array of edges, so frame i ends exactly where frame
            # i + 1 starts (no event counted twice or lost at a boundary).
            edges = dt * np.arange(frames_num + 1) + t0
            j_l[:] = np.searchsorted(events_t, edges[:-1], side="left")
            j_r[:] = np.searchsorted(events_t, edges[1:], side="left")
    else:
        raise NotImplementedError(
            f"split_by must be 'time' or 'number', got {split_by!r}"
        )
    # The last frame absorbs every remaining event (incl. those at t[-1]).
    j_r[-1] = N
    return j_l, j_r
def _integrate_events_segment_to_frame(
    x: np.ndarray, W: int, j_l: int = 0, j_r: Optional[int] = None
) -> np.ndarray:
    """Count the events of ``x[j_l:j_r]`` per channel.

    :param x: channel index of every event; values must lie in ``[0, W)``
    :param W: number of channels (length of the returned frame)
    :param j_l: first event index (inclusive)
    :param j_r: end event index (exclusive); ``None`` (default) means "up to the
        end". The former default ``-1`` silently dropped the last event.
    :return: float64 array of shape ``[W]`` with per-channel event counts
    :raises IndexError: if a channel index in the segment is ``>= W``
    """
    # cast to the platform int before counting to avoid overflow on narrow dtypes
    segment = x[j_l:j_r].astype(int)
    counts = np.bincount(segment, minlength=W)
    if counts.size > W:
        raise IndexError(
            f"channel index {counts.size - 1} is out of range for W = {W}"
        )
    return counts.astype(np.float64)
def _integrate_events_by_fixed_frames_number(
    events: dict, split_by: str, frames_num: int, W: int
) -> np.ndarray:
    """Integrate an event dict with keys ``"t"`` and ``"x"`` into ``frames_num`` frames.

    The event-index range of every frame is obtained from
    ``_cal_fixed_frames_number_segment_index`` and turned into per-channel
    counts by ``_integrate_events_segment_to_frame``.

    :param events: dict holding timestamps under ``"t"`` and channels under ``"x"``
    :param split_by: ``"time"`` or ``"number"``
    :param frames_num: number of frames
    :param W: number of channels
    :return: float64 array of shape ``[frames_num, W]``
    """
    timestamps = events["t"]
    channels = events["x"]
    starts, stops = _cal_fixed_frames_number_segment_index(
        timestamps, split_by, frames_num
    )
    frames = np.zeros([frames_num, W])
    for k, (lo, hi) in enumerate(zip(starts, stops)):
        frames[k] = _integrate_events_segment_to_frame(channels, W, lo, hi)
    return frames
def _integrate_events_file_to_frames_file_by_fixed_frames_number(
    h5_file: h5py.File,
    i: int,
    output_dir: str,
    split_by: str,
    frames_num: int,
    W: int,
    print_save: bool = False,
) -> None:
    """Integrate sample ``i`` of ``h5_file`` into a fixed number of frames and save it.

    The sample's events are read from ``spikes/times`` and ``spikes/units``,
    its label from ``labels``; the frames are written with ``utils.np_savez``
    to ``output_dir/<label>/<i>`` under the key ``frames``.

    :param h5_file: opened SHD/SSC ``.h5`` file
    :param i: sample index
    :param output_dir: split directory that contains one sub-directory per label
    :param split_by: ``"time"`` or ``"number"``
    :param frames_num: number of frames
    :param W: number of channels
    :param print_save: print the saved file name when ``True``
    """
    spikes = h5_file["spikes"]
    events = {"t": spikes["times"][i], "x": spikes["units"][i]}
    label = h5_file["labels"][i]
    fname = os.path.join(output_dir, str(label), str(i))
    frames = _integrate_events_by_fixed_frames_number(
        events, split_by, frames_num, W
    )
    utils.np_savez(fname, frames=frames)
    if print_save:
        print(f"Frames [{fname}] saved.")
def _integrate_events_by_fixed_duration(
    events: dict, duration: int, W: int
) -> np.ndarray:
    """Integrate events into consecutive frames that each span ``duration``.

    Timestamps ``events["t"]`` are multiplied by 1000 and shifted so the first
    event is at 0; ``duration`` is expressed in these scaled units (presumably
    milliseconds if ``events["t"]`` is in seconds -- confirm against the data).

    :param events: dict with ``"t"`` (timestamps, assumed sorted) and ``"x"``
        (channel index of every event)
    :param duration: length of one frame, in the scaled time unit
    :param W: number of channels
    :return: float64 array of shape ``[frames_num, W]`` with
        ``frames_num = max(1, ceil(t[-1] / duration))``; events at or after the
        last frame's start are all integrated into the last frame
    """
    x = events["x"]
    t = 1000 * events["t"]
    t = t - t[0]
    N = t.size
    # At least one frame: if every event shares one timestamp, ceil(0) == 0
    # left ``frames`` empty and ``frames[-1]`` raised IndexError.
    frames_num = max(int(math.ceil(t[-1] / duration)), 1)
    frames = np.zeros([frames_num, W])
    frame_index = t // duration
    # Exclusive end of frames 0 .. frames_num - 2 (frame_index must be sorted).
    rights = np.searchsorted(frame_index, np.arange(1, frames_num), side="left")
    left = 0
    for i, right in enumerate(rights):
        frames[i] = _integrate_events_segment_to_frame(x, W, left, right)
        left = right
    frames[-1] = _integrate_events_segment_to_frame(x, W, left, N)
    return frames
def _integrate_events_file_to_frames_file_by_fixed_duration(
    h5_file: h5py.File,
    i: int,
    output_dir: str,
    duration: int,
    W: int,
    print_save: bool = False,
) -> int:
    """Integrate sample ``i`` of ``h5_file`` into fixed-duration frames and save it.

    The sample's events are read from ``spikes/times`` and ``spikes/units``,
    its label from ``labels``; the frames are written with ``utils.np_savez``
    to ``output_dir/<label>/<i>`` under the key ``frames``.

    :param h5_file: opened SHD/SSC ``.h5`` file
    :param i: sample index
    :param output_dir: split directory that contains one sub-directory per label
    :param duration: length of one frame (see ``_integrate_events_by_fixed_duration``)
    :param W: number of channels
    :param print_save: print the saved file name when ``True``
    :return: number of frames written for this sample (it varies per sample)
    """
    events = {"t": h5_file["spikes"]["times"][i], "x": h5_file["spikes"]["units"][i]}
    label = h5_file["labels"][i]
    fname = os.path.join(output_dir, str(label), str(i))
    frames = _integrate_events_by_fixed_duration(events, duration, W)
    utils.np_savez(fname, frames=frames)
    if print_save:
        print(f"Frames [{fname}] saved.")
    return frames.shape[0]
class NullBuilder(NeuromorphicDatasetBuilder):
    """Builder that performs no preprocessing at all.

    Its processed root is the raw root and its loader returns its argument
    unchanged. The datasets in this module use it for ``data_type == "event"``.
    """

    def build_impl(self) -> None:
        # Nothing to generate.
        return None

    def build(self) -> Tuple[Path, Callable]:
        root = self.processed_root
        return root, self.get_loader()

    @property
    def processed_root(self) -> Path:
        return self.raw_root

    def get_loader(self) -> Callable:
        def _identity(sample):
            return sample

        return _identity
class _SHDFrameBuilderBase(NeuromorphicDatasetBuilder):
    """Shared logic of the SHD / SSC frame builders.

    For each split in ``splits`` the builder:

    #. creates ``processed_root / split`` and one sub-directory per label
       (``0 .. n_classes - 1``);
    #. opens ``raw_root / f"{dataset_name}_{split}.h5"`` read-only;
    #. submits one integration task per sample to a thread pool and waits for
       all of them before closing the ``.h5`` file.

    Subclasses choose the task (:meth:`_submit_sample`) and ``processed_root``.

    :param cfg: dataset configuration
    :param raw_root: directory containing the raw ``.h5`` files
    :param W: number of channels
    :param dataset_name: prefix of the ``.h5`` file names (e.g. ``"shd"``, ``"ssc"``)
    :param splits: names of the splits to build
    :param n_classes: number of label directories to create per split
    """

    def __init__(
        self,
        cfg: NeuromorphicDatasetConfig,
        raw_root: Path,
        W: int,
        dataset_name: str = "shd",
        splits: Tuple[str, ...] = ("train", "test"),
        n_classes: int = SHD_N_CLASSES,
    ):
        super().__init__(cfg, raw_root)
        self.W = W
        self.dataset_name = dataset_name
        self.splits = splits
        self.n_classes = n_classes

    def _submit_sample(self, tpe, h5_file, i: int, output_dir: Path):
        """Submit the integration of sample ``i`` to ``tpe``; return the future."""
        raise NotImplementedError

    def build_impl(self):
        for split in self.splits:
            processed_root = self._make_split_dirs(split)
            self._integrate_split(split, processed_root)

    def _make_split_dirs(self, split: str) -> Path:
        """Create ``processed_root / split`` and its per-label sub-directories."""
        processed_root = self.processed_root / split
        processed_root.mkdir()
        print(f"Mkdir [{processed_root}]")
        for i in range(self.n_classes):
            processed_class_root = processed_root / str(i)
            processed_class_root.mkdir()
            print(f"Mkdir [{processed_class_root}]")
        return processed_root

    def _integrate_split(self, split: str, processed_root: Path) -> None:
        """Integrate every sample of ``split`` into ``processed_root``."""
        t_ckp = time.time()
        h5_path = self.raw_root / f"{self.dataset_name}_{split}.h5"
        # The file is opened outside the executor so that it is closed only
        # after the executor has joined every worker that reads from it.
        with h5py.File(h5_path, "r") as h5_file:
            with ThreadPoolExecutor(
                max_workers=configure.max_threads_number_for_datasets_preprocess
            ) as tpe:
                print(
                    f"Start ThreadPoolExecutor with max workers = [{tpe._max_workers}]."
                )
                futures = []
                for i in range(len(h5_file["labels"])):
                    print(
                        f"Start to integrate [{i}]-th {split} sample to frames and "
                        f"save to [{processed_root}]."
                    )
                    futures.append(
                        self._submit_sample(tpe, h5_file, i, processed_root)
                    )
                # Re-raise the first worker exception, if any.
                for future in futures:
                    future.result()
        print(f"Used time = [{round(time.time() - t_ckp, 2)}s].")

    def get_loader(self):
        return utils.load_npz_frames


class SHDFrameFixedNumberBuilder(_SHDFrameBuilderBase):
    """Integrate every sample into ``cfg.frames_number`` frames split by ``cfg.split_by``."""

    def _submit_sample(self, tpe, h5_file, i: int, output_dir: Path):
        return tpe.submit(
            _integrate_events_file_to_frames_file_by_fixed_frames_number,
            h5_file,
            i,
            output_dir,
            self.cfg.split_by,
            self.cfg.frames_number,
            self.W,
            True,
        )

    @property
    def processed_root(self) -> Path:
        return (
            self.cfg.root
            / f"frames_number_{self.cfg.frames_number}_split_by_{self.cfg.split_by}"
        )


class SHDFrameFixedDurationBuilder(_SHDFrameBuilderBase):
    """Integrate every sample into frames that each span ``cfg.duration``."""

    def _submit_sample(self, tpe, h5_file, i: int, output_dir: Path):
        return tpe.submit(
            _integrate_events_file_to_frames_file_by_fixed_duration,
            h5_file,
            i,
            output_dir,
            self.cfg.duration,
            self.W,
            True,
        )

    @property
    def processed_root(self) -> Path:
        return self.cfg.root / f"duration_{self.cfg.duration}"


class SHDFrameCustomIntegrateBuilder(_SHDFrameBuilderBase):
    """Integrate every sample with ``cfg.custom_integrate_function``.

    The custom function is called as ``f(h5_file, i, output_dir, W)`` and is
    responsible for writing its own output under ``output_dir``.
    """

    def _submit_sample(self, tpe, h5_file, i: int, output_dir: Path):
        return tpe.submit(
            self.cfg.custom_integrate_function,
            h5_file,
            i,
            output_dir,
            self.W,
        )

    @property
    def processed_root(self) -> Path:
        name = self.cfg.custom_integrated_frames_dir_name
        if name is None:
            name = self.cfg.custom_integrate_function.__name__
        return self.cfg.root / name
class SpikingHeidelbergDigits(NeuromorphicDatasetFolder):
    def __init__(
        self,
        root: str,
        train: bool = True,
        data_type: str = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        .. _SpikingHeidelbergDigits.__init__-cn:

        .. _SpikingHeidelbergDigits.__init__-en:

        The Spiking Heidelberg Digits (SHD) dataset, which is proposed by
        `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`
        for more details about params information.

        :param root: Root directory of the dataset
        :type root: Union[str, Path]
        :param train: Whether to use the training set (``True``) or the test set (``False``)
        :type train: bool
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: str
        :param frames_number: Number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: Time duration per frame
        :type duration: Optional[int]
        :param custom_integrate_function: User-defined integrate function, called as
            ``f(h5_file, i, output_dir, W)``
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: Custom frames directory name
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: Transform function
        :type transform: Optional[Callable]
        :param target_transform: Target transform function
        :type target_transform: Optional[Callable]
        :raises ValueError: if ``train`` is ``None``

        .. note::

            Unlike other datasets in SpikingJelly, SHD is a neuromorphic audio dataset.

            #. Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``.
            #. The raw dataset replicates the extracted dataset (by symbolic links).
               The raw dataset consists of two ``.h5`` files instead of a series of ``.npz`` files.
            #. When ``data_type == "event"``, the data loading procedure of ``DatasetFolder``
               is bypassed. Instead, data is loaded in ``Dataset`` style.

        :return: None
        :rtype: None
        """
        if train is None:
            raise ValueError("`train` must be `True` or `False`")
        self.cfg = NeuromorphicDatasetConfig(
            root=Path(root),
            train=train,
            data_type=data_type,
            frames_number=frames_number,
            split_by=split_by,
            duration=duration,
            custom_integrate_function=custom_integrate_function,
            custom_integrated_frames_dir_name=custom_integrated_frames_dir_name,
            transform=transform,
            target_transform=target_transform,
        )
        self.prepare_raw_dataset()
        builder = self.get_dataset_builder()
        self.processed_root, loader = builder.build()
        split_root = self.processed_root / ("train" if self.cfg.train else "test")
        if data_type == "event":  # init as Dataset
            self.transform = transform
            self.target_transform = target_transform
        else:  # init as DatasetFolder
            # Deliberately skips NeuromorphicDatasetFolder.__init__ and runs the
            # next initializer in the MRO on the already-built split directory.
            super(NeuromorphicDatasetFolder, self).__init__(
                root=split_root,
                loader=loader,
                extensions=self.get_extensions(),
                transform=self.cfg.transform,
                target_transform=self.cfg.target_transform,
            )

    @property
    def raw_root(self) -> Path:
        """
        ``root / "events_h5"``
        """
        return self.cfg.root / "events_h5"

    def get_dataset_builder(self):
        """Return the builder matching ``self.cfg``.

        For ``data_type == "event"`` nothing is preprocessed: the ``.h5`` file of
        the selected split is opened read-only and kept on ``self.h5_file`` for
        :meth:`__getitem__`, and a :class:`NullBuilder` is returned.

        :raises NotImplementedError: if no frame integration method is configured
        """
        if self.cfg.data_type == "event":
            # prepare for manual __getitem__; the handle stays open on purpose
            h5_path = self.raw_root / (
                "shd_train.h5" if self.cfg.train else "shd_test.h5"
            )
            self.h5_file = h5py.File(h5_path, "r")
            self.length = len(self.h5_file["labels"])
            return NullBuilder(self.cfg, self.raw_root)
        _, W = self.get_H_W()
        if self.cfg.frames_number is not None:
            return SHDFrameFixedNumberBuilder(self.cfg, self.raw_root, W)
        elif self.cfg.duration is not None:
            return SHDFrameFixedDurationBuilder(self.cfg, self.raw_root, W)
        elif self.cfg.custom_integrate_function is not None:
            return SHDFrameCustomIntegrateBuilder(self.cfg, self.raw_root, W)
        else:
            # expected to be unreachable; kept as a safeguard
            raise NotImplementedError(
                "Please specify the frames number or duration or "
                "custom integrate function."
            )

    @classmethod
    def get_H_W(cls) -> Tuple:
        """
        :return: ``(None, 700)`` (i.e., 700 channels)
        :rtype: Tuple[None, int]
        """
        return None, 700

    @classmethod
    def resource_url_md5(cls) -> list:
        """
        :return: ``(file name, url, md5)`` of the SHD train and test archives
        :rtype: list
        """
        return [
            (
                "shd_train.h5.zip",
                "https://zenkelab.org/datasets/shd_train.h5.zip",
                "f3252aeb598ac776c1b526422d90eecb",
            ),
            (
                "shd_test.h5.zip",
                "https://zenkelab.org/datasets/shd_test.h5.zip",
                "1503a5064faa34311c398fb0a1ed0a6f",
            ),
        ]

    @classmethod
    def downloadable(cls) -> bool:
        """
        :return: ``True``
        :rtype: bool
        """
        return True

    def __len__(self):
        if self.cfg.data_type == "event":
            return self.length
        return super().__len__()

    def __getitem__(self, index):
        """Return ``(events, label)``.

        In event mode ``events`` is a dict with ``"t"`` (``spikes/times``) and
        ``"x"`` (``spikes/units``) of sample ``index``; otherwise the
        ``DatasetFolder`` item is returned.
        """
        if self.cfg.data_type != "event":
            return super().__getitem__(index)
        events = {
            "t": self.h5_file["spikes"]["times"][index],
            "x": self.h5_file["spikes"]["units"][index],
        }
        label = self.h5_file["labels"][index]
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return events, label
class SpikingSpeechCommands(NeuromorphicDatasetFolder):
    def __init__(
        self,
        root: str,
        split: str = "train",  # 'train' | 'valid' | 'test'
        data_type: str = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        .. _SpikingSpeechCommands.__init__-cn:

        .. _SpikingSpeechCommands.__init__-en:

        The Spiking Speech Commands (SSC) dataset, which is proposed by
        `The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks <https://doi.org/10.1109/TNNLS.2020.3044364>`_.

        Refer to :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`
        for more details about params information.

        :param root: Root directory of the dataset
        :type root: Union[str, Path]
        :param split: ``"train"``, ``"valid"``, or ``"test"``
        :type split: str
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: str
        :param frames_number: Number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: Time duration per frame
        :type duration: Optional[int]
        :param custom_integrate_function: User-defined integrate function, called as
            ``f(h5_file, i, output_dir, W)``
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: Custom frames directory name
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: Transform function
        :type transform: Optional[Callable]
        :param target_transform: Target transform function
        :type target_transform: Optional[Callable]
        :raises ValueError: if ``split`` is not one of ``"train"``, ``"valid"``, ``"test"``

        .. note::

            Unlike other datasets in SpikingJelly, SSC is a neuromorphic audio dataset.

            #. Events in this dataset are in the format of ``(x, t)`` rather than ``(x, y, t, p)``.
            #. The raw dataset replicates the extracted dataset (by symbolic links).
               The raw dataset consists of three ``.h5`` files instead of a series of ``.npz`` files.
            #. When ``data_type == "event"``, the data loading procedure of ``DatasetFolder``
               is bypassed. Instead, data is loaded in ``Dataset`` style.

        :return: None
        :rtype: None
        """
        self.splits = ("train", "valid", "test")
        if split not in self.splits:
            raise ValueError(f"Invalid split: {split}; valid splits are {self.splits}")
        self.split = split
        self.cfg = NeuromorphicDatasetConfig(
            root=Path(root),
            train=None,
            data_type=data_type,
            frames_number=frames_number,
            split_by=split_by,
            duration=duration,
            custom_integrate_function=custom_integrate_function,
            custom_integrated_frames_dir_name=custom_integrated_frames_dir_name,
            transform=transform,
            target_transform=target_transform,
        )
        self.prepare_raw_dataset()
        builder = self.get_dataset_builder()
        self.processed_root, loader = builder.build()
        split_root = self.get_root_when_train_is_none(self.processed_root)
        if data_type == "event":  # init as Dataset
            self.transform = transform
            self.target_transform = target_transform
        else:  # init as DatasetFolder
            # Deliberately skips NeuromorphicDatasetFolder.__init__ and runs the
            # next initializer in the MRO on the already-built split directory.
            super(NeuromorphicDatasetFolder, self).__init__(
                root=split_root,
                loader=loader,
                extensions=self.get_extensions(),
                transform=self.cfg.transform,
                target_transform=self.cfg.target_transform,
            )

    def get_root_when_train_is_none(self, _root: Path):
        """Return the directory of the selected split: ``_root / self.split``."""
        return _root / self.split

    @property
    def raw_root(self) -> Path:
        """
        ``root / "events_h5"``
        """
        return self.cfg.root / "events_h5"

    def get_dataset_builder(self):
        """Return the builder matching ``self.cfg``.

        For ``data_type == "event"`` nothing is preprocessed: ``ssc_<split>.h5``
        is opened read-only and kept on ``self.h5_file`` for
        :meth:`__getitem__`, and a :class:`NullBuilder` is returned. Frame
        builders process all three splits.

        :raises NotImplementedError: if no frame integration method is configured
        """
        if self.cfg.data_type == "event":
            # prepare for manual __getitem__; the handle stays open on purpose
            h5_path = self.raw_root / f"ssc_{self.split}.h5"
            self.h5_file = h5py.File(h5_path, "r")
            self.length = len(self.h5_file["labels"])
            return NullBuilder(self.cfg, self.raw_root)
        _, W = self.get_H_W()
        if self.cfg.frames_number is not None:
            return SHDFrameFixedNumberBuilder(
                self.cfg,
                self.raw_root,
                W,
                dataset_name="ssc",
                splits=self.splits,
                n_classes=SSC_N_CLASSES,
            )
        elif self.cfg.duration is not None:
            return SHDFrameFixedDurationBuilder(
                self.cfg,
                self.raw_root,
                W,
                dataset_name="ssc",
                splits=self.splits,
                n_classes=SSC_N_CLASSES,
            )
        elif self.cfg.custom_integrate_function is not None:
            return SHDFrameCustomIntegrateBuilder(
                self.cfg,
                self.raw_root,
                W,
                dataset_name="ssc",
                splits=self.splits,
                n_classes=SSC_N_CLASSES,
            )
        else:
            # expected to be unreachable; kept as a safeguard
            raise NotImplementedError(
                "Please specify the frames number or duration or "
                "custom integrate function."
            )

    @classmethod
    def get_H_W(cls) -> Tuple:
        """
        :return: ``(None, 700)`` (i.e., 700 channels)
        :rtype: Tuple[None, int]
        """
        return None, 700

    @classmethod
    def resource_url_md5(cls) -> list:
        """
        :return: ``(file name, url, md5)`` of the SSC train / valid / test archives
        :rtype: list
        """
        # These must be the SSC archives: this class opens ``ssc_{split}.h5``.
        # NOTE(review): confirm the checksums against those published on
        # zenkelab.org before release.
        return [
            (
                "ssc_train.h5.zip",
                "https://zenkelab.org/datasets/ssc_train.h5.zip",
                "d102be95e7144fcc0553d1f45ba94170",
            ),
            (
                "ssc_valid.h5.zip",
                "https://zenkelab.org/datasets/ssc_valid.h5.zip",
                "b4eee3516a4a90dd0c71a6ac23a8ae43",
            ),
            (
                "ssc_test.h5.zip",
                "https://zenkelab.org/datasets/ssc_test.h5.zip",
                "a35ff1e9cffdd02a20eb850c17c37748",
            ),
        ]

    @classmethod
    def downloadable(cls) -> bool:
        """
        :return: ``True``
        :rtype: bool
        """
        return True

    @classmethod
    def extract_downloaded_files(cls, download_root: Path, extract_root: Path):
        """Extract every file of ``download_root`` into ``extract_root`` (at most 2 threads)."""
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 2)) as tpe:
            futures = []
            for zip_file in download_root.iterdir():
                print(f"Extract [{zip_file}] to [{extract_root}].")
                futures.append(tpe.submit(extract_archive, zip_file, extract_root))
            for future in futures:
                future.result()

    @classmethod
    def create_raw_from_extracted(cls, extract_root: Path, raw_root: Path):
        """Populate ``raw_root`` with symbolic links to the extracted files."""
        for f in extract_root.iterdir():
            target = raw_root / f.name
            # is_symlink() also catches dangling links, for which exists() is False
            if target.is_symlink() or target.exists():
                continue
            # Link to an absolute path: a relative target would be resolved
            # against raw_root, not the current working directory.
            target.symlink_to(f.resolve())

    def __len__(self):
        if self.cfg.data_type == "event":
            return self.length
        return super().__len__()

    def __getitem__(self, index):
        """Return ``(events, label)``.

        In event mode ``events`` is a dict with ``"t"`` (``spikes/times``) and
        ``"x"`` (``spikes/units``) of sample ``index``; otherwise the
        ``DatasetFolder`` item is returned.
        """
        if self.cfg.data_type != "event":
            return super().__getitem__(index)
        events = {
            "t": self.h5_file["spikes"]["times"][index],
            "x": self.h5_file["spikes"]["units"][index],
        }
        label = self.h5_file["labels"][index]
        if self.transform is not None:
            events = self.transform(events)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return events, label