Clock driven: Use a single-layer fully connected SNN to classify MNIST
This tutorial shows how to train a minimal MNIST classification network using encoders and the surrogate gradient method.
Build a simple SNN network from scratch
When building a neural network in PyTorch, we can simply use
nn.Sequential to stack multiple network layers to get a
feedforward network. The input data will flow through each network layer in order to get the output.
The MNIST dataset contains 8-bit grayscale images of size \(28\times 28\) in 10 categories, the digits 0 to 9. Taking MNIST classification as an example, a simple single-layer ANN looks like this:
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10, bias=False),
        nn.Softmax(dim=1)  # normalize over the 10 class scores
    )
We can also use an SNN with the same structure for classification. For this network, we only need to remove the activation functions and place spiking neurons where they were. Here we choose the LIF neuron:
    net = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 10, bias=False),
        neuron.LIFNode(tau=tau)
    )
Here the membrane time constant \(\tau\) is set through the parameter tau.
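To give a feel for what the time constant does, here is a minimal, framework-free sketch of one leaky integrate-and-fire neuron. The update rule, the threshold of 1.0, the reset value of 0.0, and the helper names lif_step and simulate are assumptions chosen for illustration; they follow a common discrete-time LIF form and are not a copy of how LIFNode is implemented.

```python
def lif_step(v, x, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """Advance one LIF neuron by one time step and return (spike, new_v).

    Assumed update rule (illustrative, not taken from the library):
    v <- v + (x - (v - v_reset)) / tau, followed by a hard reset on spike.
    """
    # Leaky integration: v decays toward v_reset and is driven by the input x.
    v = v + (x - (v - v_reset)) / tau
    # The neuron fires when the membrane potential reaches the threshold.
    spike = 1.0 if v >= v_threshold else 0.0
    # Hard reset after a spike.
    if spike:
        v = v_reset
    return spike, v


def simulate(x, steps, tau):
    """Feed a constant input x for `steps` time steps and collect the spikes."""
    v, spikes = 0.0, []
    for _ in range(steps):
        s, v = lif_step(v, x, tau=tau)
        spikes.append(s)
    return spikes


# Same constant input, two different time constants.
spikes_fast = simulate(1.5, 20, tau=2.0)
spikes_slow = simulate(1.5, 20, tau=10.0)
print(sum(spikes_fast), sum(spikes_slow))
```

Under these assumptions a larger tau makes the neuron integrate its input more slowly, so the same input produces fewer spikes in the same number of steps.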
Train the SNN network
First, specify the training parameters and other configuration options:
    device = input('input device, e.g., "cpu" or "cuda:0": ')
    dataset_dir = input('input root directory for saving MNIST dataset, e.g., "./": ')
    batch_size = int(input('input batch_size, e.g., "64": '))
    learning_rate = float(input('input learning rate, e.g., "1e-3": '))
    T = int(input('input simulating steps, e.g., "100": '))
    tau = float(input('input membrane time constant, tau, for LIF neurons, e.g., "100.0": '))
    train_epoch = int(input('input training epochs, e.g., "100": '))
    log_dir = input('input root directory for saving tensorboard logs, e.g., "./": ')
We use the Adam optimizer, and a Poisson encoder converts each image into spikes every time it is fed to the network.
    # Use the Adam optimizer
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # Use the Poisson encoder
    encoder = encoding.PoissonEncoder()
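The idea behind Poisson encoding can be sketched without any framework. The sketch below assumes that pixel intensities are scaled to [0, 1] and that each intensity is used as the probability of emitting a spike at every time step; poisson_encode is a hypothetical helper written for this illustration, not the library's encoder.

```python
import random


def poisson_encode(pixels, rng):
    """Return one binary spike frame: a pixel with intensity p fires with probability p."""
    return [1.0 if rng.random() < p else 0.0 for p in pixels]


rng = random.Random(0)            # fixed seed so the run is repeatable
pixels = [0.0, 0.25, 0.9, 1.0]    # four example intensities in [0, 1]
T = 2000                          # number of simulated time steps

# Count the spikes of each pixel over T independently sampled frames.
counts = [0.0] * len(pixels)
for _ in range(T):
    frame = poisson_encode(pixels, rng)
    counts = [c + s for c, s in zip(counts, frame)]

# The empirical firing rate approaches the pixel intensity as T grows.
rates = [c / T for c in counts]
print(rates)
```

Each call produces a different random frame, which is why the training code calls the encoder again at every time step instead of reusing one encoded image.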
The training code needs to follow three points:
1. The output of a spiking neuron is binary, so classifying from the result of a single time step is very susceptible to noise. The output of a spiking network is therefore usually taken to be the firing rate of the output layer over a period of time, where the firing rate of a neuron indicates the response strength of its category. The network thus needs to run for T time steps, and the average firing rate over those T steps is used as the classification basis.
2. Ideally, the neuron for the correct category fires at the highest rate while all other neurons remain silent. Cross-entropy loss or MSE loss is often used; here we use MSE loss, which works better in practice.
3. After each network simulation is over, the network state needs to be reset.
Combining the above three points, the training loop is as follows:
    for img, label in train_data_loader:
        img = img.to(device)
        label = label.to(device)
        label_one_hot = F.one_hot(label, 10).float()

        optimizer.zero_grad()

        # Run for T time steps. out_spikes_counter is a tensor of shape [batch_size, 10]
        # that records how many times each of the 10 output neurons fires during the simulation
        for t in range(T):
            if t == 0:
                out_spikes_counter = net(encoder(img).float())
            else:
                out_spikes_counter += net(encoder(img).float())

        # Dividing by T gives the firing rate of the 10 output neurons over the simulation
        out_spikes_counter_frequency = out_spikes_counter / T

        # The loss is the MSE between the output firing rates and the one-hot label.
        # It pushes the firing rate of the i-th output neuron toward 1 when category i is input,
        # and the firing rates of the other neurons toward 0
        loss = F.mse_loss(out_spikes_counter_frequency, label_one_hot)
        loss.backward()
        optimizer.step()
        # After each parameter update, reset the network state, because SNN neurons have "memory"
        functional.reset_net(net)
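The arithmetic of the loss can be checked by hand on a tiny case. The following framework-free sketch uses made-up output spikes for 3 neurons over 4 time steps (the real network has 10 output neurons). The helpers firing_rate and mse are written only for this illustration, and mse averages over all elements, which is assumed to match the default mean reduction of F.mse_loss.

```python
def firing_rate(spike_frames):
    """Average a list of per-step output spike vectors into per-neuron firing rates."""
    T = len(spike_frames)
    n = len(spike_frames[0])
    return [sum(frame[i] for frame in spike_frames) / T for i in range(n)]


def mse(rates, one_hot):
    """Mean squared error between firing rates and a one-hot target."""
    return sum((r - y) ** 2 for r, y in zip(rates, one_hot)) / len(rates)


# Made-up output spikes of 3 neurons over T = 4 steps; the true category is 1.
frames = [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0]]
target = [0.0, 1.0, 0.0]

rates = firing_rate(frames)   # neuron 1 fired in 3 of 4 steps, neuron 0 in 1 of 4
loss = mse(rates, target)     # (0.25**2 + 0.25**2 + 0**2) / 3
print(rates, loss)
```

The loss reaches zero only when the correct neuron fires at every step and the others never fire, which is the target behaviour described in point 2.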
The complete code is located in
clock_driven.examples.lif_fc_mnist.py. The code also uses TensorBoard to
save training logs. You can run it directly from the Python interactive prompt:
    >>> import spikingjelly.clock_driven.examples.lif_fc_mnist as lif_fc_mnist
    >>> lif_fc_mnist.main()
Note that when training such an SNN, the GPU memory required grows linearly with the number of simulation steps T. A longer T is equivalent to using a smaller simulation step, which makes training more “fine-grained”, but the result is not necessarily better. When T is too large, the SNN becomes a very deep network once unfolded in time, which makes the gradient prone to vanishing or exploding.
In addition, because we use a Poisson encoder, a larger T is required.
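One way to see why a stochastic encoder calls for more time steps is to look at how noisy the encoded input rate is. The sketch below assumes that each time step is an independent Bernoulli draw with probability p (the pixel intensity); the standard deviation of the observed rate over T steps is then sqrt(p(1-p)/T). The function name rate_std_error is hypothetical.

```python
import math


def rate_std_error(p, T):
    """Standard deviation of the mean of T independent Bernoulli(p) spikes."""
    return math.sqrt(p * (1.0 - p) / T)


# Noise of the observed rate for a mid-gray pixel (p = 0.5) at several values of T.
errors = {T: rate_std_error(0.5, T) for T in (10, 100, 1000)}
for T, err in errors.items():
    print(T, round(err, 4))
```

Under this assumption the noise shrinks only with the square root of T, so cutting it in half takes four times as many steps, which has to be weighed against the memory and gradient issues described above.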
With tau=2.0, T=100, batch_size=128, and lr=1e-3, four npy files are output after training for 100 epochs. The highest
accuracy on the test set is 92.5%, and the accuracy curve plotted with matplotlib is as follows:
Select the first picture in the test set:
Use the trained model to classify and get the classification result.
Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
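Reading the predicted category from this output amounts to taking the index of the neuron with the highest firing rate. A minimal sketch in plain Python, with the firing-rate values copied from the output above:

```python
# Firing rates of the 10 output neurons for one image (a batch of size 1).
firing_rate = [[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]]

# For each row, the predicted category is the index of the largest firing rate.
predictions = [max(range(len(row)), key=row.__getitem__) for row in firing_rate]
print(predictions)  # [7]
```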
The membrane voltages and spikes of the output layer can be visualized with the functions in the
visualizing module, as shown in the figure below.
It can be seen that no neuron emits any spikes except the one corresponding to the correct category. The complete training code can be found in clock_driven/examples/lif_fc_mnist.py.