# spikingjelly.datasets.speechcommands source code

import os
from pathlib import Path
from random import choice
from typing import Callable, Tuple, Dict, Optional

import numpy as np
import torch
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.datasets.utils import (
    download_url,
    extract_archive,
)
from torchvision.datasets.utils import verify_str_arg

# Directory (under ``root``) that the archive is extracted into.
FOLDER_IN_ARCHIVE = "SpeechCommands"
# Default dataset version; expanded to a full download URL by SPEECHCOMMANDS.
URL = "speech_commands_v0.02"
# Clip file stems look like ``<speaker_id>_nohash_<utterance_number>``.
HASH_DIVIDER = "_nohash_"
# Folder holding the long background-noise recordings (not a word class).
EXCEPT_FOLDER = "_background_noise_"
# MD5 checksums of the official archives, keyed by the exact URL that
# SPEECHCOMMANDS builds; passed to ``download_url`` for integrity checking.
# This mapping was missing, so ``download=True`` failed with a NameError.
_CHECKSUMS = {
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz":
    "3cd23799cb2bbdec517f1cc028f8d43c",
    "https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
    "6b74f3901214cb2c2934e98196829835",
}
# Split list files shipped with (or generated for) the dataset.
VAL_RECORD = "validation_list.txt"
TEST_RECORD = "testing_list.txt"
TRAIN_RECORD = "training_list.txt"

