# Source code for spikingjelly.datasets.speechcommands

import os
from typing import Callable, Tuple, Optional
from pathlib import Path
from random import choice

import torch
import torchaudio
from torch.utils.data import Dataset
from torch import Tensor
from torchvision.datasets.utils import download_url, extract_archive
from torchvision.datasets.utils import verify_str_arg
import numpy as np


# Public names exported by this module.
__all__ = ["SpeechCommands", "SPEECHCOMMANDS"]


# Default value of the ``folder_in_archive`` argument of ``SpeechCommands``:
# the directory (under ``root``) that holds the extracted data.
FOLDER_IN_ARCHIVE = "SpeechCommands"
# Default value of the ``url`` argument of ``SpeechCommands``; this short
# version name is expanded to a full archive URL inside the constructor.
URL = "speech_commands_v0.02"
# Separator used by ``_load_speechcommands_item`` to split a file stem into
# the speaker id and the utterance number.
HASH_DIVIDER = "_nohash_"
# Folder name excluded when the training file list is generated; the same
# folder is globbed separately to collect the noise clips used for silence.
EXCEPT_FOLDER = "_background_noise_"
# Archive URL -> checksum, passed as ``md5`` to ``download_url``.
_CHECKSUMS = {
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz": "3cd23799cb2bbdec517f1cc028f8d43c",
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz": "6b74f3901214cb2c2934e98196829835",
}
# File names (inside the dataset directory) of the per-split file lists.
# The validation and testing lists are read as-is; the training list is
# generated and written by ``SpeechCommands`` when it does not exist yet.
VAL_RECORD = "validation_list.txt"
TEST_RECORD = "testing_list.txt"
TRAIN_RECORD = "training_list.txt"


def _load_speechcommands_item(
    relpath: str, path: str
) -> Tuple[Tensor, int, str, str, int]:
    """Load one audio clip and decode the metadata embedded in its path.

    :param relpath: clip path relative to ``path``; its directory part is
        returned as the label and its file stem must have the form
        ``<speaker_id>_nohash_<utterance_number>``
    :type relpath: str
    :param path: directory that ``relpath`` is relative to
    :type path: str
    :return: ``(waveform, sample_rate, label, speaker_id, utterance_number)``
    :rtype: Tuple[Tensor, int, str, str, int]
    :raises ValueError: if the file stem does not split into exactly two
        parts around ``HASH_DIVIDER`` or the second part is not an integer
    """
    # The directory component is the label, the file stem carries the rest.
    label, filename = os.path.split(relpath)
    stem = os.path.splitext(filename)[0]

    speaker_id, utterance_text = stem.split(HASH_DIVIDER)
    utterance_number = int(utterance_text)

    # The metadata is parsed first, so a malformed name fails before any I/O.
    waveform, sample_rate = torchaudio.load(os.path.join(path, relpath))
    return waveform, sample_rate, label, speaker_id, utterance_number


