spikingjelly.activation_based.examples.speechcommands 源代码

r"""
.. codeauthor:: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>

A reproduction of the paper `Technical report: supervised training of convolutional spiking neural networks with PyTorch <https://arxiv.org/pdf/1911.10124.pdf>`_\ .

This code reproduces an audio recognition task using a convolutional SNN, which achieves performance comparable to an ANN.

..  note::

    To avoid extra dependencies such as `librosa <https://librosa.org/doc/latest/index.html>`_, we implement MelScale ourselves. Two mel-scale conventions are provided, selected through the ``dct_type`` argument: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.

Confusion matrix of TEST set after training (50 epochs):

+------------------------+--------------------------------------------------------------------------------------------------+
| Count                  | Prediction                                                                                       |
|                        +-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|                        | "Yes" | "Stop" | "No" | "Right" | "Up" | "Left" | "On" | "Down" | "Off" | "Go" | Other | Silence |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
| Ground Truth | "Yes"   | 234   | 0      | 2    | 0       | 0    | 3      | 0    | 0      | 0     | 1    | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Stop"  | 0     | 233    | 0    | 1       | 5    | 0      | 0    | 0      | 0     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "No"    | 0     | 1      | 223  | 1       | 0    | 1      | 0    | 5      | 0     | 9    | 12    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Right" | 0     | 0      | 0    | 234     | 0    | 0      | 0    | 0      | 0     | 0    | 24    | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Up"    | 0     | 4      | 0    | 0       | 249  | 0      | 0    | 0      | 8     | 0    | 11    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Left"  | 3     | 1      | 2    | 3       | 1    | 250    | 0    | 0      | 1     | 0    | 6     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "On"    | 0     | 3      | 0    | 0       | 0    | 0      | 231  | 0      | 2     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Down"  | 0     | 0      | 7    | 0       | 0    | 1      | 2    | 230    | 0     | 4    | 8     | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Off"   | 0     | 0      | 2    | 1       | 4    | 2      | 6    | 0      | 237   | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Go"    | 0     | 2      | 5    | 0       | 0    | 2      | 0    | 1      | 5     | 220  | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Other   | 6     | 21     | 12   | 25      | 22   | 19     | 25   | 14     | 11    | 40   | 4072  | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Silence | 0     | 0      | 0    | 0       | 0    | 0      | 0    | 0      | 0     | 0    | 0     | 260     |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
"""

import argparse
import math
import time
from typing import Optional

import numpy as np
import torch
import torchvision.transforms
from scipy.signal import savgol_filter
from sklearn.metrics import confusion_matrix
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchaudio.transforms import Spectrogram
from tqdm import tqdm

from spikingjelly.activation_based import neuron, surrogate
from spikingjelly.activation_based.functional import reset_net
from spikingjelly.datasets.speechcommands import SPEECHCOMMANDS

# Label mapping used with Speech Commands v0.01: the ten target keywords get
# their own class ids (0-9), every other word is merged into one shared class
# (10), and the "_silence_" entry gets class 11.
label_dict = {
    **{
        word: class_id
        for class_id, word in enumerate(
            ("yes", "stop", "no", "right", "up", "left", "on", "down", "off", "go")
        )
    },
    **dict.fromkeys(
        (
            "bed",
            "three",
            "one",
            "four",
            "two",
            "five",
            "cat",
            "dog",
            "eight",
            "bird",
            "happy",
            "sheila",
            "zero",
            "wow",
            "marvin",
            "house",
            "six",
            "seven",
            "tree",
            "nine",
        ),
        10,
    ),
    "_silence_": 11,
}
# Number of distinct class ids (ten keywords + "other" + silence).
label_cnt = len({*label_dict.values()})
# Feature-extraction hyper-parameters shared by the transforms and the network.
n_mels = 40  # number of mel filter-bank channels per frame
f_max = 4000  # upper edge of the mel filter bank (Hz)
f_min = 20  # lower edge of the mel filter bank (Hz)
delta_order = 0  # highest Savitzky-Golay derivative stacked as an extra input channel
size = 16000  # waveforms are padded (or cropped) to this many samples

