# Source code of spikingjelly.activation_based.ann2snn.examples.resnet18_cifar10

import torch
import torchvision
from tqdm import tqdm
import spikingjelly.activation_based.ann2snn as ann2snn
from spikingjelly.activation_based.ann2snn.sample_models import cifar10_resnet


def val(net, device, data_loader, T=None):
    """Evaluate ``net`` on ``data_loader`` and return its top-1 accuracy.

    Args:
        net: model to evaluate. It is switched to eval mode and moved to
            ``device`` in place.
        device: device that the model and every batch are moved to.
        data_loader: iterable yielding ``(img, label)`` batches.
        T: ``None`` to run ``net`` once per batch (ANN evaluation).
            Otherwise, the number of time-steps: at the start of each batch,
            every submodule exposing ``reset`` is reset. ``net`` is then run
            ``T`` times on the same input and the outputs are summed before
            taking the arg-max.

    Returns:
        float: fraction of correctly classified samples. It is also printed.

    Raises:
        ValueError: if ``T`` is not ``None`` and is smaller than 1, or if
            ``data_loader`` yields no samples.
    """
    if T is not None and T < 1:
        # With T < 1 the accumulation loop would never run, leaving ``out``
        # unbound (first batch) or stale from the previous batch.
        raise ValueError(f'T must be a positive integer or None, got {T!r}')

    net.eval().to(device)
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for img, label in tqdm(data_loader):
            img = img.to(device)
            if T is None:
                out = net(img)
            else:
                # Reset stateful submodules so nothing carries over from the
                # previous batch's simulation.
                for m in net.modules():
                    if hasattr(m, 'reset'):
                        m.reset()
                out = net(img)
                for _ in range(T - 1):
                    out += net(img)
            correct += (out.argmax(dim=1) == label.to(device)).float().sum().item()
            total += out.shape[0]
    if total == 0:
        raise ValueError('data_loader yielded no samples')
    acc = correct / total
    print('Validating Accuracy: %.3f' % (acc))
    return acc
def main(device=None, dataset_dir='~/dataset/cifar10', batch_size=100, T=400,
         test_batch_size=50,
         checkpoint='SJ-cifar10-resnet18_model-sample.pth'):
    """Convert a pretrained CIFAR-10 ResNet-18 to an SNN and report both accuracies.

    Loads the ANN weights from ``checkpoint`` and evaluates the ANN on the
    CIFAR-10 test split. It then converts the ANN with
    ``ann2snn.Converter(mode='Max')``, which receives the training loader,
    and evaluates the resulting SNN for ``T`` time-steps.

    Args:
        device: device to run on. ``None`` (the default) selects ``'cuda'``
            when available and ``'cpu'`` otherwise.
        dataset_dir: CIFAR-10 root directory. The dataset is downloaded there
            if missing.
        batch_size: batch size of the training loader passed to the converter.
        T: number of SNN simulation time-steps used during evaluation.
        test_batch_size: batch size of the test loader.
        checkpoint: path to the pretrained ANN ``state_dict``.
    """
    torch.random.manual_seed(0)
    torch.cuda.manual_seed(0)
    if device is None:
        # The original hard-coded 'cuda:9', which fails on machines with
        # fewer than ten GPUs or with no GPU at all.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    model = cifar10_resnet.ResNet18()
    # Load onto the CPU first so a GPU-saved checkpoint also works on CPU-only
    # hosts; val() moves the model to `device` afterwards.
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    train_data_dataset = torchvision.datasets.CIFAR10(
        root=dataset_dir, train=True, transform=transform, download=True)
    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_data_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    test_data_dataset = torchvision.datasets.CIFAR10(
        root=dataset_dir, train=False, transform=transform, download=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_data_dataset, batch_size=test_batch_size, shuffle=True, drop_last=False)

    print('ANN accuracy:')
    val(model, device, test_data_loader)
    print('Converting...')
    model_converter = ann2snn.Converter(mode='Max', dataloader=train_data_loader)
    snn_model = model_converter(model)
    print('SNN accuracy:')
    val(snn_model, device, test_data_loader, T=T)
if __name__ == '__main__':
    import os

    checkpoint_path = './SJ-cifar10-resnet18_model-sample.pth'
    # Only fetch the pretrained weights when they are not already on disk, so
    # repeated or offline runs do not re-download them.
    # NOTE: a truncated file left by an interrupted download is not detected;
    # delete it manually to force a fresh download.
    if not os.path.exists(checkpoint_path):
        print('Downloading SJ-cifar10-resnet18_model-sample.pth')
        ann2snn.download_url("https://ndownloader.figshare.com/files/26676110", checkpoint_path)
    main()