# Source code for spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation

"""
.. codeauthor:: Yanqi Chen <chyq@pku.edu.cn>

A reproduction of the paper `Enabling Spike-Based Backpropagation for Training Deep Neural Network Architectures <https://doi.org/10.3389/fnins.2020.00119>`_\ .

This code reproduces a novel gradient-based training method for SNNs. We partly follow the network structure and some other implementation details of the `authors' implementation <https://github.com/chan8972/Enabling_Spikebased_Backpropagation>`_\ . Since the training method and neuron models differ slightly from those in this framework, we rewrite them in a compatible style.

This script assumes that at least one NVIDIA GPU is available.
"""

import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms

from spikingjelly.activation_based import functional, layer

from tqdm import tqdm
import math


# Command-line interface of the training script.  The parser is built at
# import time and consumed by main().
parser = argparse.ArgumentParser(description='spikingjelly CIFAR10 Training')

# Positional: root directory of the dataset.
parser.add_argument(
    'data',
    metavar='DIR',
    help='path to dataset',
)

# Data loading.
parser.add_argument(
    '-j', '--workers',
    type=int,
    default=4,
    metavar='N',
    help='number of data loading workers (default: 4)',
)
parser.add_argument(
    '-b', '--batch-size',
    type=int,
    default=16,
    metavar='N',
)

# Simulation and optimisation hyper-parameters.
parser.add_argument(
    '-T', '--timesteps',
    type=int,
    default=100,
    help='Simulation timesteps',
)
parser.add_argument(
    '--lr', '--learning-rate',
    dest='lr',
    type=float,
    default=0.0025,
    metavar='LR',
    help='initial learning rate',
)

# Checkpoint and device selection.
parser.add_argument(
    '--pretrained',
    dest='pretrained',
    action='store_true',
    help='use pre-trained parameters.',
)
parser.add_argument(
    '--gpu',
    type=int,
    default=None,
    help='GPU id to use.',
)


