spikingjelly.activation_based.examples.spiking_lstm_sequential_mnist 源代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import rnn
from torch.utils.tensorboard import SummaryWriter
import sys

if sys.platform != "win32":
    pass
import torchvision
import tqdm


class Net(nn.Module):
    """Spiking LSTM classifier that reads an image one row per time step.

    The network is a ``rnn.SpikingLSTM`` followed by a linear read-out that
    is applied to the output of the last time step only.

    Args:
        input_size: Number of features fed to the LSTM at each time step
            (28 for one MNIST row).
        hidden_size: Width of the LSTM output and of the read-out input.
        num_layers: Third positional argument of ``rnn.SpikingLSTM``
            (number of stacked layers).
        num_classes: Number of output scores produced by the read-out layer.

    The defaults reproduce the fixed configuration ``SpikingLSTM(28, 1024, 1)``
    plus ``Linear(1024, 10)``, so ``Net()`` builds the same model as before.
    """

    def __init__(self, input_size=28, hidden_size=1024, num_layers=1, num_classes=10):
        super().__init__()
        self.lstm = rnn.SpikingLSTM(input_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class scores computed from the last time step.

        Args:
            x: Time-major input sequence; ``main`` passes ``[28, N, 28]``
                (time steps, batch, features per step).

        Returns:
            The read-out layer applied to the final element along the first
            axis of the LSTM output.
        """
        # The second element returned by the LSTM (its state) is not needed.
        outputs, _ = self.lstm(x)
        return self.fc(outputs[-1])
def main():
    """Interactively train and evaluate ``Net`` on row-by-row MNIST.

    Prompts on stdin for the device, dataset directory, batch size, learning
    rate, number of epochs and TensorBoard log directory. Each epoch trains
    on the MNIST training split with Adam and an MSE loss against one-hot
    labels, then evaluates on the test split. Training accuracy is logged
    every 256 optimizer steps and test accuracy once per epoch; a one-line
    summary is printed after every epoch.

    Raises:
        ValueError: If the batch size, learning rate or epoch count typed by
            the user cannot be parsed as a number.
    """
    device = input(
        '输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": '
    )
    dataset_dir = input(
        '输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": '
    )
    batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
    learning_rate = float(
        input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": ')
    )
    train_epoch = int(
        input(
            '输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '
        )
    )
    log_dir = input(
        '输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": '
    )

    writer = SummaryWriter(log_dir)

    # Initialize the data loaders.
    to_tensor = torchvision.transforms.ToTensor()
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=True,
            transform=to_tensor,
            download=True,
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.MNIST(
            root=dataset_dir,
            train=False,
            transform=to_tensor,
            download=True,
        ),
        batch_size=batch_size,
        # Accuracy over the whole split does not depend on sample order.
        shuffle=False,
        drop_last=False,
    )

    # Initialize the network and the Adam optimizer.
    net = Net().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

    train_times = 0
    max_test_accuracy = 0

    try:
        for epoch in range(train_epoch):
            net.train()
            for img, label in tqdm.tqdm(train_data_loader):
                img = img.to(device)  # [N, 1, 28, 28]
                label = label.to(device)
                label_one_hot = F.one_hot(label, 10).float()

                # Drop only the channel axis. A bare squeeze would also drop
                # the batch axis when N == 1 and break the permute below.
                img = img.squeeze(1)  # [N, 28, 28]
                img = img.permute(1, 0, 2)  # [28, N, 28]

                optimizer.zero_grad()
                logits = net(img)
                loss = F.mse_loss(logits, label_one_hot)
                loss.backward()
                optimizer.step()

                if train_times % 256 == 0:
                    # Computed only when logged to avoid a host sync per step.
                    accuracy = (logits.argmax(dim=1) == label).float().mean().item()
                    writer.add_scalar("train_accuracy", accuracy, train_times)
                train_times += 1

            # After each pass over the training set, evaluate on the test set.
            net.eval()
            with torch.no_grad():
                test_sum = 0
                correct_sum = 0
                for img, label in test_data_loader:
                    img = img.to(device)
                    label = label.to(device)
                    # The last test batch may hold a single sample, so the
                    # channel axis must be squeezed explicitly.
                    img = img.squeeze(1)  # [N, 28, 28]
                    img = img.permute(1, 0, 2)  # [28, N, 28]

                    logits = net(img)
                    correct_sum += (logits.argmax(dim=1) == label).float().sum().item()
                    test_sum += label.numel()

                test_accuracy = correct_sum / test_sum
                writer.add_scalar("test_accuracy", test_accuracy, epoch)

                # Track the best accuracy so the summary below reports it.
                # Checkpoint saving is left disabled.
                max_test_accuracy = max(max_test_accuracy, test_accuracy)

            print(
                "device={}, dataset_dir={}, batch_size={}, learning_rate={}, log_dir={}, max_test_accuracy={}, train_times={}".format(
                    device,
                    dataset_dir,
                    batch_size,
                    learning_rate,
                    log_dir,
                    max_test_accuracy,
                    train_times,
                )
            )
    finally:
        # Flush pending events and release the log file even on failure.
        writer.close()
# Start the interactive training session only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()