import numpy as np
import torch
import torch.nn as nn
import copy
from collections import defaultdict
def layer_reduction(model: nn.Module) -> nn.Module:
    """Flatten ``model`` into an ``nn.Sequential`` of conversion-friendly layers.

    The modules of ``model`` are visited in ``named_modules()`` order
    (registration order -- assumed to match the forward order; verify for
    models whose ``forward`` reorders submodules) and handled as follows:

    * ``nn.ReLU`` (or any class named ``ReLU``) is kept; ``nn.Softmax`` is
      replaced by a new ``nn.ReLU`` and a ``UserWarning`` is emitted.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` is folded into the preceding
      ``nn.Conv2d`` / ``nn.Linear`` via ``absorb`` when that layer is the
      module appended immediately before it; otherwise a copy is appended.
    * ``nn.Conv2d`` / ``nn.Linear`` are deep-copied so the in-place weight
      edits made here (BN folding) and by ``rate_normalization`` never touch
      ``model``.
    * ``nn.MaxPool2d``, ``nn.AvgPool2d`` (rebuilt from kernel_size, stride and
      padding only) and classes named ``Flatten`` are kept; every other
      module is dropped.

    Every kept module except appended BatchNorm copies receives an index
    (the same numbering ``rate_normalization`` uses). The returned network
    carries two extra attributes:

    * ``param_module_relu_linker``: Conv/Linear index -> index of the ReLU
      that follows it (a ReLU seen before any Conv/Linear is linked to 0).
    * ``activation_range``: ``defaultdict(float)`` of index -> activation
      range, pre-filled with ``-1e5`` at every ReLU index.

    :param model: the ANN to reduce; it is not modified.
    :return: the reduced ``nn.Sequential``.
    """
    import warnings

    param_module_relu_linker = {}  # Conv/Linear index -> index of the following ReLU
    activation_range = defaultdict(float)  # layer index -> activation range
    module_list = []
    module_len = 0  # next index to assign; appended BatchNorm copies are not counted
    last_parammodule_idx = 0  # index of the most recent Conv/Linear
    last_param_module = None  # the (copied) Conv/Linear object at that index

    for _, m in model.named_modules():
        name = m.__class__.__name__
        # Activation layers.
        if isinstance(m, nn.Softmax):
            warnings.warn("Replacing Softmax by ReLU.", UserWarning)
            m, name = nn.ReLU(), 'ReLU'
        if isinstance(m, nn.ReLU) or name == 'ReLU':
            module_list.append(m)
            param_module_relu_linker[last_parammodule_idx] = module_len
            activation_range[module_len] = -1e5  # placeholder for this ReLU's own index
            module_len += 1
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            # Folding is only exact when BN directly follows the Conv/Linear;
            # after a ReLU or pooling layer it has to stay a separate layer.
            if last_param_module is not None and module_list[-1] is last_param_module:
                absorb(last_param_module, m)
            else:
                module_list.append(copy.deepcopy(m))
        # Layers with parameters.
        elif isinstance(m, (nn.Conv2d, nn.Linear)):
            last_param_module = copy.deepcopy(m)
            module_list.append(last_param_module)
            last_parammodule_idx = module_len
            module_len += 1
        # Parameter-free layers.
        elif isinstance(m, nn.MaxPool2d):
            module_list.append(m)
            module_len += 1
        elif isinstance(m, nn.AvgPool2d):
            module_list.append(nn.AvgPool2d(kernel_size=m.kernel_size, stride=m.stride, padding=m.padding))
            module_len += 1
        elif name == 'Flatten':  # matched by class name rather than isinstance(m, nn.Flatten)
            module_list.append(m)
            module_len += 1

    network = nn.Sequential(*module_list)
    network.param_module_relu_linker = param_module_relu_linker
    network.activation_range = activation_range
    return network
