# SpikingFlow.softbp.examples.mnist — source code (MNIST SNN training example)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import sys
sys.path.append('.')
import SpikingFlow.softbp.neuron as neuron
import SpikingFlow.encoding as encoding
from torch.utils.tensorboard import SummaryWriter
import readline

class Net(nn.Module):
    """A minimal two-layer fully-connected spiking network for MNIST.

    Each ``nn.Linear`` layer is followed by a layer of LIF neurons, so the
    network output is a spike tensor of shape ``[batch_size, 10]`` per step.

    :param tau: membrane time constant passed to every ``LIFNode``
    :param v_threshold: firing threshold passed to every ``LIFNode``
    :param v_reset: reset potential passed to every ``LIFNode``
    """

    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
        super().__init__()
        # Simple stack: flatten 28x28 image -> 196 hidden LIF units -> 10 output LIF units.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 14 * 14, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
            nn.Linear(14 * 14, 10, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
        )

    def forward(self, x):
        # One simulation step: propagate input through the whole stack.
        return self.fc(x)

    def reset_(self):
        # Clear the membrane state of every stateful sub-module (LIF neurons
        # keep memory across steps, so this must run between samples).
        for module in self.modules():
            if hasattr(module, 'reset'):
                module.reset()
def main():
    """Interactively configure, train and evaluate the SNN on MNIST.

    All hyper-parameters are read from stdin; training/test accuracy curves
    are written to TensorBoard under ``log_dir``.
    """
    device = input('输入运行的设备,例如“cpu”或“cuda:0” ')
    dataset_dir = input('输入保存MNIST数据集的位置,例如“./” ')
    batch_size = int(input('输入batch_size,例如“64” '))
    learning_rate = float(input('输入学习率,例如“1e-3” '))
    T = int(input('输入仿真时长,例如“50” '))
    tau = float(input('输入LIF神经元的时间常数tau,例如“100.0” '))
    train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100” '))
    log_dir = input('输入保存tensorboard日志文件的位置,例如“./” ')

    writer = SummaryWriter(log_dir)

    # Data loaders for the MNIST train/test splits.
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        drop_last=False)

    net = Net(tau=tau).to(device)
    # Adam optimizer over all network parameters.
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    # Poisson encoder turns pixel intensities into spike trains.
    encoder = encoding.PoissonEncoder()

    train_times = 0
    for _ in range(train_epoch):
        net.train()
        for img, label in train_data_loader:
            img = img.to(device)
            optimizer.zero_grad()

            # Simulate T steps and accumulate the output layer's spike counts;
            # out_spikes_counter ends up with shape [batch_size, 10].
            for step in range(T):
                spikes = net(encoder(img).float())
                out_spikes_counter = spikes if step == 0 else out_spikes_counter + spikes

            # Firing frequency of the 10 output neurons over the whole window.
            out_spikes_counter_frequency = out_spikes_counter / T

            # Cross-entropy between firing frequencies and the true class:
            # pushes neuron i's rate toward 1 for class i and toward 0 otherwise.
            loss = F.cross_entropy(out_spikes_counter_frequency, label.to(device))
            loss.backward()
            optimizer.step()
            # SNN neurons are stateful; reset membrane state after each update.
            net.reset_()

            # Prediction = output neuron with the highest firing frequency.
            correct_rate = (out_spikes_counter_frequency.max(1)[1] == label.to(device)).float().mean().item()
            writer.add_scalar('train_correct_rate', correct_rate, train_times)
            if train_times % 1024 == 0:
                print(device, dataset_dir, batch_size, learning_rate, T, tau, train_epoch, log_dir)
                print(sys.argv, 'train_times', train_times, 'train_correct_rate', correct_rate)
            train_times += 1

        net.eval()
        with torch.no_grad():
            # Evaluate on the full test set once per epoch.
            test_sum = 0
            correct_sum = 0
            for img, label in test_data_loader:
                img = img.to(device)
                for step in range(T):
                    spikes = net(encoder(img).float())
                    out_spikes_counter = spikes if step == 0 else out_spikes_counter + spikes
                correct_sum += (out_spikes_counter.max(1)[1] == label.to(device)).float().sum().item()
                test_sum += label.numel()
                net.reset_()
            writer.add_scalar('test_correct_rate', correct_sum / test_sum, train_times)
if __name__ == '__main__': main()