Clock driven: Using a convolutional SNN to classify Fashion-MNIST

Author: fangwei123456

Translator: YeYumin

In this tutorial, we will build a convolutional spiking neural network to classify the Fashion-MNIST dataset. Fashion-MNIST has the same format as MNIST: both consist of 1 * 28 * 28 grayscale images.

Network structure

Most common convolutional neural networks in the ANN domain take the form of convolutional layers followed by fully connected layers, and we use a similar structure in our SNN. Import the related modules, inherit from torch.nn.Module, and define our network:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.clock_driven import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter

class Net(nn.Module):
    def __init__(self, tau, v_threshold=1.0, v_reset=0.0):
        super().__init__()

Then we add the convolutional layers and fully connected layers as member variables of Net. The developers of SpikingJelly found in their experiments that, for static image data without temporal information, it is better to use IFNode for the neurons in the convolutional layers. We add two convolution-BN-pooling blocks:

self.conv = nn.Sequential(
    nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2),  # 14 * 14

    nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2)  # 7 * 7
)

After the 1 * 28 * 28 input passes through this convolutional stack, a 128 * 7 * 7 spike output is obtained.
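
To confirm these shapes, here is a minimal standalone sketch (our own illustration, not part of the tutorial's code) that rebuilds the same stack as an nn.Sequential with default IFNode parameters and passes a dummy image through it:

import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, surrogate

# the same two convolution-BN-pooling blocks, rebuilt standalone for a shape check
conv = nn.Sequential(
    nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2),  # 28 * 28 -> 14 * 14
    nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2),  # 14 * 14 -> 7 * 7
)
x = torch.rand(1, 1, 28, 28)  # dummy batch of one grayscale image
print(conv(x).shape)          # torch.Size([1, 128, 7, 7])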

Such convolutional layers can actually serve as an encoder: in the previous tutorial's MNIST classification code, we used a Poisson encoder to encode images into spikes. Here, we can instead feed the image directly into the SNN; in this case, the first spiking neuron layer and the layers before it can be regarded as an encoder with learnable parameters. For example, these layers in the convolutional stack we just defined:

nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(128),
neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan())

This 3-layer block, which receives an image as input and outputs spikes, can be regarded as an encoder.
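
A small sketch of this behavior (our own illustration, assuming default IFNode parameters): because the IFNode keeps its membrane potential across calls, feeding the same static image at successive time steps generally produces different spike maps, so the encoder output varies over time even for a constant input.

import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, surrogate, functional

encoder = nn.Sequential(
    nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(surrogate_function=surrogate.ATan()),
)
encoder.eval()
img = torch.rand(1, 1, 28, 28)
with torch.no_grad():
    s0 = encoder(img)  # spikes at t=0
    s1 = encoder(img)  # spikes at t=1; the membrane potential persists between calls
    print((s0 != s1).any().item())  # generally True for an untrained, randomly initialized stack
functional.reset_net(encoder)  # clear the neuron states when done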

Next, we define a 3-layer fully connected network that outputs the classification result. The fully connected layers generally act as a classifier, and using LIFNode gives better performance here. Fashion-MNIST has 10 categories, so the output layer has 10 neurons. To reduce over-fitting, we also use layer.Dropout; for more information about it, please refer to the API documentation.

self.fc = nn.Sequential(
    nn.Flatten(),
    layer.Dropout(0.7),
    nn.Linear(128 * 7 * 7, 128 * 3 * 3, bias=False),
    neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    layer.Dropout(0.7),
    nn.Linear(128 * 3 * 3, 128, bias=False),
    neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    nn.Linear(128, 10, bias=False),
    neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
)

Next, define the forward propagation. It is very simple: the input first passes through the convolutional layers, then through the fully connected layers:

def forward(self, x):
    return self.fc(self.conv(x))

Avoiding repeated computation

We can train this network directly, just like the MNIST classification in the previous tutorial:

for img, label in train_data_loader:
    img = img.to(device)
    label = label.to(device)
    label_one_hot = F.one_hot(label, 10).float()

    optimizer.zero_grad()

    # run the network for T time steps; out_spikes_counter is a tensor with shape=[batch_size, 10]
    # it records the number of spikes fired by the 10 output neurons over the whole simulation duration
    for t in range(T):
        if t == 0:
            out_spikes_counter = net(img)
        else:
            out_spikes_counter += net(img)

    # out_spikes_counter / T is the firing frequency of the 10 output neurons over the simulation duration
    out_spikes_counter_frequency = out_spikes_counter / T

    # the loss is the MSE between the firing frequency of the output neurons and the one-hot true label
    # this loss pushes the firing frequency of the i-th output neuron toward 1 when category i is the input, and toward 0 for the other neurons
    loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
    loss.backward()
    optimizer.step()
    # the network state must be reset after each optimization step, because SNN neurons have "memory"
    functional.reset_net(net)
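
For completeness, here is a minimal sketch of the setup this loop assumes (device, T, net, optimizer, train_data_loader); the values match the command-line run shown later but are otherwise illustrative assumptions, not fixed requirements:

import torch
import torchvision

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
T = 8      # simulation duration in time steps
tau = 2.0  # membrane time constant of the LIF neurons

net = Net(tau=tau).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

train_data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.FashionMNIST(
        root='./fmnist', train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True),
    batch_size=128,
    shuffle=True,
    drop_last=True)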

But if we re-examine the structure of the network, we find that some calculations are repeated: namely, for the first two layers of the network, the Conv2d and BatchNorm2d at the top of the following code:

self.conv = nn.Sequential(
    nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2),  # 14 * 14

    nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
    nn.MaxPool2d(2, 2)  # 7 * 7
)

The input images received by these two layers do not change with t, but in the for loop these two layers recompute the same img at every time step and produce identical output each time. We can extract these layers and encapsulate the time loop inside the network itself to streamline the computation. The new network structure is defined in full as:

class Net(nn.Module):
    def __init__(self, tau, T, v_threshold=1.0, v_reset=0.0):
        super().__init__()
        self.T = T

        self.static_conv = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )

        self.conv = nn.Sequential(
            neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
            nn.MaxPool2d(2, 2),  # 14 * 14

            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
            nn.MaxPool2d(2, 2)  # 7 * 7

        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            layer.Dropout(0.7),
            nn.Linear(128 * 7 * 7, 128 * 3 * 3, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
            layer.Dropout(0.7),
            nn.Linear(128 * 3 * 3, 128, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
            nn.Linear(128, 10, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
        )


    def forward(self, x):
        x = self.static_conv(x)

        out_spikes_counter = self.fc(self.conv(x))
        for t in range(1, self.T):
            out_spikes_counter += self.fc(self.conv(x))

        return out_spikes_counter / self.T
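
With the time loop now inside forward(), one epoch of training simplifies to the following sketch (our own illustration; device, optimizer, and train_data_loader are the same as in the earlier setup sketch, and the refactored Net now also takes T at construction):

net = Net(tau=2.0, T=8).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

for img, label in train_data_loader:
    img = img.to(device)
    label_one_hot = F.one_hot(label.to(device), 10).float()

    optimizer.zero_grad()
    out_spikes_counter_frequency = net(img)  # forward() already averages over T
    loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
    loss.backward()
    optimizer.step()
    functional.reset_net(net)  # neuron states must still be reset between batches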

For an SNN whose input does not change with time, although the SNN as a whole is stateful, the first few layers of the network may be stateless. We can extract these layers and move them outside the time loop to avoid redundant computation.

Training the network

The complete code is located in spikingjelly.clock_driven.examples.conv_fashion_mnist, and it can also be run directly from the command line. The network with the highest test set accuracy during training is saved in the same directory as the tensorboard log files. The server used to train this network has an Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz and a GeForce RTX 2080 Ti GPU.

>>> from spikingjelly.clock_driven.examples import conv_fashion_mnist
>>> conv_fashion_mnist.main()
input device, e.g., "cpu" or "cuda:0": cuda:0
input root directory for saving Fashion MNIST dataset, e.g., "./": ./fmnist
input batch_size, e.g., "64": 128
input learning rate, e.g., "1e-3": 1e-3
input simulating steps, e.g., "8": 8
input membrane time constant, tau, for LIF neurons, e.g., "2.0": 2.0
input training epochs, i.e., the number of passes over the training set, e.g., "100": 100
input root directory for saving tensorboard logs, e.g., "./": ./logs_conv_fashion_mnist
saving net...
saved
epoch=0, t_train=41.182421264238656, t_test=2.5504338955506682, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.8704, train_times=468
saving net...
saved
epoch=1, t_train=40.93981215544045, t_test=2.538706629537046, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.8928, train_times=936
saving net...
saved
epoch=2, t_train=40.86129532009363, t_test=2.5383697943761945, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.899, train_times=1404
saving net...
saved

...

epoch=95, t_train=40.98498909268528, t_test=2.558146824128926, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.9425, train_times=44928
saving net...
saved
epoch=96, t_train=41.19765609316528, t_test=2.6626883540302515, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.9426, train_times=45396
saving net...
saved
epoch=97, t_train=41.10238983668387, t_test=2.553960849530995, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.9427, train_times=45864
saving net...
saved
epoch=98, t_train=40.89284007716924, t_test=2.5465594390407205, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.944, train_times=46332
epoch=99, t_train=40.843392613343894, t_test=2.557370903901756, device=cuda:0, dataset_dir=./fmnist, batch_size=128, learning_rate=0.001, T=8, log_dir=./logs_conv_fashion_mnist, max_test_accuracy=0.944, train_times=46800

After 100 epochs of training, the accuracy curves on the training batches and on the test set are as follows:

(Figures: ../_images/train1.svg — accuracy on training batches; ../_images/test1.svg — accuracy on the test set)

After training for 100 epochs, the highest test set accuracy reaches 94.4%, which is very good performance for an SNN; it is only slightly below the 94.9% accuracy achieved in the Fashion-MNIST benchmark by a ResNet18 trained with normalization, random horizontal flip, random vertical flip, random translation, and random rotation.

Visualizing the encoder

As we said earlier, if the data is fed directly into the SNN, the first spiking neuron layer and the layers before it can be regarded as a learnable encoder. Specifically, this is the following part of our network:

class Net(nn.Module):
    def __init__(self, tau, T, v_threshold=1.0, v_reset=0.0):
        ...
        self.static_conv = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
        )

        self.conv = nn.Sequential(
            neuron.IFNode(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate.ATan()),
        ...

Now let's take a look at the encoding produced by the trained encoder. Create a new Python file, import the related modules, and redefine a data loader with batch_size=1, since we want to view the images one at a time:

from matplotlib import pyplot as plt
import numpy as np
from spikingjelly.clock_driven.examples.conv_fashion_mnist import Net
from spikingjelly.clock_driven import functional
from spikingjelly import visualizing
import torch
import torch.nn as nn
import torchvision

dataset_dir = './fmnist'  # assumed to be the same directory used during training

test_data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.FashionMNIST(
        root=dataset_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True),
    batch_size=1,
    shuffle=True,
    drop_last=False)

Load the trained network from where it was saved, i.e., under the log_dir directory, and extract the encoder. We simply run it on the CPU:

net = torch.load('./logs_conv_fashion_mnist/net_max_acc.pt', 'cpu')
encoder = nn.Sequential(
    net.static_conv,
    net.conv[0]
)
encoder.eval()

Next, take an image from the dataset, send it to the encoder, and inspect the accumulated output spikes \(\sum_{t} S_{t}\). For a clearer display, we also normalize each output feature map, linearly rescaling its values to the range [0, 1].

with torch.no_grad():
    # iterate over the test set, showing one image at a time
    for img, label in test_data_loader:
        fig = plt.figure(dpi=200)
        # the image fed to the network has size [1, 1, 28, 28]: dimension 0 is the batch and dimension 1 is the channel
        # so before calling imshow, use squeeze() to reduce the size to [28, 28]
        plt.imshow(img.squeeze().numpy(), cmap='gray')
        plt.title('Input image', fontsize=20)
        plt.xticks([])
        plt.yticks([])
        plt.show()
        out_spikes = 0
        for t in range(net.T):
            # encoder(img) has size [1, 128, 28, 28]; likewise use squeeze() to reduce it to [128, 28, 28]
            out_spikes += encoder(img).squeeze()
            if t == 0 or t == net.T - 1:
                out_spikes_c = out_spikes.clone()
                for i in range(out_spikes_c.shape[0]):
                    if out_spikes_c[i].max().item() > out_spikes_c[i].min().item():
                        # normalize each feature map to make the display clearer
                        out_spikes_c[i] = (out_spikes_c[i] - out_spikes_c[i].min()) / (out_spikes_c[i].max() - out_spikes_c[i].min())
                visualizing.plot_2d_spiking_feature_map(out_spikes_c, 8, 16, 1, None)
                plt.title('$\\sum_{t} S_{t}$ at $t = ' + str(t) + '$', fontsize=20)
                plt.show()
        # the encoder's IFNode is stateful, so reset it before encoding the next image
        functional.reset_net(encoder)

The following shows two input images and the encoder's accumulated output spikes \(\sum_{t} S_{t}\) at the initial time step t=0 and the final time step t=7:

(Figures: ../_images/x0.svg — input image 1; ../_images/y00.svg, ../_images/y07.svg — its accumulated spikes at t=0 and t=7; ../_images/x1.svg — input image 2; ../_images/y10.svg, ../_images/y17.svg — its accumulated spikes at t=0 and t=7)

We can observe that the encoder's accumulated output spikes \(\sum_{t} S_{t}\) closely follow the contours of the original image, which suggests that this learnable spike encoder has strong encoding ability.