Train large-scale SNNs
Author: fangwei123456
Usage of activation_based.model
spikingjelly.activation_based.model
has defined some classic networks, which we can use in the same way as torchvision.models
. For example, we can build the Spiking ResNet 1 :
import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet

# Build a Spiking ResNet-18 with IF neurons and the ATan surrogate gradient.
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
print(f's_resnet18={s_resnet18}')

with torch.no_grad():
    T = 4  # number of time-steps
    N = 1  # batch size
    x_seq = torch.rand([T, N, 3, 224, 224])
    # Switch every module to multi-step mode so the whole sequence is processed at once.
    functional.set_step_mode(s_resnet18, 'm')
    y_seq = s_resnet18(x_seq)
    print(f'y_seq.shape={y_seq.shape}')
    # Reset the membrane potentials before the next forward pass.
    functional.reset_net(s_resnet18)
The outputs are:
s_resnet18=SpikingResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
y_seq.shape=torch.Size([4, 1, 1000])
Spiking ResNet in SpikingJelly has the same network structure as that in torchvision
. Their state_dict().keys()
are identical and we can load pre-trained weights by setting pretrained=True
:
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=True, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
Usage of activation_based.model.train_classify
spikingjelly.activation_based.model.train_classify
is adapted from the torchvision 0.12 reference classification scripts. We can use this module to train easily.
spikingjelly.activation_based.model.train_classify.Trainer
provides a flexible way to train. Users can change its functions to implement the desired behavior without too much
effort. For example, spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer
defines how to set the optimizer:
# spikingjelly.activation_based.model.train_classify
class Trainer:
    # ...

    def set_optimizer(self, args, parameters):
        """Create the optimizer named by ``args.opt`` for ``parameters``.

        Supports SGD (with optional Nesterov momentum when the name contains
        ``"nesterov"``), RMSprop, and AdamW; raises ``RuntimeError`` for any
        other optimizer name.
        """
        opt_name = args.opt.lower()
        if opt_name.startswith("sgd"):
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov="nesterov" in opt_name,
            )
        elif opt_name == "rmsprop":
            optimizer = torch.optim.RMSprop(
                parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
            )
        elif opt_name == "adamw":
            optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
        return optimizer

    def main(self, args):
        # ...
        optimizer = self.set_optimizer(args, parameters)
        # ...
If we want to add an optimizer, e.g., Adamax
, we can inherit the class and override this function:
class MyTrainer(train_classify.Trainer):
    def set_optimizer(self, args, parameters):
        """Extend the base optimizer choices with Adamax."""
        opt_name = args.opt.lower()
        if opt_name.startswith("adamax"):
            optimizer = torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)
            return optimizer
        else:
            # Fall back to the optimizers supported by the base Trainer.
            return super().set_optimizer(args, parameters)
