# SpikingFlow.softbp.optim source code (源代码)

import torch
from torch import nn
from torch.optim.optimizer import Optimizer
import math
class AdamRewiring(Optimizer):
    """Adam optimizer combined with the Deep R sign-constrained rewiring (pruning) rule."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 amsgrad=False, T=1e-5, l1=1e-5):
        r"""
        .. attention::

            The convergence of this algorithm has not been proven, and the reliability of its
            pruning on softbp-based SNNs is unknown.

        :param params: (original Adam) iterable of network parameters, or dicts defining parameter groups
        :param lr: (original Adam) learning rate
        :param betas: (original Adam) coefficients used for computing the running averages of the
            gradient and of its square
        :param eps: (original Adam) small constant added to the denominator to improve numerical stability
        :param weight_decay: (original Adam) L2 penalty factor
        :param amsgrad: (original Adam) whether to use the AMSGrad variant
        :param T: temperature parameter of the Deep R algorithm; must be non-negative
        :param l1: L1 penalty parameter of the Deep R algorithm; must be non-negative
        :raises ValueError: if any hyper-parameter is outside its valid range

        G. Bellec et al, "Deep Rewiring: Training very sparse deep networks," ICLR 2018.

        This implementation ports the SGD-based `Deep R`_ algorithm from the paper onto the
        `Adam: A Method for Stochastic Optimization`_ optimizer. It is a modification of the
        `official implementation`_ of Adam in PyTorch.

        .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980
        .. _official implementation: https://github.com/pytorch/pytorch/blob/6e2bb1c05442010aff90b413e21fce99f0393727/torch/optim/adam.py
        .. _Deep R: https://openreview.net/pdf?id=BJ_wN01C-
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # Deep R hyper-parameters: a negative temperature or L1 strength has no meaning.
        if not 0.0 <= T:
            raise ValueError("Invalid T value: {}".format(T))
        if not 0.0 <= l1:
            raise ValueError("Invalid l1 value: {}".format(l1))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, T=T, l1=l1)
        super(AdamRewiring, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamRewiring, self).__setstate__(state)
        # Older pickled states may lack the 'amsgrad' group option.
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single parameter update.

        Each parameter gets a regular Adam update, then the Deep R L1 term and Gaussian noise
        are added. Finally, any entry that is dormant, or whose sign differs from the sign it
        had when the optimizer first saw it, is set to 0.

        :param closure: (original Adam) closure that re-evaluates the model and returns the loss
        :return: the value returned by ``closure``, or ``None`` when no closure is given
        :raises RuntimeError: if a parameter has a sparse gradient
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Sign of every entry at the first step; entries are never allowed to change sign.
                    state['sign'] = torch.sign(p)
                    # Despite the key name, this mask is 1 for ACTIVE (non-zero) entries and 0 for
                    # dormant (zeroed) ones. The key is kept for state_dict compatibility. The mask
                    # uses p's dtype, matching what Optimizer.load_state_dict restores.
                    state['dormant'] = (p != 0.0).to(p.dtype)

                active_mask = state['dormant']
                sgn = state['sign']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                # Plain Adam update.
                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Deep R L1 term: pull every entry toward zero along its initial sign.
                p.add_(-group['l1'] * step_size * sgn)

                # Deep R exploration noise.
                # NOTE(review): the Deep R paper scales this noise by sqrt(2 * lr * T); this port
                # uses T * step_size instead -- confirm this is intended.
                rand_normal = torch.randn_like(p)
                p.add_(rand_normal * group['T'] * step_size)

                # Zero out dormant entries and entries whose sign differs from the recorded initial
                # sign: multiplying by sgn makes same-sign values non-negative, clamp drops the rest,
                # and the second multiply restores the original sign (sgn * sgn == 1).
                p.mul_(active_mask).mul_(sgn).clamp_(min=0.0).mul_(sgn)
                # The mask can only shrink, so a dormant entry is never reactivated here.
                # NOTE(review): the paper reactivates random dormant connections to keep
                # connectivity constant; that step is not implemented.
                state['dormant'] = (p != 0.0).to(p.dtype)

        return loss