分类 DVS Gesture

本教程作者: fangwei123456

在 神经形态数据集处理 中,我们已经学习了如何使用神经形态数据集。下面让我们搭建SNN对其进行分类。

网络结构

我们将使用 [1] 一文中定义的网络,其结构如下:

../_images/network1.png

[1] 一文中的所有网络都在 spikingjelly.activation_based.model.parametric_lif_net 中进行了定义,其中用于DVS Gesture的网络结构为:

# spikingjelly.activation_based.model.parametric_lif_net

import torch
import torch.nn as nn
from .. import layer

class DVSGestureNet(nn.Module):
    """Convolutional spiking network used to classify DVS Gesture.

    The model is a single ``nn.Sequential`` (``self.conv_fc``) made of:

    * five ``Conv2d -> BatchNorm2d -> spiking neuron -> MaxPool2d(2, 2)``
      stages (the first convolution takes 2 input channels, every later one
      takes ``channels``);
    * ``Flatten``, then two ``Dropout(0.5) -> Linear -> spiking neuron``
      stages with 512 and 110 output features;
    * ``VotingLayer(10)`` applied to the 110 outputs.

    Args:
        channels: Number of output channels of every convolution.
        spiking_neuron: Callable (typically a neuron class) used to build each
            spiking-neuron layer; it is invoked as
            ``spiking_neuron(*args, **kwargs)`` once per neuron layer.
        *args: Positional arguments forwarded to ``spiking_neuron``.
        **kwargs: Keyword arguments forwarded to ``spiking_neuron``.

    Raises:
        TypeError: If ``spiking_neuron`` is not callable (for example when it
            is left at its default ``None``).
    """

    def __init__(self, channels=128, spiking_neuron: callable = None, *args, **kwargs):
        super().__init__()

        # Fail early with a readable message instead of the opaque
        # "'NoneType' object is not callable" raised deep inside the loop.
        if not callable(spiking_neuron):
            raise TypeError(
                'spiking_neuron must be a callable that builds a neuron layer, '
                f'got {spiking_neuron!r}'
            )

        conv = []
        # Only the first convolution sees the 2-channel input
        # (presumably the two event polarities -- confirm against the dataset).
        in_channels = 2
        for _ in range(5):
            conv.extend([
                layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                layer.BatchNorm2d(channels),
                spiking_neuron(*args, **kwargs),
                layer.MaxPool2d(2, 2),
            ])
            in_channels = channels

        self.conv_fc = nn.Sequential(
            *conv,

            layer.Flatten(),
            layer.Dropout(0.5),
            # channels * 4 * 4 assumes a 4x4 feature map here, e.g. a 128x128
            # input halved by each of the five 2x2 poolings.
            layer.Linear(channels * 4 * 4, 512),
            spiking_neuron(*args, **kwargs),

            layer.Dropout(0.5),
            layer.Linear(512, 110),
            spiking_neuron(*args, **kwargs),

            layer.VotingLayer(10)
        )

    def forward(self, x: torch.Tensor):
        """Return the output of ``self.conv_fc`` applied to ``x``."""
        return self.conv_fc(x)

训练

训练的代码与之前的教程 使用卷积SNN识别Fashion-MNIST 几乎相同,相同之处不再赘述,下面只介绍差异部分。

定义网络,使用多步模式。若使用 CuPy 则将所有的 neuron.LIFNode 设置为 cupy 后端:

# spikingjelly.activation_based.examples.classify_dvsg

import torch
import sys
import torch.nn.functional as F
from torch.cuda import amp
from spikingjelly.activation_based import functional, surrogate, neuron
from spikingjelly.activation_based.model import parametric_lif_net
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import os
import argparse
import datetime

def main():
    # ...
    # Build the network with LIF neurons; surrogate_function and detach_reset
    # are extra keyword arguments passed through DVSGestureNet.
    net = parametric_lif_net.DVSGestureNet(
        channels=args.channels,
        spiking_neuron=neuron.LIFNode,
        surrogate_function=surrogate.ATan(),
        detach_reset=True,
    )

    # Put the whole network into multi-step mode.
    functional.set_step_mode(net, 'm')
    if args.cupy:
        # Only neuron.LIFNode instances are switched to the cupy backend.
        functional.set_backend(net, 'cupy', instance=neuron.LIFNode)
    # ...

