import torch
from torch import nn
from torch.optim.optimizer import Optimizer
import math
class AdamRewiring(Optimizer):
    """Adam-style optimizer with an L1 penalty, random perturbation and sign-constrained pruning."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, T=1e-5, l1=1e-5):
        r"""
        .. attention::
            The convergence of this algorithm has not been proven, and the
            reliability of pruning softbp-based SNNs with it is also unknown.

        :param params: (original Adam) iterable of network parameters, or dicts defining parameter groups
        :param lr: (original Adam) learning rate
        :param betas: (original Adam) the two coefficients used for the running averages of the gradient and of its square
        :param eps: (original Adam) small constant added to the denominator to improve numerical stability
        :param weight_decay: (original Adam) L2 penalty factor
        :param amsgrad: (original Adam) whether to use the AMSGrad variant
        :param T: temperature parameter of the Deep R algorithm
        :param l1: L1 penalty parameter of the Deep R algorithm
        :raises ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, or a beta is outside ``[0, 1)``

        G. Bellec et al, "Deep Rewiring: Training very sparse deep networks," ICLR 2018.

        This implementation ports the SGD-based `Deep R`_ algorithm from the paper to the
        `Adam: A Method for Stochastic Optimization`_ optimizer. It is a modification of the
        `official implementation`_ of Adam in PyTorch.

        .. _Adam\: A Method for Stochastic Optimization:
            https://arxiv.org/abs/1412.6980
        .. _official implementation:
            https://github.com/pytorch/pytorch/blob/6e2bb1c05442010aff90b413e21fce99f0393727/torch/optim/adam.py
        .. _Deep R:
            https://openreview.net/pdf?id=BJ_wN01C-
        """
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad, T=T, l1=l1)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in ``amsgrad=False`` for groups that lack the key."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single parameter update.

        :param closure: (original Adam) optional callable that re-evaluates the model and returns the loss
        :return: the value returned by ``closure``, or ``None`` if no closure was given
        :raises RuntimeError: if a parameter has a sparse gradient
        """
        loss = None
        if closure is not None:
            # The closure typically needs autograd, which the decorator disabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Per-group hyper-parameters do not change across the parameters
            # of the group, so look them up once.
            amsgrad = group['amsgrad']
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            weight_decay = group['weight_decay']
            temperature = group['T']
            l1 = group['l1']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Sign of every entry at the first step; entries may never
                    # leave this sign afterwards.
                    state['sign'] = torch.sign(p)
                    # NOTE: despite the key name, this mask is 1.0 for entries
                    # that are currently non-zero and 0.0 for entries that are
                    # zero. The key is kept for checkpoint compatibility.
                    state['dormant'] = (p != 0.0).float()

                active_mask = state['dormant']
                sgn = state['sign']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

                step_size = lr / bias_correction1

                # Regular Adam update.
                p.addcdiv_(exp_avg, denom, value=-step_size)

                # L1 penalty: move every entry towards zero along its initial sign.
                p.add_(-l1 * step_size * sgn)

                # Random perturbation scaled by the temperature.
                rand_normal = torch.randn_like(p)
                p.add_(rand_normal * temperature * step_size)

                # Clip entries that crossed zero: masked-out entries are zeroed,
                # then p * sgn is non-negative exactly where the entry still has
                # its initial sign, so clamping at 0 and multiplying by sgn again
                # zeroes every entry whose sign flipped and keeps the others.
                p.mul_(active_mask).mul_(sgn).clamp_(min=0.0).mul_(sgn)

                # Entries that are zero now are masked out on all later steps.
                state['dormant'] = (p != 0.0).float()

        return loss