r"""
.. codeauthor:: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>
A reproduction of the paper `Technical report: supervised training of convolutional spiking neural networks with PyTorch <https://arxiv.org/pdf/1911.10124.pdf>`_\ .
This code reproduces an audio recognition task using convolutional SNN. It provides comparable performance to ANN.

.. note::

    To avoid extra dependencies such as `librosa <https://librosa.org/doc/latest/index.html>`_, MelScale is implemented here directly. Two mel filterbank styles, selected through ``dct_type``, are provided: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.

Confusion matrix of the TEST set after training (50 epochs):

+------------------------+--------------------------------------------------------------------------------------------------+
| Count                  | Prediction                                                                                       |
|                        +-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|                        | "Yes" | "Stop" | "No" | "Right" | "Up" | "Left" | "On" | "Down" | "Off" | "Go" | Other | Silence |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
| Ground Truth | "Yes"   | 234   | 0      | 2    | 0       | 0    | 3      | 0    | 0      | 0     | 1    | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Stop"  | 0     | 233    | 0    | 1       | 5    | 0      | 0    | 0      | 0     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "No"    | 0     | 1      | 223  | 1       | 0    | 1      | 0    | 5      | 0     | 9    | 12    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Right" | 0     | 0      | 0    | 234     | 0    | 0      | 0    | 0      | 0     | 0    | 24    | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Up"    | 0     | 4      | 0    | 0       | 249  | 0      | 0    | 0      | 8     | 0    | 11    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Left"  | 3     | 1      | 2    | 3       | 1    | 250    | 0    | 0      | 1     | 0    | 6     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "On"    | 0     | 3      | 0    | 0       | 0    | 0      | 231  | 0      | 2     | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Down"  | 0     | 0      | 7    | 0       | 0    | 1      | 2    | 230    | 0     | 4    | 8     | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Off"   | 0     | 0      | 2    | 1       | 4    | 2      | 6    | 0      | 237   | 1    | 9     | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | "Go"    | 0     | 2      | 5    | 0       | 0    | 2      | 0    | 1      | 5     | 220  | 16    | 0       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Other   | 6     | 21     | 12   | 25      | 22   | 19     | 25   | 14     | 11    | 40   | 4072  | 1       |
|              +---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
|              | Silence | 0     | 0      | 0    | 0       | 0    | 0      | 0    | 0      | 0     | 0    | 0     | 260     |
+--------------+---------+-------+--------+------+---------+------+--------+------+--------+-------+------+-------+---------+
"""
import argparse
import math
import time
from typing import Optional
import numpy as np
import torch
import torchvision.transforms
from scipy.signal import savgol_filter
from sklearn.metrics import confusion_matrix
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchaudio.transforms import Spectrogram
from tqdm import tqdm
from spikingjelly.activation_based import neuron, surrogate
from spikingjelly.activation_based.functional import reset_net
from spikingjelly.datasets.speechcommands import SPEECHCOMMANDS
# The ten target commands; their position in this tuple is their class index (0-9).
_command_words = (
    "yes",
    "stop",
    "no",
    "right",
    "up",
    "left",
    "on",
    "down",
    "off",
    "go",
)
# Every other spoken word shares one catch-all class index (10).
_other_words = (
    "bed",
    "three",
    "one",
    "four",
    "two",
    "five",
    "cat",
    "dog",
    "eight",
    "bird",
    "happy",
    "sheila",
    "zero",
    "wow",
    "marvin",
    "house",
    "six",
    "seven",
    "tree",
    "nine",
)
# Word -> class index. Insertion order is kept: commands, other words, silence.
label_dict = {word: index for index, word in enumerate(_command_words)}
label_dict.update(dict.fromkeys(_other_words, len(_command_words)))
label_dict["_silence_"] = len(_command_words) + 1

# Number of distinct class indices (several words map to the same index).
label_cnt = len({*label_dict.values()})

# Feature-extraction settings consumed by the transforms and the network below.
n_mels = 40  # number of mel bands
f_max = 4000  # upper edge of the mel filterbank
f_min = 20  # lower edge of the mel filterbank
delta_order = 0  # highest derivative order stacked onto the mel features
size = 16000  # waveform length (in samples) that Pad produces
# Pick the neuron backend: use CuPy when it can be imported, otherwise fall
# back to the plain PyTorch implementation.
try:
    import cupy  # noqa

    backend = "cupy"
