Clock driven: Use a single-layer fully connected SNN to classify MNIST

Author: Yanqi-Chen

Translator: YeYumin

This tutorial introduces how to train the simplest MNIST classification network using encoders and the surrogate gradient method.

Build a simple SNN network from scratch

When building a neural network in PyTorch, we can simply use nn.Sequential to stack multiple network layers to get a feedforward network. The input data will flow through each network layer in order to get the output.

The MNIST dataset contains 8-bit grayscale images of size \(28\times 28\), covering a total of 10 categories, the digits 0 to 9. Taking MNIST classification as an example, a simple single-layer ANN is as follows:

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False),
    nn.Softmax(dim=1))  # activation function of the ANN

We can also use an SNN with a very similar structure for the classification task. For this network, we only need to remove all the activation functions and put spiking neurons in their place. Here we choose the LIF neuron:

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau))  # LIF neurons replace the activation function

Here, the membrane time constant \(\tau\) is set through the parameter tau.
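To give a feel for what an LIF layer computes and how tau affects it, here is a minimal pure-Python sketch of one common discrete-time LIF update. The update rule, the threshold of 1.0, the reset value of 0.0, and the helper names lif_step and simulate are assumptions made for illustration. They are not taken from the library's source, and the library's actual implementation may differ.

```python
def lif_step(v, x, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """One assumed LIF update: leak toward v_reset, integrate input x, fire, hard reset."""
    v = v + (x - (v - v_reset)) / tau
    spike = 1 if v >= v_threshold else 0
    if spike:
        v = v_reset  # hard reset after firing
    return v, spike


def simulate(inputs, tau=2.0):
    """Feed a sequence of input currents to one neuron and return its spike train."""
    v, spikes = 0.0, []
    for x in inputs:
        v, s = lif_step(v, x, tau=tau)
        spikes.append(s)
    return spikes


print(simulate([1.5] * 6, tau=2.0))         # strong input: the neuron fires every other step
print(sum(simulate([0.5] * 50, tau=2.0)))   # weak input: the potential never reaches threshold
print(sum(simulate([1.5] * 100, tau=100.0)))  # large tau: the potential rises much more slowly
```

In this sketch a larger tau makes the membrane potential respond more slowly to its input, so the same input produces fewer spikes within the same number of steps.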

Train SNN network

First, specify the training parameters and several other configurations:

device = input('input device, e.g., "cpu" or "cuda:0": ')
dataset_dir = input('input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('input batch_size, e.g., "64": '))
learning_rate = float(input('input learning rate, e.g., "1e-3": '))
T = int(input('input simulating steps, e.g., "100": '))
tau = float(input('input membrane time constant, tau, for LIF neurons, e.g., "100.0": '))
train_epoch = int(input('input training epochs, e.g., "100": '))
log_dir = input('input root directory for saving tensorboard logs, e.g., "./": ')
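For scripted or repeated runs, the same settings could be read from command-line flags instead of interactive prompts. The sketch below is one possible alternative and is not part of the original example. The flag names and the build_parser helper are hypothetical, and the defaults simply reuse the example values shown in the prompts.

```python
import argparse


def build_parser():
    """Hypothetical command-line equivalent of the interactive prompts."""
    parser = argparse.ArgumentParser(description="Train a single-layer LIF SNN on MNIST")
    parser.add_argument("--device", default="cpu", help='e.g. "cpu" or "cuda:0"')
    parser.add_argument("--dataset-dir", default="./", help="root directory for the MNIST dataset")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--learning-rate", type=float, default=1e-3)
    parser.add_argument("--T", type=int, default=100, help="number of simulation steps")
    parser.add_argument("--tau", type=float, default=100.0, help="LIF membrane time constant")
    parser.add_argument("--train-epoch", type=int, default=100)
    parser.add_argument("--log-dir", default="./", help="root directory for tensorboard logs")
    return parser


# An explicit list is parsed here so the sketch runs anywhere;
# call parse_args() with no argument to read the real command line.
args = build_parser().parse_args(["--T", "50", "--tau", "2.0"])
print(args.device, args.batch_size, args.T, args.tau)
```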

We use the Adam optimizer and a Poisson encoder, which performs spike encoding each time an image is fed into the network.

# Use Adam optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# Use Poisson encoder
encoder = encoding.PoissonEncoder()
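The idea behind Poisson (rate) encoding can be sketched without any library: at every time step, a pixel with normalized intensity p in [0, 1] emits a spike with probability p, so its average firing rate over many steps approaches p. The function poisson_encode below is a hypothetical stand-in written for illustration, not the library's encoder, and the pixel values are made up.

```python
import random


def poisson_encode(pixels, rng):
    """Return one binary spike frame: a pixel with intensity p fires with probability p."""
    return [1 if rng.random() < p else 0 for p in pixels]


rng = random.Random(0)            # fixed seed so the run is repeatable
pixels = [0.0, 0.25, 0.9, 1.0]    # made-up normalized pixel intensities
T = 2000                          # number of simulation steps

counts = [0] * len(pixels)
for _ in range(T):
    frame = poisson_encode(pixels, rng)   # a fresh random frame at every step
    counts = [c + s for c, s in zip(counts, frame)]

rates = [c / T for c in counts]
print(rates)  # each rate should be close to the corresponding pixel intensity
```

Note that a new frame is drawn at every step, which matches the training loop's pattern of calling the encoder inside the loop over t.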

The training code needs to follow these three points:

1. The output of a spiking neuron is binary, so using the result of a single run directly for classification is very susceptible to interference. Therefore, the output of a spiking network is generally taken to be the firing frequency (or firing rate) of the output layer over a period of time, and the firing rate indicates the response strength for each category. The network therefore needs to run for a period of time, and the average firing rate over T time steps is used as the basis for classification.

2. The ideal result is that the neuron for the correct category fires at the highest frequency while the other neurons remain silent. Cross-entropy loss or MSE loss is often used; here we use MSE loss, which works better in practice.

3. After each network simulation is over, the network state needs to be reset.

Combining the above three points, the code of training loop is as follows:

for img, label in train_data_loader:
    img = img.to(device)
    label = label.to(device)
    label_one_hot = F.one_hot(label, 10).float()

    optimizer.zero_grad()

    # Run for T time steps; out_spikes_counter is a tensor of shape=[batch_size, 10]
    # that records how many spikes each of the 10 output neurons fired during the whole simulation
    for t in range(T):
        if t == 0:
            out_spikes_counter = net(encoder(img).float())
        else:
            out_spikes_counter += net(encoder(img).float())

    # out_spikes_counter / T is the firing rate of the 10 output neurons over the simulation
    out_spikes_counter_frequency = out_spikes_counter / T

    # The loss is the MSE between the output firing rates and the one-hot true category.
    # It pushes the firing rate of the i-th output neuron toward 1 when category i is input,
    # and the firing rates of the other neurons toward 0
    loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
    loss.backward()
    optimizer.step()
    # After each parameter update, the network state must be reset, because SNN neurons have "memory"
    functional.reset_net(net)  # functional is spikingjelly.clock_driven.functional
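To make the loss concrete, the following sketch computes the firing rates and the MSE against a one-hot target by hand, in plain Python. The spike counts are made up for illustration, and the helpers one_hot and mse are hypothetical stand-ins that average over all elements, as a mean-reduced MSE does.

```python
def one_hot(label, num_classes):
    """One-hot target vector for a single label."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mse(a, b):
    """Mean squared error averaged over all elements."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


T = 100
target = one_hot(7, 10)  # the true category is 7

# Made-up spike counts of the 10 output neurons over T steps
good_counts = [0, 2, 0, 0, 5, 0, 0, 90, 0, 3]   # neuron 7 fires most
bad_counts = [0, 0, 0, 80, 0, 0, 0, 10, 0, 0]   # neuron 3 fires most

loss_good = mse([c / T for c in good_counts], target)
loss_bad = mse([c / T for c in bad_counts], target)
print(loss_good, loss_bad)  # the loss is much smaller when the correct neuron dominates
```

The loss is small only when the correct neuron fires on almost every step and the others stay nearly silent, which is exactly the behavior described in point 2.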

The complete code is in the module spikingjelly.clock_driven.examples.lif_fc_mnist. It also uses TensorBoard to save training logs. You can run it directly from the Python command line:

>>> import spikingjelly.clock_driven.examples.lif_fc_mnist as lif_fc_mnist
>>> lif_fc_mnist.main()

Note that when training such an SNN, the amount of GPU memory required grows linearly with the simulation duration T. A longer T is equivalent to using a smaller simulation step, so the training is more "fine-grained", but the result is not necessarily better. When T is too large, the SNN becomes a very deep network once it is unrolled in time, which makes the gradients prone to vanishing or exploding.

In addition, because we use a Poisson encoder, whose output is random, a larger T is required for the spike counts to reflect the pixel intensities reliably.
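The reason can be made precise with a short calculation, under the assumption that a pixel with normalized intensity \(p\) spikes independently with probability \(p\) at each step. The spike count \(N\) over \(T\) steps is then binomial, and the estimated rate \(N/T\) satisfies:

```latex
\mathbb{E}\left[\frac{N}{T}\right] = p, \qquad
\operatorname{Var}\left[\frac{N}{T}\right] = \frac{p(1-p)}{T}, \qquad
\sigma_{N/T} = \sqrt{\frac{p(1-p)}{T}}
```

For example, with \(p = 0.5\) the standard deviation of the estimated rate is 0.05 at T=100 but about 0.16 at T=10, so a short simulation gives the network a much noisier view of the image.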

Training result

With tau=2.0, T=100, batch_size=128, and lr=1e-3, four npy files are output after training for 100 epochs. The highest accuracy on the test set is 92.5%, and the accuracy curve plotted with matplotlib is as follows:


Select the first picture in the test set:


Use the trained model to classify it and get the classification result:

Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
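The predicted class is the index of the output neuron with the highest firing rate. A minimal sketch using the firing rates printed above (the helper name predict is hypothetical):

```python
def predict(firing_rates):
    """Return, for each sample, the index of the neuron with the highest firing rate."""
    return [max(range(len(row)), key=row.__getitem__) for row in firing_rates]


firing_rate = [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]
print(predict(firing_rate))  # [7]
```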

The voltage and spikes of the output layer can be visualized with the functions in the visualizing module, as shown in the figures below.

[Figures: 1d_spikes.svg, 2d_heatmap.svg]

It can be seen that no neuron emits any spikes except the one corresponding to the correct category. The complete training code can be found in clock_driven/examples/