spikingjelly.datasets.cifar10_dvs 源代码

from typing import Callable, Dict, Optional, Tuple
import numpy as np
from .. import datasets as sjds
from torchvision.datasets.utils import extract_archive
import os
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
import time
from .. import configure
from ..datasets import np_savez
# Reference for the AEDAT event-parsing helpers below:
# https://github.com/jackd/events-tfds/blob/master/events_tfds/data_io/aedat.py

# Event-type codes. When ``filter_dvs`` is set, ``load_raw_events`` keeps only
# the words whose bit selected by ``valid_mask``/``valid_shift`` equals EVT_DVS.
EVT_DVS, EVT_APS = 0, 1  # DVS event type, APS event type

def read_bits(arr, mask=None, shift=None):
    """Extract a bit field from ``arr``.

    ``arr`` is first ANDed with ``mask`` and the result is then right-shifted
    by ``shift``; either step is skipped when its argument is ``None``.

    :param arr: integer value or integer array (anything supporting ``&`` and ``>>``)
    :param mask: bit mask to apply, or ``None`` to skip masking
    :param shift: number of bits to shift right, or ``None`` to skip shifting
    :return: the masked and shifted value, of the same kind as ``arr``
    """
    masked = arr if mask is None else arr & mask
    return masked if shift is None else masked >> shift
# Default bit layout of the 32-bit raw address word. ``parse_raw_address``
# uses the x / y / polarity fields as its defaults, and ``load_raw_events``
# uses the valid field. Each field is read as ``(addr & <name>_mask) >> <name>_shift``.
y_mask, y_shift = 0x7FC00000, 22  # bits 22-30
x_mask, x_shift = 0x003FF000, 12  # bits 12-21
polarity_mask, polarity_shift = 0x800, 11  # bit 11
valid_mask, valid_shift = 0x80000000, 31  # bit 31
def skip_header(fp):
    """Return the byte length of the ``#``-prefixed text header at the start of ``fp``.

    Lines are read one at a time from the current position of ``fp``. A line
    counts as header when its UTF-8-decoded, stripped text starts with ``#``.
    Scanning stops at the first line that:

    * is empty after stripping,
    * does not start with ``#``, or
    * cannot be decoded (i.e. the binary payload has been reached).

    The first line is treated the same way, so a file without any header
    yields ``0`` instead of raising ``UnicodeDecodeError``.

    Note that this advances the file position; callers are expected to
    ``seek`` afterwards.

    :param fp: file object opened in binary mode
    :return: total number of bytes occupied by the header lines
    :rtype: int
    """
    header_len = 0
    while True:
        line = fp.readline()
        try:
            text = line.decode().strip()
        except UnicodeDecodeError:
            # Undecodable bytes mean we are already inside the binary payload.
            break
        if not text.startswith('#'):
            # Covers EOF (b''), blank lines and the first non-comment line.
            break
        header_len += len(line)
    return header_len
def load_raw_events(fp, bytes_skip=0, bytes_trim=0, filter_dvs=False, times_first=False):
    """Read the raw (timestamp, address) word pairs from the binary stream ``fp``.

    The ``#``-prefixed header is skipped, then ``bytes_skip`` further bytes.
    The remaining payload, minus ``bytes_trim`` trailing bytes, is read as
    big-endian unsigned 32-bit words. The words alternate address, timestamp
    (or timestamp, address when ``times_first`` is True).

    :param fp: binary file object. It should be positioned at the start of the
        stream, because the measured header length is used as an absolute
        ``seek`` offset.
    :param bytes_skip: extra bytes to skip after the header
    :param bytes_trim: bytes to drop from the end of the payload
    :param filter_dvs: if True, keep only words whose bit selected by
        ``valid_mask``/``valid_shift`` equals ``EVT_DVS``
    :param times_first: if True, each pair stores the timestamp before the address
    :return: ``(timestamp, raw_addr)``, two numpy arrays of dtype ``>u4``
    :raises ValueError: if the payload size is not a multiple of 4 bytes, or
        it does not hold an even number of 32-bit words
    """
    p = skip_header(fp)
    fp.seek(p + bytes_skip)
    data = fp.read()
    if bytes_trim > 0:
        data = data[:-bytes_trim]
    # The binary mode of ``np.fromstring`` is deprecated. ``np.frombuffer`` is
    # the replacement; ``copy()`` keeps the result writable, as it was before.
    data = np.frombuffer(data, dtype='>u4').copy()
    if len(data) % 2 != 0:
        raise ValueError(
            f'odd number of data elements ({len(data)}); '
            f'even-indexed words: {data[:20:2]}, odd-indexed words: {data[1:21:2]}'
        )
    raw_addr = data[::2]
    timestamp = data[1::2]
    if times_first:
        timestamp, raw_addr = raw_addr, timestamp
    if filter_dvs:
        valid = read_bits(raw_addr, valid_mask, valid_shift) == EVT_DVS
        timestamp = timestamp[valid]
        raw_addr = raw_addr[valid]
    return timestamp, raw_addr
