# Source code of spikingjelly.activation_based.ann2snn.examples.cnn_mnist

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from tqdm import tqdm

from spikingjelly.activation_based import ann2snn
from spikingjelly.activation_based.ann2snn.sample_models import mnist_cnn


def val(net, device, data_loader, T=None):
    """Measure the classification accuracy of ``net`` on ``data_loader``.

    Args:
        net: model to evaluate. It is switched to eval mode and moved to
            ``device`` in place.
        device: device the images, labels and model are moved to.
        data_loader: iterable yielding ``(img, label)`` batches.
        T: ``None`` to evaluate ``net`` as an ANN (one forward pass per
            batch). Otherwise the number of simulation time steps: every
            submodule exposing ``reset`` is reset at the start of each batch,
            ``net`` is run ``T`` times on the same input, and the outputs are
            summed over time.

    Returns:
        If ``T`` is ``None``, the accuracy as a float. Otherwise a numpy
        array of length ``T`` whose entry ``t`` is the accuracy obtained from
        the outputs accumulated over time steps ``0..t``.

    Raises:
        ValueError: if ``T`` is given but smaller than 1, or if
            ``data_loader`` yields no samples.
    """
    net.eval().to(device)
    correct = 0.0
    total = 0.0
    if T is not None:
        if T < 1:
            # With T == 0 the simulation loop would never run and no output
            # would exist to score.
            raise ValueError("T must be a positive number of time steps, got %r" % (T,))
        corrects = np.zeros(T)
    with torch.no_grad():
        for img, label in tqdm(data_loader):
            img = img.to(device)
            # Move labels once per batch instead of once per time step.
            label = label.to(device)
            if T is None:
                out = net(img)
                correct += (out.argmax(dim=1) == label).float().sum().item()
            else:
                # Clear stateful modules so each batch starts from a fresh state.
                for m in net.modules():
                    if hasattr(m, "reset"):
                        m.reset()
                out = None
                for t in range(T):
                    step_out = net(img)
                    # Out-of-place add so the tensor returned by the network
                    # is never mutated.
                    out = step_out if out is None else out + step_out
                    corrects[t] += (out.argmax(dim=1) == label).float().sum().item()
            total += label.shape[0]
    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct / total if T is None else corrects / total
def main(device=None, dataset_dir="G:/Dataset/mnist", T=50):
    """Convert a pretrained MNIST CNN to an SNN with several normalization modes.

    Loads pretrained weights from ``SJ-mnist-cnn_model-sample.pth`` into
    ``mnist_cnn.CNN`` and reports the ANN test accuracy. It then converts
    the model with ``ann2snn.Converter`` under six ``mode`` settings ("max",
    "99.9%", and 1/2 to 1/5). Each converted SNN is simulated for ``T`` time
    steps on the MNIST test set, and the accuracy-vs-time-step curves are
    plotted together.

    Args:
        device: device to run on. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        dataset_dir: root directory for torchvision's MNIST. It is passed to
            ``torchvision.datasets.MNIST`` with ``download=True``.
        T: number of SNN simulation time steps.
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 100

    model = mnist_cnn.CNN().to(device)
    to_tensor = torchvision.transforms.ToTensor()
    train_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir, train=True, transform=to_tensor, download=True
    )
    # The training split is only handed to the converter (not used for training here).
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_data_dataset, batch_size=batch_size, shuffle=True, drop_last=False
    )
    test_data_dataset = torchvision.datasets.MNIST(
        root=dataset_dir, train=False, transform=to_tensor, download=True
    )
    # Evaluation only: a fixed order keeps runs reproducible.
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_data_dataset, batch_size=50, shuffle=False, drop_last=False
    )

    # map_location lets CUDA-saved weights load on a CPU-only machine.
    model.load_state_dict(
        torch.load("SJ-mnist-cnn_model-sample.pth", map_location=device)
    )
    acc = val(model, device, test_data_loader)
    print("ANN Validating Accuracy: %.4f" % (acc))

    # (converter mode, human-readable description, plot legend label)
    conversions = [
        ("max", "MaxNorm", "mode: max"),
        ("99.9%", "RobustNorm", "mode: 99.9%"),
        (1.0 / 2, "1/2 max(activation) as scales", "mode: 1.0/2"),
        (1.0 / 3, "1/3 max(activation) as scales", "mode: 1.0/3"),
        (1.0 / 4, "1/4 max(activation) as scales", "mode: 1.0/4"),
        (1.0 / 5, "1/5 max(activation) as scales", "mode: 1.0/5"),
    ]
    curves = []
    for mode, description, label in conversions:
        print("---------------------------------------------")
        print("Converting using %s" % description)
        model_converter = ann2snn.Converter(mode=mode, dataloader=train_data_loader)
        snn_model = model_converter(model)
        print("Simulating...")
        accs = val(snn_model, device, test_data_loader, T=T)
        print("SNN accuracy (simulation %d time-steps): %.4f" % (T, accs[-1]))
        curves.append((label, accs))

    time_steps = np.arange(0, T)
    plt.figure()
    for label, accs in curves:
        plt.plot(time_steps, accs, label=label)
    plt.legend()
    plt.xlabel("t")
    plt.ylabel("Acc")
    plt.show()
if __name__ == "__main__":
    # Fetch the pretrained ANN weights that main() loads, then run the demo.
    weights_url = "https://ndownloader.figshare.com/files/34960191"
    weights_path = "./SJ-mnist-cnn_model-sample.pth"
    print("Downloading SJ-mnist-cnn_model-sample.pth...")
    ann2snn.download_url(weights_url, weights_path)
    main()