Trainer.get_args_parser
defines the args for training:
(pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h
usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING]
[--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY]
[--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT]
[--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION]
[--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED]
[--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda]
PyTorch Classification Training
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
dataset path
--model MODEL model name
--device DEVICE device (Use cuda or cpu Default: cuda)
-b BATCH_SIZE, --batch-size BATCH_SIZE
images per gpu, the total batch size is $NGPU x batch_size
--epochs N number of total epochs to run
-j N, --workers N number of data loading workers (default: 16)
--opt OPT optimizer
--lr LR initial learning rate
--momentum M momentum
--wd W, --weight-decay W
weight decay (default: 0.)
--norm-weight-decay NORM_WEIGHT_DECAY
weight decay for Normalization layers (default: None, same value as --wd)
--label-smoothing LABEL_SMOOTHING
label smoothing (default: 0.1)
--mixup-alpha MIXUP_ALPHA
mixup alpha (default: 0.2)
--cutmix-alpha CUTMIX_ALPHA
cutmix alpha (default: 1.0)
--lr-scheduler LR_SCHEDULER
the lr scheduler (default: cosa)
--lr-warmup-epochs LR_WARMUP_EPOCHS
the number of epochs to warmup (default: 5)
--lr-warmup-method LR_WARMUP_METHOD
the warmup method (default: linear)
--lr-warmup-decay LR_WARMUP_DECAY
the decay for lr
--lr-step-size LR_STEP_SIZE
decrease lr every step-size epochs
--lr-gamma LR_GAMMA decrease lr by a factor of lr-gamma
--output-dir OUTPUT_DIR
path to save outputs
--resume RESUME path of checkpoint. If set to 'latest', it will try to load the latest checkpoint
--start-epoch N start epoch
--cache-dataset Cache the datasets for quicker initialization. It also serializes the transforms
--sync-bn Use sync batch norm
--test-only Only test the model
--pretrained Use pre-trained models from the modelzoo
--auto-augment AUTO_AUGMENT
auto augment policy (default: ta_wide)
--random-erase RANDOM_ERASE
random erasing probability (default: 0.1)
--world-size WORLD_SIZE
number of distributed processes
--dist-url DIST_URL url used to set up distributed training
--model-ema enable tracking Exponential Moving Average of model parameters
--model-ema-steps MODEL_EMA_STEPS
the number of iterations that controls how often to update the EMA model (default: 32)
--model-ema-decay MODEL_EMA_DECAY
decay factor for Exponential Moving Average of model parameters (default: 0.99998)
--interpolation INTERPOLATION
the interpolation method (default: bilinear)
--val-resize-size VAL_RESIZE_SIZE
the resize size used for validation (default: 232)
--val-crop-size VAL_CROP_SIZE
the central crop size used for validation (default: 224)
--train-crop-size TRAIN_CROP_SIZE
the random crop size used for training (default: 176)
--clip-grad-norm CLIP_GRAD_NORM
the maximum gradient norm (default None)
--ra-sampler whether to use Repeated Augmentation in training
--ra-reps RA_REPS number of repetitions for Repeated Augmentation (default: 4)
--prototype Use prototype model builders instead those from main area
--weights WEIGHTS the weights enum name to load
--seed SEED the random seed
--print-logdir print the dirs for tensorboard logs and pt files and exit
--clean delete the dirs for tensorboard logs and pt files
--disable-pinmemory not use pin memory in dataloader, which can help reduce memory consumption
--disable-amp not use automatic mixed precision training
--local_rank LOCAL_RANK
args for DDP, which should not be set by user
--disable-uda not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation
If we want to add some args, we can also inherit and override it:
class MyTrainer(train_classify.Trainer):
    def get_args_parser(self, add_help=True):
        """Add custom command-line args on top of the base parser."""
        parser = super().get_args_parser()
        parser.add_argument('--do-something', type=str, help="do something")
        # Return the parser so callers can chain .parse_args() on it.
        return parser
We can modify most functions in Trainer
.
We can use the following codes to train with Trainer
or the user-defined trainer:
# Build the trainer, parse the command-line args, and start training.
trainer = Trainer()
args = trainer.get_args_parser().parse_args()
trainer.main(args)
Trainer
will calculate Acc@1, Acc@5, loss
on the training and test dataset, and save them by tensorboard
. The model weights of the latest epoch and the maximum test accuracy will also be saved. Trainer
also supports Distributed Data Parallel (DDP) training.
Training on ImageNet
The default data loading function load_data
will load the ImageNet 2 dataset. With Trainer
and spikingjelly.activation_based.model.spiking_resnet
, we can train large-scale SNNs easily. Here are the example codes:
# spikingjelly.activation_based.model.train_imagenet_example
import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify


class SResNetTrainer(train_classify.Trainer):
    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to process a train sample before sending it to the model:
        # repeat the static image over args.T time-steps
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to process a test sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        # average over the time dimension: return the firing rate
        return y.mean(0)

    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        # include the number of time-steps in the tensorboard log dir name
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

    def load_model(self, args, num_classes):
        if args.model in spiking_resnet.__all__:
            model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode,
                                                        surrogate_function=surrogate.ATan(), detach_reset=True)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.IFNode)
            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")


if __name__ == "__main__":
    trainer = SResNetTrainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
The codes are saved in spikingjelly.activation_based.model.train_imagenet_example
. Training on a single GPU:
python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
Training with DDP on two GPUs:
python -m torch.distributed.launch --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
- 1
He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- 2
Deng, Jia, et al. “Imagenet: A large-scale hierarchical image database.” 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009.