def parse_raw_address(addr,
                      x_mask=x_mask,
                      x_shift=x_shift,
                      y_mask=y_mask,
                      y_shift=y_shift,
                      polarity_mask=polarity_mask,
                      polarity_shift=polarity_shift):
    """Decode raw address words into x, y and polarity fields.

    Each field is extracted with ``read_bits(addr, <mask>, <shift>)``. The
    defaults are the module-level bit-layout constants; callers override them
    for other layouts.

    :param addr: raw address word(s), e.g. as returned by ``load_raw_events``
    :return: ``(x, y, polarity)``, where ``polarity`` is cast to ``np.bool_``
    """
    x_coord = read_bits(addr, x_mask, x_shift)
    y_coord = read_bits(addr, y_mask, y_shift)
    pol_bits = read_bits(addr, polarity_mask, polarity_shift)
    return x_coord, y_coord, pol_bits.astype(np.bool_)
def load_events(fp, filter_dvs=False, *, bytes_skip=0, bytes_trim=0, times_first=False, **kwargs):
    """Read raw events from the binary stream ``fp`` and decode their address fields.

    This combines ``load_raw_events`` (reading the word pairs) with
    ``parse_raw_address`` (decoding x / y / polarity).

    :param fp: binary file object, positioned at the start of the stream
    :param filter_dvs: forwarded to ``load_raw_events``
    :param bytes_skip: forwarded to ``load_raw_events`` (keyword-only)
    :param bytes_trim: forwarded to ``load_raw_events`` (keyword-only)
    :param times_first: forwarded to ``load_raw_events`` (keyword-only)
    :param kwargs: bit-layout overrides forwarded to ``parse_raw_address``
        (``x_mask``, ``x_shift``, ``y_mask``, ``y_shift``, ``polarity_mask``,
        ``polarity_shift``)
    :return: ``(timestamp, x, y, polarity)``
    """
    # The three reader options used to be commented out. Passing them then
    # fell through to ``parse_raw_address`` and raised TypeError. They are now
    # keyword-only, with the same defaults as ``load_raw_events``, so existing
    # calls are unaffected.
    timestamp, addr = load_raw_events(
        fp,
        bytes_skip=bytes_skip,
        bytes_trim=bytes_trim,
        filter_dvs=filter_dvs,
        times_first=times_first,
    )
    x, y, polarity = parse_raw_address(addr, **kwargs)
    return timestamp, x, y, polarity
