spikingjelly.event_driven.examples.tempotron_mnist 源代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import spikingjelly.event_driven.encoding as encoding
import spikingjelly.event_driven.neuron as neuron
import sys
if sys.platform != 'win32':
    import readline
import math

[文档]class Net(nn.Module): def __init__(self, m, T): # m是高斯调谐曲线编码器编码一个像素点所使用的神经元数量 super().__init__() self.tempotron = neuron.Tempotron(784*m, 10, T)
[文档] def forward(self, x: torch.Tensor): # 返回的是输出层10个Tempotron在仿真时长内的电压峰值 return self.tempotron(x, 'v_max')
[文档]def main(): ''' :return: None 使用高斯调谐曲线编码器编码图像为脉冲,单层Tempotron进行MNIST识别。运行示例: .. code-block:: python >>> import spikingjelly.event_driven.examples.tempotron_mnist as tempotron_mnist >>> tempotron_mnist.main() 输入运行的设备,例如“cpu”或“cuda:0” input device, e.g., "cpu" or "cuda:0": cuda:15 输入保存MNIST数据集的位置,例如“./” input root directory for saving MNIST dataset, e.g., "./": ./mnist 输入batch_size,例如“64” input batch_size, e.g., "64": 64 输入学习率,例如“1e-3” input learning rate, e.g., "1e-3": 1e-3 输入仿真时长,例如“100” input simulating steps, e.g., "100": 100 输入训练轮数,即遍历训练集的次数,例如“100” input training epochs, e.g., "100": 10 输入使用高斯调谐曲线编码每个像素点使用的神经元数量,例如“16” input neuron number for encoding a piexl in GaussianTuning encoder, e.g., "16": 16 输入保存tensorboard日志文件的位置,例如“./” input root directory for saving tensorboard logs, e.g., "./": ./logs_tempotron_mnist cuda:15 ./mnist 64 0.001 100 100 16 ./logs_tempotron_mnist train_acc 0.09375 0 cuda:15 ./mnist 64 0.001 100 100 16 ./logs_tempotron_mnist train_acc 0.78125 512 ... ''' device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ') dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ') batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": ')) learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": ')) T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": ')) train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": ')) m = int(input('输入使用高斯调谐曲线编码每个像素点使用的神经元数量,例如“16”\n input neuron number for encoding a piexl in GaussianTuning encoder, e.g., "16": ')) log_dir = input('输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": ') # 每个像素点用m个神经元来编码 encoder = encoding.GaussianTuning(n=1, m=m, x_min=torch.zeros(size=[1]).to(device), x_max=torch.ones(size=[1]).to(device)) writer = SummaryWriter(log_dir) # 初始化数据加载器 train_data_loader = torch.utils.data.DataLoader( dataset=torchvision.datasets.MNIST( root=dataset_dir, train=True, transform=torchvision.transforms.ToTensor(), download=True), batch_size=batch_size, shuffle=True, drop_last=True) test_data_loader = torch.utils.data.DataLoader( dataset=torchvision.datasets.MNIST( root=dataset_dir, train=False, transform=torchvision.transforms.ToTensor(), download=True), batch_size=batch_size, shuffle=True, drop_last=False) # 初始化网络 net = Net(m, T).to(device) # 使用Adam优化器 optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate) train_times = 0 for epoch in range(train_epoch): net.train() for img, label in train_data_loader: img = img.view(img.shape[0], -1).unsqueeze(1) # [batch_size, 1, 784] in_spikes = encoder.encode(img.to(device), T) # [batch_size, 1, 784, m] in_spikes = in_spikes.view(in_spikes.shape[0], -1) # [batch_size, 784*m] v_max = net(in_spikes) train_acc = (v_max.argmax(dim=1) == label.to(device)).float().mean().item() if train_times % 256 == 0: writer.add_scalar('train_acc', train_acc, train_times) if train_times % 512 == 0: print(device, dataset_dir, batch_size, learning_rate, T, train_epoch, m, log_dir) print('train_acc', train_acc, train_times) loss = neuron.Tempotron.mse_loss(v_max, net.tempotron.v_threshold, label.to(device), 10) loss.backward() optimizer.step() train_times += 1 net.eval() with torch.no_grad(): correct_num = 0 img_num = 0 for img, label in test_data_loader: img = img.view(img.shape[0], -1).unsqueeze(1) # [batch_size, 1, 784] in_spikes = encoder.encode(img.to(device), T) # [batch_size, 1, 784, m] in_spikes = in_spikes.view(in_spikes.shape[0], -1) # [batch_size, 784*m] v_max = net(in_spikes) correct_num += (v_max.argmax(dim=1) == label.to(device)).float().sum().item() img_num += img.shape[0] test_acc = correct_num / img_num writer.add_scalar('test_acc', test_acc, epoch) print('test_acc', test_acc, train_times, log_dir)
if __name__ == '__main__': main()