#### Surrogate function ####
class relu(torch.autograd.Function):
    """Step-function spike generator with a gated pass-through gradient.

    Forward emits ``1.0`` wherever the input is strictly positive and ``0.0``
    elsewhere.  Backward passes the incoming gradient through unchanged where
    the saved input was positive and replaces it with zero elsewhere.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a float tensor that is 1 where ``x > 0`` and 0 otherwise."""
        ctx.save_for_backward(x)
        return torch.gt(x, 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where the input was <= 0."""
        (saved_input,) = ctx.saved_tensors
        # Out-of-place masked fill: the incoming gradient itself is never
        # modified, a fresh tensor is returned.
        return grad_output.masked_fill(saved_input <= 0.0, 0)
#### Neurons ####
class BaseNode(nn.Module):
    """Shared state and spike/reset logic for the neuron models in this file.

    Subclasses implement :meth:`forward`; this class only stores the membrane
    state and provides :meth:`spiking` and :meth:`reset`.

    :param v_threshold: value subtracted from the membrane potential before
        it is passed to ``surrogate_function``
    :param v_reset: value written into the membrane potential where a spike
        occurred; also the initial potential (``0`` is used when it is ``None``)
    :param surrogate_function: callable mapping ``v - v_threshold`` to spikes
    :param monitor: if truthy, ``self.monitor`` starts as a dict with empty
        ``'v'`` and ``'s'`` lists; otherwise it is ``False``
    """

    def __init__(self, v_threshold=1.0, v_reset=0.0, surrogate_function=relu.apply, monitor=False):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.surrogate_function = surrogate_function

        # Membrane potential starts at the reset value (0 when none is given).
        self.v = 0 if self.v_reset is None else self.v_reset

        # Accumulated voltage (Assuming NO fire for this neuron)
        self.v_acc = 0
        # Accumulated voltage with leaky (Assuming NO fire for this neuron)
        self.v_acc_l = 0

        self.monitor = {'v': [], 's': []} if monitor else False
        self.new_grad = None

    def spiking(self):
        """Emit spikes for the current potential and reset the fired entries.

        ``self.v`` must already be a tensor: it is modified in place, with
        ``v_reset`` written wherever ``v - v_threshold > 0``.

        :return: ``surrogate_function(v - v_threshold)``, computed from the
            potential as it was before the reset
        """
        overshoot = self.v - self.v_threshold
        self.v.masked_fill_(overshoot > 0, self.v_reset)
        return self.surrogate_function(overshoot)

    def forward(self, dv: torch.Tensor):
        """Abstract hook; subclasses define how ``dv`` updates the state."""
        raise NotImplementedError

    def reset(self):
        """Restore the potential and accumulators to their initial values.

        ``new_grad`` is left untouched.
        """
        self.v = 0 if self.v_reset is None else self.v_reset
        if self.monitor:
            self.monitor = {'v': [], 's': []}
        self.v_acc = 0
        self.v_acc_l = 0
class LIFNode(BaseNode):
    """Leaky integrate-and-fire neuron with an externally supplied gradient scale.

    :param tau: divisor of ``(v - v_reset)`` in the leak term
    :param v_threshold: forwarded to :class:`BaseNode`
    :param v_reset: forwarded to :class:`BaseNode`; must be a number here,
        because the leak term computes ``v - v_reset``
    :param surrogate_function: forwarded to :class:`BaseNode`
    :param fire: when ``False`` the neuron never spikes and :meth:`forward`
        returns the leaked membrane potential instead
    """

    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=relu.apply, fire=True):
        super().__init__(v_threshold, v_reset, surrogate_function)
        self.tau = tau
        self.fire = fire  # If no fire, the voltage threshold of neuron is infinity
        # Multiplier applied to the spike gradient by the backward hook; the
        # training loop is expected to assign it before calling backward().
        self.new_grad = None

    def forward(self, dv: torch.Tensor):
        """Integrate ``dv``, optionally spike, then apply the leak.

        :param dv: input added to the membrane potential
        :return: the spike tensor when ``fire`` is true, otherwise the
            membrane potential after the leak
        """
        self.v += dv

        if not self.fire:
            # Non-spiking neuron: only leak.  The leak term is detached, so
            # no gradient flows through it.
            self.v = self.v - ((self.v - self.v_reset) / self.tau).detach()
            return self.v

        spike = self.spiking()
        self.v_acc += spike
        # Leak is computed from the potential after the spike reset.
        leak = (self.v - self.v_reset) / self.tau
        self.v_acc_l = self.v - leak + spike
        self.v = self.v - leak.detach()

        # Tensor.register_hook raises on tensors that do not require grad
        # (e.g. train mode under torch.no_grad(), or frozen upstream
        # parameters), so only hook spikes that take part in backward.
        if self.training and spike.requires_grad:
            # The lambda reads self.new_grad when backward runs, not now.
            spike.register_hook(lambda grad: torch.mul(grad, self.new_grad))
        return spike
class IFNode(BaseNode):
    """Integrate-and-fire neuron: accumulates its input and spikes, with no leak.

    :param v_threshold: forwarded to :class:`BaseNode`
    :param v_reset: forwarded to :class:`BaseNode`
    :param surrogate_function: forwarded to :class:`BaseNode`
    """

    def __init__(self, v_threshold=0.75, v_reset=0.0, surrogate_function=relu.apply):
        super().__init__(
            v_threshold=v_threshold,
            v_reset=v_reset,
            surrogate_function=surrogate_function,
        )

    def forward(self, dv: torch.Tensor):
        """Add ``dv`` to the membrane potential and return the resulting spikes."""
        self.v += dv
        spikes = self.spiking()
        return spikes
#### Network ####
class ResNet11(nn.Module):
    """Spiking residual network: a stem, four residual stages and two linear layers.

    The stem is conv -> LIF/dropout -> 2x2 average pooling -> IF neuron.
    Each residual stage is conv -> LIF/dropout -> conv, added to a shortcut
    of the stage input and passed through another LIF/dropout.  Stages 3 and
    5 use stride 2.  ``fc0`` expects ``512 * 4 * 4`` features, which matches
    32x32 inputs.  The output neuron is created with ``fire=False``, so the
    network returns its membrane potential rather than spikes.

    Neuron state persists across calls; call :meth:`reset_` between samples.
    """

    def __init__(self):
        super().__init__()
        self.train_epoch = 0

        # Small local factories.  Modules are still created and registered
        # in the same order as the attribute assignments below.
        def conv3x3(in_channels, out_channels, stride=1):
            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=3, stride=stride, padding=1, bias=False)

        def conv1x1_shortcut(in_channels, out_channels, stride=1):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            )

        def spiking_dropout():
            return nn.Sequential(LIFNode(), layer.Dropout(0.25))

        # Stem
        self.cnn11 = conv3x3(3, 64)
        self.lif11 = spiking_dropout()
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.if1 = IFNode()

        # Stage 2: 64 -> 128
        self.cnn21 = conv3x3(64, 128)
        self.lif21 = spiking_dropout()
        self.cnn22 = conv3x3(128, 128)
        self.shortcut1 = conv1x1_shortcut(64, 128)
        self.lif2 = spiking_dropout()

        # Stage 3: 128 -> 256, stride 2
        self.cnn31 = conv3x3(128, 256)
        self.lif31 = spiking_dropout()
        self.cnn32 = conv3x3(256, 256, stride=2)
        self.shortcut2 = conv1x1_shortcut(128, 256, stride=2)
        self.lif3 = spiking_dropout()

        # Stage 4: 256 -> 512
        self.cnn41 = conv3x3(256, 512)
        self.lif41 = spiking_dropout()
        self.cnn42 = conv3x3(512, 512)
        self.shortcut3 = conv1x1_shortcut(256, 512)
        self.lif4 = spiking_dropout()

        # Stage 5: 512 -> 512, stride 2; the shortcut only subsamples.
        self.cnn51 = conv3x3(512, 512)
        self.lif51 = spiking_dropout()
        self.cnn52 = conv3x3(512, 512, stride=2)
        self.shortcut4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=(1, 1), stride=(2, 2), padding=(0, 0))
        )
        self.lif5 = spiking_dropout()

        # Classifier head
        self.fc0 = nn.Linear(512 * 4 * 4, 1024, bias=False)
        self.lif6 = spiking_dropout()
        self.fc1 = nn.Linear(1024, 10, bias=False)
        self.lif_out = LIFNode(fire=False)

        # Re-initialise every conv/linear weight from N(0, 1 / fan_in).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_in = kernel_h * kernel_w * module.in_channels
            elif isinstance(module, nn.Linear):
                fan_in = module.weight.size(1)
            else:
                continue
            module.weight.data.normal_(0.0, math.sqrt(1.0 / fan_in))

    def forward(self, x):
        """Advance the network by one simulation step.

        :param x: input batch for this step
        :return: output of ``lif_out`` for this step
        """
        x = self.cnn11(x)
        x = self.lif11(x)
        x = self.avgpool1(x)
        x = self.if1(x)

        residual_stages = (
            (self.cnn21, self.lif21, self.cnn22, self.shortcut1, self.lif2),
            (self.cnn31, self.lif31, self.cnn32, self.shortcut2, self.lif3),
            (self.cnn41, self.lif41, self.cnn42, self.shortcut3, self.lif4),
            (self.cnn51, self.lif51, self.cnn52, self.shortcut4, self.lif5),
        )
        for conv_in, lif_mid, conv_out, shortcut, lif_sum in residual_stages:
            # Main branch is evaluated before the shortcut.
            x = lif_sum(conv_out(lif_mid(conv_in(x))) + shortcut(x))

        out = x.view(x.size(0), -1)
        out = self.fc0(out)
        out = self.lif6(out)
        out = self.fc1(out)
        out = self.lif_out(out)
        return out

    def reset_(self):
        """Call ``reset()`` on every submodule that defines it."""
        resettable = (module for module in self.modules() if hasattr(module, 'reset'))
        for module in resettable:
            module.reset()
def _poisson_encode(img):
    """Draw one signed rate-coded spike frame from ``img``.

    An element spikes when ``|img|`` exceeds a fresh uniform sample from
    ``[0, 1)``; the spike carries the sign of the element.
    """
    rand_num = torch.rand_like(img)
    spikes = (torch.abs(img) > rand_num).float()
    return torch.mul(spikes, torch.sign(img))


def _simulate(net, img, T):
    """Run ``net`` for ``T`` steps on freshly encoded frames; return the last output.

    Every step, including the last, gets its own encoded frame, so this also
    works for ``T == 1``.  ``T`` must be at least 1.
    """
    output = None
    for _ in range(T):
        output = net(_poisson_encode(img))
    return output


def main():
    """Train and evaluate ``ResNet11`` on CIFAR-10 until interrupted.

    Reads the command line through the module-level ``parser``.  Each epoch
    trains on the full training set, evaluates on the test set, logs to
    TensorBoard under ``./logs/<log_prefix>`` and saves the weights there
    whenever the test accuracy improves.  The loop never terminates by itself.
    """
    args = parser.parse_args()
    # torch.cuda.set_device() rejects None, which is the default of --gpu;
    # only switch devices when an id was actually given.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
    learning_rate = args.lr
    batch_size = args.batch_size
    T = args.timesteps
    if T < 1:
        parser.error('--timesteps must be at least 1')

    log_prefix = f'{sys.argv[0]}-lr-{learning_rate}-T-{T}-b-{batch_size}'
    writer = SummaryWriter('./logs/' + log_prefix)

    cudnn.benchmark = True

    dataset_root_dir = args.data

    # Load data
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534])
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_root_dir,
            train=True,
            transform=torchvision.transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            download=True),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
        num_workers=args.workers)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_root_dir,
            train=False,
            transform=torchvision.transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ]),
            download=True),
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        num_workers=args.workers)

    # Prepare model
    net = ResNet11().cuda()
    if args.pretrained:
        # The pretrained parameter can either be downloaded in authors' Dropbox
        # (https://www.dropbox.com/sh/vvq9afkq90refka/AAAIEnyBZ_wO7eM510GCyZ8ta?dl=0)
        # or trained by yourself.
        # Should be placed together with code before training!
        checkpoint = torch.load('./model_bestT1_cifar10_r11.pth.tar')
        # Accept both a {'state_dict': ...} wrapper and the bare state dict
        # that this script writes with torch.save(net.state_dict(), ...).
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        net.load_state_dict(state_dict)

    print(net)

    criterion = nn.MSELoss(reduction='sum').cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[70, 100, 125], gamma=0.2)

    best_model_path = './logs/' + log_prefix + '/model_bestT1_cifar10_r11.pth.tar'
    max_test_accuracy = 0
    train_epoch = 0
    step = 0

    while 1:
        print(log_prefix)

        #### Train ####
        for img, label in tqdm(train_data_loader):
            label = label.cuda()
            label = F.one_hot(label, 10).float()
            img = img.cuda()

            optimizer.zero_grad()

            output = _simulate(net, img, T)

            # Set the per-neuron gradient multiplier that the hooks registered
            # in LIFNode.forward apply during backward().
            for m in net.modules():
                if isinstance(m, LIFNode) and m.fire:
                    # Bump near-zero accumulators to avoid dividing by zero.
                    m.v_acc += (m.v_acc < 1e-3).float()
                    m.new_grad = (m.v_acc_l > 1e-3).float() + math.log(1 - 1 / m.tau) * torch.div(m.v_acc_l, m.v_acc)

            loss = criterion(output / T, label)
            loss.backward()
            optimizer.step()
            # Neuron state must not leak into the next batch.
            net.reset_()

            writer.add_scalar('train_loss', loss, step)
            step += 1

        #### Evaluate ####
        with torch.no_grad():
            print('Test:')
            net.eval()
            accuracy = 0
            test_num = 0
            for img, label in tqdm(test_data_loader):
                label = label.cuda()
                img = img.cuda()

                output = _simulate(net, img, T)

                accuracy += (output.argmax(dim=1) == label).float().sum().item()
                test_num += label.numel()
                net.reset_()

            accuracy /= test_num
            if max_test_accuracy < accuracy:
                max_test_accuracy = accuracy
                torch.save(net.state_dict(), best_model_path)
                print('保存模型参数', best_model_path)
            writer.add_scalar('test_acc', accuracy, train_epoch)
            print(f'Test Acc: {accuracy}, Max Acc: {max_test_accuracy}')
            net.train()

        train_epoch += 1
        net.train_epoch += 1
        scheduler.step()
# Run the training loop only when executed as a script, not on import.
if __name__ == '__main__':
    main()