# Use the CuPy backend for the LIF neurons when CuPy can be imported,
# otherwise fall back to the pure-PyTorch backend.
try:
    import cupy  # noqa

    backend = "cupy"
except ImportError:
    # ImportError is the parent of ModuleNotFoundError, so this also covers a
    # CuPy that is installed but fails to load (for example a CUDA mismatch).
    backend = "torch"
    print("Cupy is not installed. Using torch backend for neurons.")


def mel_to_hz(mels, dct_type):
    """Convert mel-scale values to frequencies in Hz.

    ``dct_type == "htk"`` uses the HTK formula
    ``hz = 700 * (10 ** (mel / 2595) - 1)``. Any other value uses the Slaney
    formula: linear below 1000 Hz (200/3 Hz per mel) and logarithmic above,
    where every 27 mels multiply the frequency by 6.4.

    :param mels: a Python number or a ``torch.Tensor``. Tensors with at least
        one dimension are converted element-wise; the input is not modified.
    :param dct_type: ``"htk"`` selects the HTK formula; any other value
        (normally ``"slaney"``) selects the Slaney formula. Despite its name,
        it chooses the mel-scale convention, not a DCT.
    :return: the corresponding frequency / frequencies in Hz.
    """
    if dct_type == "htk":
        return 700.0 * (10 ** (mels / 2595.0) - 1.0)

    hz_origin = 0.0
    hz_per_mel = 200.0 / 3
    log_region_hz = 1000.0  # the logarithmic region starts at this frequency (Hz)
    log_region_mel = (log_region_hz - hz_origin) / hz_per_mel  # same point on the mel axis
    log_step = math.log(6.4) / 27.0  # natural-log increment per mel in the log region

    # Linear mapping; already correct for everything below the log region.
    hz = hz_origin + hz_per_mel * mels

    if torch.is_tensor(mels) and mels.ndim:
        # Vector input: overwrite only the entries that fall in the log region.
        in_log_region = mels >= log_region_mel
        hz[in_log_region] = log_region_hz * torch.exp(
            log_step * (mels[in_log_region] - log_region_mel)
        )
        return hz

    # Scalar input: decide the region directly.
    if mels >= log_region_mel:
        return log_region_hz * math.exp(log_step * (mels - log_region_mel))
    return hz
def hz_to_mel(frequencies, dct_type):
    """Convert frequencies in Hz to the mel scale (inverse of ``mel_to_hz``).

    ``dct_type == "htk"`` uses ``mel = 2595 * log10(1 + hz / 700)``. Any other
    value uses the Slaney formula: linear below 1000 Hz (200/3 Hz per mel)
    and logarithmic above, where multiplying the frequency by 6.4 adds 27 mels.

    :param frequencies: a Python number or a ``torch.Tensor``. Tensors with at
        least one dimension are converted element-wise; the input is not
        modified.
    :param dct_type: ``"htk"`` selects the HTK formula; any other value
        (normally ``"slaney"``) selects the Slaney formula.
    :return: the corresponding mel value(s).
    """
    vectorised = torch.is_tensor(frequencies) and frequencies.ndim

    if dct_type == "htk":
        log10 = torch.log10 if vectorised else math.log10
        return 2595.0 * log10(1.0 + frequencies / 700.0)

    hz_origin = 0.0
    hz_per_mel = 200.0 / 3
    log_region_hz = 1000.0  # the logarithmic region starts at this frequency (Hz)
    log_region_mel = (log_region_hz - hz_origin) / hz_per_mel  # same point on the mel axis
    log_step = math.log(6.4) / 27.0  # natural-log increment per mel in the log region

    # Linear mapping; already correct for everything below the log region.
    mels = (frequencies - hz_origin) / hz_per_mel

    if vectorised:
        # Vector input: overwrite only the entries that fall in the log region.
        in_log_region = frequencies >= log_region_hz
        mels[in_log_region] = (
            log_region_mel
            + torch.log(frequencies[in_log_region] / log_region_hz) / log_step
        )
        return mels

    # Scalar input: decide the region directly.
    if frequencies >= log_region_hz:
        return log_region_mel + math.log(frequencies / log_region_hz) / log_step
    return mels