except ImportError:
    # ImportError is the parent of ModuleNotFoundError, so this also covers a
    # CuPy package that is present but fails while being imported.
    backend = "torch"
    print("Cupy is not installed or failed to import. Using torch backend for neurons.")
def mel_to_hz(mels, dct_type):
    """Convert mel-scale values to frequencies in Hz.

    :param mels: Mel value(s), either a Python number or a tensor.
    :param dct_type: ``"htk"`` selects the HTK formula; any other value
        selects the piecewise scale that is linear below 1000 Hz and
        logarithmic above it.
    :return: The corresponding frequency or frequencies in Hz.
    """
    if dct_type == "htk":
        return 700.0 * (10 ** (mels / 2595.0) - 1.0)

    # Linear segment parameters.
    base_hz = 0.0
    hz_per_mel = 200.0 / 3
    # Logarithmic segment parameters.
    knee_hz = 1000.0  # frequency where the log region starts
    knee_mel = (knee_hz - base_hz) / hz_per_mel  # same point in mels
    log_rate = math.log(6.4) / 27.0  # step size inside the log region

    # Start from the linear mapping, then overwrite the log region.
    hz = base_hz + hz_per_mel * mels

    if torch.is_tensor(mels) and mels.ndim:
        # Tensor with at least one dimension: patch the log region by mask.
        in_log_region = mels >= knee_mel
        hz[in_log_region] = knee_hz * torch.exp(
            log_rate * (mels[in_log_region] - knee_mel)
        )
        return hz

    # Scalar input: a single comparison decides which formula applies.
    if mels >= knee_mel:
        return knee_hz * math.exp(log_rate * (mels - knee_mel))
    return hz
def hz_to_mel(frequencies, dct_type):
    """Convert frequencies in Hz to mel-scale values.

    :param frequencies: Frequency value(s) in Hz, either a Python number or a
        tensor.
    :param dct_type: ``"htk"`` selects the HTK formula; any other value
        selects the piecewise scale that is linear below 1000 Hz and
        logarithmic above it.
    :return: The corresponding mel value or values.
    """
    # True only for tensors that have at least one dimension.
    is_vector = torch.is_tensor(frequencies) and frequencies.ndim

    if dct_type == "htk":
        ratio = 1.0 + frequencies / 700.0
        if is_vector:
            return 2595.0 * torch.log10(ratio)
        return 2595.0 * math.log10(ratio)

    # Linear segment parameters.
    base_hz = 0.0
    hz_per_mel = 200.0 / 3
    # Logarithmic segment parameters.
    knee_hz = 1000.0  # frequency where the log region starts
    knee_mel = (knee_hz - base_hz) / hz_per_mel  # same point in mels
    log_rate = math.log(6.4) / 27.0  # step size inside the log region

    # Start from the linear mapping, then overwrite the log region.
    mels = (frequencies - base_hz) / hz_per_mel

    if is_vector:
        # Tensor input: patch the log region by mask.
        in_log_region = frequencies >= knee_hz
        mels[in_log_region] = (
            knee_mel + torch.log(frequencies[in_log_region] / knee_hz) / log_rate
        )
        return mels

    # Scalar input: a single comparison decides which formula applies.
    if frequencies >= knee_hz:
        return knee_mel + math.log(frequencies / knee_hz) / log_rate
    return mels
