Source code for spikingjelly.activation_based.model.train_classify

import datetime
import os
import shutil
import time
import warnings
from .tv_ref_classify import presets, transforms, utils
import torch
import torch.utils.data
import torchvision
from .tv_ref_classify.sampler import RASampler
from torch import nn
from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
import random
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import sys
import argparse
from .. import functional


try:
    from torchvision import prototype
except ImportError:
    prototype = None

def set_deterministic(_seed_: int = 2020, disable_uda=False):
    random.seed(_seed_)
    np.random.seed(_seed_)
    torch.manual_seed(_seed_)  # use torch.manual_seed() to seed the RNG for all devices (both CPU and CUDA)
    torch.cuda.manual_seed_all(_seed_)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if not disable_uda:
        # set the debug environment variable CUBLAS_WORKSPACE_CONFIG to ":16:8" (may limit overall performance)
        # or ":4096:8" (will increase the library footprint in GPU memory by approximately 24MiB)
        os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
        torch.use_deterministic_algorithms(True)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

class Trainer:

    def cal_acc1_acc5(self, output, target):
        # define how to calculate acc1 and acc5
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
        return acc1, acc5

    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to process a training sample before sending it to the model
        return x

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to process a test sample before sending it to the model
        return x

    def process_model_output(self, args, y: torch.Tensor):
        # define how to process y = model(x)
        return y

    def train_one_epoch(self, model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
        model.train()
        metric_logger = utils.MetricLogger(delimiter="  ")
        metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
        metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

        header = f"Epoch: [{epoch}]"
        for i, (image, target) in enumerate(metric_logger.log_every(data_loader, -1, header)):
            start_time = time.time()
            image, target = image.to(device), target.to(device)
            with torch.cuda.amp.autocast(enabled=scaler is not None):
                image = self.preprocess_train_sample(args, image)
                output = self.process_model_output(args, model(image))
                loss = criterion(output, target)

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                if args.clip_grad_norm is not None:
                    # we should unscale the gradients of the optimizer's assigned params if we do gradient clipping
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                if args.clip_grad_norm is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
                optimizer.step()
            functional.reset_net(model)

            if model_ema and i % args.model_ema_steps == 0:
                model_ema.update_parameters(model)
                if epoch < args.lr_warmup_epochs:
                    # Reset ema buffer to keep copying weights during warmup period
                    model_ema.n_averaged.fill_(0)

            acc1, acc5 = self.cal_acc1_acc5(output, target)
            batch_size = target.shape[0]
            metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        train_loss, train_acc1, train_acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
        print(f'Train: train_acc1={train_acc1:.3f}, train_acc5={train_acc5:.3f}, train_loss={train_loss:.6f}, samples/s={metric_logger.meters["img/s"]}')
        return train_loss, train_acc1, train_acc5

    def evaluate(self, args, model, criterion, data_loader, device, log_suffix=""):
        model.eval()
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = f"Test: {log_suffix}"

        num_processed_samples = 0
        start_time = time.time()
        with torch.inference_mode():
            for image, target in metric_logger.log_every(data_loader, -1, header):
                image = image.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                image = self.preprocess_test_sample(args, image)
                output = self.process_model_output(args, model(image))
                loss = criterion(output, target)

                acc1, acc5 = self.cal_acc1_acc5(output, target)
                # FIXME need to take into account that the datasets
                # could have been padded in distributed setup
                batch_size = target.shape[0]
                metric_logger.update(loss=loss.item())
                metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
                metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
                num_processed_samples += batch_size
                functional.reset_net(model)
        # gather the stats from all processes
        num_processed_samples = utils.reduce_across_processes(num_processed_samples)
        if (
            hasattr(data_loader.dataset, "__len__")
            and len(data_loader.dataset) != num_processed_samples
            and torch.distributed.get_rank() == 0
        ):
            # See FIXME above
            warnings.warn(
                f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
                "samples were used for the validation, which might bias the results. "
                "Try adjusting the batch size and / or the world size. "
                "Setting the world size to 1 is always a safe bet."
            )

        metric_logger.synchronize_between_processes()

        test_loss, test_acc1, test_acc5 = metric_logger.loss.global_avg, metric_logger.acc1.global_avg, metric_logger.acc5.global_avg
        print(f'Test: test_acc1={test_acc1:.3f}, test_acc5={test_acc5:.3f}, test_loss={test_loss:.6f}, samples/s={num_processed_samples / (time.time() - start_time):.3f}')
        return test_loss, test_acc1, test_acc5

    def _get_cache_path(self, filepath):
        import hashlib

        h = hashlib.sha1(filepath.encode()).hexdigest()
        cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
        cache_path = os.path.expanduser(cache_path)
        return cache_path

    def load_data(self, args):
        return self.load_ImageNet(args)

    def load_CIFAR10(self, args):
        # Data loading code
        print("Loading data")
        val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
        interpolation = InterpolationMode(args.interpolation)

        print("Loading training data")
        st = time.time()
        auto_augment_policy = getattr(args, "auto_augment", None)
        random_erase_prob = getattr(args, "random_erase", 0.0)
        dataset = torchvision.datasets.CIFAR10(
            root=args.data_path,
            train=True,
            transform=presets.ClassificationPresetTrain(
                crop_size=train_crop_size,
                mean=(0.4914, 0.4822, 0.4465),
                std=(0.2023, 0.1994, 0.2010),
                interpolation=interpolation,
                auto_augment_policy=auto_augment_policy,
                random_erase_prob=random_erase_prob,
            ),
        )
        print("Took", time.time() - st)

        print("Loading validation data")
        dataset_test = torchvision.datasets.CIFAR10(
            root=args.data_path,
            train=False,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]),
        )

        print("Creating data loaders")
        loader_g = torch.Generator()
        loader_g.manual_seed(args.seed)

        if args.distributed:
            if hasattr(args, "ra_sampler") and args.ra_sampler:
                train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps, seed=args.seed)
            else:
                train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=args.seed)
            test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
        else:
            train_sampler = torch.utils.data.RandomSampler(dataset, generator=loader_g)
            test_sampler = torch.utils.data.SequentialSampler(dataset_test)

        return dataset, dataset_test, train_sampler, test_sampler

    def load_ImageNet(self, args):
        # Data loading code
        traindir = os.path.join(args.data_path, "train")
        valdir = os.path.join(args.data_path, "val")
        print("Loading data")
        val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
        interpolation = InterpolationMode(args.interpolation)

        print("Loading training data")
        st = time.time()
        cache_path = self._get_cache_path(traindir)
        if args.cache_dataset and os.path.exists(cache_path):
            # Attention, as the transforms are also cached!
            print(f"Loading dataset_train from {cache_path}")
            dataset, _ = torch.load(cache_path)
        else:
            auto_augment_policy = getattr(args, "auto_augment", None)
            random_erase_prob = getattr(args, "random_erase", 0.0)
            dataset = torchvision.datasets.ImageFolder(
                traindir,
                presets.ClassificationPresetTrain(
                    crop_size=train_crop_size,
                    interpolation=interpolation,
                    auto_augment_policy=auto_augment_policy,
                    random_erase_prob=random_erase_prob,
                ),
            )
            if args.cache_dataset:
                print(f"Saving dataset_train to {cache_path}")
                utils.mkdir(os.path.dirname(cache_path))
                utils.save_on_master((dataset, traindir), cache_path)
        print("Took", time.time() - st)

        print("Loading validation data")
        cache_path = self._get_cache_path(valdir)
        if args.cache_dataset and os.path.exists(cache_path):
            # Attention, as the transforms are also cached!
            print(f"Loading dataset_test from {cache_path}")
            dataset_test, _ = torch.load(cache_path)
        else:
            if not args.prototype:
                preprocessing = presets.ClassificationPresetEval(
                    crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
                )
            else:
                if args.weights:
                    weights = prototype.models.get_weight(args.weights)
                    preprocessing = weights.transforms()
                else:
                    preprocessing = prototype.transforms.ImageNetEval(
                        crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
                    )

            dataset_test = torchvision.datasets.ImageFolder(
                valdir,
                preprocessing,
            )
            if args.cache_dataset:
                print(f"Saving dataset_test to {cache_path}")
                utils.mkdir(os.path.dirname(cache_path))
                utils.save_on_master((dataset_test, valdir), cache_path)

        print("Creating data loaders")
        loader_g = torch.Generator()
        loader_g.manual_seed(args.seed)

        if args.distributed:
            if hasattr(args, "ra_sampler") and args.ra_sampler:
                train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps, seed=args.seed)
            else:
                train_sampler = torch.utils.data.distributed.DistributedSampler(dataset, seed=args.seed)
            test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
        else:
            train_sampler = torch.utils.data.RandomSampler(dataset, generator=loader_g)
            test_sampler = torch.utils.data.SequentialSampler(dataset_test)

        return dataset, dataset_test, train_sampler, test_sampler

    def load_model(self, args, num_classes):
        raise NotImplementedError("Users should override this method to define how to load the model")

    def get_tb_logdir_name(self, args):
        tb_dir = (
            f'{args.model}'
            f'_b{args.batch_size}'
            f'_e{args.epochs}'
            f'_{args.opt}'
            f'_lr{args.lr}'
            f'_wd{args.weight_decay}'
            f'_ls{args.label_smoothing}'
            f'_ma{args.mixup_alpha}'
            f'_ca{args.cutmix_alpha}'
            f'_sbn{1 if args.sync_bn else 0}'
            f'_ra{args.ra_reps if args.ra_sampler else 0}'
            f'_re{args.random_erase}'
            f'_aaug{args.auto_augment}'
            f'_size{args.train_crop_size}_{args.val_resize_size}_{args.val_crop_size}'
            f'_seed{args.seed}'
        )
        return tb_dir

    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("sgd"):
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov="nesterov" in opt_name,
            )
        elif opt_name == "rmsprop":
            optimizer = torch.optim.RMSprop(
                parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
            )
        elif opt_name == "adamw":
            optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            optimizer = None
        return optimizer

    def set_lr_scheduler(self, args, optimizer):
        args.lr_scheduler = args.lr_scheduler.lower()
        if args.lr_scheduler == "step":
            main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
        elif args.lr_scheduler == "cosa":
            main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=args.epochs - args.lr_warmup_epochs
            )
        elif args.lr_scheduler == "exp":
            main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
        else:
            main_lr_scheduler = None

        if args.lr_warmup_epochs > 0:
            if args.lr_warmup_method == "linear":
                warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                    optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
                )
            elif args.lr_warmup_method == "constant":
                warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                    optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
                )
            else:
                warmup_lr_scheduler = None
            lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
            )
        else:
            lr_scheduler = main_lr_scheduler

        return lr_scheduler

    def main(self, args):
        set_deterministic(args.seed, args.disable_uda)
        if args.prototype and prototype is None:
            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
        if not args.prototype and args.weights:
            raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
        if args.output_dir:
            utils.mkdir(args.output_dir)

        utils.init_distributed_mode(args)
        print(args)

        device = torch.device(args.device)

        dataset, dataset_test, train_sampler, test_sampler = self.load_data(args)

        collate_fn = None
        num_classes = len(dataset.classes)
        mixup_transforms = []
        if args.mixup_alpha > 0.0:
            if torch.__version__ >= torch.torch_version.TorchVersion('1.10.0'):
                pass
            else:
                # TODO implement a CrossEntropyLoss that supports probabilities for each class
                raise NotImplementedError(
                    "CrossEntropyLoss in PyTorch < 1.10.0 does not support probabilities for each class. "
                    "Set mixup_alpha=0. to avoid this problem or update your PyTorch."
                )
            mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
        if args.cutmix_alpha > 0.0:
            mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
        if mixup_transforms:
            mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
            collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731

        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.workers,
            pin_memory=not args.disable_pinmemory,
            collate_fn=collate_fn,
            worker_init_fn=seed_worker
        )

        data_loader_test = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=args.batch_size,
            sampler=test_sampler,
            num_workers=args.workers,
            pin_memory=not args.disable_pinmemory,
            worker_init_fn=seed_worker
        )

        print("Creating model")
        model = self.load_model(args, num_classes)
        model.to(device)
        print(model)

        if args.distributed and args.sync_bn:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

        if args.norm_weight_decay is None:
            parameters = model.parameters()
        else:
            param_groups = torchvision.ops._utils.split_normalization_params(model)
            wd_groups = [args.norm_weight_decay, args.weight_decay]
            parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

        optimizer = self.set_optimizer(args, parameters)

        if args.disable_amp:
            scaler = None
        else:
            scaler = torch.cuda.amp.GradScaler()

        lr_scheduler = self.set_lr_scheduler(args, optimizer)

        model_without_ddp = model
        if args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            model_without_ddp = model.module

        model_ema = None
        if args.model_ema:
            # Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
            # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
            #
            # total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
            # We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
            # adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
            adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
            alpha = 1.0 - args.model_ema_decay
            alpha = min(1.0, alpha * adjust)
            model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)

        # determine the directory names for tensorboard logs and checkpoints
        tb_dir = self.get_tb_logdir_name(args)
        pt_dir = os.path.join(args.output_dir, 'pt', tb_dir)
        tb_dir = os.path.join(args.output_dir, tb_dir)
        if args.print_logdir:
            print(tb_dir)
            print(pt_dir)
            exit()
        if args.clean:
            if utils.is_main_process():
                # os.remove cannot delete a directory, so use shutil.rmtree
                if os.path.exists(tb_dir):
                    shutil.rmtree(tb_dir)
                if os.path.exists(pt_dir):
                    shutil.rmtree(pt_dir)
                print(f'remove {tb_dir} and {pt_dir}.')

        if utils.is_main_process():
            os.makedirs(tb_dir, exist_ok=args.resume is not None)
            os.makedirs(pt_dir, exist_ok=args.resume is not None)
            # default values; overwritten below when resuming from a checkpoint
            max_test_acc1 = -1.
            if model_ema:
                max_ema_test_acc1 = -1.

        if args.resume is not None:
            if args.resume == 'latest':
                checkpoint = torch.load(os.path.join(pt_dir, 'checkpoint_latest.pth'), map_location="cpu")
            else:
                checkpoint = torch.load(args.resume, map_location="cpu")
            model_without_ddp.load_state_dict(checkpoint["model"])
            if not args.test_only:
                optimizer.load_state_dict(checkpoint["optimizer"])
                lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                args.start_epoch = checkpoint["epoch"] + 1
                if model_ema:
                    model_ema.load_state_dict(checkpoint["model_ema"])
                if scaler:
                    scaler.load_state_dict(checkpoint["scaler"])

            if utils.is_main_process():
                max_test_acc1 = checkpoint['max_test_acc1']
                if model_ema:
                    max_ema_test_acc1 = checkpoint['max_ema_test_acc1']

        if utils.is_main_process():
            tb_writer = SummaryWriter(tb_dir, purge_step=args.start_epoch)
            with open(os.path.join(tb_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
                args_txt.write(str(args))
                args_txt.write('\n')
                args_txt.write(' '.join(sys.argv))

        if args.test_only:
            if model_ema:
                self.evaluate(args, model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
            else:
                self.evaluate(args, model, criterion, data_loader_test, device=device)
            return

        for epoch in range(args.start_epoch, args.epochs):
            start_time = time.time()
            if args.distributed:
                train_sampler.set_epoch(epoch)
            self.before_train_one_epoch(args, model, epoch)
            train_loss, train_acc1, train_acc5 = self.train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
            if utils.is_main_process():
                tb_writer.add_scalar('train_loss', train_loss, epoch)
                tb_writer.add_scalar('train_acc1', train_acc1, epoch)
                tb_writer.add_scalar('train_acc5', train_acc5, epoch)
            lr_scheduler.step()

            self.before_test_one_epoch(args, model, epoch)
            test_loss, test_acc1, test_acc5 = self.evaluate(args, model, criterion, data_loader_test, device=device)
            if utils.is_main_process():
                tb_writer.add_scalar('test_loss', test_loss, epoch)
                tb_writer.add_scalar('test_acc1', test_acc1, epoch)
                tb_writer.add_scalar('test_acc5', test_acc5, epoch)

            if model_ema:
                ema_test_loss, ema_test_acc1, ema_test_acc5 = self.evaluate(args, model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
                if utils.is_main_process():
                    tb_writer.add_scalar('ema_test_loss', ema_test_loss, epoch)
                    tb_writer.add_scalar('ema_test_acc1', ema_test_acc1, epoch)
                    tb_writer.add_scalar('ema_test_acc5', ema_test_acc5, epoch)

            if utils.is_main_process():
                save_max_test_acc1 = False
                save_max_ema_test_acc1 = False

                if test_acc1 > max_test_acc1:
                    max_test_acc1 = test_acc1
                    save_max_test_acc1 = True

                checkpoint = {
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "args": args,
                    "max_test_acc1": max_test_acc1,
                }
                if model_ema:
                    if ema_test_acc1 > max_ema_test_acc1:
                        max_ema_test_acc1 = ema_test_acc1
                        save_max_ema_test_acc1 = True
                    checkpoint["model_ema"] = model_ema.state_dict()
                    checkpoint["max_ema_test_acc1"] = max_ema_test_acc1
                if scaler:
                    checkpoint["scaler"] = scaler.state_dict()
                utils.save_on_master(checkpoint, os.path.join(pt_dir, f"checkpoint_{epoch}.pth"))
                utils.save_on_master(checkpoint, os.path.join(pt_dir, "checkpoint_latest.pth"))
                if save_max_test_acc1:
                    utils.save_on_master(checkpoint, os.path.join(pt_dir, "checkpoint_max_test_acc1.pth"))
                if model_ema and save_max_ema_test_acc1:
                    utils.save_on_master(checkpoint, os.path.join(pt_dir, "checkpoint_max_ema_test_acc1.pth"))

            if utils.is_main_process() and epoch > 0:
                # keep only the latest per-epoch checkpoint
                os.remove(os.path.join(pt_dir, f"checkpoint_{epoch - 1}.pth"))

            print(f'estimated finish time={(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')
        print(args)

    def before_test_one_epoch(self, args, model, epoch):
        pass

    def before_train_one_epoch(self, args, model, epoch):
        pass

    def get_args_parser(self, add_help=True):
        parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

        parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
        parser.add_argument("--model", default="resnet18", type=str, help="model name")
        parser.add_argument("--device", default="cuda", type=str, help="device (use cuda or cpu, default: cuda)")
        parser.add_argument(
            "-b", "--batch-size", default=32, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
        )
        parser.add_argument("--epochs", default=90, type=int, metavar="N", help="number of total epochs to run")
        parser.add_argument(
            "-j", "--workers", default=16, type=int, metavar="N", help="number of data loading workers (default: 16)"
        )
        parser.add_argument("--opt", default="sgd", type=str, help="optimizer")
        parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
        parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
        parser.add_argument(
            "--wd",
            "--weight-decay",
            default=0.,
            type=float,
            metavar="W",
            help="weight decay (default: 0.)",
            dest="weight_decay",
        )
        parser.add_argument(
            "--norm-weight-decay",
            default=None,
            type=float,
            help="weight decay for Normalization layers (default: None, same value as --wd)",
        )
        parser.add_argument(
            "--label-smoothing", default=0.1, type=float, help="label smoothing (default: 0.1)", dest="label_smoothing"
        )
        parser.add_argument("--mixup-alpha", default=0.2, type=float, help="mixup alpha (default: 0.2)")
        parser.add_argument("--cutmix-alpha", default=1.0, type=float, help="cutmix alpha (default: 1.0)")
        parser.add_argument("--lr-scheduler", default="cosa", type=str, help="the lr scheduler (default: cosa)")
        parser.add_argument("--lr-warmup-epochs", default=5, type=int, help="the number of epochs to warmup (default: 5)")
        parser.add_argument(
            "--lr-warmup-method", default="linear", type=str, help="the warmup method (default: linear)"
        )
        parser.add_argument("--lr-warmup-decay", default=0.01, type=float, help="the decay for lr")
        parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
        parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
        parser.add_argument("--output-dir", default="./logs", type=str, help="path to save outputs")
        parser.add_argument("--resume", default=None, type=str, help="path of checkpoint. If set to 'latest', it will try to load the latest checkpoint")
        parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
        parser.add_argument(
            "--cache-dataset",
            dest="cache_dataset",
            help="Cache the datasets for quicker initialization. It also serializes the transforms",
            action="store_true",
        )
        parser.add_argument(
            "--sync-bn",
            dest="sync_bn",
            help="Use sync batch norm",
            action="store_true",
        )
        parser.add_argument(
            "--test-only",
            dest="test_only",
            help="Only test the model",
            action="store_true",
        )
        parser.add_argument(
            "--pretrained",
            dest="pretrained",
            help="Use pre-trained models from the modelzoo",
            action="store_true",
        )
        parser.add_argument("--auto-augment", default='ta_wide', type=str, help="auto augment policy (default: ta_wide)")
        parser.add_argument("--random-erase", default=0.1, type=float, help="random erasing probability (default: 0.1)")

        # distributed training parameters
        parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
        parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
        parser.add_argument(
            "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
        )
        parser.add_argument(
            "--model-ema-steps",
            type=int,
            default=32,
            help="the number of iterations that controls how often to update the EMA model (default: 32)",
        )
        parser.add_argument(
            "--model-ema-decay",
            type=float,
            default=0.99998,
            help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
        )
        parser.add_argument(
            "--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
        )
        parser.add_argument(
            "--val-resize-size", default=232, type=int, help="the resize size used for validation (default: 232)"
        )
        parser.add_argument(
            "--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
        )
        parser.add_argument(
            "--train-crop-size", default=176, type=int, help="the random crop size used for training (default: 176)"
        )
        parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default: None)")
        parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
        parser.add_argument(
            "--ra-reps", default=4, type=int, help="number of repetitions for Repeated Augmentation (default: 4)"
        )

        # Prototype models only
        parser.add_argument(
            "--prototype",
            dest="prototype",
            help="Use prototype model builders instead of those from the main area",
            action="store_true",
        )
        parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

        parser.add_argument("--seed", default=2020, type=int, help="the random seed")
        parser.add_argument("--print-logdir", action="store_true", help="print the dirs for tensorboard logs and pt files and exit")
        parser.add_argument("--clean", action="store_true", help="delete the dirs for tensorboard logs and pt files")
        parser.add_argument("--disable-pinmemory", action="store_true", help="do not use pin memory in the dataloader, which can help reduce memory consumption")
        parser.add_argument("--disable-amp", action="store_true", help="do not use automatic mixed precision training")
        parser.add_argument("--local_rank", type=int, help="arg for DDP, which should not be set by the user")
        parser.add_argument("--disable-uda", action="store_true", help="do not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation")
        return parser


if __name__ == "__main__":
    trainer = Trainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
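
The Trainer class above is meant to be subclassed: a user overrides load_model (and, if needed, the preprocess_* hooks or load_data) and then reuses get_args_parser and main. The following is a minimal, hypothetical usage sketch and is not part of this module; the file name my_train.py, the class name MyTrainer, and the choice of torchvision.models.resnet18 are illustrative assumptions.

# my_train.py -- hypothetical example, not part of train_classify.py
import torch
import torchvision
from spikingjelly.activation_based.model import train_classify


class MyTrainer(train_classify.Trainer):
    def load_model(self, args, num_classes):
        # illustrative choice: a plain torchvision ResNet-18; any nn.Module
        # (e.g. a spiking network) could be returned here instead
        return torchvision.models.resnet18(num_classes=num_classes)

    def preprocess_train_sample(self, args, x: torch.Tensor):
        # hook for e.g. repeating the input over time-steps for an SNN; identity here
        return x


if __name__ == "__main__":
    trainer = MyTrainer()
    parser = trainer.get_args_parser()
    # extra command-line options could be added to the parser here before parsing
    args = parser.parse_args()
    trainer.main(args)

With the defaults above, running something like python my_train.py --data-path /path/to/imagenet would train the model via load_ImageNet, with all hyper-parameters taken from get_args_parser.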