def create_fb_matrix(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    dct_type: Optional[str] = "slaney",
) -> Tensor:
    """Build a triangular mel filter bank of shape ``(n_freqs, n_mels)``.

    The construction is equivalent to the Librosa filter bank. ``n_mels + 2``
    points are spaced evenly on the mel axis between ``f_min`` and ``f_max``
    and mapped back to Hz. Each consecutive triple of points
    (left, centre, right) defines one triangular filter evaluated at the
    ``n_freqs`` STFT bin frequencies, which are evenly spaced from 0 to
    ``sample_rate // 2``. With ``dct_type == "slaney"``, each filter is
    additionally divided by half its bandwidth, giving approximately constant
    energy per channel.

    :param n_freqs: number of STFT frequency bins.
    :param f_min: lowest filter-bank frequency (Hz).
    :param f_max: highest filter-bank frequency (Hz).
    :param n_mels: number of mel filters.
    :param sample_rate: audio sample rate (Hz).
    :param dct_type: ``"htk"`` or ``"slaney"`` mel-scale convention.
    :return: filter-bank matrix; multiply ``[..., time, n_freqs]`` by it to get
        ``[..., time, n_mels]``.
    :raises ValueError: if ``dct_type`` is neither ``"htk"`` nor ``"slaney"``.
    """
    if dct_type not in ("htk", "slaney"):
        raise ValueError("DCT type must be either 'htk' or 'slaney'")

    # Frequencies of the STFT bins, evenly spaced from 0 Hz to Nyquist.
    stft_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # Filter edges: evenly spaced in mel, converted back to Hz.
    mel_edges = torch.linspace(
        hz_to_mel(f_min, dct_type), hz_to_mel(f_max, dct_type), n_mels + 2
    )
    hz_edges = mel_to_hz(mel_edges, dct_type)
    edge_gaps = hz_edges[1:] - hz_edges[:-1]  # (n_mels + 1,)

    # offsets[i, j] = hz_edges[j] - stft_freqs[i]  ->  (n_freqs, n_mels + 2)
    offsets = hz_edges.unsqueeze(0) - stft_freqs.unsqueeze(1)
    rising = -offsets[:, :-2] / edge_gaps[:-1]  # left side of each triangle, (n_freqs, n_mels)
    falling = offsets[:, 2:] / edge_gaps[1:]  # right side of each triangle, (n_freqs, n_mels)
    fb = torch.max(torch.zeros(1), torch.min(rising, falling))

    if dct_type == "slaney":
        # Slaney normalisation: scale each filter by 2 / (right edge - left edge).
        fb *= (2.0 / (hz_edges[2 : n_mels + 2] - hz_edges[:n_mels])).unsqueeze(0)
    return fb