def create_fb_matrix(
    n_freqs: int,
    f_min: float,
    f_max: float,
    n_mels: int,
    sample_rate: int,
    dct_type: Optional[str] = "slaney",
) -> Tensor:
    """Build a triangular mel filterbank matrix.

    :param n_freqs: Number of linearly spaced frequency bins between 0 and
        ``sample_rate // 2``.
    :param f_min: Lowest filter edge in Hz.
    :param f_max: Highest filter edge in Hz.
    :param n_mels: Number of mel filters.
    :param sample_rate: Sample rate of the audio.
    :param dct_type: Mel scale style, ``"slaney"`` or ``"htk"``.
    :return: Filterbank of shape ``(n_freqs, n_mels)``.
    :raises ValueError: If ``dct_type`` is neither ``"htk"`` nor ``"slaney"``.
    """
    if dct_type not in ("htk", "slaney"):
        raise ValueError("DCT type must be either 'htk' or 'slaney'")

    # Centre frequency of every spectrogram bin.
    bin_hz = torch.linspace(0, sample_rate // 2, n_freqs)

    # n_mels + 2 filter edges, evenly spaced on the mel axis, mapped back to Hz.
    mel_low = hz_to_mel(f_min, dct_type)
    mel_high = hz_to_mel(f_max, dct_type)
    edges_hz = mel_to_hz(torch.linspace(mel_low, mel_high, n_mels + 2), dct_type)

    # Width of each interval between neighbouring edges: (n_mels + 1,)
    edge_gaps = edges_hz[1:] - edges_hz[:-1]
    # Signed distance from every bin to every edge: (n_freqs, n_mels + 2)
    distances = edges_hz.unsqueeze(0) - bin_hz.unsqueeze(1)

    # Each triangle is the lower of its two ramps, clipped at zero.
    left_ramp = (-1.0 * distances[:, :-2]) / edge_gaps[:-1]  # (n_freqs, n_mels)
    right_ramp = distances[:, 2:] / edge_gaps[1:]  # (n_freqs, n_mels)
    fb = torch.maximum(torch.zeros(1), torch.minimum(left_ramp, right_ramp))

    if dct_type == "slaney":
        # Scale each filter by 2 / bandwidth for roughly constant energy per channel.
        bandwidth = edges_hz[2 : n_mels + 2] - edges_hz[:n_mels]
        fb *= (2.0 / bandwidth).unsqueeze(0)
    return fb
class MelScaleDelta(nn.Module):
    """Project a spectrogram onto mel bands and stack time derivatives.

    The forward pass multiplies the spectrogram by a triangular mel
    filterbank, compresses the result with ``log1p`` after dividing by its
    largest magnitude, and appends Savitzky-Golay derivatives of order
    ``1 .. order`` computed along the last axis.

    :param order: Highest derivative order to append (0 keeps only the
        log-mel features).
    :param n_mels: Number of mel filters.
    :param sample_rate: Sample rate of the audio.
    :param f_min: Lowest filter edge in Hz.
    :param f_max: Highest filter edge in Hz; defaults to ``sample_rate // 2``.
    :param n_stft: Number of spectrogram frequency bins. When ``None`` the
        filterbank is built on the first forward call from the input size.
    :param dct_type: Mel scale style, ``"slaney"`` or ``"htk"``.
    """

    __constants__ = ["n_mels", "sample_rate", "f_min", "f_max"]

    def __init__(
        self,
        order,
        n_mels: int = 128,
        sample_rate: int = 16000,
        f_min: float = 0.0,
        f_max: Optional[float] = None,
        n_stft: Optional[int] = None,
        dct_type: Optional[str] = "slaney",
    ) -> None:
        super().__init__()
        self.order = order
        self.n_mels = n_mels
        self.sample_rate = sample_rate
        self.f_max = f_max if f_max is not None else float(sample_rate // 2)
        self.f_min = f_min
        self.dct_type = dct_type
        assert f_min <= self.f_max, "Require f_min: {} < f_max: {}".format(
            f_min, self.f_max
        )
        # An empty buffer marks "not built yet"; forward() fills it lazily.
        fb = (
            torch.empty(0)
            if n_stft is None
            else create_fb_matrix(
                n_stft,
                self.f_min,
                self.f_max,
                self.n_mels,
                self.sample_rate,
                self.dct_type,
            )
        )
        self.register_buffer("fb", fb)

    def forward(self, specgram: Tensor) -> Tensor:
        """Compute log-mel features and their derivatives.

        :param specgram: Spectrogram whose last two dimensions are
            ``(frequency, time)``.
        :return: Tensor built as ``(1, order + 1, ...)``. For an input whose
            leading dimensions are all 1 this is
            ``(1, order + 1, time, n_mels)``. The result is placed on the
            same device as ``specgram``.
        """
        # pack batch
        shape = specgram.size()
        specgram = specgram.reshape(-1, shape[-2], shape[-1])
        if self.fb.numel() == 0:
            tmp_fb = create_fb_matrix(
                specgram.size(1),
                self.f_min,
                self.f_max,
                self.n_mels,
                self.sample_rate,
                self.dct_type,
            )
            # Fill the registered buffer in place instead of rebinding it.
            self.fb.resize_(tmp_fb.size())
            self.fb.copy_(tmp_fb)
        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
        # unpack batch; squeeze() drops every size-1 dimension
        mel_specgram = mel_specgram.reshape(
            shape[:-2] + mel_specgram.shape[-2:]
        ).squeeze()
        # Normalise by the peak magnitude before the log compression; an
        # all-zero input is passed through unchanged to avoid dividing by 0.
        M = torch.max(torch.abs(mel_specgram))
        feat = torch.log1p(mel_specgram / M) if M > 0 else mel_specgram
        # Convert once. detach()/cpu() make this work for tensors that
        # require grad or live on another device; Tensor.numpy() alone
        # rejects both.
        feat_np = feat.detach().cpu().numpy()
        feat_list = [feat_np.T]
        feat_list.extend(
            savgol_filter(
                feat_np, 9, deriv=k, axis=-1, mode="interp", polyorder=k
            ).T
            for k in range(1, self.order + 1)
        )
        stacked = np.expand_dims(np.stack(feat_list), axis=0)
        return torch.as_tensor(stacked).to(specgram.device)
class Pad:
    """Zero-pad a waveform on both sides so its last dimension equals ``size``.

    :param size: Target length of the last dimension.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, wav):
        """Return ``wav`` padded with zeros along its last dimension.

        The missing samples are split between both ends; when the amount is
        odd the extra sample goes to the right.

        :param wav: Waveform tensor.
        :return: The padded tensor.
        """
        missing = self.size - wav.shape[-1]
        left = missing // 2
        right = missing - left
        return torch.nn.functional.pad(
            wav, (left, right), mode="constant", value=0
        )
class Rescale:
    """Divide a tensor by its standard deviation along dimension 2."""

    def __call__(self, input):
        """Return ``input`` scaled to unit standard deviation along dim 2.

        :param input: Tensor with at least three dimensions.
        :return: ``input`` divided by its per-slice standard deviation;
            slices whose deviation is exactly zero are left unscaled.
        """
        # Biased estimator (unbiased=False), matching NumPy's default std.
        # https://github.com/romainzimmer/s2net/blob/82c38bf80b55d16d12d0243440e34e52d237a2df/data.py#L201
        scale = input.std(dim=2, keepdim=True, unbiased=False)
        # Replace zero deviations by 1 so constant slices are not divided by 0.
        scale = scale.masked_fill(scale == 0, 1)
        return input / scale
def collate_fn(data):
    """Merge dataset samples into one normalised batch.

    :param data: Sequence of ``(features, label)`` pairs; the feature
        tensors are concatenated along their first dimension.
    :return: ``(X_batch, y_batch)`` where ``X_batch`` is divided by its
        standard deviation taken over dimensions 0 and 2, and ``y_batch``
        is a tensor holding the labels.
    """
    X_batch = torch.cat([d[0] for d in data])
    std = X_batch.std(dim=(0, 2), keepdim=True, unbiased=False)
    # A constant slice has zero deviation; dividing by it would produce
    # NaN/inf, so such slices are left unscaled (same guard as Rescale).
    std.masked_fill_(std == 0, 1)
    X_batch.div_(std)
    y_batch = torch.tensor([d[1] for d in data])
    return X_batch, y_batch
#### Network ####
class LIFWrapper(nn.Module):
    """Adapter that feeds batch-first tensors to a time-first module.

    :param module: Wrapped module; it receives the input with dimensions 0
        and 2 swapped.
    :param flatten: When ``True`` the output is rearranged to
        ``[batch_size, T, channel * n_mel]``.
    """

    def __init__(self, module, flatten=False):
        super().__init__()
        self.module = module
        self.flatten = flatten

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        r"""Run the wrapped module with the time dimension moved to the front.

        The input ``x_seq`` has shape ``[batch_size, channel, T, n_mel]``.
        Before it is passed to the wrapped module, the time and batch
        dimensions are swapped, giving ``[T, channel, batch_size, n_mel]``.

        :param x_seq: Input sequence, shape=[batch_size, channel, T, n_mel]
        :type x_seq: torch.Tensor
        :return: Output sequence; shape=[batch_size, T, channel * n_mel] when
            ``self.flatten=True``, otherwise
            shape=[batch_size, channel, T, n_mel]
        :rtype: torch.Tensor
        """
        time_first = x_seq.transpose(0, 2)  # [T, channel, batch size, n_mel]
        y_seq = self.module(time_first)

        if not self.flatten:
            # Undo the swap: back to [batch size, channel, T, n_mel]
            return y_seq.transpose(0, 2)

        batch_first = y_seq.permute(2, 0, 1, 3)  # [batch size, T, channel, n_mel]
        leading = batch_first.shape[:2]
        # Merge the last two dimensions: [batch size, T, channel * n_mel]
        return batch_first.reshape(leading + (-1,))
class Net(nn.Module):
    """Convolutional spiking network for speech command recognition."""

    def __init__(self):
        r"""Initialize the speech command recognition network.

        The network stacks three convolution + spiking-neuron blocks, flattens
        the final features to ``channel * n_mel``, and maps them to class
        logits with a linear layer. Training statistics fields are also
        initialized.
        """
        super().__init__()
        self.train_times = 0
        self.epochs = 0
        self.max_test_acccuracy = 0

        channels = 64

        def conv_block(in_channels, padding, dilation):
            # All three convolutions share the kernel size and have no bias;
            # they differ only in input channels, padding and dilation.
            return nn.Conv2d(
                in_channels=in_channels,
                out_channels=channels,
                kernel_size=(4, 3),
                stride=1,
                padding=padding,
                dilation=dilation,
                bias=False,
            )

        def spiking_block(flatten=False):
            # Multi-step LIF neurons, wrapped so they receive time-first input.
            return LIFWrapper(
                neuron.LIFNode(
                    tau=10.0 / 7,
                    surrogate_function=surrogate.Sigmoid(alpha=10.0),
                    backend=backend,
                    step_mode="m",
                ),
                flatten=flatten,
            )

        # Input: batch size * delta_order+1 * T * n_mel.
        # With these kernel/padding/dilation values the first convolution
        # lengthens the time axis by one step and the other two keep it;
        # the mel axis keeps its length n_mels through all three.
        self.conv = nn.Sequential(
            conv_block(delta_order + 1, padding=(2, 1), dilation=(1, 1)),
            spiking_block(),
            conv_block(channels, padding=(6, 3), dilation=(4, 3)),
            spiking_block(),
            conv_block(channels, padding=(24, 9), dilation=(16, 9)),
            spiking_block(flatten=True),
        )
        # [batch size, T, channel * n_mel]; the width follows n_mels rather
        # than a hard-coded 40 so both stay in sync.
        self.fc = nn.Linear(channels * n_mels, label_cnt)

    def forward(self, x):
        r"""Compute classification logits for a batch of features.

        The input runs through the convolutional spiking stack to obtain
        per-time-step class logits, which are then averaged over the time
        dimension to produce the final logits of each sample.

        :param x: Input features, shape=[batch_size, delta_order + 1, T, n_mel]
        :type x: torch.Tensor
        :return: Classification logits, shape=[batch_size, label_cnt]
        :rtype: torch.Tensor
        """
        x = self.fc(self.conv(x))  # [batch size, T, #Class]
        return x.mean(dim=1)  # [batch size, #Class]
if __name__ == "__main__":
    # ----- Command-line options -----
    parser = argparse.ArgumentParser()
    parser.add_argument("-b", "--batch-size", type=int, default=64)
    parser.add_argument("-sr", "--sample-rate", type=int, default=16000)
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-2)
    parser.add_argument("-dir", "--dataset-dir", type=str)
    parser.add_argument("-e", "--epoch", type=int, default=50)
    parser.add_argument("-d", "--device", type=str, default="cuda:0")
    args = parser.parse_args()

    sr = args.sample_rate
    # 30 ms window and 10 ms hop expressed in samples
    # (480 and 160 at the default 16 kHz sample rate).
    n_fft = int(30e-3 * sr)
    hop_length = int(10e-3 * sr)
    dataset_dir = args.dataset_dir
    batch_size = args.batch_size
    lr = args.learning_rate
    epoch = args.epoch
    device = args.device

    # ----- Feature pipeline: pad -> spectrogram -> log-mel (+deltas) -> rescale -----
    pad = Pad(size)
    spec = Spectrogram(n_fft=n_fft, hop_length=hop_length)
    melscale = MelScaleDelta(
        order=delta_order,
        n_mels=n_mels,
        sample_rate=sr,
        f_min=f_min,
        f_max=f_max,
        dct_type="slaney",
    )
    rescale = Rescale()
    transform = torchvision.transforms.Compose([pad, spec, melscale, rescale])
    print(label_cnt)

    # ----- Data -----
    train_dataset = SPEECHCOMMANDS(
        label_dict,
        dataset_dir,
        silence_cnt=2300,
        url="speech_commands_v0.01",
        split="train",
        transform=transform,
        download=True,
    )
    # Draw training samples according to the weights the dataset provides.
    train_sampler = torch.utils.data.WeightedRandomSampler(
        train_dataset.weights, len(train_dataset.weights)
    )
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=16,
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    test_dataset = SPEECHCOMMANDS(
        label_dict,
        dataset_dir,
        silence_cnt=260,
        url="speech_commands_v0.01",
        split="test",
        transform=transform,
        download=True,
    )
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        num_workers=16,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # ----- Model, optimiser and logging -----
    net = Net().to(device)
    optimizer = Adam(net.parameters(), lr=lr)
    gamma = 0.85
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma, last_epoch=-1
    )
    # The learning rate starts decaying only after this many epochs.
    warmup_epochs = 1
    print(net)
    writer = SummaryWriter("./logs/")
    criterion = nn.CrossEntropyLoss().to(device)

    for e in range(epoch):
        net.train()
        print(f"Epoch {net.epochs}")
        time_start = time.time()

        ##### TRAIN #####
        for audios, labels in tqdm(train_dataloader):
            audios = audios.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad()
            out_spikes_counter_frequency = net(audios)
            loss = criterion(out_spikes_counter_frequency, labels)
            loss.backward()
            optimizer.step()
            # Clear the neurons' state before the next batch.
            reset_net(net)
            net.train_times += 1
        if e >= warmup_epochs:
            lr_scheduler.step()
        net.eval()
        # Logs the loss of the last training batch of this epoch.
        writer.add_scalar("Train Loss", loss.item(), global_step=net.epochs)

        ##### TEST #####
        with torch.no_grad():
            test_sum = 0
            correct_sum = 0
            pred = []
            label = []
            for audios, labels in tqdm(test_dataloader):
                # Use the device chosen on the command line rather than
                # whatever the default CUDA device happens to be.
                audios = audios.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                out_spikes_counter = net(audios)
                preds = out_spikes_counter.argmax(dim=1)
                correct_sum += (preds == labels).float().sum().item()
                pred.append(preds)
                label.append(labels)
                test_sum += labels.numel()
                reset_net(net)
            pred = torch.cat(pred).cpu().numpy()
            label = torch.cat(label).cpu().numpy()
            # Confusion matrix
            cmatrix = confusion_matrix(label, pred)
            print("Confusion Matrix:")
            print(cmatrix)
            test_accuracy = correct_sum / test_sum
            writer.add_scalar("Test Acc.", test_accuracy, global_step=net.epochs)
            # Keep the best accuracy seen so far on the model's statistics field.
            net.max_test_acccuracy = max(net.max_test_acccuracy, test_accuracy)
        net.epochs += 1
        time_end = time.time()
        print(
            f"Test Acc: {test_accuracy} Loss: {loss} Elapse: {time_end - time_start:.2f}s"
        )

    # Flush pending events and release the log file.
    writer.close()