# Source code of the spikingjelly.activation_based.model.train_classify module

import argparse
import datetime
import os
import random
import sys
import time
import warnings

__all__ = ["Trainer", "set_deterministic", "seed_worker"]

import numpy as np
import torch
import torch.utils.data
import torchvision
from torch import nn
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms.functional import InterpolationMode

from .. import functional
from .tv_ref_classify import presets, transforms, utils
from .tv_ref_classify.sampler import RASampler

try:
    from torchvision import prototype
except ImportError:
    prototype = None


def set_deterministic(_seed_: int = 2020, disable_uda=False):
    r"""
    .. _set_deterministic-cn:

    .. _set_deterministic-en:

    Seed every random number generator used during training and switch
    PyTorch to deterministic behaviour so that results are reproducible for
    the same input.

    The function seeds Python's :mod:`random`, NumPy and PyTorch (CPU and all
    CUDA devices). It forces cuDNN to use deterministic algorithms with
    benchmarking disabled. Unless ``disable_uda`` is set, it also sets
    ``CUBLAS_WORKSPACE_CONFIG`` and calls
    ``torch.use_deterministic_algorithms(True)``.

    :param _seed_: random seed, default is 2020
    :type _seed_: int
    :param disable_uda: if ``True``, do NOT call
        ``torch.use_deterministic_algorithms(True)`` ("UDA" = use deterministic
        algorithms) and do not set ``CUBLAS_WORKSPACE_CONFIG``. This avoids the
        error raised by operators that have no deterministic implementation.
    :type disable_uda: bool
    :return: None
    :rtype: None
    """
    random.seed(_seed_)
    np.random.seed(_seed_)
    # torch.manual_seed seeds the RNG for all devices (both CPU and CUDA).
    torch.manual_seed(_seed_)
    torch.cuda.manual_seed_all(_seed_)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if not disable_uda:
        # Deterministic cuBLAS needs CUBLAS_WORKSPACE_CONFIG set to ":16:8"
        # (may limit overall performance) or ":4096:8" (increases the library
        # footprint in GPU memory by approximately 24 MiB).
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
def seed_worker(worker_id):
    r"""
    .. _seed_worker-cn:

    .. _seed_worker-en:

    ``worker_init_fn`` for :class:`torch.utils.data.DataLoader`.

    Derives a 32-bit seed from the worker's PyTorch initial seed. It uses that
    seed for NumPy and Python's :mod:`random`, so every worker process gets a
    distinct but reproducible random stream.

    :param worker_id: index of the worker (not used by the body; required by
        the ``worker_init_fn`` signature)
    :type worker_id: int
    :return: None
    :rtype: None
    """
    # NumPy only accepts seeds in [0, 2**32), so fold the 64-bit torch seed.
    seed32 = torch.initial_seed() % (1 << 32)
    np.random.seed(seed32)
    random.seed(seed32)
