import os
from pathlib import Path
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Callable, Optional, Tuple, Union
import numpy as np
from torchvision.datasets.utils import extract_archive
from .. import configure
from . import utils
from .base import NeuromorphicDatasetFolder, NeuromorphicDatasetConfig
# https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py
# Names exported by ``from <this module> import *``.
__all__ = [
"CIFAR10DVS_CLASS_NAMES",
"CIFAR10DVS",
"CIFAR10DVSTEBNSplit",
]
# The ten class names. `_move_data` uses them as per-class directory names and
# as part of the sample file names (``cifar10_<name>_<i>.npz``); the archives
# listed by `CIFAR10DVS.resource_url_md5` are named ``<name>.zip``.
CIFAR10DVS_CLASS_NAMES = (
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck",
)
# Event-type codes. When `filter_dvs` is set, `_load_raw_events` keeps only the
# words whose bit selected by `valid_mask` / `valid_shift` equals EVT_DVS.
EVT_DVS = 0 # DVS event type
EVT_APS = 1 # APS event type (not referenced in the visible part of this file)
def _read_bits(arr, mask=None, shift=None):
    """Extract a bit field from ``arr``.

    :param arr: Integer or integer array holding the packed bits.
    :param mask: Bit mask AND-ed with ``arr``; skipped when ``None``.
    :param shift: Number of bits to shift right after masking; skipped when
        ``None``.
    :return: ``arr`` after the optional mask and the optional right shift.
    """
    masked = arr if mask is None else arr & mask
    return masked if shift is None else masked >> shift
# Default bit layout of a 32-bit raw address word. Each mask is written as
# (field-width mask) << (field offset), so the shift and the mask cannot drift
# apart. These are the default arguments of `_parse_raw_address`;
# `valid_mask` / `valid_shift` are used by `_load_raw_events`.
y_shift = 22
y_mask = 0x1FF << y_shift  # 9 bits  -> 0x7FC00000
x_shift = 12
x_mask = 0x3FF << x_shift  # 10 bits -> 0x003FF000
polarity_shift = 11
polarity_mask = 1 << polarity_shift  # 1 bit -> 0x800
valid_shift = 31
valid_mask = 1 << valid_shift  # 1 bit -> 0x80000000
def _skip_header(fp):
    """Return the byte length of the leading ``#`` comment header of ``fp``.

    Lines are read from the current position of ``fp`` until one is found
    that is empty, does not start with ``#`` (after stripping whitespace), or
    cannot be decoded as text. The lengths of the header lines seen before
    that point are summed.

    The file position is left wherever the last ``readline`` ended, so the
    caller is expected to ``seek`` explicitly afterwards.

    :param fp: Binary file object positioned at the start of the file.
    :return: Number of bytes occupied by the header lines (``0`` if none).
    """
    header_bytes = 0
    while True:
        line = fp.readline()
        try:
            text = line.decode().strip()
        except UnicodeDecodeError:
            # Binary payload reached. This also covers a file that has no
            # text header at all, which previously raised on the first line.
            break
        if not text.startswith("#"):
            break
        header_bytes += len(line)
    return header_bytes
def _load_raw_events(
    fp, bytes_skip=0, bytes_trim=0, filter_dvs=False, times_first=False
):
    """Read the binary payload of ``fp`` as (timestamp, raw address) arrays.

    Everything after the ``#`` header (plus ``bytes_skip`` extra bytes) is
    interpreted as big-endian unsigned 32-bit words that alternate between the
    two fields.

    :param fp: Binary file object.
    :param bytes_skip: Extra bytes to skip after the header.
    :param bytes_trim: Number of trailing bytes to drop when positive.
    :param filter_dvs: When true, keep only the words whose "valid" bit equals
        ``EVT_DVS``.
    :param times_first: When true, even-indexed words are timestamps and
        odd-indexed words are addresses; otherwise the other way round.
    :return: ``(timestamp, raw_addr)``.
    :raises ValueError: If the payload holds an odd number of 32-bit words
        (the first words of both interleaved streams are printed first).
    """
    header_len = _skip_header(fp)
    fp.seek(header_len + bytes_skip)
    payload = fp.read()
    if bytes_trim > 0:
        payload = payload[:-bytes_trim]
    words = np.frombuffer(payload, dtype=">u4")
    if len(words) % 2:
        print(words[:20:2])
        print("---")
        print(words[1:21:2])
        raise ValueError("odd number of data elements")
    evens, odds = words[::2], words[1::2]
    if times_first:
        timestamp, raw_addr = evens, odds
    else:
        raw_addr, timestamp = evens, odds
    if filter_dvs:
        keep = _read_bits(raw_addr, valid_mask, valid_shift) == EVT_DVS
        timestamp, raw_addr = timestamp[keep], raw_addr[keep]
    return timestamp, raw_addr