class MelScaleDelta(nn.Module):
    """Log mel features with optional Savitzky-Golay time derivatives.

    The spectrogram is projected onto a triangular mel filter bank and
    normalised by its global peak magnitude, followed by ``log1p``. The
    first ``order`` Savitzky-Golay derivatives along time are then stacked
    as extra channels.
    """

    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        order,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: Optional[int] = None,
        dct_type: Optional[str] = "slaney",
    ) -> None:
        """
        :param order: highest derivative order to append (0 = features only).
        :param n_mels: number of mel filters.
        :param sample_rate: audio sample rate (Hz).
        :param f_min: lowest filter-bank frequency (Hz).
        :param f_max: highest filter-bank frequency (Hz); defaults to
            ``sample_rate // 2``.
        :param n_stft: number of STFT frequency bins. When ``None``, the filter
            bank is built lazily on the first ``forward`` call.
        :param dct_type: ``"htk"`` or ``"slaney"`` mel-scale convention.
        """
        super().__init__()
        self.order = order
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.dct_type = dct_type
        assert f_min <= self.f_max, "Require f_min: {} <= f_max: {}".format(
            f_min, self.f_max
        )
        # An empty buffer marks "not built yet"; forward() fills it once the
        # number of frequency bins is known.
        fb = (
            torch.empty(0)
            if n_stft is None
            else create_fb_matrix(
                n_stft,
                self.f_min,
                self.f_max,
                self.n_mels,
                self.sample_rate,
                self.dct_type,
            )
        )
        self.register_buffer("fb", fb)

    def forward(self, specgram: Tensor) -> Tensor:
        """Compute stacked log-mel (+ derivative) features.

        :param specgram: spectrogram whose last two dims are
            ``(frequency, time)``, e.g. ``[1, freq, time]`` for mono audio.
        :return: tensor of shape ``[1, order + 1, time, n_mels]`` for mono
            input.

        NOTE(review): ``.squeeze()`` removes every size-1 dimension, so the
        ``.T`` below yields ``[time, n_mels]`` only for single-channel input
        with ``n_mels > 1`` and ``time > 1``. Confirm before feeding
        multi-channel audio.
        """
        # Pack any leading batch/channel dims into one.
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = create_fb_matrix(
                specgram.size(1),
                self.f_min,
                self.f_max,
                self.n_mels,
                self.sample_rate,
                self.dct_type,
            )
            # Buffers cannot be reassigned outside __init__, so fill in place.
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (B, freq, time) -> (B, time, freq) @ (freq, n_mels) -> (B, n_mels, time)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
        # Unpack the batch dims, then drop size-1 dims.
        mel_specgram = mel_specgram.reshape(
            shape[:-2] + mel_specgram.shape[-2:]
        ).squeeze()

        # Peak-normalised log compression. An all-zero input is passed through
        # unchanged to avoid 0 / 0.
        peak = torch.max(torch.abs(mel_specgram))
        feat = torch.log1p(mel_specgram / peak) if peak > 0 else mel_specgram

        # Convert once; detach/cpu let numpy() work on grad or GPU tensors.
        feat_np = feat.detach().cpu().numpy()
        feat_list = [feat_np.T]
        for k in range(1, self.order + 1):
            # k-th derivative along time (last axis), window of 9 frames.
            feat_list.append(
                savgol_filter(
                    feat_np, 9, deriv=k, axis=-1, mode="interp", polyorder=k
                ).T
            )
        return torch.as_tensor(np.expand_dims(np.stack(feat_list), axis=0))
class Pad:
    """Centre a waveform in a window of ``size`` samples along its last axis.

    Shorter inputs are zero-padded on both sides; when the amount is odd, the
    extra sample goes on the right. Longer inputs receive negative padding,
    which ``torch.nn.functional.pad`` applies as cropping (an odd extra
    sample is then removed on the left).
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, wav):
        missing = self.size - wav.shape[-1]
        left = missing // 2
        right = missing - left
        return torch.nn.functional.pad(wav, (left, right), mode="constant", value=0)
class Rescale:
    """Divide features by their standard deviation along dim 2.

    The population standard deviation (``unbiased=False``) is used to match
    NumPy's default estimator, as in
    https://github.com/romainzimmer/s2net/blob/82c38bf80b55d16d12d0243440e34e52d237a2df/data.py#L201
    Positions with zero standard deviation are divided by 1 (left unchanged).
    """

    def __call__(self, input):
        scale = input.std(dim=2, keepdim=True, unbiased=False)
        scale = scale.masked_fill(scale == 0, 1)
        return input / scale
def collate_fn(data):
    """Merge ``(features, label)`` samples into one normalised batch.

    Feature tensors are concatenated along dim 0 and divided by their
    population standard deviation over dims 0 and 2, giving one scale per
    remaining (dim 1, dim 3) position. For the ``[1, channel, T, n_mel]``
    features produced by this file's transforms, that is one scale per
    channel and mel bin.

    A zero standard deviation is replaced by 1, as in ``Rescale``. Without
    this, a position that is constant across the batch (e.g. an all-zero
    batch) would be turned into NaN/inf.

    :param data: sequence of ``(feature_tensor, int_label)`` pairs.
    :return: ``(features, labels)`` where ``labels`` is a 1-D tensor.
    """
    features = torch.cat([sample[0] for sample in data])
    std = features.std(dim=(0, 2), keepdim=True, unbiased=False)
    std.masked_fill_(std == 0, 1)
    # torch.cat allocated a fresh tensor, so in-place division does not touch
    # the dataset's tensors.
    features.div_(std)
    labels = torch.tensor([sample[1] for sample in data])
    return features, labels
#### Network ####
class LIFWrapper(nn.Module):
    """Adapt a multi-step (time-major) module to batch-major 4-D features.

    :param module: module that is called with a ``[T, channel, batch, n_mel]``
        tensor (e.g. a multi-step spiking neuron).
    :param flatten: if ``True``, return ``[batch, T, channel * n_mel]``
        instead of ``[batch, channel, T, n_mel]``.
    """

    def __init__(self, module, flatten=False):
        super().__init__()
        self.module = module
        self.flatten = flatten

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Run the wrapped module with time as the leading dimension.

        The input ``[batch, channel, T, n_mel]`` has its batch and time
        dimensions swapped to ``[T, channel, batch, n_mel]`` before calling
        the wrapped module, which matches multi-step forward mode.

        :param x_seq: input of shape ``[batch, channel, T, n_mel]``.
        :return: ``[batch, T, channel * n_mel]`` when ``self.flatten`` is set,
            otherwise ``[batch, channel, T, n_mel]``.
        """
        time_major = x_seq.transpose(0, 2)  # [T, channel, batch, n_mel]
        out = self.module(time_major)
        if not self.flatten:
            return out.transpose(0, 2)  # [batch, channel, T, n_mel]
        batch_major = out.permute(2, 0, 1, 3)  # [batch, T, channel, n_mel]
        n_batch, n_steps = batch_major.shape[:2]
        return batch_major.reshape(n_batch, n_steps, -1)  # [batch, T, channel * n_mel]
