训练大规模SNN
本教程作者: fangwei123456
使用 activation_based.model
在 spikingjelly.activation_based.model 中定义了一些经典网络模型,可以直接拿来使用,使用方法与 torchvision.models 类似。以 Spiking ResNet [1] 为例:
import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet
# Build a Spiking ResNet-18 whose neurons are IF nodes trained with an ATan surrogate gradient.
s_resnet18 = spiking_resnet.spiking_resnet18(
    pretrained=False,
    spiking_neuron=neuron.IFNode,
    surrogate_function=surrogate.ATan(),
    detach_reset=True,
)
print(f's_resnet18={s_resnet18}')

T = 4  # number of time-steps (leading dimension of the input sequence)
N = 1  # batch size
with torch.no_grad():
    x_seq = torch.rand(T, N, 3, 224, 224)  # [T, N, C, H, W]
    # Multi-step mode lets the network consume the whole [T, ...] sequence in a single call.
    functional.set_step_mode(s_resnet18, step_mode='m')
    y_seq = s_resnet18(x_seq)
    print(f'y_seq.shape={y_seq.shape}')
    # Reset the stateful neurons so the next sequence starts from a clean state.
    functional.reset_net(s_resnet18)
输出为:
s_resnet18=SpikingResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
y_seq.shape=torch.Size([4, 1, 1000])
SpikingJelly 按照 torchvision 中的 ResNet 结构搭建了 Spiking ResNet,并保持其 state_dict().keys() 与 torchvision 的 ResNet 相同,因此支持直接加载预训练权重,只需设置 pretrained=True 即可:
# Same architecture as above, but initialized from the pre-trained weights.
s_resnet18 = spiking_resnet.spiking_resnet18(
    pretrained=True,
    spiking_neuron=neuron.IFNode,
    surrogate_function=surrogate.ATan(),
    detach_reset=True,
)
使用 activation_based.model.train_classify
spikingjelly.activation_based.model.train_classify 是在 torchvision 0.12 references 分类代码的基础上改动而来的,使用这个模块可以很方便地进行训练。
spikingjelly.activation_based.model.train_classify.Trainer 提供了较为灵活的训练方式,并预留了一些接口供用户修改。例如, spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer 定义了如何设置优化器,默认实现为:
# spikingjelly.activation_based.model.train_classify
class Trainer:
    # ...
    def set_optimizer(self, args, parameters):
        """Create the optimizer selected by ``args.opt``.

        Args:
            args: parsed command-line arguments; ``opt``, ``lr``, ``momentum`` and
                ``weight_decay`` are read here.
            parameters: the parameters (or parameter groups) to optimize.

        Returns:
            The constructed ``torch.optim`` optimizer.

        Raises:
            RuntimeError: if ``args.opt`` is not an SGD variant, ``rmsprop`` or ``adamw``.
        """
        opt_name = args.opt.lower()
        if opt_name.startswith("sgd"):
            # Any name starting with "sgd" selects SGD; Nesterov momentum is enabled
            # when the name also contains "nesterov" (e.g. "sgd_nesterov").
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov="nesterov" in opt_name,
            )
        elif opt_name == "rmsprop":
            # eps/alpha are fixed here, presumably carried over from the torchvision
            # reference script this module is derived from.
            optimizer = torch.optim.RMSprop(
                parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
            )
        elif opt_name == "adamw":
            optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
        return optimizer

    def main(self, args):
        # ... (elided in this excerpt; presumably builds the model and `parameters`)
        optimizer = self.set_optimizer(args, parameters)
        # ...
如果我们想增加一个优化器,例如 Adamax,只需要继承 Trainer 并重写此方法,示例如下:
class MyTrainer(train_classify.Trainer):
    """Trainer that additionally accepts ``--opt adamax``."""

    def set_optimizer(self, args, parameters):
        # Any optimizer name not starting with "adamax" is handled by the parent class.
        if not args.opt.lower().startswith("adamax"):
            return super().set_optimizer(args, parameters)
        return torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)