def _parse_raw_address(
    addr,
    x_mask=x_mask,
    x_shift=x_shift,
    y_mask=y_mask,
    y_shift=y_shift,
    polarity_mask=polarity_mask,
    polarity_shift=polarity_shift,
):
    """Split raw address words into ``x``, ``y`` and ``polarity``.

    Each field is extracted with `_read_bits` using its own mask/shift pair;
    the defaults are the module-level layout constants.

    :param addr: Array of raw address words.
    :return: ``(x, y, polarity)``, where ``polarity`` is cast to ``np.bool_``.
    """
    y = _read_bits(addr, y_mask, y_shift)
    x = _read_bits(addr, x_mask, x_shift)
    polarity_bits = _read_bits(addr, polarity_mask, polarity_shift)
    return x, y, polarity_bits.astype(np.bool_)
def _load_events(fp, filter_dvs=False, **kwargs):
    """Load events from ``fp`` and decode their addresses.

    Only ``filter_dvs`` is forwarded to `_load_raw_events`; its
    ``bytes_skip``, ``bytes_trim`` and ``times_first`` options keep their
    defaults.

    :param fp: Binary file object.
    :param filter_dvs: Forwarded to `_load_raw_events`.
    :param kwargs: Mask/shift overrides forwarded to `_parse_raw_address`.
    :return: ``(timestamp, x, y, polarity)``.
    """
    timestamp, addr = _load_raw_events(fp, filter_dvs=filter_dvs)
    return (timestamp, *_parse_raw_address(addr, **kwargs))
def _load_origin_data(file_name: Union[str, Path]) -> dict:
    """Read one raw event file and return its events as a dict.

    The address word is decoded with a layout specific to this function
    (``x``: mask ``0xFE`` shifted by 1, ``y``: mask ``0x7F00`` shifted by 8,
    polarity: lowest bit). In the result the two coordinates are swapped and
    mirrored (``x = 127 - y_raw``, ``y = 127 - x_raw``) and the polarity is
    inverted (``p = 1 - polarity``).

    :param file_name: Path of the raw event file.
    :return: Dict with keys ``"t"``, ``"x"``, ``"y"`` and ``"p"``.
    """
    layout = dict(
        x_mask=0xFE,
        x_shift=1,
        y_mask=0x7F00,
        y_shift=8,
        polarity_mask=1,
        polarity_shift=None,
    )
    with open(file_name, "rb") as fp:
        t, raw_x, raw_y, polarity = _load_events(fp, **layout)
    return {
        "t": t,
        "x": 127 - raw_y,
        "y": 127 - raw_x,
        "p": 1 - polarity.astype(int),
    }
def _read_aedat_save_to_np(bin_file: Union[str, Path], np_file: Union[str, Path]):
    """Convert one raw event file into an ``.npz``-style archive.

    The events are loaded with `_load_origin_data` and written through
    ``utils.np_savez`` under the keys ``t``, ``x``, ``y`` and ``p``; a
    one-line message is printed afterwards.

    :param bin_file: Path of the raw event file to read.
    :param np_file: Path of the file to write.
    """
    loaded = _load_origin_data(bin_file)
    arrays = {key: loaded[key] for key in ("t", "x", "y", "p")}
    utils.np_savez(np_file, **arrays)
    print(f"Save [{bin_file}] to [{np_file}].")
