from typing import Callable, Dict, Optional, Tuple
import numpy as np
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from .. import configure
from ..datasets import np_savez
# https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py
EVT_DVS = 0 # DVS event type
EVT_APS = 1 # APS event
def read_bits(arr, mask=None, shift=None):
    if mask is not None:
        arr = arr & mask
    if shift is not None:
        arr = arr >> shift
    return arr

y_mask = 0x7FC00000
y_shift = 22
x_mask = 0x003FF000
x_shift = 12
polarity_mask = 0x800
polarity_shift = 11
valid_mask = 0x80000000
valid_shift = 31
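
# `load_raw_events` below calls `skip_header`, which is not shown in this excerpt.
# A minimal sketch is given here as an assumption, following the usual AEDAT layout
# of '#'-prefixed ASCII header lines followed by binary events (see the events-tfds
# reference above); it returns the byte offset of the first binary event.
def skip_header(fp):
    p = 0
    lt = fp.readline()
    ltd = lt.decode().strip()
    while ltd and ltd[0] == '#':
        # accumulate the byte length of each comment line
        p += len(lt)
        lt = fp.readline()
        try:
            ltd = lt.decode().strip()
        except UnicodeDecodeError:
            # binary data reached; stop scanning for header lines
            break
    return p
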
def load_raw_events(fp,
                    bytes_skip=0,
                    bytes_trim=0,
                    filter_dvs=False,
                    times_first=False):
    p = skip_header(fp)
    fp.seek(p + bytes_skip)
    data = fp.read()
    if bytes_trim > 0:
        data = data[:-bytes_trim]
    # np.fromstring is deprecated for binary input; np.frombuffer reads the same
    # big-endian uint32 stream.
    data = np.frombuffer(data, dtype='>u4')
    if len(data) % 2 != 0:
        print(data[:20:2])
        print('---')
        print(data[1:21:2])
        raise ValueError('odd number of data elements')
    raw_addr = data[::2]
    timestamp = data[1::2]
    if times_first:
        timestamp, raw_addr = raw_addr, timestamp
    if filter_dvs:
        valid = read_bits(raw_addr, valid_mask, valid_shift) == EVT_DVS
        timestamp = timestamp[valid]
        raw_addr = raw_addr[valid]
    return timestamp, raw_addr

def parse_raw_address(addr,
                      x_mask=x_mask,
                      x_shift=x_shift,
                      y_mask=y_mask,
                      y_shift=y_shift,
                      polarity_mask=polarity_mask,
                      polarity_shift=polarity_shift):
    polarity = read_bits(addr, polarity_mask, polarity_shift).astype(np.bool_)
    x = read_bits(addr, x_mask, x_shift)
    y = read_bits(addr, y_mask, y_shift)
    return x, y, polarity
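
# Worked example (illustrative values): with the default masks above, a raw
# address word encoding x=3, y=5, polarity=1 is
#     addr = (5 << y_shift) | (3 << x_shift) | (1 << polarity_shift)  # == 0x01403800
# and parse_raw_address(np.array([addr], dtype=np.uint32)) yields
# x == [3], y == [5], polarity == [True].
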
def load_events(
        fp,
        filter_dvs=False,
        # bytes_skip=0,
        # bytes_trim=0,
        # times_first=False,
        **kwargs):
    timestamp, addr = load_raw_events(
        fp,
        filter_dvs=filter_dvs,
        # bytes_skip=bytes_skip,
        # bytes_trim=bytes_trim,
        # times_first=times_first
    )
    x, y, polarity = parse_raw_address(addr, **kwargs)
    return timestamp, x, y, polarity
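
# A minimal usage sketch (the file path is hypothetical). CIFAR10-DVS records use a
# different address layout than the default masks above, so the same mask/shift
# overrides as in `CIFAR10DVS.load_origin_data` below are passed:
#     with open('airplane/cifar10_airplane_0.aedat', 'rb') as fp:
#         t, x, y, p = load_events(fp, x_mask=0xfE, x_shift=1,
#                                  y_mask=0x7f00, y_shift=8,
#                                  polarity_mask=1, polarity_shift=None)
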
class CIFAR10DVS(sjds.NeuromorphicDatasetFolder):
    def __init__(
            self,
            root: str,
            data_type: str = 'event',
            frames_number: int = None,
            split_by: str = None,
            duration: int = None,
            custom_integrate_function: Callable = None,
            custom_integrated_frames_dir_name: str = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The CIFAR10-DVS dataset, which is proposed by `CIFAR10-DVS: An Event-Stream Dataset for Object Classification
        <https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about the parameters.
        """
        super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function,
                         custom_integrated_frames_dir_name, transform, target_transform)

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url``, where ``url[i]`` is a tuple containing the i-th file's name, download link, and MD5 checksum
        :rtype: list
        '''
        return [
            ('airplane.zip', 'https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'),
            ('automobile.zip', 'https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'),
            ('bird.zip', 'https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'),
            ('cat.zip', 'https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'),
            ('deer.zip', 'https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'),
            ('dog.zip', 'https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'),
            ('frog.zip', 'https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'),
            ('horse.zip', 'https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'),
            ('ship.zip', 'https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'),
            ('truck.zip', 'https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7')
        ]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
        :rtype: bool
        '''
        return True

    @staticmethod
    def load_origin_data(file_name: str) -> Dict:
        '''
        :param file_name: path of the events file
        :type file_name: str
        :return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
        :rtype: Dict

        This function defines how to read the original binary data.
        '''
        with open(file_name, 'rb') as fp:
            t, x, y, p = load_events(fp,
                                     x_mask=0xfE,
                                     x_shift=1,
                                     y_mask=0x7f00,
                                     y_shift=8,
                                     polarity_mask=1,
                                     polarity_shift=None)
            # return {'t': t, 'x': 127 - x, 'y': y, 'p': 1 - p.astype(int)}  # this would give the same data as http://www2.imse-cnm.csic.es/caviar/MNIST_DVS/dat2mat.m
            # see https://github.com/jackd/events-tfds/pull/1 for more details about this problem
            return {'t': t, 'x': 127 - y, 'y': 127 - x, 'p': 1 - p.astype(int)}

    @staticmethod
    def get_H_W() -> Tuple:
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
            For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
        :rtype: tuple
        '''
        return 128, 128

    @staticmethod
    def read_aedat_save_to_np(bin_file: str, np_file: str):
        events = CIFAR10DVS.load_origin_data(bin_file)
        np_savez(np_file,
                 t=events['t'],
                 x=events['x'],
                 y=events['y'],
                 p=events['p']
                 )
        print(f'Save [{bin_file}] to [{np_file}].')

    @staticmethod
    def create_events_np_files(extract_root: str, events_np_root: str):
        '''
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :param events_np_root: Root directory path which saves events files in the ``npz`` format
        :type events_np_root: str
        :return: None

        This function defines how to convert the original binary data in ``extract_root`` to ``npz`` format and save converted files in ``events_np_root``.
        '''
        t_ckp = time.time()
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), configure.max_threads_number_for_datasets_preprocess)) as tpe:
            sub_threads = []
            for class_name in os.listdir(extract_root):
                aedat_dir = os.path.join(extract_root, class_name)
                np_dir = os.path.join(events_np_root, class_name)
                os.mkdir(np_dir)
                print(f'Mkdir [{np_dir}].')
                for bin_file in os.listdir(aedat_dir):
                    source_file = os.path.join(aedat_dir, bin_file)
                    target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
                    print(f'Start to convert [{source_file}] to [{target_file}].')
                    sub_threads.append(tpe.submit(CIFAR10DVS.read_aedat_save_to_np, source_file, target_file))

            for sub_thread in sub_threads:
                if sub_thread.exception():
                    print(sub_thread.exception())
                    exit(-1)

        print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')
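
# A minimal usage sketch (the root path is assumed; parameter semantics follow
# spikingjelly.datasets.NeuromorphicDatasetFolder): integrate each event stream
# into 20 frames, splitting events evenly by event number.
#     from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
#     data_set = CIFAR10DVS(root='./CIFAR10DVS', data_type='frame',
#                           frames_number=20, split_by='number')
#     frame, label = data_set[0]  # frame is a [T, 2, H, W] array, here (20, 2, 128, 128)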