默认的 Trainer.get_args_parser 已经包含了较多的参数设置:
(pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h
usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING]
[--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY]
[--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT]
[--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION]
[--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED]
[--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda]
PyTorch Classification Training
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
dataset path
--model MODEL model name
--device DEVICE device (Use cuda or cpu Default: cuda)
-b BATCH_SIZE, --batch-size BATCH_SIZE
images per gpu, the total batch size is $NGPU x batch_size
--epochs N number of total epochs to run
-j N, --workers N number of data loading workers (default: 16)
--opt OPT optimizer
--lr LR initial learning rate
--momentum M momentum
--wd W, --weight-decay W
weight decay (default: 0.)
--norm-weight-decay NORM_WEIGHT_DECAY
weight decay for Normalization layers (default: None, same value as --wd)
--label-smoothing LABEL_SMOOTHING
label smoothing (default: 0.1)
--mixup-alpha MIXUP_ALPHA
mixup alpha (default: 0.2)
--cutmix-alpha CUTMIX_ALPHA
cutmix alpha (default: 1.0)
--lr-scheduler LR_SCHEDULER
the lr scheduler (default: cosa)
--lr-warmup-epochs LR_WARMUP_EPOCHS
the number of epochs to warmup (default: 5)
--lr-warmup-method LR_WARMUP_METHOD
the warmup method (default: linear)
--lr-warmup-decay LR_WARMUP_DECAY
the decay for lr
--lr-step-size LR_STEP_SIZE
decrease lr every step-size epochs
--lr-gamma LR_GAMMA decrease lr by a factor of lr-gamma
--output-dir OUTPUT_DIR
path to save outputs
--resume RESUME path of checkpoint. If set to 'latest', it will try to load the latest checkpoint
--start-epoch N start epoch
--cache-dataset Cache the datasets for quicker initialization. It also serializes the transforms
--sync-bn Use sync batch norm
--test-only Only test the model
--pretrained Use pre-trained models from the modelzoo
--auto-augment AUTO_AUGMENT
auto augment policy (default: ta_wide)
--random-erase RANDOM_ERASE
random erasing probability (default: 0.1)
--world-size WORLD_SIZE
number of distributed processes
--dist-url DIST_URL url used to set up distributed training
--model-ema enable tracking Exponential Moving Average of model parameters
--model-ema-steps MODEL_EMA_STEPS
the number of iterations that controls how often to update the EMA model (default: 32)
--model-ema-decay MODEL_EMA_DECAY
decay factor for Exponential Moving Average of model parameters (default: 0.99998)
--interpolation INTERPOLATION
the interpolation method (default: bilinear)
--val-resize-size VAL_RESIZE_SIZE
the resize size used for validation (default: 232)
--val-crop-size VAL_CROP_SIZE
the central crop size used for validation (default: 224)
--train-crop-size TRAIN_CROP_SIZE
the random crop size used for training (default: 176)
--clip-grad-norm CLIP_GRAD_NORM
the maximum gradient norm (default None)
--ra-sampler whether to use Repeated Augmentation in training
--ra-reps RA_REPS number of repetitions for Repeated Augmentation (default: 4)
--prototype Use prototype model builders instead those from main area
--weights WEIGHTS the weights enum name to load
--seed SEED the random seed
--print-logdir print the dirs for tensorboard logs and pt files and exit
--clean delete the dirs for tensorboard logs and pt files
--disable-pinmemory not use pin memory in dataloader, which can help reduce memory consumption
--disable-amp not use automatic mixed precision training
--local_rank LOCAL_RANK
args for DDP, which should not be set by user
--disable-uda not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation
如果想增加参数,仍然可以通过继承的方式实现(注意重写的 get_args_parser 需要返回 parser,因为调用方会执行 get_args_parser().parse_args()):
class MyTrainer(train_classify.Trainer):
    def get_args_parser(self, add_help=True):
        """Extend the default argument parser with project-specific options.

        Args:
            add_help: forwarded to the parent parser (controls the ``-h`` option).

        Returns:
            The parent's parser with ``--do-something`` added.
        """
        parser = super().get_args_parser(add_help=add_help)
        parser.add_argument('--do-something', type=str, help="do something")
        # Must return the parser: callers run ``trainer.get_args_parser().parse_args()``.
        return parser
