"""
.. codeauthor:: Yanqi Chen <chyq@pku.edu.cn>
A reproduction of the paper `Enabling Spike-Based Backpropagation for Training Deep Neural Network Architectures <https://doi.org/10.3389/fnins.2020.00119>`_\ .
This code reproduces a novel gradient-based training method for SNNs. To some extent, we follow the network structure and other implementation details of the `authors' implementation <https://github.com/chan8972/Enabling_Spikebased_Backpropagation>`_\ . Since the training method and neuron models differ slightly from those in this framework, we rewrite them in a compatible style.
This script assumes that at least one Nvidia GPU is available.
"""
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
from spikingjelly.activation_based import layer
from tqdm import tqdm
import math
def _build_parser():
    """Create the command-line parser of the CIFAR10 training script.

    Arguments are registered in the same order and with the same options as
    before, so parsing results and ``--help`` output are unchanged.
    """
    cli = argparse.ArgumentParser(description="spikingjelly CIFAR10 Training")

    # Positional: dataset root directory.
    cli.add_argument("data", help="path to dataset", metavar="DIR")

    # Data loading and batching.
    cli.add_argument(
        "-j",
        "--workers",
        type=int,
        default=4,
        metavar="N",
        help="number of data loading workers (default: 4)",
    )
    cli.add_argument("-b", "--batch-size", type=int, default=16, metavar="N")

    # Simulation length and optimisation.
    cli.add_argument(
        "-T", "--timesteps", type=int, default=100, help="Simulation timesteps"
    )
    cli.add_argument(
        "--lr",
        "--learning-rate",
        dest="lr",
        type=float,
        default=0.0025,
        metavar="LR",
        help="initial learning rate",
    )

    # Checkpoint and device selection.
    cli.add_argument(
        "--pretrained",
        action="store_true",
        dest="pretrained",
        help="use pre-trained parameters.",
    )
    cli.add_argument("--gpu", type=int, default=None, help="GPU id to use.")
    return cli


