Train large-scale SNNs

Author: fangwei123456

Usage of activation_based.model

spikingjelly.activation_based.model defines some classic networks, which we can use in the same way as those in torchvision.models. For example, we can build the Spiking ResNet [1]:

import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet

s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

print(f's_resnet18={s_resnet18}')

with torch.no_grad():
    T = 4
    N = 1
    x_seq = torch.rand([T, N, 3, 224, 224])
    functional.set_step_mode(s_resnet18, 'm')
    y_seq = s_resnet18(x_seq)
    print(f'y_seq.shape={y_seq.shape}')
    functional.reset_net(s_resnet18)

The outputs are:

s_resnet18=SpikingResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
  (sn1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
    (surrogate_function): ATan(alpha=2.0, spiking=True)
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn1): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
      (sn2): IFNode(
        v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
        (surrogate_function): ATan(alpha=2.0, spiking=True)
      )
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s)
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
y_seq.shape=torch.Size([4, 1, 1000])
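
Since the network runs in multi-step mode, y_seq has the shape [T, N, num_classes]. A common readout, which the ImageNet example below also uses, is the firing rate averaged over the time dimension. A minimal sketch:

fr = y_seq.mean(0)   # firing rate over T time-steps, shape [N, 1000]
pred = fr.argmax(1)  # predicted class indices, shape [N]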

The Spiking ResNet in SpikingJelly has the same network structure as the ResNet in torchvision. Their state_dict().keys() are identical, so we can load pre-trained weights by setting pretrained=True:

s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=True, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
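
To verify that the parameter names match, we can compare the keys with torchvision directly. A minimal sketch, assuming torchvision is installed:

import torchvision

tv_resnet18 = torchvision.models.resnet18()
print(set(s_resnet18.state_dict().keys()) == set(tv_resnet18.state_dict().keys()))  # expected: True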

Usage of activation_based.model.train_classify

spikingjelly.activation_based.model.train_classify is modified from the torchvision 0.12 classification reference scripts. We can use this module to train models easily.

spikingjelly.activation_based.model.train_classify.Trainer provides a flexible training pipeline. Users can override its functions to implement the desired behavior without much effort. For example, spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer defines how to set the optimizer:

# spikingjelly.activation_based.model.train_classify
class Trainer:
  # ...
  def set_optimizer(self, args, parameters):
      opt_name = args.opt.lower()
      if opt_name.startswith("sgd"):
          optimizer = torch.optim.SGD(
              parameters,
              lr=args.lr,
              momentum=args.momentum,
              weight_decay=args.weight_decay,
              nesterov="nesterov" in opt_name,
          )
      elif opt_name == "rmsprop":
          optimizer = torch.optim.RMSprop(
              parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
          )
      elif opt_name == "adamw":
          optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
      else:
          raise RuntimeError(f"Invalid optimizer
           {args.opt}. Only SGD, RMSprop and AdamW are supported.")
      return optimizer

  def main(self, args):
    # ...
    optimizer = self.set_optimizer(args, parameters)
    # ...

If we want to add an optimizer, e.g., Adamax, we can inherit the class and override this function:

class MyTrainer(train_classify.Trainer):
    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("adamax"):
            optimizer = torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)
            return optimizer
        else:
            return super().set_optimizer(args, parameters)
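
A quick way to sanity-check the override is to call set_optimizer directly with a minimal argparse.Namespace; the field values below are hypothetical, and only the attributes that set_optimizer reads are filled in:

import argparse
import torch.nn as nn

args = argparse.Namespace(opt='adamax', lr=0.001, weight_decay=0.0)  # hypothetical values
net = nn.Linear(8, 2)  # any module with parameters works here
optimizer = MyTrainer().set_optimizer(args, net.parameters())
print(type(optimizer))  # <class 'torch.optim.adamax.Adamax'>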

Trainer.get_args_parser defines the args for training:

(pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h

usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING]
                        [--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY]
                        [--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT]
                        [--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION]
                        [--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED]
                        [--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda]

PyTorch Classification Training

optional arguments:
  -h, --help            show this help message and exit
  --data-path DATA_PATH
                        dataset path
  --model MODEL         model name
  --device DEVICE       device (Use cuda or cpu Default: cuda)
  -b BATCH_SIZE, --batch-size BATCH_SIZE
                        images per gpu, the total batch size is $NGPU x batch_size
  --epochs N            number of total epochs to run
  -j N, --workers N     number of data loading workers (default: 16)
  --opt OPT             optimizer
  --lr LR               initial learning rate
  --momentum M          momentum
  --wd W, --weight-decay W
                        weight decay (default: 0.)
  --norm-weight-decay NORM_WEIGHT_DECAY
                        weight decay for Normalization layers (default: None, same value as --wd)
  --label-smoothing LABEL_SMOOTHING
                        label smoothing (default: 0.1)
  --mixup-alpha MIXUP_ALPHA
                        mixup alpha (default: 0.2)
  --cutmix-alpha CUTMIX_ALPHA
                        cutmix alpha (default: 1.0)
  --lr-scheduler LR_SCHEDULER
                        the lr scheduler (default: cosa)
  --lr-warmup-epochs LR_WARMUP_EPOCHS
                        the number of epochs to warmup (default: 5)
  --lr-warmup-method LR_WARMUP_METHOD
                        the warmup method (default: linear)
  --lr-warmup-decay LR_WARMUP_DECAY
                        the decay for lr
  --lr-step-size LR_STEP_SIZE
                        decrease lr every step-size epochs
  --lr-gamma LR_GAMMA   decrease lr by a factor of lr-gamma
  --output-dir OUTPUT_DIR
                        path to save outputs
  --resume RESUME       path of checkpoint. If set to 'latest', it will try to load the latest checkpoint
  --start-epoch N       start epoch
  --cache-dataset       Cache the datasets for quicker initialization. It also serializes the transforms
  --sync-bn             Use sync batch norm
  --test-only           Only test the model
  --pretrained          Use pre-trained models from the modelzoo
  --auto-augment AUTO_AUGMENT
                        auto augment policy (default: ta_wide)
  --random-erase RANDOM_ERASE
                        random erasing probability (default: 0.1)
  --world-size WORLD_SIZE
                        number of distributed processes
  --dist-url DIST_URL   url used to set up distributed training
  --model-ema           enable tracking Exponential Moving Average of model parameters
  --model-ema-steps MODEL_EMA_STEPS
                        the number of iterations that controls how often to update the EMA model (default: 32)
  --model-ema-decay MODEL_EMA_DECAY
                        decay factor for Exponential Moving Average of model parameters (default: 0.99998)
  --interpolation INTERPOLATION
                        the interpolation method (default: bilinear)
  --val-resize-size VAL_RESIZE_SIZE
                        the resize size used for validation (default: 232)
  --val-crop-size VAL_CROP_SIZE
                        the central crop size used for validation (default: 224)
  --train-crop-size TRAIN_CROP_SIZE
                        the random crop size used for training (default: 176)
  --clip-grad-norm CLIP_GRAD_NORM
                        the maximum gradient norm (default None)
  --ra-sampler          whether to use Repeated Augmentation in training
  --ra-reps RA_REPS     number of repetitions for Repeated Augmentation (default: 4)
  --prototype           Use prototype model builders instead those from main area
  --weights WEIGHTS     the weights enum name to load
  --seed SEED           the random seed
  --print-logdir        print the dirs for tensorboard logs and pt files and exit
  --clean               delete the dirs for tensorboard logs and pt files
  --disable-pinmemory   not use pin memory in dataloader, which can help reduce memory consumption
  --disable-amp         not use automatic mixed precision training
  --local_rank LOCAL_RANK
                        args for DDP, which should not be set by user
  --disable-uda         not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation

If we want to add some args, we can also inherit the class and override get_args_parser:

class MyTrainer(train_classify.Trainer):
    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--do-something', type=str, help="do something")
        return parser
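
After parsing, the new option is available as an attribute of args. A minimal sketch with a hypothetical value:

trainer = MyTrainer()
args = trainer.get_args_parser().parse_args(['--do-something', 'hello'])
print(args.do_something)  # hello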

We can modify most functions in Trainer.

We can use the following code to train with Trainer or a user-defined trainer:

trainer = Trainer()
args = trainer.get_args_parser().parse_args()
trainer.main(args)

Trainer will calculate Acc@1, Acc@5, and loss on the training and test datasets, and record them with tensorboard. The model weights of the latest epoch and of the epoch with the maximum test accuracy will also be saved. Trainer also supports Distributed Data Parallel (DDP) training.

Training on ImageNet

The default data loading function load_data loads the ImageNet [2] dataset. With Trainer and spikingjelly.activation_based.model.spiking_resnet, we can train large-scale SNNs easily. Here is the example code:

# spikingjelly.activation_based.model.train_imagenet_example
import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify


class SResNetTrainer(train_classify.Trainer):
    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to process the train sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to process the test sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        return y.mean(0)  # return firing rate

    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

    def load_model(self, args, num_classes):
        if args.model in spiking_resnet.__all__:
            model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode,
                                                        surrogate_function=surrogate.ATan(), detach_reset=True)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.IFNode)

            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")


if __name__ == "__main__":
    trainer = SResNetTrainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
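
Note that preprocess_train_sample and preprocess_test_sample simply copy the static image at every time-step, adding a time dimension in front of the batch dimension. A minimal sketch of the shape transformation:

import torch

x = torch.rand([8, 3, 224, 224])  # [N, C, H, W]
T = 4
x_seq = x.unsqueeze(0).repeat(T, 1, 1, 1, 1)
print(x_seq.shape)  # torch.Size([4, 8, 3, 224, 224])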

The code is saved in spikingjelly.activation_based.model.train_imagenet_example. Training on a single GPU:

python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90

Training with DDP on two GPUs:

python -m torch.distributed.launch --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90

[1] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2016.

[2] Deng, Jia, et al. "ImageNet: A large-scale hierarchical image database." 2009 IEEE Conference on Computer Vision and Pattern Recognition. IEEE, 2009.