# Source code for spikingjelly.activation_based.model.train_imagenet_example

import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify


class SResNetTrainer(train_classify.Trainer):
    """Trainer for spiking ResNet classifiers (ImageNet example).

    Extends :class:`train_classify.Trainer` and overrides sample
    preprocessing, model-output post-processing, argument parsing, the
    TensorBoard log-directory name and model construction.

    Key features:

    - Data preprocessing: a ``[N, C, H, W]`` batch is expanded to
      ``[T, N, C, H, W]``, where ``T`` is the total number of time-steps
      (``args.T``).
    - Model output processing: the model output is averaged over
      dimension 0 (the time-step dimension) and used as the final
      prediction (firing rate).
    - Model loading: builds any constructor listed in
      ``spiking_resnet.__all__`` with ``neuron.IFNode`` neurons and the
      ``surrogate.ATan`` surrogate function, sets step mode ``"m"``, and
      optionally switches the neurons to the CuPy backend.
    - Extra CLI arguments: ``--T`` (number of time-steps) and ``--cupy``
      (use the CuPy backend).
    """

    @staticmethod
    def _expand_over_time(x: torch.Tensor, T: int) -> torch.Tensor:
        """Add a leading axis of size ``T`` to ``x``.

        :param x: input batch; the five-argument ``expand`` below requires
            it to be 4-D (``[N, C, H, W]``).
        :param T: number of time-steps.
        :return: an expanded (broadcast) tensor of shape ``[T, N, C, H, W]``.
        """
        # ``-1`` keeps the size of each existing dimension; only the new
        # leading singleton axis is broadcast to ``T``.
        return x.unsqueeze(0).expand(T, -1, -1, -1, -1)

    def preprocess_train_sample(self, args, x: torch.Tensor) -> torch.Tensor:
        """Prepare a training batch before it is sent to the model.

        :param args: parsed arguments; ``args.T`` is read.
        :param x: input batch of shape ``[N, C, H, W]``.
        :return: tensor of shape ``[T, N, C, H, W]``.
        """
        # NOTE(review): ``--T`` has no default, so ``args.T`` is ``None``
        # when the option is omitted -- confirm callers always supply it.
        return self._expand_over_time(x, args.T)

    def preprocess_test_sample(self, args, x: torch.Tensor) -> torch.Tensor:
        """Prepare a test batch before it is sent to the model.

        Same transformation as :meth:`preprocess_train_sample`.

        :param args: parsed arguments; ``args.T`` is read.
        :param x: input batch of shape ``[N, C, H, W]``.
        :return: tensor of shape ``[T, N, C, H, W]``.
        """
        return self._expand_over_time(x, args.T)

    def process_model_output(self, args, y: torch.Tensor) -> torch.Tensor:
        """Reduce the model output over the time-step dimension.

        :param args: parsed arguments (unused here).
        :param y: model output whose dimension 0 is the time-step axis.
        :return: the mean over dimension 0, i.e. the firing rate.
        """
        return y.mean(0)

    def get_args_parser(self, add_help=True):
        """Build the argument parser, adding ``--T`` and ``--cupy``.

        :param add_help: forwarded to the parent parser factory.
        :return: the parent's parser extended with the extra options.
        """
        # ``add_help`` used to be accepted but dropped; forward it so that
        # ``get_args_parser(add_help=False)`` is honoured.
        # NOTE(review): assumes the parent method takes ``add_help`` too,
        # as the matching override signature suggests -- confirm.
        parser = super().get_args_parser(add_help=add_help)
        parser.add_argument("--T", type=int, help="total time-steps")
        parser.add_argument(
            "--cupy",
            action="store_true",
            help="set the neurons to use cupy backend",
        )
        return parser

    def get_tb_logdir_name(self, args):
        """Return the parent's log-directory name suffixed with ``_T{args.T}``."""
        base_name = super().get_tb_logdir_name(args)
        return f"{base_name}_T{args.T}"

    def load_model(self, args, num_classes):
        """Construct the spiking ResNet selected by ``args.model``.

        :param args: parsed arguments; ``args.model``, ``args.pretrained``
            and ``args.cupy`` are read.
        :param num_classes: accepted for interface compatibility.
            NOTE(review): it is not passed to the model constructor, so the
            constructor's own class count is used -- confirm this is intended.
        :return: the constructed model, with step mode ``"m"`` set.
        :raises ValueError: if ``args.model`` is not listed in
            ``spiking_resnet.__all__``.
        """
        if args.model not in spiking_resnet.__all__:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")

        build_model = spiking_resnet.__dict__[args.model]
        model = build_model(
            pretrained=args.pretrained,
            spiking_neuron=neuron.IFNode,
            surrogate_function=surrogate.ATan(),
            detach_reset=True,
        )
        # The preprocessed input carries a leading time axis, so the model
        # is switched to step mode "m" before use.
        functional.set_step_mode(model, step_mode="m")
        if args.cupy:
            functional.set_backend(model, "cupy", neuron.IFNode)
        return model
if __name__ == "__main__":
    # Example invocations kept from the original source:
    #   -m torch.distributed.launch --nproc_per_node=2 spikingjelly.activation_based.model.train_imagenet_example
    #   python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
    runner = SResNetTrainer()
    cli_parser = runner.get_args_parser()
    # Parse the command line, then hand the parsed arguments to the trainer.
    runner.main(cli_parser.parse_args())