def rate_normalization(model: nn.Module, data: torch.Tensor, **kargs) -> nn.Module:
    """Rescale Conv/Linear weights in place using activations measured on ``data``.

    Pass 1 pushes ``data`` through the tracked layers of ``model`` (Conv2d,
    ReLU, MaxPool2d, AvgPool2d, Flatten, Linear, in ``named_modules()``
    order) and stores each layer's activation range in
    ``model.activation_range[i]``. The range is the maximum, or with
    ``robust=True`` the 99.9th percentile of the non-zero activations (0.0 if
    there are none).

    Pass 2 rescales every Conv/Linear layer ``l`` as
    ``W_l <- W_l * lambda_prev / lambda_l`` and ``b_l <- b_l / lambda_l``.
    Here ``lambda_l`` is the recorded range of the ReLU that follows layer
    ``l``, and ``lambda_prev`` is the previous Conv/Linear layer's ``lambda``
    (1.0 for the first). If a layer has no following ReLU, or that ReLU's
    range is not positive (it never fired on ``data``), ``lambda_l = 1.0`` is
    used instead. This avoids a KeyError or division by zero and keeps the
    rescaling chain consistent.

    :param model: network returned by ``layer_reduction``; it must carry
        ``activation_range`` and ``param_module_relu_linker``. Modified in place.
    :param data: calibration input fed through the network.
    :param kargs: ``robust`` (bool, default False) selects percentile mode.
    :return: ``model`` itself.
    :raises AttributeError: if ``model`` lacks the attributes set by ``layer_reduction``.
    """
    if not hasattr(model, "activation_range") or not hasattr(model, "param_module_relu_linker"):
        raise AttributeError("run layer_reduction first!")
    robust_norm = kargs.get('robust', False)

    # Same filter as the original numbering, so index i matches layer_reduction's indices.
    tracked = ('Conv2d', 'ReLU', 'MaxPool2d', 'AvgPool2d', 'Flatten', 'Linear')
    layers = [m for _, m in model.named_modules() if m.__class__.__name__ in tracked]

    # Pass 1: record the activation range of every tracked layer.
    x = data
    with torch.no_grad():
        for i, m in enumerate(layers):
            x = m.forward(x)
            a = x.cpu().numpy().reshape(-1)
            if robust_norm:
                nonzero = a[np.nonzero(a)]
                # np.percentile cannot handle an empty array (e.g. a ReLU that never fires).
                model.activation_range[i] = np.percentile(nonzero, 99.9) if nonzero.size else 0.0
            else:
                model.activation_range[i] = np.max(a)

    # Pass 2: rescale each Conv/Linear by the range of the ReLU that follows it.
    last_lambda = 1.0
    for i, m in enumerate(layers):
        if m.__class__.__name__ not in ('Conv2d', 'Linear'):
            continue
        relu_idx = model.param_module_relu_linker.get(i)
        out_lambda = 1.0 if relu_idx is None else model.activation_range[relu_idx]
        if not out_lambda > 0:  # also catches NaN
            out_lambda = 1.0
        if getattr(m, 'weight', None) is not None:
            m.weight.data = m.weight.data * last_lambda / out_lambda
        if getattr(m, 'bias', None) is not None:
            m.bias.data = m.bias.data / out_lambda
        last_lambda = out_lambda
    return model
def save_model(model: nn.Module, f):
    """Serialize the whole ``model`` object with ``torch.save``.

    :param model: module to save.
    :param f: any destination ``torch.save`` accepts: a path (``str`` or
        ``os.PathLike``) or a writable binary file-like object. ``None`` means
        "do not save" and returns without doing anything.
    """
    # Previously only str paths were saved; Path objects and open files were
    # silently ignored.
    if f is None:
        return
    torch.save(model, f)
def absorb(param_module, bn_module):
    """Fold a BatchNorm's running statistics into the layer that feeds it, in place.

    With ``s = gamma / sqrt(running_var + eps)`` the folded parameters are
    ``W' = W * s`` (broadcast over output channels, dim 0 of ``W``) and
    ``b' = (b - running_mean) * s + beta``. As a result, ``param_module(x)``
    alone equals ``bn_module(param_module(x))`` evaluated with the running
    statistics (eval mode). A layer built with ``bias=False`` gets a new
    ``nn.Parameter`` bias. A BatchNorm built with ``affine=False`` is treated
    as ``gamma = 1, beta = 0``.

    :param param_module: layer whose weight has output channels in dim 0
        (e.g. ``nn.Linear``, ``nn.Conv1d``, ``nn.Conv2d``); modified in place.
    :param bn_module: BatchNorm layer with running statistics; only read.
    :return: ``param_module``.
    """
    mean = bn_module.running_mean.data
    std = torch.sqrt(bn_module.running_var.data + bn_module.eps)
    # affine=False BatchNorm stores None for weight/bias.
    gamma = bn_module.weight.data if bn_module.weight is not None else torch.ones_like(mean)
    beta = bn_module.bias.data if bn_module.bias is not None else torch.zeros_like(mean)
    scale = gamma / std

    weight = param_module.weight.data
    # Reshape the per-channel scale to (C, 1, ..., 1) so it broadcasts over any weight rank.
    param_module.weight.data = weight * scale.view((-1,) + (1,) * (weight.dim() - 1))

    old_bias = param_module.bias.data if param_module.bias is not None else torch.zeros_like(mean)
    new_bias = (old_bias - mean) * scale + beta
    if param_module.bias is None:
        # A bias=False layer has no Parameter whose .data could be overwritten.
        param_module.bias = nn.Parameter(new_bias)
    else:
        param_module.bias.data = new_bias
    return param_module