[文档] class Trainer: r""" **API Language:** :ref:`中文 <Trainer-cn>` | :ref:`English <Trainer-en>` ---- .. _Trainer-cn: * **中文** * **中文** 分类任务的训练器。封装了训练/验证循环、学习率调度、混合精度训练、torch.compile 支持、TensorBoard 日志等功能。 ---- .. _Trainer-en: * **English** * **English** Classification task trainer. Wraps training/validation loops, LR scheduling, mixed-precision training, torch.compile support, and TensorBoard logging. """
[文档] def get_data_to_device_kwargs(self, args): return {"non_blocking": not args.disable_pinmemory}
[文档] def cal_acc1_acc5(self, output, target): # define how to calculate acc1 and acc5 acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) return acc1, acc5
[文档] def preprocess_train_sample(self, args, x: torch.Tensor): # define how to process train sample before send it to model return x
[文档] def preprocess_test_sample(self, args, x: torch.Tensor): # define how to process test sample before send it to model return x
[文档] def process_model_output(self, args, y: torch.Tensor): # define how to process y = model(x) return y
[文档] def compile_model( self, args, model: nn.Module, *, enabled: bool | None = None ) -> nn.Module: if enabled is None: enabled = args.compile if not enabled: return model if not hasattr(torch, "compile"): raise RuntimeError( "torch.compile is not available in the current PyTorch version." ) compile_kwargs = {"backend": args.compile_backend} if args.compile_mode is not None: compile_kwargs["mode"] = args.compile_mode if args.compile_backend == "inductor": compile_options = {} if args.compile_disable_cudagraphs: compile_options.update( { "triton.cudagraphs": False, "triton.cudagraph_trees": False, } ) if compile_options: compile_kwargs["options"] = compile_options try: return torch.compile(model, **compile_kwargs) except RuntimeError as e: compile_options = compile_kwargs.get("options") error_text = str(e).lower() retryable_option_error = ( "options" in error_text or "cudagraph" in error_text or "config" in error_text ) if not compile_options or not retryable_option_error: raise warnings.warn( "torch.compile failed with backend options " f"{compile_options!r}; retrying without options. " f"Original error: {e}", RuntimeWarning, stacklevel=2, ) compile_kwargs.pop("options", None) return torch.compile(model, **compile_kwargs)
[文档] def get_eval_model(self, args, train_model, model_without_ddp): if args.compile and not args.compile_eval: return model_without_ddp return train_model
[文档] def train_one_epoch( self, model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, ): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("img/s", utils.ThroughputValue(fmt="{global_avg:.3f}")) header = f"Epoch: [{epoch}]" data_to_device_kwargs = self.get_data_to_device_kwargs(args) for i, (image, target) in enumerate( metric_logger.log_every(data_loader, -1, header) ): start_time = time.time() image = image.to(device, **data_to_device_kwargs) target = target.to(device, **data_to_device_kwargs) with torch.cuda.amp.autocast(enabled=scaler is not None): image = self.preprocess_train_sample(args, image) output = self.process_model_output(args, model(image)) loss = criterion(output, target) optimizer.zero_grad(set_to_none=True) if scaler is not None: scaler.scale(loss).backward() if args.clip_grad_norm is not None: # we should unscale the gradients of optimizer's assigned params if do gradient clipping scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) scaler.step(optimizer) scaler.update() else: loss.backward() if args.clip_grad_norm is not None: nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm) optimizer.step() functional.reset_net(model) if model_ema and i % args.model_ema_steps == 0: model_ema.update_parameters(model) if epoch < args.lr_warmup_epochs: # Reset ema buffer to keep copying weights during warmup period model_ema.n_averaged.fill_(0) acc1, acc5 = self.cal_acc1_acc5(output, target) batch_size = target.shape[0] metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) metric_logger.meters["img/s"].update( samples=batch_size, elapsed_time=time.time() - start_time ) # gather the stats from all processes 
metric_logger.synchronize_between_processes() train_loss, train_acc1, train_acc5 = ( metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, ) print( f"Train: train_acc1={train_acc1:.3f}, train_acc5={train_acc5:.3f}, train_loss={train_loss:.6f}, samples/s={metric_logger.meters['img/s'].global_avg:.3f}" ) return train_loss, train_acc1, train_acc5
[文档] def evaluate(self, args, model, criterion, data_loader, device, log_suffix=""): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = f"Test: {log_suffix}" data_to_device_kwargs = self.get_data_to_device_kwargs(args) num_processed_samples = 0 start_time = time.time() with torch.inference_mode(): for image, target in metric_logger.log_every(data_loader, -1, header): image = image.to(device, **data_to_device_kwargs) target = target.to(device, **data_to_device_kwargs) image = self.preprocess_test_sample(args, image) output = self.process_model_output(args, model(image)) loss = criterion(output, target) acc1, acc5 = self.cal_acc1_acc5(output, target) # FIXME need to take into account that the datasets # could have been padded in distributed setup batch_size = target.shape[0] metric_logger.update(loss=loss.item()) metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) num_processed_samples += batch_size functional.reset_net(model) # gather the stats from all processes num_processed_samples = utils.reduce_across_processes(num_processed_samples) if ( hasattr(data_loader.dataset, "__len__") and len(data_loader.dataset) != num_processed_samples and torch.distributed.get_rank() == 0 ): # See FIXME above warnings.warn( f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} " "samples were used for the validation, which might bias the results. " "Try adjusting the batch size and / or the world size. " "Setting the world size to 1 is always a safe bet." ) metric_logger.synchronize_between_processes() test_loss, test_acc1, test_acc5 = ( metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg, ) print( f"Test: test_acc1={test_acc1:.3f}, test_acc5={test_acc5:.3f}, test_loss={test_loss:.6f}, samples/s={num_processed_samples / (time.time() - start_time):.3f}" ) return test_loss, test_acc1, test_acc5
def _get_cache_path(self, filepath): import hashlib h = hashlib.sha1(filepath.encode()).hexdigest() cache_path = os.path.join( "~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt" ) cache_path = os.path.expanduser(cache_path) return cache_path
[文档] def load_data(self, args): return self.load_ImageNet(args)
[文档] def load_CIFAR10(self, args): # Data loading code print("Loading data") train_crop_size = args.train_crop_size interpolation = InterpolationMode(args.interpolation) print("Loading training data") st = time.time() auto_augment_policy = getattr(args, "auto_augment", None) random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.CIFAR10( root=args.data_path, train=True, transform=presets.ClassificationPresetTrain( crop_size=train_crop_size, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010), interpolation=interpolation, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob, ), ) print("Took", time.time() - st) print("Loading validation data") dataset_test = torchvision.datasets.CIFAR10( root=args.data_path, train=False, transform=torchvision.transforms.Compose( [ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) ), ] ), ) print("Creating data loaders") loader_g = torch.Generator() loader_g.manual_seed(args.seed) if args.distributed: if hasattr(args, "ra_sampler") and args.ra_sampler: train_sampler = RASampler( dataset, shuffle=True, repetitions=args.ra_reps, seed=args.seed ) else: train_sampler = torch.utils.data.distributed.DistributedSampler( dataset, seed=args.seed ) test_sampler = torch.utils.data.distributed.DistributedSampler( dataset_test, shuffle=False ) else: train_sampler = torch.utils.data.RandomSampler(dataset, generator=loader_g) test_sampler = torch.utils.data.SequentialSampler(dataset_test) return dataset, dataset_test, train_sampler, test_sampler
[文档] def load_ImageNet(self, args): # Data loading code traindir = os.path.join(args.data_path, "train") valdir = os.path.join(args.data_path, "val") print("Loading data") val_resize_size, val_crop_size, train_crop_size = ( args.val_resize_size, args.val_crop_size, args.train_crop_size, ) interpolation = InterpolationMode(args.interpolation) print("Loading training data") st = time.time() cache_path = self._get_cache_path(traindir) if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print(f"Loading dataset_train from {cache_path}") dataset, _ = torch.load(cache_path) else: auto_augment_policy = getattr(args, "auto_augment", None) random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( traindir, presets.ClassificationPresetTrain( crop_size=train_crop_size, interpolation=interpolation, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob, ), ) if args.cache_dataset: print(f"Saving dataset_train to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) print("Took", time.time() - st) print("Loading validation data") cache_path = self._get_cache_path(valdir) if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! 
print(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: if not args.prototype: preprocessing = presets.ClassificationPresetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation, ) else: if args.weights: weights = prototype.models.get_weight(args.weights) preprocessing = weights.transforms() else: preprocessing = prototype.transforms.ImageNetEval( crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation, ) dataset_test = torchvision.datasets.ImageFolder( valdir, preprocessing, ) if args.cache_dataset: print(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") loader_g = torch.Generator() loader_g.manual_seed(args.seed) if args.distributed: if hasattr(args, "ra_sampler") and args.ra_sampler: train_sampler = RASampler( dataset, shuffle=True, repetitions=args.ra_reps, seed=args.seed ) else: train_sampler = torch.utils.data.distributed.DistributedSampler( dataset, seed=args.seed ) test_sampler = torch.utils.data.distributed.DistributedSampler( dataset_test, shuffle=False ) else: train_sampler = torch.utils.data.RandomSampler(dataset, generator=loader_g) test_sampler = torch.utils.data.SequentialSampler(dataset_test) return dataset, dataset_test, train_sampler, test_sampler
[文档] def load_model(self, args, num_classes): raise NotImplementedError("Users should define this function to load model")
[文档] def get_tb_logdir_name(self, args): tb_dir = ( f"{args.model}" f"_b{args.batch_size}" f"_e{args.epochs}" f"_{args.opt}" f"_lr{args.lr}" f"_wd{args.weight_decay}" f"_ls{args.label_smoothing}" f"_ma{args.mixup_alpha}" f"_ca{args.cutmix_alpha}" f"_sbn{1 if args.sync_bn else 0}" f"_ra{args.ra_reps if args.ra_sampler else 0}" f"_re{args.random_erase}" f"_aaug{args.auto_augment}" f"_size{args.train_crop_size}_{args.val_resize_size}_{args.val_crop_size}" f"_seed{args.seed}" ) return tb_dir
[文档] def set_optimizer(self, args, parameters): opt_name = args.opt.lower() if opt_name.startswith("sgd"): optimizer = torch.optim.SGD( parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov="nesterov" in opt_name, ) elif opt_name == "rmsprop": optimizer = torch.optim.RMSprop( parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9, ) elif opt_name == "adamw": optimizer = torch.optim.AdamW( parameters, lr=args.lr, weight_decay=args.weight_decay ) else: raise RuntimeError( f"Invalid optimizer {args.opt}. " "Only SGD, RMSprop and AdamW are supported." ) return optimizer
[文档] def set_lr_scheduler(self, args, optimizer): args.lr_scheduler = args.lr_scheduler.lower() if args.lr_scheduler == "step": main_lr_scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma ) elif args.lr_scheduler == "cosa": main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=args.epochs - args.lr_warmup_epochs ) elif args.lr_scheduler == "exp": main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( optimizer, gamma=args.lr_gamma ) else: main_lr_scheduler = None if args.lr_warmup_epochs > 0: if args.lr_warmup_method == "linear": warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs, ) elif args.lr_warmup_method == "constant": warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR( optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs, ) else: warmup_lr_scheduler = None lr_scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs], ) else: lr_scheduler = main_lr_scheduler return lr_scheduler
[文档] def main(self, args): set_deterministic(args.seed, args.disable_uda) if args.workers > 0 and args.prefetch_factor < 1: raise ValueError( f"--prefetch-factor must be >= 1 when --workers > 0, but got {args.prefetch_factor}." ) if args.prototype and prototype is None: raise ImportError( "The prototype module couldn't be found. Please install the latest torchvision nightly." ) if not args.prototype and args.weights: raise ValueError( "The weights parameter works only in prototype mode. Please pass the --prototype argument." ) if args.output_dir: utils.mkdir(args.output_dir) utils.init_distributed_mode(args) print(args) device = torch.device(args.device) dataset, dataset_test, train_sampler, test_sampler = self.load_data(args) collate_fn = None num_classes = len(dataset.classes) mixup_transforms = [] if args.mixup_alpha > 0.0: if torch.__version__ >= torch.torch_version.TorchVersion("1.10.0"): pass else: # TODO implement a CrossEntropyLoss to support for probabilities for each class. raise NotImplementedError( "CrossEntropyLoss in pytorch < 1.11.0 does not support for probabilities for each class." "Set mixup_alpha=0. to avoid such a problem or update your pytorch." 
) mixup_transforms.append( transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha) ) if args.cutmix_alpha > 0.0: mixup_transforms.append( transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha) ) if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) collate_fn = lambda batch: mixupcutmix(*default_collate(batch)) # noqa: E731 data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=not args.disable_pinmemory, collate_fn=collate_fn, worker_init_fn=seed_worker, persistent_workers=args.persistent_workers and args.workers > 0, prefetch_factor=(args.prefetch_factor if args.workers > 0 else None), ) data_loader_test = torch.utils.data.DataLoader( dataset_test, batch_size=args.batch_size, sampler=test_sampler, num_workers=args.workers, pin_memory=not args.disable_pinmemory, worker_init_fn=seed_worker, persistent_workers=args.persistent_workers and args.workers > 0, prefetch_factor=(args.prefetch_factor if args.workers > 0 else None), ) print("Creating model") model = self.load_model(args, num_classes) model.to(device) print(model) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) if args.norm_weight_decay is None: parameters = model.parameters() else: param_groups = torchvision.ops._utils.split_normalization_params(model) wd_groups = [args.norm_weight_decay, args.weight_decay] parameters = [ {"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p ] optimizer = self.set_optimizer(args, parameters) if args.disable_amp: scaler = None else: scaler = torch.cuda.amp.GradScaler() lr_scheduler = self.set_lr_scheduler(args, optimizer) model_without_ddp = model model_ema = None if args.model_ema: # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed 
at: # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 # # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps) # We consider constant = Dataset_size for a given dataset/setup and ommit it. Thus: # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs adjust = ( args.world_size * args.batch_size * args.model_ema_steps / args.epochs ) alpha = 1.0 - args.model_ema_decay alpha = min(1.0, alpha * adjust) model_ema = utils.ExponentialMovingAverage( model_without_ddp, device=device, decay=1.0 - alpha ) # 确定目录文件名 tb_dir = self.get_tb_logdir_name(args) pt_dir = os.path.join(args.output_dir, "pt", tb_dir) tb_dir = os.path.join(args.output_dir, tb_dir) if args.print_logdir: print(tb_dir) print(pt_dir) exit() if args.clean: if utils.is_main_process(): if os.path.exists(tb_dir): os.remove(tb_dir) if os.path.exists(pt_dir): os.remove(pt_dir) print(f"remove {tb_dir} and {pt_dir}.") if utils.is_main_process(): os.makedirs(tb_dir, exist_ok=args.resume is not None) os.makedirs(pt_dir, exist_ok=args.resume is not None) if args.resume is not None: if args.resume == "latest": checkpoint = torch.load( os.path.join(pt_dir, "checkpoint_latest.pth"), map_location="cpu" ) else: checkpoint = torch.load(args.resume, map_location="cpu") model_without_ddp.load_state_dict(checkpoint["model"]) if not args.test_only: optimizer.load_state_dict(checkpoint["optimizer"]) lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) args.start_epoch = checkpoint["epoch"] + 1 if model_ema: model_ema.load_state_dict(checkpoint["model_ema"]) if scaler: scaler.load_state_dict(checkpoint["scaler"]) if utils.is_main_process(): max_test_acc1 = checkpoint["max_test_acc1"] if model_ema: max_ema_test_acc1 = checkpoint["max_ema_test_acc1"] model = self.compile_model(args, model_without_ddp) eval_model = self.get_eval_model(args, model, model_without_ddp) if args.distributed: model = torch.nn.parallel.DistributedDataParallel( 
model, device_ids=[args.gpu] ) if utils.is_main_process(): tb_writer = SummaryWriter(tb_dir, purge_step=args.start_epoch) with open( os.path.join(tb_dir, "args.txt"), "w", encoding="utf-8" ) as args_txt: args_txt.write(str(args)) args_txt.write("\n") args_txt.write(" ".join(sys.argv)) max_test_acc1 = -1.0 if model_ema: max_ema_test_acc1 = -1.0 if args.test_only: if model_ema: self.evaluate( args, model_ema, criterion, data_loader_test, device=device, log_suffix="EMA", ) else: self.evaluate( args, eval_model, criterion, data_loader_test, device=device ) return for epoch in range(args.start_epoch, args.epochs): start_time = time.time() if args.distributed: train_sampler.set_epoch(epoch) self.before_train_one_epoch(args, model, epoch) train_loss, train_acc1, train_acc5 = self.train_one_epoch( model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, ) if utils.is_main_process(): tb_writer.add_scalar("train_loss", train_loss, epoch) tb_writer.add_scalar("train_acc1", train_acc1, epoch) tb_writer.add_scalar("train_acc5", train_acc5, epoch) lr_scheduler.step() self.before_test_one_epoch(args, eval_model, epoch) test_loss, test_acc1, test_acc5 = self.evaluate( args, eval_model, criterion, data_loader_test, device=device ) if utils.is_main_process(): tb_writer.add_scalar("test_loss", test_loss, epoch) tb_writer.add_scalar("test_acc1", test_acc1, epoch) tb_writer.add_scalar("test_acc5", test_acc5, epoch) if model_ema: ema_test_loss, ema_test_acc1, ema_test_acc5 = self.evaluate( args, model_ema, criterion, data_loader_test, device=device, log_suffix="EMA", ) if utils.is_main_process(): tb_writer.add_scalar("ema_test_loss", ema_test_loss, epoch) tb_writer.add_scalar("ema_test_acc1", ema_test_acc1, epoch) tb_writer.add_scalar("ema_test_acc5", ema_test_acc5, epoch) if utils.is_main_process(): save_max_test_acc1 = False save_max_ema_test_acc1 = False if test_acc1 > max_test_acc1: max_test_acc1 = test_acc1 save_max_test_acc1 = True checkpoint = { "model": 
model_without_ddp.state_dict(), "optimizer": optimizer.state_dict(), "lr_scheduler": lr_scheduler.state_dict(), "epoch": epoch, "args": args, "max_test_acc1": max_test_acc1, } if model_ema: if ema_test_acc1 > max_ema_test_acc1: max_ema_test_acc1 = ema_test_acc1 save_max_ema_test_acc1 = True checkpoint["model_ema"] = model_ema.state_dict() checkpoint["max_ema_test_acc1"] = max_ema_test_acc1 if scaler: checkpoint["scaler"] = scaler.state_dict() utils.save_on_master( checkpoint, os.path.join(pt_dir, f"checkpoint_{epoch}.pth") ) utils.save_on_master( checkpoint, os.path.join(pt_dir, "checkpoint_latest.pth") ) if save_max_test_acc1: utils.save_on_master( checkpoint, os.path.join(pt_dir, "checkpoint_max_test_acc1.pth"), ) if model_ema and save_max_ema_test_acc1: utils.save_on_master( checkpoint, os.path.join(pt_dir, "checkpoint_max_ema_test_acc1.pth"), ) if utils.is_main_process() and epoch > 0: os.remove(os.path.join(pt_dir, f"checkpoint_{epoch - 1}.pth")) print( f"escape time={(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime('%Y-%m-%d %H:%M:%S')}\n" ) print(args)
[文档] def before_test_one_epoch(self, args, model, epoch): pass
[文档] def before_train_one_epoch(self, args, model, epoch): pass
[文档] def get_args_parser(self, add_help=True): parser = argparse.ArgumentParser( description="PyTorch Classification Training", add_help=add_help ) parser.add_argument( "--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path", ) parser.add_argument("--model", default="resnet18", type=str, help="model name") parser.add_argument( "--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)", ) parser.add_argument( "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size", ) parser.add_argument( "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run", ) parser.add_argument( "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)", ) parser.add_argument("--opt", default="sgd", type=str, help="optimizer") parser.add_argument( "--lr", default=0.1, type=float, help="initial learning rate" ) parser.add_argument( "--momentum", default=0.9, type=float, metavar="M", help="momentum" ) parser.add_argument( "--wd", "--weight-decay", default=0.0, type=float, metavar="W", help="weight decay (default: 0.)", dest="weight_decay", ) parser.add_argument( "--norm-weight-decay", default=None, type=float, help="weight decay for Normalization layers (default: None, same value as --wd)", ) parser.add_argument( "--label-smoothing", default=0.1, type=float, help="label smoothing (default: 0.1)", dest="label_smoothing", ) parser.add_argument( "--mixup-alpha", default=0.2, type=float, help="mixup alpha (default: 0.2)" ) parser.add_argument( "--cutmix-alpha", default=1.0, type=float, help="cutmix alpha (default: 1.0)", ) parser.add_argument( "--lr-scheduler", default="cosa", type=str, help="the lr scheduler (default: cosa)", ) parser.add_argument( "--lr-warmup-epochs", default=5, type=int, help="the number of epochs to warmup (default: 5)", ) parser.add_argument( "--lr-warmup-method", default="linear", 
type=str, help="the warmup method (default: linear)", ) parser.add_argument( "--lr-warmup-decay", default=0.01, type=float, help="the decay for lr" ) parser.add_argument( "--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs", ) parser.add_argument( "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma", ) parser.add_argument( "--output-dir", default="./logs", type=str, help="path to save outputs" ) parser.add_argument( "--resume", default=None, type=str, help="path of checkpoint. If set to 'latest', it will try to load the latest checkpoint", ) parser.add_argument( "--start-epoch", default=0, type=int, metavar="N", help="start epoch" ) parser.add_argument( "--cache-dataset", dest="cache_dataset", help="Cache the datasets for quicker initialization. It also serializes the transforms", action="store_true", ) parser.add_argument( "--sync-bn", dest="sync_bn", help="Use sync batch norm", action="store_true", ) parser.add_argument( "--test-only", dest="test_only", help="Only test the model", action="store_true", ) parser.add_argument( "--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", action="store_true", ) parser.add_argument( "--auto-augment", default="ta_wide", type=str, help="auto augment policy (default: ta_wide)", ) parser.add_argument( "--random-erase", default=0.1, type=float, help="random erasing probability (default: 0.1)", ) # Mixed precision training parameters # distributed training parameters parser.add_argument( "--world-size", default=1, type=int, help="number of distributed processes" ) parser.add_argument( "--dist-url", default="env://", type=str, help="url used to set up distributed training", ) parser.add_argument( "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters", ) parser.add_argument( "--model-ema-steps", type=int, default=32, help="the number of iterations that controls how often to update the EMA model 
(default: 32)", ) parser.add_argument( "--model-ema-decay", type=float, default=0.99998, help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)", ) parser.add_argument( "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)", ) parser.add_argument( "--val-resize-size", default=232, type=int, help="the resize size used for validation (default: 232)", ) parser.add_argument( "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)", ) parser.add_argument( "--train-crop-size", default=176, type=int, help="the random crop size used for training (default: 176)", ) parser.add_argument( "--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)", ) parser.add_argument( "--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training", ) parser.add_argument( "--ra-reps", default=4, type=int, help="number of repetitions for Repeated Augmentation (default: 4)", ) # Prototype models only parser.add_argument( "--prototype", dest="prototype", help="Use prototype model builders instead those from main area", action="store_true", ) parser.add_argument( "--weights", default=None, type=str, help="the weights enum name to load" ) parser.add_argument("--seed", default=2020, type=int, help="the random seed") parser.add_argument( "--print-logdir", action="store_true", help="print the dirs for tensorboard logs and pt files and exit", ) parser.add_argument( "--clean", action="store_true", help="delete the dirs for tensorboard logs and pt files", ) parser.add_argument( "--disable-pinmemory", action="store_true", help="not use pin memory in dataloader, which can help reduce memory consumption", ) parser.add_argument( "--persistent-workers", action="store_true", help="keep dataloader workers alive across epochs for better throughput", ) parser.add_argument( "--prefetch-factor", default=2, type=int, 
help="number of batches prefetched by each dataloader worker (default: 2)", ) parser.add_argument( "--disable-amp", action="store_true", help="not use automatic mixed precision training", ) parser.add_argument( "--compile", action="store_true", help="compile the training model with torch.compile", ) parser.add_argument( "--compile-backend", default="inductor", type=str, help="backend passed to torch.compile (default: inductor)", ) parser.add_argument( "--compile-mode", default=None, type=str, help="optional mode passed to torch.compile, e.g. default or max-autotune", ) parser.add_argument( "--compile-disable-cudagraphs", action="store_true", help="disable Inductor cudagraph options for better cross-version stability", ) parser.add_argument( "--compile-eval", action="store_true", help="also compile evaluation instead of using the eager model for validation", ) parser.add_argument( "--local_rank", type=int, help="args for DDP, which should not be set by user", ) parser.add_argument( "--disable-uda", action="store_true", help="not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation", ) return parser
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse sys.argv and train.
    _trainer = Trainer()
    _parser = _trainer.get_args_parser()
    _trainer.main(_parser.parse_args())