新建数据集:

# spikingjelly.activation_based.examples.classify_dvsg

def main():
    # ...
    # The two splits share every option except the ``train`` flag.
    dataset_kwargs = dict(
        root=args.data_dir,
        data_type='frame',
        frames_number=args.T,
        split_by='number',
    )
    train_set = DVS128Gesture(train=True, **dataset_kwargs)
    test_set = DVS128Gesture(train=False, **dataset_kwargs)
    # ...

注意,由 DataLoader 打包的数据,第0维总是batch维度,因此我们从 DataLoader 读取的数据的 shape 实际上是 [N, T, C, H, W]。而SpikingJelly的多步模式使用的是 shape = [T, N, C, H, W],所以需要先交换前两个维度:

# spikingjelly.activation_based.examples.classify_dvsg

 def main():
    # ...
    for epoch in range(start_epoch, args.epochs):
        # Training pass over the training loader.
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.to(args.device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            # ...

        # Evaluation pass: the loop must sit inside the no_grad block so that
        # no gradients are tracked while iterating the test loader.
        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.to(args.device)
                frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
                # ...

    # ...

DVS Gesture有11类,因此在生成one hot的target时别忘了设置为11类:

# spikingjelly.activation_based.examples.classify_dvsg

def main():
    # ...
    # One-hot encode the labels into 11 classes and cast to float.
    label_onehot = F.one_hot(label, num_classes=11).float()
    # ...

DVSGestureNet 输出的并不是脉冲发放频率,而是 shape = [T, N, 11] 的原始输出:

# spikingjelly.activation_based.model.parametric_lif_net

class DVSGestureNet(nn.Module):
    # ...
    def forward(self, x: torch.Tensor):
        """Pass ``x`` through ``self.conv_fc`` and return the result as-is."""
        out = self.conv_fc(x)
        return out

因此,我们需要先将输出在时间维度(第0维)上求平均,得到脉冲发放频率,然后再计算损失和正确率:

# spikingjelly.activation_based.examples.classify_dvsg

def main():
    # ...
    # Average the network output over dim 0 to get the firing rate, then
    # compare it with the one-hot target.
    out_fr = net(frame).mean(dim=0)
    loss = F.mse_loss(out_fr, label_onehot)
    # ...

运行我们的网络:

python -m spikingjelly.activation_based.examples.classify_dvsg -T 16 -device cuda:0 -b 16 -epochs 64 -data-dir /datasets/DVSGesture/ -amp -cupy -opt adam -lr 0.001 -j 8

得到输出为:

Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
DVSGestureNet(
(conv_fc): Sequential(
    (0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (2): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (6): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (10): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (14): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
    (17): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
    (18): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (19): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
    (20): Flatten(start_dim=1, end_dim=-1, step_mode=m)
    (21): Dropout(p=0.5)
    (22): Linear(in_features=2048, out_features=512, bias=True)
    (23): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (24): Dropout(p=0.5)
    (25): Linear(in_features=512, out_features=110, bias=True)
    (26): LIFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
    (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
    (27): VotingLayer(voting_size=10, step_mode=m)
)
)
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
Mkdir ./logs/T16_b16_adam_lr0.001_c128_amp_cupy.
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 0, train_loss = 0.0666, train_acc = 0.3964, test_loss = 0.0514, test_acc = 0.6042, max_test_acc = 0.6042
train speed = 92.7646 images/s, test speed = 115.2935 images/s
escape time = 2022-05-25 21:31:54

Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 1, train_loss = 0.0463, train_acc = 0.6036, test_loss = 0.0439, test_acc = 0.6319, max_test_acc = 0.6319
train speed = 101.5938 images/s, test speed = 120.5184 images/s
escape time = 2022-05-25 21:30:48

...

Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 63, train_loss = 0.0011, train_acc = 0.9991, test_loss = 0.0103, test_acc = 0.9375, max_test_acc = 0.9375
train speed = 100.4324 images/s, test speed = 121.0402 images/s
escape time = 2022-05-25 21:30:51

最终获得了 max_test_acc = 0.9375 的性能。如果精心调整超参数、增加训练 epochs,通常还能获得更高的性能。

下图展示了训练过程中的正确率曲线:

../_images/dvsg_logs.svg