class Net(nn.Module):
    """Convolutional spiking network for speech-command classification."""

    def __init__(self):
        """Build three convolution + LIF blocks followed by a linear readout.

        The last LIF block flattens its output to ``channel * n_mel`` features
        per time step, and a linear layer maps them to ``label_cnt`` class
        scores. Training-statistics counters are initialised as well.
        """
        super().__init__()
        self.train_times = 0
        self.epochs = 0
        self.max_test_acccuracy = 0  # (sic) name kept for compatibility

        def lif_block(flatten=False):
            # Multi-step LIF neuron; LIFWrapper feeds it time-major input.
            return LIFWrapper(
                neuron.LIFNode(
                    tau=10.0 / 7,
                    surrogate_function=surrogate.Sigmoid(alpha=10.0),
                    backend=backend,
                    step_mode="m",
                ),
                flatten=flatten,
            )

        # Input: [batch size, delta_order + 1, T, n_mel]
        # (shape comments below assume T = 101 and the default n_mels = 40)
        self.conv = nn.Sequential(
            # 101 * 40
            nn.Conv2d(
                in_channels=delta_order + 1,
                out_channels=64,
                kernel_size=(4, 3),
                stride=1,
                padding=(2, 1),
                bias=False,
            ),
            lif_block(),
            # 102 * 40
            nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=(4, 3),
                stride=1,
                padding=(6, 3),
                dilation=(4, 3),
                bias=False,
            ),
            lif_block(),
            # 102 * 40
            nn.Conv2d(
                in_channels=64,
                out_channels=64,
                kernel_size=(4, 3),
                stride=1,
                padding=(24, 9),
                dilation=(16, 9),
                bias=False,
            ),
            lif_block(flatten=True),
        )
        # [batch size, T, channel * n_mel]. Each conv's padding preserves the
        # mel axis, so the feature width is 64 * n_mels rather than a fixed 40.
        self.fc = nn.Linear(64 * n_mels, label_cnt)

    def forward(self, x):
        """Classify a batch of feature sequences.

        The convolutional spiking stack yields per-time-step class scores,
        which are averaged over time.

        :param x: input features, shape ``[batch_size, delta_order + 1, T, n_mel]``.
        :return: class scores, shape ``[batch_size, label_cnt]``.
        """
        x = self.fc(self.conv(x))  # [batch size, T, #Class]
        return x.mean(dim=1)  # [batch size, #Class]