class CIFAR10DVS(sjds.NeuromorphicDatasetFolder):
    def __init__(
            self,
            root: str,
            data_type: str = 'event',
            frames_number: Optional[int] = None,
            split_by: Optional[str] = None,
            duration: Optional[int] = None,
            custom_integrate_function: Optional[Callable] = None,
            custom_integrated_frames_dir_name: Optional[str] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        """
        The CIFAR10-DVS dataset, which is proposed by `CIFAR10-DVS: An Event-Stream Dataset for Object Classification
        <https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full>`_.

        Refer to :class:`spikingjelly.datasets.NeuromorphicDatasetFolder` for more details about params information.
        """
        # NOTE(review): the second positional argument of the base class is passed as None.
        # It is presumably the train/test selector (this dataset has no official split);
        # confirm against NeuromorphicDatasetFolder.
        super().__init__(root, None, data_type, frames_number, split_by, duration, custom_integrate_function,
                         custom_integrated_frames_dir_name, transform, target_transform)

    @staticmethod
    def resource_url_md5() -> list:
        '''
        :return: A list ``url`` in which each ``url[i]`` is a tuple containing the i-th file's name, download link,
            and MD5
        :rtype: list
        '''
        return [
            ('airplane.zip', 'https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'),
            ('automobile.zip', 'https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'),
            ('bird.zip', 'https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'),
            ('cat.zip', 'https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'),
            ('deer.zip', 'https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'),
            ('dog.zip', 'https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'),
            ('frog.zip', 'https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'),
            ('horse.zip', 'https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'),
            ('ship.zip', 'https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'),
            ('truck.zip', 'https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7')
        ]

    @staticmethod
    def downloadable() -> bool:
        '''
        :return: Whether the dataset can be downloaded directly by Python code. If not, the user has to download it
            manually
        :rtype: bool
        '''
        return True

    @staticmethod
    def _exit_on_first_failure(futures) -> None:
        '''Wait for ``futures`` in order; on the first failure, print its exception and exit with status -1.'''
        for future in futures:
            err = future.exception()  # blocks until this future has finished
            if err is not None:
                print(err)
                # Raise SystemExit directly instead of calling the ``site``-provided ``exit`` builtin,
                # which is not guaranteed to exist (e.g. under ``python -S``). Chain the cause for debugging.
                raise SystemExit(-1) from err

    @staticmethod
    def extract_downloaded_files(download_root: str, extract_root: str):
        '''
        :param download_root: Root directory path which saves downloaded dataset files
        :type download_root: str
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :return: None

        This function defines how to extract the downloaded files. Every file in ``download_root`` is extracted
        into ``extract_root`` in a thread pool. If any extraction fails, its exception is printed and the process
        exits with status -1.
        '''
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(), 10)) as tpe:
            sub_threads = []
            for zip_file in os.listdir(download_root):
                zip_file = os.path.join(download_root, zip_file)
                print(f'Extract [{zip_file}] to [{extract_root}].')
                sub_threads.append(tpe.submit(extract_archive, zip_file, extract_root))
            CIFAR10DVS._exit_on_first_failure(sub_threads)

    @staticmethod
    def load_origin_data(file_name: str) -> Dict:
        '''
        :param file_name: path of the events file
        :type file_name: str
        :return: a dict whose keys are ``['t', 'x', 'y', 'p']`` and values are ``numpy.ndarray``
        :rtype: Dict

        This function defines how to read the origin binary data.
        '''
        with open(file_name, 'rb') as fp:
            # Custom bit layout: x in bits 1-7, y in bits 8-14, polarity in bit 0.
            t, x, y, p = load_events(fp,
                                     x_mask=0xFE,
                                     x_shift=1,
                                     y_mask=0x7f00,
                                     y_shift=8,
                                     polarity_mask=1,
                                     polarity_shift=None)
            # Note: returning {'t': t, 'x': 127 - x, 'y': y, 'p': 1 - p.astype(int)} instead would give the same
            # data as http://www2.imse-cnm.csic.es/caviar/MNIST_DVS/dat2mat.m.
            # Here x and y are swapped and both flipped; see https://github.com/jackd/events-tfds/pull/1 for more
            # details about this problem.
            return {'t': t, 'x': 127 - y, 'y': 127 - x, 'p': 1 - p.astype(int)}

    @staticmethod
    def get_H_W() -> Tuple:
        '''
        :return: A tuple ``(H, W)``, where ``H`` is the height of the data and ``W`` is the width of the data.
            For example, this function returns ``(128, 128)`` for the DVS128 Gesture dataset.
        :rtype: tuple
        '''
        return 128, 128

    @staticmethod
    def read_aedat_save_to_np(bin_file: str, np_file: str):
        '''
        :param bin_file: path of the origin binary events file
        :type bin_file: str
        :param np_file: path of the output ``npz`` file
        :type np_file: str
        :return: None

        Read ``bin_file`` with :meth:`load_origin_data` and save the ``t``, ``x``, ``y`` and ``p`` arrays to
        ``np_file``.
        '''
        events = CIFAR10DVS.load_origin_data(bin_file)
        np_savez(np_file,
                 t=events['t'],
                 x=events['x'],
                 y=events['y'],
                 p=events['p']
                 )
        print(f'Save [{bin_file}] to [{np_file}].')

    @staticmethod
    def create_events_np_files(extract_root: str, events_np_root: str):
        '''
        :param extract_root: Root directory path which saves extracted files from downloaded files
        :type extract_root: str
        :param events_np_root: Root directory path which saves events files in the ``npz`` format
        :type events_np_root: str
        :return: None

        This function defines how to convert the origin binary data in ``extract_root`` to ``npz`` format and save
        converted files in ``events_np_root``.

        Each sub-directory of ``extract_root`` is mirrored under ``events_np_root``. Every file in it is converted
        in a thread pool. If any conversion fails, its exception is printed and the process exits with status -1.
        '''
        t_ckp = time.time()
        with ThreadPoolExecutor(max_workers=min(multiprocessing.cpu_count(),
                                                configure.max_threads_number_for_datasets_preprocess)) as tpe:
            sub_threads = []
            for class_name in os.listdir(extract_root):
                aedat_dir = os.path.join(extract_root, class_name)
                np_dir = os.path.join(events_np_root, class_name)
                # exist_ok lets an interrupted conversion be re-run instead of failing with FileExistsError.
                os.makedirs(np_dir, exist_ok=True)
                print(f'Mkdir [{np_dir}].')
                for bin_file in os.listdir(aedat_dir):
                    source_file = os.path.join(aedat_dir, bin_file)
                    target_file = os.path.join(np_dir, os.path.splitext(bin_file)[0] + '.npz')
                    print(f'Start to convert [{source_file}] to [{target_file}].')
                    sub_threads.append(tpe.submit(CIFAR10DVS.read_aedat_save_to_np, source_file, target_file))
            CIFAR10DVS._exit_on_first_failure(sub_threads)

        print(f'Used time = [{round(time.time() - t_ckp, 2)}s].')