class SpeechCommands(Dataset):
    def __init__(
        self,
        label_dict: dict,
        root: str,
        silence_cnt: Optional[int] = 0,
        silence_size: Optional[int] = 16000,
        transform: Optional[Callable] = None,
        url: Optional[str] = URL,
        split: Optional[str] = "train",
        folder_in_archive: Optional[str] = FOLDER_IN_ARCHIVE,
        download: Optional[bool] = False,
    ) -> None:
        r"""
        The SpeechCommands dataset, from `Speech Commands: A Dataset for
        Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_,
        split according to the provided validation and testing lists. Both
        the v0.01 and v0.02 versions are supported.

        The dataset contains audio of three major categories of words:

        #. Command words, totaling 10: "Yes", "No", "Up", "Down", "Left",
           "Right", "On", "Off", "Stop", "Go". For v0.02, 5 additional words
           are included: "Forward", "Backward", "Follow", "Learn", "Visual".

        #. Numbers 0-9, totaling 10: "Zero", "One", "Two", "Three", "Four",
           "Five", "Six", "Seven", "Eight", "Nine".

        #. Auxiliary words, which can be considered as noise words, totaling
           10: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin",
           "Sheila", "Tree", "Wow".

        The v0.01 version contains a total of 30 classes and 64,727 audio
        clips, while the v0.02 version contains a total of 35 classes and
        105,829 audio clips. For more details, please refer to the
        aforementioned paper and the dataset's README.

        The code implementation is based on torchaudio with expanded
        functionality, and also refers to `the original paper implementation
        <https://github.com/romainzimmer/s2net/blob/b073f755e70966ef133bbcd4a8f0343354f5edcd/data.py>`_.

        .. note::

            SpeechCommands is not a neuromorphic dataset. Therefore,
            :class:`SpeechCommands` does not inherit from
            :class:`NeuromorphicDatasetFolder <spikingjelly.datasets.base.NeuromorphicDatasetFolder>`,
            but instead inherits from :class:`torch.utils.data.Dataset`.

        For the ``"train"`` split, ``self.weights`` holds one sampling weight
        per item (files first, then the silence items) such that every class,
        silence included, has the same total weight.

        :param label_dict: dictionary mapping each word (folder name), and
            ``"_silence_"`` when silence is used, to its class label
        :type label_dict: dict
        :param root: root directory of the dataset
        :type root: str
        :param silence_cnt: number of Silence data samples
        :type silence_cnt: Optional[int]
        :param silence_size: size (in samples) of each Silence data sample
        :type silence_size: Optional[int]
        :param transform: a function/transform that takes in a raw audio
            waveform and returns the transformed audio
        :type transform: Optional[Callable]
        :param url: dataset version, default is v0.02
        :type url: Optional[str]
        :param split: dataset split, can be ``"train", "test", "val"``,
            default is ``"train"``
        :type split: Optional[str]
        :param folder_in_archive: directory name after extraction, default is
            ``"SpeechCommands"``
        :type folder_in_archive: Optional[str]
        :param download: whether to download the dataset, default is False
        :type download: Optional[bool]
        :raises ValueError: if ``silence_cnt`` is negative, ``silence_size``
            is not positive, or (train split) ``label_dict`` has no entry for
            a word found in the training list
        :raises FileNotFoundError: if the data is missing and ``download`` is
            False
        :return: None
        :rtype: None
        """
        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
        self.label_dict = label_dict
        self.transform = transform
        self.silence_cnt = silence_cnt
        self.silence_size = silence_size

        if silence_cnt < 0:
            raise ValueError(f"Invalid silence_cnt parameter: {silence_cnt}")
        if silence_size <= 0:
            raise ValueError(f"Invalid silence_size parameter: {silence_size}")

        if url in [
            "speech_commands_v0.01",
            "speech_commands_v0.02",
        ]:
            base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
            ext_archive = ".tar.gz"
            # Plain concatenation: a URL is not a filesystem path, and the
            # result must match the keys of _CHECKSUMS exactly.
            url = base_url + url + ext_archive

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        # Strip the two-part ".tar.gz" extension to get the version folder.
        basename = basename.rsplit(".", 2)[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url, root, md5=checksum)
                extract_archive(archive, self._path)
        elif not os.path.isdir(self._path):
            raise FileNotFoundError(
                'Audio data not found. Please specify "download=True" and try again.'
            )

        # Collected only after the data is guaranteed to be on disk;
        # globbing before extraction would leave this list empty on the
        # first run with download=True.
        self.noise_list = sorted(
            str(p) for p in Path(self._path).glob("_background_noise_/*.wav")
        )

        if self.split == "train":
            record = os.path.join(self._path, TRAIN_RECORD)
            if os.path.exists(record):
                self._walker = self._read_record(record)
            else:
                print("No training list, generating...")
                self._walker = self._generate_training_list()
                with open(record, "w") as f:
                    f.write("\n".join(self._walker))
                print("Training list generated!")

            self.weights = self._compute_sample_weights()
        else:
            record_name = VAL_RECORD if self.split == "val" else TEST_RECORD
            self._walker = self._read_record(os.path.join(self._path, record_name))

    @staticmethod
    def _read_record(record: str) -> list:
        """Read a split list file: one relative path per line, blanks skipped."""
        with open(record, "r") as f:
            entries = [line.rstrip("\n") for line in f]
        return [entry for entry in entries if entry]

    def _generate_training_list(self) -> list:
        """Return every labelled clip that is in neither the validation nor the testing list."""
        dataset_dir = Path(self._path)
        candidates = set()
        for wav in dataset_dir.glob("*/*.wav"):
            # The official lists use "/" separators, so build POSIX-style
            # relative paths on every platform; otherwise the set difference
            # below silently removes nothing on Windows.
            rel = wav.relative_to(dataset_dir).as_posix()
            # Filter on the relative path so that the location of the
            # dataset on disk cannot influence the result.
            if HASH_DIVIDER in rel and EXCEPT_FOLDER not in rel:
                candidates.add(rel)

        held_out = set(self._read_record(os.path.join(self._path, VAL_RECORD)))
        held_out.update(self._read_record(os.path.join(self._path, TEST_RECORD)))

        # Sorted so that the generated (and persisted) list is deterministic.
        return sorted(candidates - held_out)

    def _compute_sample_weights(self) -> Tensor:
        """Return per-item sampling weights that give every class the same total weight."""
        words = [os.path.split(relpath)[0] for relpath in self._walker]
        labels = [self.label_dict.get(word) for word in words]

        missing = sorted({w for w, lab in zip(words, labels) if lab is None})
        if missing:
            raise ValueError(f"label_dict has no entry for: {missing}")

        # ``inverse`` maps each sample to the position of its label in the
        # sorted unique labels, i.e. to the matching entry of ``counts``.
        # Indexing the weights by the raw label value is only correct when
        # the labels are exactly 0..K-1 with every class present.
        _, inverse, counts = np.unique(
            labels, return_inverse=True, return_counts=True
        )
        label_weights = 1.0 / counts
        per_sample_index = np.asarray(inverse).reshape(-1)

        if self.silence_cnt == 0:
            label_weights /= np.sum(label_weights)
            weights = label_weights[per_sample_index]
        else:
            # Silence is one extra class with ``silence_cnt`` members.
            silence_weight = 1.0 / self.silence_cnt
            total_weight = np.sum(label_weights) + silence_weight
            label_weights /= total_weight
            weights = np.concatenate(
                [
                    label_weights[per_sample_index],
                    np.full(self.silence_cnt, silence_weight / total_weight),
                ]
            )

        return torch.as_tensor(weights, dtype=torch.float64)

    def __getitem__(self, n: int) -> Tuple[Tensor, int]:
        """Return the ``n``-th item as ``(waveform, label)``.

        Indices below the number of listed files load that file; the
        remaining ``silence_cnt`` indices yield a random crop of a random
        background-noise clip labelled ``"_silence_"``. The waveform is
        scaled by its peak absolute value (when non-zero) and then passed
        through ``transform`` if one was given.

        :raises IndexError: if ``n`` is not smaller than ``len(self)``
        :raises RuntimeError: if a silence item is requested but no
            background-noise clip is available
        """
        # Without this check plain iteration (which relies on IndexError to
        # stop) would never terminate normally, and out-of-range indices
        # would silently return silence.
        if n >= len(self):
            raise IndexError(
                f"index {n} is out of range for dataset of size {len(self)}"
            )

        if n < len(self._walker):
            fileid = self._walker[n]
            waveform, _sample_rate, label, _speaker_id, _utterance_number = (
                _load_speechcommands_item(fileid, self._path)
            )
        else:
            # Silence data are randomly and dynamically generated from noise data
            if not self.noise_list:
                raise RuntimeError(
                    f"No background noise found in {self._path!r}; "
                    "cannot generate silence samples."
                )

            # Load random noise
            noisepath = choice(self.noise_list)
            waveform, _sample_rate = torchaudio.load(noisepath)

            # Random crop; a clip no longer than the crop size is used from
            # its start instead of crashing in ``randint``.
            max_offset = waveform.shape[1] - self.silence_size
            offset = np.random.randint(max_offset) if max_offset > 0 else 0
            waveform = waveform[:, offset : offset + self.silence_size]
            label = "_silence_"

        # Peak-normalise; an all-zero clip is left untouched.
        m = waveform.abs().max()
        if m > 0:
            waveform /= m

        if self.transform is not None:
            waveform = self.transform(waveform)

        label = self.label_dict.get(label)

        return waveform, label

    def __len__(self) -> int:
        """Return the number of listed files plus the number of silence items."""
        return len(self._walker) + self.silence_cnt


# All-caps alias of the class, also exported through ``__all__``.
SPEECHCOMMANDS = SpeechCommands # for backward compatibility