def load_speechcommands_item(relpath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
    """Load one clip and derive its metadata from its relative path.

    The relative path is split as ``<label>/<stem>.<ext>``; the stem is then
    split on ``HASH_DIVIDER`` into exactly two parts (speaker id and utterance
    number), otherwise unpacking raises ``ValueError``.

    :param relpath: clip path relative to ``path``
    :type relpath: str
    :param path: directory that ``relpath`` is relative to
    :type path: str
    :return: ``(waveform, sample_rate, label, speaker_id, utterance_number)``
        where ``waveform``/``sample_rate`` come from ``torchaudio.load``
    :rtype: Tuple[Tensor, int, str, str, int]
    """
    label, basename = os.path.split(relpath)
    stem = os.path.splitext(basename)[0]
    speaker_id, utterance_str = stem.split(HASH_DIVIDER)
    # Parse the metadata fully before touching the file on disk.
    utterance_number = int(utterance_str)

    waveform, sample_rate = torchaudio.load(os.path.join(path, relpath))
    return waveform, sample_rate, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset):
    def __init__(self,
                 label_dict: Dict,
                 root: str,
                 silence_cnt: Optional[int] = 0,
                 silence_size: Optional[int] = 16000,
                 transform: Optional[Callable] = None,
                 url: Optional[str] = URL,
                 split: Optional[str] = "train",
                 folder_in_archive: Optional[str] = FOLDER_IN_ARCHIVE,
                 download: Optional[bool] = False) -> None:
        '''
        :param label_dict: mapping from word to class label. Keys are the
            clip sub-folder names (e.g. ``"yes"``) plus ``"_silence_"`` for
            generated silence items
        :type label_dict: Dict
        :param root: root directory of the dataset
        :type root: str
        :param silence_cnt: number of silence items appended after the speech
            clips (generated on the fly from background noise)
        :type silence_cnt: int, optional
        :param silence_size: number of samples in each silence item (length of
            the random crop taken along dim 1 of the noise waveform)
        :type silence_size: int, optional
        :param transform: a function/transform applied to the (normalised)
            waveform tensor before it is returned
        :type transform: Callable, optional
        :param url: dataset version (``"speech_commands_v0.01"`` or
            ``"speech_commands_v0.02"``, default v0.02) or a full archive URL
        :type url: str, optional
        :param split: dataset split, one of ``"train", "test", "val"``,
            default ``"train"``
        :type split: str, optional
        :param folder_in_archive: directory name after extraction, default
            ``"SpeechCommands"``
        :type folder_in_archive: str, optional
        :param download: whether to download the data, default False
        :type download: bool, optional

        The SpeechCommands audio dataset, from `Speech Commands: A Dataset for
        Limited-Vocabulary Speech Recognition <https://arxiv.org/abs/1804.03209>`_.
        It is split according to the provided testing and validation lists,
        and both v0.01 and v0.02 are supported.

        The dataset contains audio for three kinds of words:

        #. Command words, 10 in total: "Yes", "No", "Up", "Down", "Left",
           "Right", "On", "Off", "Stop", "Go". v0.02 adds 5 more: "Forward",
           "Backward", "Follow", "Learn", "Visual".

        #. Digits 0-9, 10 in total: "Zero", "One", "Two", "Three", "Four",
           "Five", "Six", "Seven", "Eight", "Nine".

        #. Auxiliary words, which can be regarded as distractors, 10 in total:
           "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila",
           "Tree", "Wow".

        v0.01 contains 30 classes and 64,727 clips; v0.02 contains 35 classes
        and 105,829 clips. See the paper above and the dataset README for
        details.

        For the training split, if ``training_list.txt`` is absent it is
        generated (all clips minus the validation and testing lists) and
        saved. ``self.weights`` then holds one weight per item: each speech
        clip's weight is proportional to 1 / (number of training clips with
        its label), each silence item's to 1 / ``silence_cnt`` (presumably
        meant for a weighted sampler).

        The implementation is based on torchaudio with extended functionality,
        and also refers to `the original paper's implementation
        <https://github.com/romainzimmer/s2net/blob/b073f755e70966ef133bbcd4a8f0343354f5edcd/data.py>`_.
        '''
        self.split = verify_str_arg(split, "split", ("train", "val", "test"))
        self.label_dict = label_dict
        self.transform = transform
        self.silence_cnt = silence_cnt
        self.silence_size = silence_size

        if silence_cnt < 0:
            raise ValueError(f"Invalid silence_cnt parameter: {silence_cnt}")
        if silence_size <= 0:
            raise ValueError(f"Invalid silence_size parameter: {silence_size}")

        if url in [
            "speech_commands_v0.01",
            "speech_commands_v0.02",
        ]:
            base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
            ext_archive = ".tar.gz"
            # Plain concatenation: os.path.join is for filesystem paths, not URLs.
            url = base_url + url + ext_archive

        basename = os.path.basename(url)
        archive = os.path.join(root, basename)

        basename = basename.rsplit(".", 2)[0]
        folder_in_archive = os.path.join(folder_in_archive, basename)

        self._path = os.path.join(root, folder_in_archive)

        if download:
            if not os.path.isdir(self._path):
                if not os.path.isfile(archive):
                    checksum = _CHECKSUMS.get(url, None)
                    download_url(url, root, md5=checksum)
                extract_archive(archive, self._path)
        elif not os.path.isdir(self._path):
            raise FileNotFoundError("Audio data not found. Please specify \"download=True\" and try again.")

        # Must be listed only AFTER the optional download/extraction above:
        # globbing earlier yields an empty list on a fresh download, which then
        # breaks every silence item.
        self.noise_list = sorted(str(p) for p in Path(self._path).glob(f"{EXCEPT_FOLDER}/*.wav"))
        if self.silence_cnt > 0 and not self.noise_list:
            raise FileNotFoundError(
                f"No background noise .wav files found in "
                f"{os.path.join(self._path, EXCEPT_FOLDER)}; cannot generate silence items.")

        if self.split == "train":
            self._walker = self._load_train_walker()
            self.weights = self._compute_sample_weights()
        else:
            record_name = VAL_RECORD if self.split == "val" else TEST_RECORD
            self._walker = self._read_record(os.path.join(self._path, record_name))

    @staticmethod
    def _read_record(record: str) -> list:
        """Read a split list file: one clip path (relative to the dataset dir) per line."""
        with open(record, 'r') as f:
            # Skip blank lines so a trailing empty line can't become a bogus entry.
            return [line.rstrip('\n') for line in f if line.strip()]

    def _load_train_walker(self) -> list:
        """Return the training clip list, generating and caching it if missing."""
        record = os.path.join(self._path, TRAIN_RECORD)
        if os.path.exists(record):
            return self._read_record(record)

        print("No training list, generating...")
        root = Path(self._path)
        # as_posix() makes entries '/'-separated on every OS so they match the
        # official validation/testing lists (os.path.relpath gives '\\' on
        # Windows, which made the set difference below remove nothing).
        relpaths = (p.relative_to(root).as_posix() for p in root.glob('*/*.wav'))
        candidates = {r for r in relpaths if HASH_DIVIDER in r and EXCEPT_FOLDER not in r}

        val_walker = set(self._read_record(os.path.join(self._path, VAL_RECORD)))
        test_walker = set(self._read_record(os.path.join(self._path, TEST_RECORD)))

        # Sorted so the generated list (and thus item indices) is reproducible;
        # iteration order of a set of str depends on the per-process hash seed.
        walker = sorted(candidates - val_walker - test_walker)
        with open(record, 'w') as f:
            f.write('\n'.join(walker))
        print("Training list generated!")
        return walker

    def _compute_sample_weights(self) -> torch.Tensor:
        """Per-item weights: inverse label frequency for clips, 1/silence_cnt for silence."""
        labels = [self.label_dict.get(os.path.split(relpath)[0]) for relpath in self._walker]
        unique_labels, counts = np.unique(labels, return_counts=True)
        inv_freq = 1. / counts

        if self.silence_cnt == 0:
            inv_freq /= np.sum(inv_freq)
            silence_weights = []
        else:
            silence_weight = 1. / self.silence_cnt
            total_weight = np.sum(inv_freq) + silence_weight
            inv_freq /= total_weight
            silence_weights = [silence_weight / total_weight] * self.silence_cnt

        # Look up by label value, not by position, so labels need not be the
        # contiguous range 0..K-1.
        weight_of = dict(zip(unique_labels.tolist(), inv_freq.tolist()))
        return torch.DoubleTensor([weight_of[label] for label in labels] + silence_weights)

    def __getitem__(self, n: int) -> Tuple[Tensor, int]:
        """Return ``(waveform, label)`` for item ``n``.

        Indices below ``len(self._walker)`` are speech clips; the next
        ``silence_cnt`` indices are silence items, regenerated on every access
        as a random ``silence_size`` crop of a random background-noise file.
        The waveform is divided by its max absolute value (when non-zero)
        before ``transform`` is applied.

        :raises IndexError: if ``n >= len(self)`` (lets plain iteration stop)
        """
        if n >= len(self):
            raise IndexError(f"Index {n} out of range for dataset of length {len(self)}")

        if n < len(self._walker):
            fileid = self._walker[n]
            waveform, _, label, _, _ = load_speechcommands_item(fileid, self._path)
        else:
            noisepath = choice(self.noise_list)
            waveform, _ = torchaudio.load(noisepath)
            max_offset = waveform.shape[1] - self.silence_size
            if max_offset < 0:
                raise ValueError(
                    f"Noise file {noisepath} has {waveform.shape[1]} samples, "
                    f"fewer than silence_size={self.silence_size}")
            # randint's upper bound is exclusive; +1 keeps the last full window reachable.
            offset = np.random.randint(max_offset + 1)
            waveform = waveform[:, offset:offset + self.silence_size]
            label = "_silence_"

        m = waveform.abs().max()
        if m > 0:
            waveform /= m

        if self.transform is not None:
            waveform = self.transform(waveform)

        label = self.label_dict.get(label)

        return waveform, label

    def __len__(self) -> int:
        """Number of speech clips in the split plus ``silence_cnt`` silence items."""
        return len(self._walker) + self.silence_cnt