Source code for spikingjelly.activation_based.examples.speechcommands

"""
.. codeauthor:: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>

A reproduction of the paper `Technical report: supervised training of convolutional spiking neural networks with PyTorch <https://arxiv.org/pdf/1911.10124.pdf>`_\ .

This code reproduces an audio recognition task using a convolutional SNN, achieving performance comparable to an ANN.

..  note::

    To avoid heavy dependencies such as `librosa <https://librosa.org/doc/latest/index.html>`_, we implement MelScale ourselves. Two DCT types are provided: Slaney & HTK. The Slaney style is used in the original paper and is applied by default.
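
For reference, the two mel scales implemented by ``hz_to_mel`` / ``mel_to_hz`` below are

.. math::

    m_{\mathrm{HTK}}(f) = 2595 \log_{10}\left(1 + \frac{f}{700}\right), \qquad
    m_{\mathrm{Slaney}}(f) =
    \begin{cases}
        \dfrac{3f}{200}, & f < 1000 \\
        15 + \dfrac{27 \ln(f / 1000)}{\ln 6.4}, & f \ge 1000
    \end{cases}

where :math:`f` is the frequency in Hz.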

Confusion matrix of the TEST set after training (50 epochs):

+------------------------+--------------------------------------------------------------------------------------------------+
| Count                  | Prediction                                                                                       |
|                        +-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|                        | "Yes" | "Stop" | "No" | "Right" | "Up" | "Left" | "On" | "Down" | "Off" | "Go" | Other | Silence |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
| Ground Truth | "Yes"   | 234   | 0      | 2    | 0       | 0    | 3      | 0    | 0      | 0     | 1    | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Stop"  | 0     | 233    | 0    | 1       | 5    | 0      | 0    | 0      | 0     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "No"    | 0     | 1      | 223  | 1       | 0    | 1      | 0    | 5      | 0     | 9    | 12    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Right" | 0     | 0      | 0    | 234     | 0    | 0      | 0    | 0      | 0     | 0    | 24    | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Up"    | 0     | 4      | 0    | 0       | 249  | 0      | 0    | 0      | 8     | 0    | 11    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Left"  | 3     | 1      | 2    | 3       | 1    | 250    | 0    | 0      | 1     | 0    | 6     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "On"    | 0     | 3      | 0    | 0       | 0    | 0      | 231  | 0      | 2     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Down"  | 0     | 0      | 7    | 0       | 0    | 1      | 2    | 230    | 0     | 4    | 8     | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Off"   | 0     | 0      | 2    | 1       | 4    | 2      | 6    | 0      | 237   | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Go"    | 0     | 2      | 5    | 0       | 0    | 2      | 0    | 1      | 5     | 220  | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Other   | 6     | 21     | 12   | 25      | 22   | 19     | 25   | 14     | 11    | 40   | 4072  | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Silence | 0     | 0      | 0    | 0       | 0    | 0      | 0    | 0      | 0     | 0    | 0     | 260     |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
"""

import torch
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms
from torchaudio.transforms import Spectrogram
from spikingjelly.activation_based import neuron, surrogate

from spikingjelly.datasets.speechcommands import SPEECHCOMMANDS
from spikingjelly.activation_based.functional import reset_net
from scipy.signal import savgol_filter

from sklearn.metrics import confusion_matrix

import numpy as np

import math
import time
import argparse
from typing import Optional
from tqdm import tqdm

label_dict = {'yes': 0, 'stop': 1, 'no': 2, 'right': 3, 'up': 4, 'left': 5, 'on': 6, 'down': 7, 'off': 8, 'go': 9, 'bed': 10, 'three': 10, 'one': 10, 'four': 10, 'two': 10, 'five': 10, 'cat': 10, 'dog': 10, 'eight': 10, 'bird': 10, 'happy': 10, 'sheila': 10, 'zero': 10, 'wow': 10, 'marvin': 10, 'house': 10, 'six': 10, 'seven': 10, 'tree': 10, 'nine': 10, '_silence_': 11}
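# The 10 target words map to classes 0-9; every other word collapses to
# class 10 ("Other") and recorded silence to class 11, so label_cnt is 12.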
label_cnt = len(set(label_dict.values()))
n_mels = 40
f_max = 4000
f_min = 20
delta_order = 0
size = 16000
try:
    import cupy
    backend = 'cupy'
except ModuleNotFoundError:
    backend = 'torch'
    print('Cupy is not installed. Using torch backend for neurons.')

def mel_to_hz(mels, dct_type):
    if dct_type == 'htk':
        return 700.0 * (10 ** (mels / 2595.0) - 1.0)

    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels

    # And now the nonlinear scale
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = math.log(6.4) / 27.0  # step size for log region

    if torch.is_tensor(mels) and mels.ndim:
        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
    elif mels >= min_log_mel:
        # If we have scalar data, check directly
        freqs = min_log_hz * math.exp(logstep * (mels - min_log_mel))

    return freqs

def hz_to_mel(frequencies, dct_type):
    if dct_type == 'htk':
        if torch.is_tensor(frequencies) and frequencies.ndim:
            return 2595.0 * torch.log10(1.0 + frequencies / 700.0)
        return 2595.0 * math.log10(1.0 + frequencies / 700.0)

    # Fill in the linear part
    f_min = 0.0
    f_sp = 200.0 / 3
    mels = (frequencies - f_min) / f_sp

    # Fill in the log-scale part
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = math.log(6.4) / 27.0  # step size for log region

    if torch.is_tensor(frequencies) and frequencies.ndim:
        # If we have array data, vectorize
        log_t = frequencies >= min_log_hz
        mels[log_t] = min_log_mel + torch.log(frequencies[log_t] / min_log_hz) / logstep
    elif frequencies >= min_log_hz:
        # If we have scalar data, check directly
        mels = min_log_mel + math.log(frequencies / min_log_hz) / logstep

    return mels

def create_fb_matrix(
        n_freqs: int,
        f_min: float,
        f_max: float,
        n_mels: int,
        sample_rate: int,
        dct_type: Optional[str] = 'slaney') -> Tensor:
    if dct_type != "htk" and dct_type != "slaney":
        raise ValueError("DCT type must be either 'htk' or 'slaney'")

    # freq bins
    # Equivalent filterbank construction by Librosa
    all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)

    # calculate mel freq bins
    # hertz to mel(f)
    m_min = hz_to_mel(f_min, dct_type)
    m_max = hz_to_mel(f_max, dct_type)
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    # mel to hertz(mel)
    f_pts = mel_to_hz(m_pts, dct_type)
    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    # (n_freqs, n_mels + 2)
    slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1)
    # create overlapping triangles
    zero = torch.zeros(1)
    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
    fb = torch.max(zero, torch.min(down_slopes, up_slopes))

    if dct_type == "slaney":
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (f_pts[2:n_mels + 2] - f_pts[:n_mels])
        fb *= enorm.unsqueeze(0)
    return fb

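# A minimal shape check for the filterbank above (an illustrative sketch using
# this script's defaults: n_fft = 480 gives n_freqs = n_fft // 2 + 1 = 241, and
# n_mels/f_min/f_max come from the module-level constants):
#
#     fb = create_fb_matrix(n_freqs=241, f_min=20., f_max=4000., n_mels=40,
#                           sample_rate=16000, dct_type='slaney')
#     assert fb.shape == (241, 40)  # one triangular mel filter per column
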
class MelScaleDelta(nn.Module):
    __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']

    def __init__(self,
                 order,
                 n_mels: int = 128,
                 sample_rate: int = 16000,
                 f_min: float = 0.,
                 f_max: Optional[float] = None,
                 n_stft: Optional[int] = None,
                 dct_type: Optional[str] = 'slaney') -> None:
        super(MelScaleDelta, self).__init__()
        self.order = order
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.dct_type = dct_type

        assert f_min <= self.f_max, 'Require f_min: {} < f_max: {}'.format(
            f_min, self.f_max)

        fb = torch.empty(0) if n_stft is None else create_fb_matrix(
            n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate, self.dct_type)
        self.register_buffer('fb', fb)
    def forward(self, specgram: Tensor) -> Tensor:
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])

        if self.fb.numel() == 0:
            tmp_fb = create_fb_matrix(specgram.size(1), self.f_min, self.f_max,
                                      self.n_mels, self.sample_rate, self.dct_type)
            # Attributes cannot be reassigned outside __init__ so workaround
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(
            specgram.transpose(1, 2), self.fb).transpose(1, 2)

        # unpack batch
        mel_specgram = mel_specgram.reshape(
            shape[:-2] + mel_specgram.shape[-2:]).squeeze()

        M = torch.max(torch.abs(mel_specgram))
        if M > 0:
            feat = torch.log1p(mel_specgram / M)
        else:
            feat = mel_specgram

        feat_list = [feat.numpy().T]
        for k in range(1, self.order + 1):
            feat_list.append(savgol_filter(
                feat.numpy(), 9, deriv=k, axis=-1, mode='interp', polyorder=k).T)
        return torch.as_tensor(np.expand_dims(np.stack(feat_list), axis=0))

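# Shape walk-through for MelScaleDelta.forward (a sketch for this script's
# defaults): a 1 s clip at 16 kHz with n_fft = 480 and hop_length = 160 yields
# a (1, 241, 101) spectrogram; the mel projection and squeeze() give a
# (40, 101) log1p-compressed feature map, and stacking the `order`
# Savitzky-Golay deltas returns a tensor of shape
# [1, order + 1, T, n_mels] = [1, 1, 101, 40] when order = 0.
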
class Pad(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, wav):
        wav_size = wav.shape[-1]
        pad_size = (self.size - wav_size) // 2
        padded_wav = torch.nn.functional.pad(
            wav, (pad_size, self.size - wav_size - pad_size), mode='constant', value=0)
        return padded_wav

class Rescale(object):
    def __call__(self, input):
        # Numpy std is calculated via Numpy's biased estimator. See
        # https://github.com/romainzimmer/s2net/blob/82c38bf80b55d16d12d0243440e34e52d237a2df/data.py#L201
        std = torch.std(input, axis=2, keepdims=True, unbiased=False)
        std.masked_fill_(std == 0, 1)
        return input / std

def collate_fn(data):
    X_batch = torch.cat([d[0] for d in data])
    std = X_batch.std(axis=(0, 2), keepdim=True, unbiased=False)
    X_batch.div_(std)
    y_batch = torch.tensor([d[1] for d in data])
    return X_batch, y_batch

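# Note: each (delta-channel, mel-bin) pair is divided by a standard deviation
# computed jointly over the batch and time dimensions (dims 0 and 2), so the
# normalization statistics are shared across the whole mini-batch.
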
#### Network ####
class LIFWrapper(nn.Module):
    def __init__(self, module, flatten=False):
        super().__init__()
        self.module = module
        self.flatten = flatten
    def forward(self, x_seq: torch.Tensor):
        '''
        :param x_seq: shape=[batch size, channel, T, n_mel]
        :type x_seq: torch.Tensor
        :return: y_seq, shape=[batch size, channel, T, n_mel]
        :rtype: torch.Tensor
        '''
        # Input: [batch size, channel, T, n_mel]
        y_seq = self.module(x_seq.transpose(0, 2))  # [T, channel, batch size, n_mel]
        if self.flatten:
            y_seq = y_seq.permute(2, 0, 1, 3)  # [batch size, T, channel, n_mel]
            shape = y_seq.shape[:2]
            return y_seq.reshape(shape + (-1,))  # [batch size, T, channel * n_mel]
        else:
            return y_seq.transpose(0, 2)  # [batch size, channel, T, n_mel]

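# Design note: spikingjelly neurons configured with step_mode='m' consume the
# first tensor dimension as the time dimension, so LIFWrapper transposes T to
# the front before the LIF layer and restores the original layout afterwards
# (or flattens [channel, n_mel] for the final fully-connected layer).
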
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.train_times = 0
        self.epochs = 0
        self.max_test_accuracy = 0

        # Input: [batch size, delta_order + 1, T, n_mel]
        self.conv = nn.Sequential(
            # 101 * 40
            nn.Conv2d(in_channels=delta_order + 1, out_channels=64, kernel_size=(4, 3),
                      stride=1, padding=(2, 1), bias=False),
            LIFWrapper(neuron.LIFNode(tau=10.0 / 7, surrogate_function=surrogate.Sigmoid(alpha=10.),
                                      backend=backend, step_mode='m')),
            # 102 * 40
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4, 3),
                      stride=1, padding=(6, 3), dilation=(4, 3), bias=False),
            LIFWrapper(neuron.LIFNode(tau=10.0 / 7, surrogate_function=surrogate.Sigmoid(alpha=10.),
                                      backend=backend, step_mode='m')),
            # 102 * 40
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(4, 3),
                      stride=1, padding=(24, 9), dilation=(16, 9), bias=False),
            LIFWrapper(neuron.LIFNode(tau=10.0 / 7, surrogate_function=surrogate.Sigmoid(alpha=10.),
                                      backend=backend, step_mode='m'), flatten=True),
        )  # [batch size, T, channel * n_mel]
        self.fc = nn.Linear(64 * 40, label_cnt)
    def forward(self, x):
        x = self.fc(self.conv(x))  # [batch size, T, #Class]
        return x.mean(dim=1)  # [batch size, #Class]

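# Quick smoke test (a sketch; with the cupy backend the LIF neurons must run
# on a CUDA device, so move everything to GPU or fall back to backend='torch'):
#
#     net = Net()
#     dummy = torch.randn(2, delta_order + 1, 101, n_mels)  # [B, C, T, F]
#     out = net(dummy)                                      # [2, label_cnt]
#     reset_net(net)  # stateful neurons must be reset between forward passes
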
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('-sr', '--sample-rate', type=int, default=16000)
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-2)
    parser.add_argument('-dir', '--dataset-dir', type=str)
    parser.add_argument('-e', '--epoch', type=int, default=50)
    parser.add_argument('-d', '--device', type=str, default='cuda:0')
    args = parser.parse_args()

    sr = args.sample_rate
    n_fft = int(30e-3 * sr)  # 480 when sr = 16000
    hop_length = int(10e-3 * sr)  # 160 when sr = 16000
    dataset_dir = args.dataset_dir
    batch_size = args.batch_size
    lr = args.learning_rate
    epoch = args.epoch
    device = args.device

    pad = Pad(size)
    spec = Spectrogram(n_fft=n_fft, hop_length=hop_length)
    melscale = MelScaleDelta(order=delta_order, n_mels=n_mels, sample_rate=sr,
                             f_min=f_min, f_max=f_max, dct_type='slaney')
    rescale = Rescale()
    transform = torchvision.transforms.Compose([pad, spec, melscale, rescale])

    print(label_cnt)
    train_dataset = SPEECHCOMMANDS(
        label_dict, dataset_dir, silence_cnt=2300, url="speech_commands_v0.01",
        split="train", transform=transform, download=True)
    train_sampler = torch.utils.data.WeightedRandomSampler(
        train_dataset.weights, len(train_dataset.weights))
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=16,
                                  sampler=train_sampler, collate_fn=collate_fn)

    test_dataset = SPEECHCOMMANDS(
        label_dict, dataset_dir, silence_cnt=260, url="speech_commands_v0.01",
        split="test", transform=transform, download=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=16,
                                 collate_fn=collate_fn, shuffle=False, drop_last=False)

    net = Net().to(device)

    optimizer = Adam(net.parameters(), lr=lr)
    gamma = 0.85
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1)
    warmup_epochs = 1

    print(net)
    writer = SummaryWriter('./logs/')
    criterion = nn.CrossEntropyLoss().to(device)

    for e in range(epoch):
        net.train()
        print(f'Epoch {net.epochs}')
        time_start = time.time()

        ##### TRAIN #####
        for audios, labels in tqdm(train_dataloader):
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            out_spikes_counter_frequency = net(audios)

            loss = criterion(out_spikes_counter_frequency, labels)
            loss.backward()
            # nn.utils.clip_grad_value_(net.parameters(), 5)
            optimizer.step()
            reset_net(net)

            # Rate-based output decoding
            correct_rate = (out_spikes_counter_frequency.argmax(
                dim=1) == labels).float().mean().item()
            net.train_times += 1

        if e >= warmup_epochs:
            lr_scheduler.step()

        net.eval()
        writer.add_scalar('Train Loss', loss.item(), global_step=net.epochs)

        ##### TEST #####
        with torch.no_grad():
            test_sum = 0
            correct_sum = 0
            pred = []
            label = []
            for audios, labels in tqdm(test_dataloader):
                # Use .to(device) rather than .cuda() so the '-d' argument
                # also selects the device here
                audios = audios.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                out_spikes_counter = net(audios)

                preds = out_spikes_counter.argmax(dim=1)
                correct_sum += (preds == labels).float().sum().item()
                pred.append(preds)
                label.append(labels)
                test_sum += labels.numel()
                reset_net(net)

            pred = torch.cat(pred).cpu().numpy()
            label = torch.cat(label).cpu().numpy()

            # Confusion matrix
            cmatrix = confusion_matrix(label, pred)
            print("Confusion Matrix:")
            print(cmatrix)
            # plt.clf()
            # fig = plt.figure()
            # plt.imshow(cmatrix)
            # writer.add_figure('Confusion Matrix', figure=fig,
            #                   global_step=net.epochs)

            test_accuracy = correct_sum / test_sum
            writer.add_scalar('Test Acc.', test_accuracy, global_step=net.epochs)

        net.epochs += 1
        time_end = time.time()
        print(f'Test Acc: {test_accuracy} Loss: {loss} Elapse: {time_end - time_start:.2f}s')
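
# Example invocation (a sketch; adjust the dataset directory to your setup):
#
#     python -m spikingjelly.activation_based.examples.speechcommands \
#         -dir ./SpeechCommands -b 64 -lr 1e-2 -e 50 -d cuda:0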