spikingjelly.activation_based.ann2snn.examples.resnet18_cifar10 源代码

import torch
import torchvision
from tqdm import tqdm
import spikingjelly.activation_based.ann2snn as ann2snn
from spikingjelly.activation_based.ann2snn.sample_models import cifar10_resnet


[文档] def val(net, device, data_loader, T=None): net.eval().to(device) correct = 0.0 total = 0.0 with torch.no_grad(): for batch, (img, label) in enumerate(tqdm(data_loader)): img = img.to(device) if T is None: out = net(img) else: for m in net.modules(): if hasattr(m, "reset"): m.reset() for t in range(T): if t == 0: out = net(img) else: out += net(img) correct += (out.argmax(dim=1) == label.to(device)).float().sum().item() total += out.shape[0] acc = correct / total print("Validating Accuracy: %.3f" % (acc)) return acc
[文档] def main(): torch.random.manual_seed(0) torch.cuda.manual_seed(0) device = "cuda:9" dataset_dir = "~/dataset/cifar10" batch_size = 100 T = 400 transform = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) ), ] ) model = cifar10_resnet.ResNet18() model.load_state_dict(torch.load("SJ-cifar10-resnet18_model-sample.pth")) train_data_dataset = torchvision.datasets.CIFAR10( root=dataset_dir, train=True, transform=transform, download=True ) train_data_loader = torch.utils.data.DataLoader( dataset=train_data_dataset, batch_size=batch_size, shuffle=True, drop_last=False ) test_data_dataset = torchvision.datasets.CIFAR10( root=dataset_dir, train=False, transform=transform, download=True ) test_data_loader = torch.utils.data.DataLoader( dataset=test_data_dataset, batch_size=50, shuffle=True, drop_last=False ) print("ANN accuracy:") val(model, device, test_data_loader) print("Converting...") model_converter = ann2snn.Converter(mode="Max", dataloader=train_data_loader) snn_model = model_converter(model) print("SNN accuracy:") val(snn_model, device, test_data_loader, T=T)
if __name__ == "__main__": print("Downloading SJ-cifar10-resnet18_model-sample.pth") ann2snn.download_url( "https://ndownloader.figshare.com/files/26676110", "./SJ-cifar10-resnet18_model-sample.pth", ) main()