Train large-scale SNNs ====================================== Author: `fangwei123456 `_ 中文版: :doc:`../cn/train_large_scale_snn` Usage of ``activation_based.model`` ---------------------------------------------- :class:`spikingjelly.activation_based.model` has defined some classic networks, which we can use as we use :class:`torchvision.models`. \ For example, we can build the Spiking ResNet [#ResNet]_ : .. code-block:: python import torch import torch.nn as nn from spikingjelly.activation_based import surrogate, neuron, functional from spikingjelly.activation_based.model import spiking_resnet s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True) print(f's_resnet18={s_resnet18}') with torch.no_grad(): T = 4 N = 1 x_seq = torch.rand([T, N, 3, 224, 224]) functional.set_step_mode(s_resnet18, 'm') y_seq = s_resnet18(x_seq) print(f'y_seq.shape={y_seq.shape}') functional.reset_net(s_resnet18) The outputs are: .. code-block:: shell s_resnet18=SpikingResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s) (layer1): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) ) (1): BasicBlock( (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) ) ) (layer2): Sequential( (0): BasicBlock( (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (downsample): Sequential( (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s) (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) ) ) (1): BasicBlock( (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) ) ) (layer3): Sequential( (0): BasicBlock( (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (downsample): Sequential( (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s) (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) ) ) (1): BasicBlock( (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) ) ) (layer4): Sequential( (0): BasicBlock( (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (downsample): Sequential( (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s) (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) ) ) (1): BasicBlock( (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn1): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s) (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s) (sn2): IFNode( v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch (surrogate_function): ATan(alpha=2.0, spiking=True) ) ) ) (avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s) (fc): Linear(in_features=512, out_features=1000, bias=True) ) y_seq.shape=torch.Size([4, 1, 1000]) Spiking ResNet in SpikingJelly has the same network structure as that in ``torchvision``. Their ``state_dict().keys()`` are identical and we can load \ pre-trained weights by setting ``pretrained=True``: .. code-block:: python s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=True, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True) Usage of ``activation_based.model.train_classify`` -------------------------------------------------------- :class:`spikingjelly.activation_based.model.train_classify` is modified by `torchvision 0.12 references `_. \ We can use this module to train easily. :class:`spikingjelly.activation_based.model.train_classify.Trainer` provides a flexible method to train. Users can change its functions to implement the desirable behaviors without too much efforts. For example, :class:`spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer` defines how to set the optimizer: .. code-block:: python # spikingjelly.activation_based.model.train_classify class Trainer: # ... def set_optimizer(self, args, parameters): opt_name = args.opt.lower() if opt_name.startswith("sgd"): optimizer = torch.optim.SGD( parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov="nesterov" in opt_name, ) elif opt_name == "rmsprop": optimizer = torch.optim.RMSprop( parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9 ) elif opt_name == "adamw": optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay) else: raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.") return optimizer def main(self, args): # ... optimizer = self.set_optimizer(args, parameters) # ... If we want to add an optimizer, e.g., ``Adamax``, we can inherit the class and override this function: .. code-block:: python class MyTrainer(train_classify.Trainer): def set_optimizer(self, args, parameters): opt_name = args.opt.lower() if opt_name.startswith("adamax"): optimizer = torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay) return optimizer else: return super().set_optimizer(args, parameters) :class:`Trainer.get_args_parser ` defines the args for training: .. code-block:: shell (pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING] [--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY] [--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT] [--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION] [--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED] [--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda] PyTorch Classification Training optional arguments: -h, --help show this help message and exit --data-path DATA_PATH dataset path --model MODEL model name --device DEVICE device (Use cuda or cpu Default: cuda) -b BATCH_SIZE, --batch-size BATCH_SIZE images per gpu, the total batch size is $NGPU x batch_size --epochs N number of total epochs to run -j N, --workers N number of data loading workers (default: 16) --opt OPT optimizer --lr LR initial learning rate --momentum M momentum --wd W, --weight-decay W weight decay (default: 0.) --norm-weight-decay NORM_WEIGHT_DECAY weight decay for Normalization layers (default: None, same value as --wd) --label-smoothing LABEL_SMOOTHING label smoothing (default: 0.1) --mixup-alpha MIXUP_ALPHA mixup alpha (default: 0.2) --cutmix-alpha CUTMIX_ALPHA cutmix alpha (default: 1.0) --lr-scheduler LR_SCHEDULER the lr scheduler (default: cosa) --lr-warmup-epochs LR_WARMUP_EPOCHS the number of epochs to warmup (default: 5) --lr-warmup-method LR_WARMUP_METHOD the warmup method (default: linear) --lr-warmup-decay LR_WARMUP_DECAY the decay for lr --lr-step-size LR_STEP_SIZE decrease lr every step-size epochs --lr-gamma LR_GAMMA decrease lr by a factor of lr-gamma --output-dir OUTPUT_DIR path to save outputs --resume RESUME path of checkpoint. If set to 'latest', it will try to load the latest checkpoint --start-epoch N start epoch --cache-dataset Cache the datasets for quicker initialization. It also serializes the transforms --sync-bn Use sync batch norm --test-only Only test the model --pretrained Use pre-trained models from the modelzoo --auto-augment AUTO_AUGMENT auto augment policy (default: ta_wide) --random-erase RANDOM_ERASE random erasing probability (default: 0.1) --world-size WORLD_SIZE number of distributed processes --dist-url DIST_URL url used to set up distributed training --model-ema enable tracking Exponential Moving Average of model parameters --model-ema-steps MODEL_EMA_STEPS the number of iterations that controls how often to update the EMA model (default: 32) --model-ema-decay MODEL_EMA_DECAY decay factor for Exponential Moving Average of model parameters (default: 0.99998) --interpolation INTERPOLATION the interpolation method (default: bilinear) --val-resize-size VAL_RESIZE_SIZE the resize size used for validation (default: 232) --val-crop-size VAL_CROP_SIZE the central crop size used for validation (default: 224) --train-crop-size TRAIN_CROP_SIZE the random crop size used for training (default: 176) --clip-grad-norm CLIP_GRAD_NORM the maximum gradient norm (default None) --ra-sampler whether to use Repeated Augmentation in training --ra-reps RA_REPS number of repetitions for Repeated Augmentation (default: 4) --prototype Use prototype model builders instead those from main area --weights WEIGHTS the weights enum name to load --seed SEED the random seed --print-logdir print the dirs for tensorboard logs and pt files and exit --clean delete the dirs for tensorboard logs and pt files --disable-pinmemory not use pin memory in dataloader, which can help reduce memory consumption --disable-amp not use automatic mixed precision training --local_rank LOCAL_RANK args for DDP, which should not be set by user --disable-uda not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation If we want to add some args, we can also inherit and override it: .. code-block:: python class MyTrainer(train_classify.Trainer): def get_args_parser(self, add_help=True): parser = super().get_args_parser() parser.add_argument('--do-something', type=str, help="do something") We can modify most functions in :class:`Trainer `. We can use the following codes to train with ``Trainer`` or the user-defined trainer: .. code-block:: python trainer = Trainer() args = trainer.get_args_parser().parse_args() trainer.main(args) ``Trainer`` will calculate ``Acc@1, Acc@5, loss`` on the training and test dataset, and save them by ``tensorboard``. The model weights of the latest epoch and the maximum test accuracy will also be saved.\ ``Trainer`` also supports Distributed Data Parallel (DDP) training. Training on ImageNet ---------------------------------------------- The default data loading function :class:`load_data ` will load the ImageNet [#ImageNet]_ dataset. With :class:`Trainer ` and :class:`spikingjelly.activation_based.model.spiking_resnet`, \ we can train large-scale SNNs easily. Here are the example codes: .. code-block:: python # spikingjelly.activation_based.model.train_imagenet_example import torch from spikingjelly.activation_based import surrogate, neuron, functional from spikingjelly.activation_based.model import spiking_resnet, train_classify class SResNetTrainer(train_classify.Trainer): def preprocess_train_sample(self, args, x: torch.Tensor): # define how to process train sample before send it to model return x.unsqueeze(0).expand(args.T, -1, -1, -1, -1) # [N, C, H, W] -> [T, N, C, H, W] def preprocess_test_sample(self, args, x: torch.Tensor): # define how to process test sample before send it to model return x.unsqueeze(0).expand(args.T, -1, -1, -1, -1) # [N, C, H, W] -> [T, N, C, H, W] def process_model_output(self, args, y: torch.Tensor): return y.mean(0) # return firing rate def get_args_parser(self, add_help=True): parser = super().get_args_parser() parser.add_argument('--T', type=int, help="total time-steps") parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend") return parser def get_tb_logdir_name(self, args): return super().get_tb_logdir_name(args) + f'_T{args.T}' def load_model(self, args, num_classes): if args.model in spiking_resnet.__all__: model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True) functional.set_step_mode(model, step_mode='m') if args.cupy: functional.set_backend(model, 'cupy', neuron.IFNode) return model else: raise ValueError(f"args.model should be one of {spiking_resnet.__all__}") if __name__ == "__main__": trainer = SResNetTrainer() args = trainer.get_args_parser().parse_args() trainer.main(args) The codes are saved in :class:`spikingjelly.activation_based.model.train_imagenet_example`. Training on a single GPU: .. code-block:: shell python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90 Training with DDP on two GPUs: .. code-block:: shell torchrun --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90 .. [#ResNet] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. .. [#ImageNet] Deng, Jia, et al. "Imagenet: A large-scale hierarchical image database." 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009.