if __name__ == "__main__":
    # Train the SNN on Speech Commands v0.01 and evaluate it after every epoch.
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--batch-size", type=int, default=64)
    parser.add_argument("-sr", "--sample-rate", type=int, default=16000)
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-2)
    parser.add_argument("-dir", "--dataset-dir", type=str)
    parser.add_argument("-e", "--epoch", type=int, default=50)
    parser.add_argument("-d", "--device", type=str, default="cuda:0")
    args = parser.parse_args()

    sr = args.sample_rate
    n_fft = int(30e-3 * sr)  # 30 ms window: 480 samples at 16 kHz
    hop_length = int(10e-3 * sr)  # 10 ms hop: 160 samples at 16 kHz

    dataset_dir = args.dataset_dir
    batch_size = args.batch_size
    lr = args.learning_rate
    epoch = args.epoch
    device = args.device

    # waveform -> fixed length -> spectrogram -> log-mel (+ deltas) -> rescaled
    pad = Pad(size)
    spec = Spectrogram(n_fft=n_fft, hop_length=hop_length)
    melscale = MelScaleDelta(
        order=delta_order,
        n_mels=n_mels,
        sample_rate=sr,
        f_min=f_min,
        f_max=f_max,
        dct_type="slaney",
    )
    rescale = Rescale()
    transform = torchvision.transforms.Compose([pad, spec, melscale, rescale])
    print(label_cnt)

    train_dataset = SPEECHCOMMANDS(
        label_dict,
        dataset_dir,
        silence_cnt=2300,
        url="speech_commands_v0.01",
        split="train",
        transform=transform,
        download=True,
    )
    # Sample training items according to the per-sample weights of the dataset.
    train_sampler = torch.utils.data.WeightedRandomSampler(
        train_dataset.weights, len(train_dataset.weights)
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=16,
        sampler=train_sampler,
        collate_fn=collate_fn,
    )

    test_dataset = SPEECHCOMMANDS(
        label_dict,
        dataset_dir,
        silence_cnt=260,
        url="speech_commands_v0.01",
        split="test",
        transform=transform,
        download=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=16,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
    )

    net = Net().to(device)
    optimizer = Adam(net.parameters(), lr=lr)
    gamma = 0.85
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma, last_epoch=-1
    )
    # The learning rate stays at its initial value for this many epochs and
    # is then multiplied by `gamma` after every epoch.
    warmup_epochs = 1
    print(net)
    writer = SummaryWriter("./logs/")
    criterion = nn.CrossEntropyLoss().to(device)

    for e in range(epoch):
        net.train()
        print(f"Epoch {net.epochs}")
        time_start = time.time()

        ##### TRAIN #####
        train_loss_sum = 0.0
        train_batches = 0
        for audios, labels in tqdm(train_dataloader):
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            out_spikes_counter_frequency = net(audios)
            loss = criterion(out_spikes_counter_frequency, labels)
            loss.backward()
            optimizer.step()
            # Clear neuron states before the next (independent) batch.
            reset_net(net)
            train_loss_sum += loss.item()
            train_batches += 1
            net.train_times += 1
        # Mean over the epoch rather than the noisy loss of the last batch.
        train_loss = train_loss_sum / max(train_batches, 1)

        if e >= warmup_epochs:
            lr_scheduler.step()

        net.eval()
        writer.add_scalar("Train Loss", train_loss, global_step=net.epochs)

        ##### TEST #####
        with torch.no_grad():
            test_sum = 0
            correct_sum = 0
            pred = []
            label = []
            for audios, labels in tqdm(test_dataloader):
                # Same device as the model (previously hard-wired to .cuda()).
                audios = audios.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)

                out_spikes_counter = net(audios)
                preds = out_spikes_counter.argmax(dim=1)
                correct_sum += (preds == labels).float().sum().item()
                pred.append(preds)
                label.append(labels)
                test_sum += labels.numel()
                reset_net(net)

            pred = torch.cat(pred).cpu().numpy()
            label = torch.cat(label).cpu().numpy()

            # Fixed label set keeps the matrix label_cnt x label_cnt even if
            # some class never appears in the labels or predictions.
            cmatrix = confusion_matrix(label, pred, labels=list(range(label_cnt)))
            print("Confusion Matrix:")
            print(cmatrix)

            test_accuracy = correct_sum / test_sum
            writer.add_scalar("Test Acc.", test_accuracy, global_step=net.epochs)

        net.epochs += 1
        time_end = time.time()
        print(
            f"Test Acc: {test_accuracy} Loss: {train_loss} Elapse: {time_end - time_start:.2f}s"
        )

    writer.close()