# Module-level parser instance, consumed by ``main()``.
parser = _build_parser()
#### Surrogate function ####
class relu(torch.autograd.Function):
    """Spike function with a ReLU-style pass-through gradient.

    Forward emits ``1.0`` where the input is strictly positive and ``0.0``
    elsewhere (as a float32 tensor). Backward passes the incoming gradient
    through unchanged where the forward input was strictly positive and
    zeroes it where the input was ``<= 0``.
    """

    @staticmethod
    def forward(ctx, x):
        """Return the binary spike tensor ``(x > 0).float()``."""
        # The input is needed in backward to know where the gradient is gated.
        ctx.save_for_backward(x)
        return (x > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where the input was <= 0."""
        (inputs,) = ctx.saved_tensors
        # Out-of-place so the gradient tensor handed in by autograd is untouched.
        return grad_output.masked_fill(inputs <= 0.0, 0)
#### Neurons ####
class BaseNode(nn.Module):
    """Common state and firing logic for the spiking neurons in this file.

    Args:
        v_threshold: Voltage above which the neuron fires.
        v_reset: Voltage assigned to neurons that just fired (hard reset).
            If ``None``, fired neurons instead have ``v_threshold`` subtracted
            from their voltage (soft reset) and the initial voltage is ``0``.
        surrogate_function: Callable mapping ``v - v_threshold`` to the spike
            tensor; it also defines the gradient used in backward.
        monitor: If true, ``self.monitor`` is a dict with empty ``"v"`` and
            ``"s"`` lists; otherwise it is ``False``. Nothing in this class
            appends to those lists.

    Attributes set here: ``v`` (current voltage, a plain number until a
    subclass adds a tensor to it), ``v_acc`` and ``v_acc_l`` (accumulators
    initialised to ``0``; described in the original code as accumulated
    voltage assuming no fire, without and with leak) and ``new_grad``
    (initialised to ``None``).
    """

    def __init__(
        self, v_threshold=1.0, v_reset=0.0, surrogate_function=relu.apply, monitor=False
    ):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.surrogate_function = surrogate_function
        self.v = self._initial_v()
        self.v_acc = 0  # Accumulated voltage (Assuming NO fire for this neuron)
        self.v_acc_l = 0  # Accumulated voltage with leak (Assuming NO fire for this neuron)
        self.monitor = {"v": [], "s": []} if monitor else False
        self.new_grad = None

    def _initial_v(self):
        """Return the voltage a freshly reset neuron starts from."""
        return 0 if self.v_reset is None else self.v_reset

    def spiking(self):
        """Fire where ``v`` exceeds the threshold, reset those neurons, return spikes.

        ``self.v`` must already be a tensor (i.e. ``forward`` of a subclass has
        added an input to it).
        """
        spike = self.v - self.v_threshold
        fired = spike > 0
        if self.v_reset is None:
            # Soft reset: masked_fill_ cannot take None as a fill value, so
            # subtract the threshold from the neurons that fired instead.
            self.v = self.v - fired.to(spike.dtype) * self.v_threshold
        else:
            # Hard reset, in place, of the neurons that fired.
            self.v.masked_fill_(fired, self.v_reset)
        return self.surrogate_function(spike)

    def forward(self, dv: torch.Tensor):
        """Integrate the input ``dv``; must be implemented by subclasses."""
        raise NotImplementedError

    def reset(self):
        """Restore the voltage and accumulators to their initial values."""
        self.v = self._initial_v()
        if self.monitor:
            self.monitor = {"v": [], "s": []}
        self.v_acc = 0
        self.v_acc_l = 0
class LIFNode(BaseNode):
    """Leaky integrate-and-fire neuron.

    Args:
        tau: Leak time constant; each step the voltage loses
            ``(v - v_reset) / tau``.
        v_threshold: Firing threshold, forwarded to ``BaseNode``.
        v_reset: Reset voltage, forwarded to ``BaseNode``. When ``None`` the
            leak is computed relative to ``0``.
        surrogate_function: Spike function, forwarded to ``BaseNode``.
        fire: If ``False`` the neuron never fires and ``forward`` returns the
            voltage instead of spikes.
    """

    def __init__(
        self,
        tau=100.0,
        v_threshold=1.0,
        v_reset=0.0,
        surrogate_function=relu.apply,
        fire=True,
    ):
        super().__init__(v_threshold, v_reset, surrogate_function)
        self.tau = tau
        self.fire = fire  # If no fire, the voltage threshold of neuron is infinity
        self.new_grad = None

    def _scale_spike_grad(self, grad):
        """Backward hook: multiply the spike gradient by ``self.new_grad``.

        ``new_grad`` is read when backward runs, so it has to be assigned
        before ``backward()`` is called.
        """
        return torch.mul(grad, self.new_grad)

    def forward(self, dv: torch.Tensor):
        """Integrate ``dv``, optionally fire, then apply the leak.

        Returns the spike tensor when ``self.fire`` is true, otherwise the
        voltage after the leak.
        """
        self.v += dv
        spike = self.spiking() if self.fire else None

        # Leak towards the reset voltage (0 if there is no reset voltage).
        v_rest = 0.0 if self.v_reset is None else self.v_reset
        leak = (self.v - v_rest) / self.tau

        if self.fire:
            self.v_acc += spike
            self.v_acc_l = self.v - leak + spike

        # The leak is detached so it does not contribute to the gradient of v.
        self.v = self.v - leak.detach()

        if not self.fire:
            return self.v
        # register_hook raises on tensors that do not require grad (e.g. a
        # training-mode forward under torch.no_grad()), so guard on it.
        if self.training and spike.requires_grad:
            spike.register_hook(self._scale_spike_grad)
        return spike
class IFNode(BaseNode):
    """Integrate-and-fire neuron without leak.

    Args:
        v_threshold: Firing threshold (default ``0.75``).
        v_reset: Reset voltage, forwarded to ``BaseNode``.
        surrogate_function: Spike function, forwarded to ``BaseNode``.
    """

    def __init__(self, v_threshold=0.75, v_reset=0.0, surrogate_function=relu.apply):
        super().__init__(v_threshold, v_reset, surrogate_function)

    def forward(self, dv: torch.Tensor):
        """Add ``dv`` to the voltage and return the resulting spikes."""
        self.v += dv
        return self.spiking()
#### Network ####
class ResNet11(nn.Module):
    """Spiking residual network with bias-free convolutions and LIF neurons.

    Structure: one stem convolution followed by average pooling and an IF
    neuron, four residual blocks (two 3x3 convolutions plus a shortcut each),
    and two fully connected layers. The final ``LIFNode(fire=False)`` returns
    its voltage instead of spikes. ``fc0`` expects ``512 * 4 * 4`` input
    features.

    All neuron layers are stateful; call :meth:`reset_` between samples.
    """

    @staticmethod
    def _conv3x3(in_channels, out_channels, stride=1):
        """Return a bias-free 3x3 convolution with padding 1."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )

    @staticmethod
    def _lif_dropout():
        """Return a LIF neuron followed by dropout with p=0.25."""
        return nn.Sequential(LIFNode(), layer.Dropout(0.25))

    def __init__(self):
        super().__init__()
        self.train_epoch = 0

        # Stem.
        self.cnn11 = self._conv3x3(3, 64)
        self.lif11 = self._lif_dropout()
        self.avgpool1 = nn.AvgPool2d(kernel_size=2)
        self.if1 = IFNode()

        # Residual block 1: 64 -> 128 channels.
        self.cnn21 = self._conv3x3(64, 128)
        self.lif21 = self._lif_dropout()
        self.cnn22 = self._conv3x3(128, 128)
        self.shortcut1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1, stride=1, bias=False),
        )
        self.lif2 = self._lif_dropout()

        # Residual block 2: 128 -> 256 channels, stride 2.
        self.cnn31 = self._conv3x3(128, 256)
        self.lif31 = self._lif_dropout()
        self.cnn32 = self._conv3x3(256, 256, stride=2)
        self.shortcut2 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),
        )
        self.lif3 = self._lif_dropout()

        # Residual block 3: 256 -> 512 channels.
        self.cnn41 = self._conv3x3(256, 512)
        self.lif41 = self._lif_dropout()
        self.cnn42 = self._conv3x3(512, 512)
        self.shortcut3 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=1, stride=1, bias=False),
        )
        self.lif4 = self._lif_dropout()

        # Residual block 4: 512 channels, stride 2; the shortcut only
        # subsamples (1x1 average pooling with stride 2).
        self.cnn51 = self._conv3x3(512, 512)
        self.lif51 = self._lif_dropout()
        self.cnn52 = self._conv3x3(512, 512, stride=2)
        self.shortcut4 = nn.Sequential(
            nn.AvgPool2d(kernel_size=(1, 1), stride=(2, 2), padding=(0, 0))
        )
        self.lif5 = self._lif_dropout()

        # Classifier head.
        self.fc0 = nn.Linear(512 * 4 * 4, 1024, bias=False)
        self.lif6 = self._lif_dropout()
        self.fc1 = nn.Linear(1024, 10, bias=False)
        self.lif_out = LIFNode(fire=False)

        # Re-initialise every conv/linear weight from N(0, 1 / fan_in).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
            elif isinstance(m, nn.Linear):
                fan_in = m.weight.size(1)
            else:
                continue
            nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(1.0 / fan_in))

    def forward(self, x):
        """Run one simulation timestep and return the output layer's voltage."""
        x = self.if1(self.avgpool1(self.lif11(self.cnn11(x))))
        x = self.lif2(self.cnn22(self.lif21(self.cnn21(x))) + self.shortcut1(x))
        x = self.lif3(self.cnn32(self.lif31(self.cnn31(x))) + self.shortcut2(x))
        x = self.lif4(self.cnn42(self.lif41(self.cnn41(x))) + self.shortcut3(x))
        x = self.lif5(self.cnn52(self.lif51(self.cnn51(x))) + self.shortcut4(x))
        out = x.view(x.size(0), -1)
        out = self.lif_out(self.fc1(self.lif6(self.fc0(out))))
        return out

    def reset_(self):
        """Call ``reset()`` on every submodule that defines it."""
        for item in self.modules():
            if hasattr(item, "reset"):
                item.reset()