class CIFAR10DVS(NeuromorphicDatasetFolder):
    """CIFAR10-DVS event-stream dataset (128 x 128, ten classes)."""

    def __init__(
        self,
        root: Union[str, Path],
        data_type: str = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        The CIFAR10-DVS dataset, which is proposed by
        `CIFAR10-DVS: An Event-Stream Dataset for Object Classification <https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full>`_.

        Refer to :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`
        for more details about params information.

        :param root: Root directory of the dataset
        :type root: Union[str, Path]
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: str
        :param frames_number: Number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: Time duration per frame
        :type duration: Optional[int]
        :param custom_integrate_function: User-defined integrate function
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: Custom frames directory name
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: Transform function
        :type transform: Optional[Callable]
        :param target_transform: Target transform function
        :type target_transform: Optional[Callable]
        :return: None
        :rtype: None
        """
        # The second positional argument is passed as None -- presumably the
        # base class's `train` flag (confirm against NeuromorphicDatasetFolder).
        super().__init__(
            root,
            None,
            data_type,
            frames_number,
            split_by,
            duration,
            custom_integrate_function,
            custom_integrated_frames_dir_name,
            transform,
            target_transform,
        )

    @classmethod
    def get_H_W(cls) -> Tuple[int, int]:
        r"""
        Return the height and width used for this dataset.

        :return: ``(128, 128)``
        :rtype: Tuple[int, int]
        """
        return 128, 128

    @classmethod
    def resource_url_md5(cls) -> list:
        """
        List the downloadable archives, one per class.

        :return: A list of ``(file name, download URL, MD5 checksum)`` tuples
        :rtype: list
        """
        return [
            (
                "airplane.zip",
                "https://ndownloader.figshare.com/files/7712788",
                "0afd5c4bf9ae06af762a77b180354fdd",
            ),
            (
                "automobile.zip",
                "https://ndownloader.figshare.com/files/7712791",
                "8438dfeba3bc970c94962d995b1b9bdd",
            ),
            (
                "bird.zip",
                "https://ndownloader.figshare.com/files/7712794",
                "a9c207c91c55b9dc2002dc21c684d785",
            ),
            (
                "cat.zip",
                "https://ndownloader.figshare.com/files/7712812",
                "52c63c677c2b15fa5146a8daf4d56687",
            ),
            (
                "deer.zip",
                "https://ndownloader.figshare.com/files/7712815",
                "b6bf21f6c04d21ba4e23fc3e36c8a4a3",
            ),
            (
                "dog.zip",
                "https://ndownloader.figshare.com/files/7712818",
                "f379ebdf6703d16e0a690782e62639c3",
            ),
            (
                "frog.zip",
                "https://ndownloader.figshare.com/files/7712842",
                "cad6ed91214b1c7388a5f6ee56d08803",
            ),
            (
                "horse.zip",
                "https://ndownloader.figshare.com/files/7712851",
                "e7cbbf77bec584ffbf913f00e682782a",
            ),
            (
                "ship.zip",
                "https://ndownloader.figshare.com/files/7712836",
                "41c7bd7d6b251be82557c6cce9a7d5c9",
            ),
            (
                "truck.zip",
                "https://ndownloader.figshare.com/files/7712839",
                "89f3922fd147d9aeff89e76a2b0b70a7",
            ),
        ]

    @classmethod
    def downloadable(cls) -> bool:
        """
        Tell whether the archives can be downloaded automatically.

        :return: ``True``
        :rtype: bool
        """
        return True
def _move_data(
    root: Union[str, Path], class_names: Optional[Tuple[str, ...]] = None
):
    """Build the TEBN train/test split under ``root`` out of symlinks.

    For every class ``cn``, samples ``cifar10_{cn}_0.npz`` ..
    ``cifar10_{cn}_99.npz`` are linked into ``root/test/cn`` and samples
    ``100`` .. ``999`` into ``root/train/cn``. The link targets are the files
    in ``root/cn``; nothing is copied or moved.

    The links point at the *resolved* source path. A relative ``root`` would
    otherwise be stored verbatim in the link and be interpreted relative to
    the link's own directory, producing dangling links. Entries that already
    exist are left untouched, so the function can be re-run after an
    interrupted split without raising ``FileExistsError``.

    :param root: Directory that holds one sub-directory per class.
    :param class_names: Class names to process; defaults to
        ``CIFAR10DVS_CLASS_NAMES``.
    """
    root = Path(root)
    if class_names is None:
        class_names = CIFAR10DVS_CLASS_NAMES
    # (split directory name, sample indices that belong to it)
    splits = (("test", range(100)), ("train", range(100, 1000)))
    for cn in class_names:
        source = (root / cn).resolve()
        for split_name, indices in splits:
            target = root / split_name / cn
            if not target.exists():
                target.mkdir(parents=True)
                print(f"mkdir [{target}]")
            for i in indices:
                file_name = f"cifar10_{cn}_{i}.npz"
                target_file = target / file_name
                # is_symlink() also catches dangling links, which exists()
                # reports as missing.
                if target_file.is_symlink() or target_file.exists():
                    continue
                source_file = source / file_name
                target_file.symlink_to(source_file)
                print(f"symlink: [{target_file}] -> [{source_file}]")
class CIFAR10DVSTEBNSplit(CIFAR10DVS):
    """CIFAR10-DVS with the fixed train/test split introduced by TEBN."""

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        data_type: str = "event",
        frames_number: Optional[int] = None,
        split_by: Optional[str] = None,
        duration: Optional[int] = None,
        custom_integrate_function: Optional[Callable] = None,
        custom_integrated_frames_dir_name: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ):
        """
        The CIFAR10-DVS dataset, which is proposed by
        `CIFAR10-DVS: An Event-Stream Dataset for Object Classification <https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full>`_.

        The original CIFAR10-DVS dataset does not provide train and test split.
        In `Temporal Effective Batch Normalization in Spiking Neural Networks <https://proceedings.neurips.cc/paper_files/paper/2022/hash/de2ad3ed44ee4e675b3be42aa0b615d0-Abstract-Conference.html>`_ ,
        the authors use sample 0-99 in each class as the test set, and the 100-999 as the train set.
        This split is widely used by later works. This class implements this split.

        .. note::

            The validation accuracy on this split is typically much higher than
            that on a random split. Be careful when making comparisons!

        Refer to :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`
        for more details about params information.

        :param root: Root directory of the dataset
        :type root: Union[str, Path]
        :param train: Whether to use training set (``True``) or test set (``False``)
        :type train: bool
        :param data_type: ``"event"`` or ``"frame"``
        :type data_type: str
        :param frames_number: Number of frames to integrate
        :type frames_number: Optional[int]
        :param split_by: ``"time"`` or ``"number"``
        :type split_by: Optional[str]
        :param duration: Time duration per frame
        :type duration: Optional[int]
        :param custom_integrate_function: User-defined integrate function
        :type custom_integrate_function: Optional[Callable]
        :param custom_integrated_frames_dir_name: Custom frames directory name
        :type custom_integrated_frames_dir_name: Optional[str]
        :param transform: Transform function
        :type transform: Optional[Callable]
        :param target_transform: Target transform function
        :type target_transform: Optional[Callable]
        :raises ValueError: If ``train`` is ``None``
        :return: None
        :rtype: None
        """
        if train is None:
            raise ValueError("`train` must be `True` or `False`")
        self.cfg = NeuromorphicDatasetConfig(
            root=Path(root),
            train=train,
            data_type=data_type,
            frames_number=frames_number,
            split_by=split_by,
            duration=duration,
            custom_integrate_function=custom_integrate_function,
            custom_integrated_frames_dir_name=custom_integrated_frames_dir_name,
            transform=transform,
            target_transform=target_transform,
        )
        self.prepare_raw_dataset()
        builder = self.get_dataset_builder()
        self.processed_root, loader = builder.build()

        # The processed dataset is unsplit; the split lives in the "train" and
        # "test" sub-directories and is created the first time the requested
        # one is found to be missing.
        split_root = self.processed_root / ("train" if self.cfg.train else "test")
        if not split_root.exists():
            print(
                f"We have the unsplit processed dataset at [{self.processed_root}]. "
                f"_move_data() is called to split the dataset following TEBN's approach."
            )
            _move_data(self.processed_root)
            print("CIFAR10-DVS has been split after TEBN's approach.")

        # Deliberately skip NeuromorphicDatasetFolder.__init__ and call the
        # initializer of the next class in the MRO, pointing it at the split
        # directory instead of the dataset root.
        super(NeuromorphicDatasetFolder, self).__init__(
            root=split_root,
            loader=loader,
            extensions=self.get_extensions(),
            transform=self.cfg.transform,
            target_transform=self.cfg.target_transform,
        )