Trainer 的许多其他函数都可以用类似的方法进行补充、修改或覆盖,不再赘述。
对于 Trainer 及用户自己继承实现的子类,可以通过如下方式调用并进行训练:
# Build the trainer, parse the command line with its parser, then run training.
trainer = Trainer()
parser = trainer.get_args_parser()
args = parser.parse_args()
trainer.main(args)
Trainer 在训练中会自动计算训练集、测试集上的 Acc@1、Acc@5 和 loss,并使用 tensorboard 保存为日志文件;此外,训练过程中最新一个 epoch 的模型以及测试集上性能最高的模型也会被保存下来。Trainer 支持 Distributed Data Parallel (DDP) 训练。
在ImageNet上训练
Trainer 默认的数据加载函数 load_data 加载的是 ImageNet [2] 数据集。结合 Trainer 和 spikingjelly.activation_based.model.spiking_resnet,我们可以轻松训练大型深度 SNN,示例代码如下:
# spikingjelly.activation_based.model.train_imagenet_example
import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify
class SResNetTrainer(train_classify.Trainer):
    """Trainer that feeds static images to a multi-step Spiking ResNet.

    Every image batch is repeated ``args.T`` times along a new leading time
    dimension, and the model output is averaged over that dimension.
    """

    def preprocess_train_sample(self, args, x: torch.Tensor):
        # Define how a training sample is processed before it is sent to the model:
        # the static image is repeated at every time-step.
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # Test samples are encoded exactly like training samples.
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        # Average the classifier output over the time dimension:
        # [T, N, num_classes] -> [N, num_classes].
        return y.mean(0)

    def get_args_parser(self, add_help=True):
        # Forward ``add_help`` so callers that disable ``-h`` are respected.
        parser = super().get_args_parser(add_help=add_help)
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        # Append the number of time-steps so runs with different T log to different dirs.
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

    def load_model(self, args, num_classes):
        """Build the Spiking ResNet named by ``args.model``.

        The network uses IF neurons with an ATan surrogate gradient and is switched
        to multi-step mode. With ``--cupy`` its IF neurons use the cupy backend.

        Raises:
            ValueError: if ``args.model`` is not listed in ``spiking_resnet.__all__``.
        """
        if args.model in spiking_resnet.__all__:
            # NOTE(review): ``num_classes`` is not forwarded, so the classifier keeps the
            # model's default size (1000 outputs in the printed ResNet-18); confirm this
            # matches the dataset.
            model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode,
                                                        surrogate_function=surrogate.ATan(), detach_reset=True)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.IFNode)
            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")
if __name__ == "__main__":
    # Parse the CLI (including --T and --cupy) and launch training.
    trainer = SResNetTrainer()
    parser = trainer.get_args_parser()
    trainer.main(parser.parse_args())
代码位于 spikingjelly.activation_based.model.train_imagenet_example,可以直接运行。
在单卡上进行训练:
# Train spiking_resnet18 with T=4 time-steps on one GPU (--batch-size is images per GPU).
python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
在多卡上进行训练:
# Train on 2 GPUs with DDP; --batch-size is per GPU, so the total batch size is 2 x 64.
# NOTE(review): recent PyTorch deprecates torch.distributed.launch in favor of torchrun; confirm which launcher / --local_rank handling the Trainer expects.
python -m torch.distributed.launch --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
[1] He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
[2] Deng, Jia, et al. “ImageNet: A large-scale hierarchical image database.” 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009.