def _poisson_encode(img):
    """Sample one signed Poisson spike frame from ``img``.

    Each element fires with probability ``|img|`` (compared against a uniform
    random number) and carries the sign of the corresponding input element.
    """
    fired = (torch.abs(img) > torch.rand_like(img)).float()
    return torch.mul(fired, torch.sign(img))


def _simulate(net, img, timesteps):
    """Feed ``timesteps`` freshly encoded frames to ``net``; return the last output."""
    output = None
    for _ in range(timesteps):
        output = net(_poisson_encode(img))
    return output


def main():
    """Train ``ResNet11`` on CIFAR10 and evaluate it after every epoch.

    Reads the command line through the module-level ``parser``. Training
    scalars are logged to TensorBoard under ``./logs/<log_prefix>`` and the
    parameters with the best test accuracy are saved in the same directory.

    The epoch loop has no stopping condition: the script runs until it is
    interrupted. The TensorBoard writer is closed on the way out.
    """
    args = parser.parse_args()
    if args.timesteps < 1:
        parser.error("-T/--timesteps must be at least 1")
    # --gpu defaults to None, which torch.cuda.set_device() rejects; in that
    # case keep the current default device.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    learning_rate = args.lr
    batch_size = args.batch_size
    T = args.timesteps
    log_prefix = f"{sys.argv[0]}-lr-{learning_rate}-T-{T}-b-{batch_size}"
    log_dir = "./logs/" + log_prefix
    best_model_path = log_dir + "/model_bestT1_cifar10_r11.pth.tar"
    writer = SummaryWriter(log_dir)
    cudnn.benchmark = True
    dataset_root_dir = args.data

    # Load data
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465], std=[0.557, 0.549, 0.5534]
    )
    train_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_root_dir,
            train=True,
            transform=torchvision.transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
            download=True,
        ),
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
        num_workers=args.workers,
    )
    test_data_loader = torch.utils.data.DataLoader(
        dataset=torchvision.datasets.CIFAR10(
            root=dataset_root_dir,
            train=False,
            transform=torchvision.transforms.Compose(
                [transforms.ToTensor(), normalize]
            ),
            download=True,
        ),
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        drop_last=False,
        num_workers=args.workers,
    )

    # Prepare model
    net = ResNet11().cuda()
    if args.pretrained:
        # The pretrained parameter can either be downloaded in authors' Dropbox (https://www.dropbox.com/sh/vvq9afkq90refka/AAAIEnyBZ_wO7eM510GCyZ8ta?dl=0) or trained by yourself.
        # Should be placed together with code before training!
        checkpoint = torch.load(
            "./model_bestT1_cifar10_r11.pth.tar", map_location="cpu"
        )
        # Accept both a checkpoint that wraps the weights under "state_dict"
        # and the bare state dict that this script saves itself.
        net.load_state_dict(checkpoint.get("state_dict", checkpoint))
    print(net)

    criterion = nn.MSELoss(reduction="sum").cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[70, 100, 125], gamma=0.2
    )

    max_test_accuracy = 0
    train_epoch = 0
    step = 0
    try:
        while True:
            print(log_prefix)

            #### Train ####
            for img, label in tqdm(train_data_loader):
                label = F.one_hot(label.cuda(), 10).float()
                img = img.cuda()
                optimizer.zero_grad()

                output = _simulate(net, img, T)

                # Set the per-neuron gradient factor used by the LIF backward
                # hooks; it must be in place before loss.backward().
                for m in net.modules():
                    if isinstance(m, LIFNode) and m.fire:
                        # Bump (near-)zero accumulators to avoid dividing by zero.
                        m.v_acc += (m.v_acc < 1e-3).float()
                        m.new_grad = (m.v_acc_l > 1e-3).float() + math.log(
                            1 - 1 / m.tau
                        ) * torch.div(m.v_acc_l, m.v_acc)

                loss = criterion(output / T, label)
                loss.backward()
                optimizer.step()
                net.reset_()

                writer.add_scalar("train_loss", loss.item(), step)
                step += 1

            #### Evaluate ####
            with torch.no_grad():
                print("Test:")
                net.eval()
                correct = 0
                test_num = 0
                for img, label in tqdm(test_data_loader):
                    label = label.cuda()
                    img = img.cuda()
                    output = _simulate(net, img, T)
                    correct += (output.argmax(dim=1) == label).float().sum().item()
                    test_num += label.numel()
                    net.reset_()
                accuracy = correct / test_num

                if max_test_accuracy < accuracy:
                    max_test_accuracy = accuracy
                    torch.save(net.state_dict(), best_model_path)
                    print("保存模型参数", best_model_path)
                writer.add_scalar("test_acc", accuracy, train_epoch)
                print(f"Test Acc: {accuracy}, Max Acc: {max_test_accuracy}")

            net.train()
            train_epoch += 1
            net.train_epoch += 1
            scheduler.step()
    finally:
        # Flush pending TensorBoard events when the loop is interrupted.
        writer.close()
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()