import onnx
import onnx.helper as helper
import onnx.numpy_helper as numpy_helper
import collections
import numpy as np
import torch
import torch.nn as nn
import os
import tqdm
import onnxruntime as ort
from collections import defaultdict
[文档]class Mul(nn.Module):
def __init__(self):
super().__init__()
[文档] def forward(self, input1, input2):
return input1 * input2
[文档]class Add(nn.Module):
def __init__(self):
super().__init__()
[文档] def forward(self,input1,input2):
return input1 + input2
[文档]class Reshape(nn.Module):
def __init__(self):
super().__init__()
[文档] def forward(self, input1, input2):
return torch.reshape(input1,shape=list(input2))
[文档]class Concat(nn.Module):
def __init__(self, dim=[1]):
super().__init__()
self.dim = dim
if not isinstance(self.dim, list):
self.dim = [self.dim]
for i, d in enumerate(self.dim):
if not isinstance(d, int):
self.dim[i] = int(d)
[文档] def forward(self, *args):
args = list(args)
for i,a in enumerate(args):
args[i] = a.type_as(args[0])
return torch.cat(args,dim=self.dim[0])
[文档]class Shape(nn.Module):
def __init__(self):
super().__init__()
[文档] def forward(self, input):
return torch.IntTensor([input.size(i) for i in range(len(input.size()))])
[文档]class Gather(nn.Module):
def __init__(self,dim=1):
super().__init__()
self.dim= int(dim)
[文档] def forward(self, input1, input2):
return torch.gather(input1,dim=self.dim,index=input2.cpu())
[文档]class Unsqueeze(nn.Module):
def __init__(self, dim=[1]):
super().__init__()
self.dim = dim
if not isinstance(self.dim, list):
self.dim = [self.dim]
for i,d in enumerate(self.dim):
if not isinstance(d,int):
self.dim[i] = int(d)
[文档] def forward(self, input):
x = input
for i in self.dim:
x = torch.unsqueeze(x,dim=i)
return x
[文档]class TopologyAnalyser:
def __init__(self):
'''
* :ref:`API in English <TopologyAnalyser.__init__-en>`
.. _TopologyAnalyser.__init__-cn:
这个类通过onnx分析模型的拓扑结构,方便后续处理
此处还有更多更好的实现方法,欢迎开发者不断优化
* :ref:`API in English <TopologyAnalyser.__init__-cn>`
.. _TopologyAnalyser.__init__-en:
This class analyzes the topological structure of the model through onnx to facilitate subsequent processing
There are better implementation methods here, developers are welcome to continue to optimize
'''
self.data_nodes = []
self.module_output = collections.OrderedDict()
self.module_op = collections.OrderedDict()
self.module_idx = collections.OrderedDict()
self.param_idx = collections.OrderedDict()
self.edge = collections.OrderedDict()
self.reverse_edge = collections.OrderedDict() # 快速计算前驱结点
[文档] def add_data_node(self, a):
if not a in self.data_nodes:
self.data_nodes.append(a)
[文档] def insert(self, a, b, info=None):
self.add_data_node(a)
self.add_data_node(b)
if a not in self.edge.keys():
self.edge[a] = [(b, info)]
else:
self.edge[a].append((b, info))
if b not in self.reverse_edge.keys():
self.reverse_edge[b] = [a]
else:
self.reverse_edge[b].append(a)
[文档] def findNext(self, id):
if isinstance(id, str):
if id in self.edge.keys():
return self.edge[id]
else:
return []
elif isinstance(id, list):
l = []
for i in id:
l += self.findNext(i)
return l
[文档] def findPre(self, id):
l = []
if isinstance(id, str):
for pre_id in self.reverse_edge[id]:
if pre_id in self.reverse_edge.keys():
for pre_pre_id in self.reverse_edge[pre_id]:
if pre_pre_id in self.edge.keys():
for item in self.edge[pre_pre_id]:
if item[0] == pre_id:
l.append(item)
elif isinstance(id, list):
for i in id:
l += self.findPre(i)
return l
[文档] def find_pre_module(self, module_name):
if module_name in self.module_output.keys():
ids = self.module_output[module_name]
return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findPre(ids)])
else:
return set()
[文档] def find_next_module(self, module_name):
if module_name in self.module_output.keys():
ids = self.module_output[module_name]
return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findNext(ids)])
else:
return set()
[文档] def update_module_idx(self, onnx_graph):
for idx, n in enumerate(onnx_graph.node):
trainable_input = n.input[1:]
op = n.op_type
k = set()
for i in trainable_input:
n = self._get_module_name_from_value_name(i)
if n is not None:
k.add(n)
if len(k) > 1:
# TODO: sanity check, raise error
pass
if len(k) == 1:
param_module_name = list(k)[0]
self.module_op[param_module_name] = op
self.module_idx[param_module_name] = idx
[文档] def analyse(self, onnx_graph): # 输入的onnx graph需要保证所以常量在已经在initializer中
# 先把该往initializer放下面的参数,保证下面只有运算没有常量
for idx, constant in enumerate(onnx_graph.initializer):
self.param_idx[constant.name] = idx
for idx, n in enumerate(onnx_graph.node):
param_module_name = None
op = n.op_type
inputs = n.input
outputs = n.output
# print(inputs,outputs)
k = set()
trainable_input = inputs[1:]
for i in trainable_input:
n = self._get_module_name_from_value_name(i)
if n is not None:
k.add(n)
if len(k) > 1:
# TODO: sanity check, raise error
pass
if len(k) == 1:
param_module_name = list(k)[0]
self.module_op[param_module_name] = op
self.module_idx[param_module_name] = idx
if op is not None:
for o in outputs:
for i in inputs:
self.insert(i, o, {'op': op, 'param_module_name': param_module_name})
if param_module_name is not None:
if param_module_name not in self.module_output.keys():
self.module_output[param_module_name] = [o]
else:
self.module_output[param_module_name].append(o)
return self
@staticmethod
def _get_module_name_from_value_name(value_name):
module_name = None
if len(value_name.split('.')) > 1:
l = value_name.split('.')[:-1]
l = '.'.join(l)
module_name = l # [1:]
module_name.replace(' ', '')
return module_name
[文档]def pytorch2onnx_model(model: nn.Module, data, **kargs) -> onnx.ModelProto:
'''
* :ref:`API in English <pytorch2onnx_model-en>`
.. _pytorch2onnx_model-cn:
:param model: 待转换的PyTorch模型
:param data: 用于转换的数据(用来确定输入维度)
:param log_dir: 输出文件夹
转换PyTorch模型到onnx模型
* :ref:`API in English <pytorch2onnx_model-cn>`
.. _pytorch2onnx_model-en:
:param model: the PyTorch model to be converted
:param data: The data used for conversion (used to determine the input dimension)
:param log_dir: output folder
Convert PyTorch model to onnx model
'''
try:
log_dir = kargs['log_dir']
except KeyError:
print('pytorch2onnx_model need argument log_dir!')
dump_input_size = [data.size(i) for i in range(len(data.size()))]
dump_input_size[0] = 1
fname = os.path.join(log_dir,'onnxmodel')
try:
dynamic_axes = {'input': {0: 'batch_size'},
'output': {0: 'batch_size'}}
torch.onnx.export(model, torch.ones(dump_input_size), fname,
input_names=['input'],
output_names=['output'],
dynamic_axes=dynamic_axes)
except BaseException:
raise NotImplementedError("Models with multiple inputs are not supported yet!")
return onnx.load(fname)
[文档]def onnx2pytorch_model(model: onnx.ModelProto, _converter) -> nn.Module:
model = _pt_model(model, _converter)
model = model.reduce()
return model
[文档]def layer_reduction(model: onnx.ModelProto) -> onnx.ModelProto:
graph = model.graph
topo_analyser = TopologyAnalyser()
graph = move_constant_to_initializer(graph)
topo_analyser.analyse(graph)
absorb_bn(graph, topo_analyser)
remove_unreferenced_initializer(graph)
update_topology(graph)
print("Finish layer reduction!")
return model
[文档]def rate_normalization(model: onnx.ModelProto, data: torch.Tensor, **kargs) -> onnx.ModelProto:
'''
* :ref:`API in English <rate_normalization-en>`
.. _rate_normalization-cn:
:param model: ANN模型,类型为onnx.ModelProto
:param data: 用于转换的数据,类型为torch.Tensor
:param channelwise: 如果为``True``,则控制激活幅值的统计是channelwise的;否则,控制激活幅值的统计是layerwise的
:param robust: 如果为``True``,则控制激活幅值的统计是激活的99.9百分位;否则,控制激活幅值的统计是激活的最值
:param eps: epsilon;未设置值时默认1e-5
发放率归一化
* :ref:`API in English <rate_normalization-cn>`
.. _rate_normalization-en:
:param model: ANN model, the type is onnx.ModelProto
:param data: the data used for conversion, the type is torch.Tensor
:param channelwise: If ``True`` , the statistics that control the activation amplitude are channelwise; otherwise, the statistics that control the activation amplitude are layerwise
:param robust: If ``True``, the statistic of the control activation amplitude is the 99.9th percentile of activation; otherwise, the statistic of the activation amplitude is the maximum value of activation
:param eps: epsilon; if no value is set, the default is 1e-5
normalize the firing rate
'''
try:
channelwise = kargs['channelwise']
except KeyError:
channelwise = False
try:
robust_norm = kargs['robust']
except KeyError:
robust_norm = False
try:
eps = kargs['eps']
except KeyError:
eps = 1e-5
topo_analyser = update_topology(model.graph)
output_debug = {}
output_statistics = get_intermediate_output_statistics(model, data,
channelwise=channelwise) # if want debug, debug=output_debug
model = normalize_model(model, output_statistics, topo_analyser, robust_norm=robust_norm,
channelwise=channelwise, eps=eps)
return model
[文档]def save_model(model: onnx.ModelProto, f=None):
fb = model.SerializeToString()
if f is not None:
if hasattr(f, 'write'):
f.write(fb)
else:
with open(f, "wb") as f:
f.write(fb)
return fb
[文档]def move_constant_to_initializer(graph):
constant_idx = []
for idx, n in enumerate(graph.node):
op = n.op_type
if op == 'Constant':
constant_idx.append(idx)
if len(constant_idx):
for idx in reversed(constant_idx):
n = graph.node[idx]
graph.initializer.append(
numpy_helper.from_array(numpy_helper.to_array(n.attribute[0].t), n.output[0]))
graph.node.remove(n)
return graph
[文档]def print_onnx_model(graph):
print(onnx.helper.printable_graph(graph))
[文档]def absorb_bn(graph, topo_analyser):
print("\nAbsorbing BatchNorm Parameters...\n")
for mn in tqdm.tqdm(reversed(topo_analyser.module_output.keys())):
if topo_analyser.module_op[mn] == 'BatchNormalization':
pre_m = topo_analyser.find_pre_module(mn)
next_m = topo_analyser.find_next_module(mn)
bn_weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[1]]
bn_weight = np.array(numpy_helper.to_array(graph.initializer[bn_weight_idx]))
bn_bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[2]]
bn_bias = np.array(numpy_helper.to_array(graph.initializer[bn_bias_idx]))
bn_mean_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[3]]
bn_mean = np.array(numpy_helper.to_array(graph.initializer[bn_mean_idx]))
bn_var_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[4]]
bn_var = np.array(numpy_helper.to_array(graph.initializer[bn_var_idx]))
bn_eps = graph.node[topo_analyser.module_idx[mn]].attribute[0].f
bn_std = np.sqrt(bn_var + bn_eps)
if len(pre_m) == 1 and list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm']:
pre_mn = list(pre_m)[0].split(':')[1]
weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[1]]
weight = np.array(numpy_helper.to_array(graph.initializer[weight_idx]))
if len(graph.node[topo_analyser.module_idx[pre_mn]].input) == 2:
bias = None
else:
bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[2]]
bias = np.array(numpy_helper.to_array(graph.initializer[bias_idx]))
wrsp_args = (-1, 1) if len(weight.shape) == 2 else (-1, 1, 1, 1)
weight_ = weight * bn_weight.reshape(*wrsp_args) / bn_std.reshape(*wrsp_args)
bias_ = ((bias if bias is not None else 0) - bn_mean.reshape(-1)) * bn_weight.reshape(
-1) / bn_std.reshape(-1) \
+ bn_bias.reshape(-1)
assert (list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm'])
args = {}
for attr in graph.node[topo_analyser.module_idx[pre_mn]].attribute:
args[attr.name] = helper.get_attribute_value(attr)
new_node = onnx.helper.make_node(
list(pre_m)[0].split(':')[0],
inputs=[graph.node[topo_analyser.module_idx[pre_mn]].input[0], pre_mn + ".new.weight", pre_mn + ".new.bias"],
outputs=[graph.node[topo_analyser.module_idx[mn]].output[0]],
**args
)
graph.initializer.append(numpy_helper.from_array(weight_.astype(np.float32), pre_mn + ".new.weight"))
graph.initializer.append(numpy_helper.from_array(bias_.astype(np.float32), pre_mn + ".new.bias"))
graph.node.remove(graph.node[topo_analyser.module_idx[pre_mn]])
graph.node.insert(topo_analyser.module_idx[pre_mn], new_node)
graph.node.remove(graph.node[topo_analyser.module_idx[mn]])
else:
weight_ = bn_weight / bn_std
bias_ = bn_bias - bn_weight * bn_mean / bn_std
name = graph.initializer[bn_weight_idx].name
graph.initializer.remove(graph.initializer[bn_weight_idx])
graph.initializer.insert(bn_weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
name = graph.initializer[bn_bias_idx].name
graph.initializer.remove(graph.initializer[bn_bias_idx])
graph.initializer.insert(bn_bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
name = graph.initializer[bn_mean_idx].name
graph.initializer.remove(graph.initializer[bn_mean_idx])
graph.initializer.insert(bn_mean_idx,
numpy_helper.from_array(np.zeros_like(bn_mean).astype(np.float32), name))
name = graph.initializer[bn_var_idx].name
graph.initializer.remove(graph.initializer[bn_var_idx])
graph.initializer.insert(bn_var_idx,
numpy_helper.from_array(np.ones_like(bn_var).astype(np.float32), name))
[文档]def remove_unreferenced_initializer(graph):
in_graph = set()
in_initializer = set()
for node in graph.node:
in_graph.update(node.input)
in_graph.update(node.output)
for init in graph.initializer:
in_initializer.add(init.name)
not_in_graph = in_initializer - in_graph
l = len(graph.initializer)
for i in range(l - 1, -1, -1):
if graph.initializer[i].name in not_in_graph:
graph.initializer.remove(graph.initializer[i])
[文档]def update_topology(graph):
topo_analyser = TopologyAnalyser()
move_constant_to_initializer(graph)
topo_analyser.analyse(graph)
return topo_analyser
[文档]def find_node_by_output(output_name, graph):
flag = False
idx, node = None, None
for idx, node in enumerate(graph.node):
if output_name in node.output:
flag = True
break
if not flag:
idx, node = None, None
return idx, node
[文档]def scale_node_weight_bias(topo_analyser, graph, node_idx, scale):
initializer = graph.initializer
node = graph.node[node_idx]
if len(node.input) < 2:
return
weight_idx = topo_analyser.param_idx[node.input[1]]
bias_idx = topo_analyser.param_idx[node.input[2]] if len(node.input) >= 3 else None
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None
w_scale = scale.reshape([*scale.shape] + [1 for _ in range(len(weight.shape) - 1)]) \
if len(scale.shape) == 1 else scale
b_scale = scale
weight_ = weight * w_scale
name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
if bias is not None:
bias_ = bias * b_scale
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
[文档]def get_onnx_output(model, numpy_tensor):
ort_session = ort.InferenceSession(model.SerializeToString())
outputs = ort_session.run(None, {'input': numpy_tensor})
return outputs
[文档]def normalize_model(model, output_statistics, topo_analyser, robust_norm=True, channelwise=False, eps=1e-5):
nodes = model.graph.node
graph = model.graph
initializer = model.graph.initializer
if robust_norm:
statistic_key = '99.9'
else:
statistic_key = 'max'
node_scaled_range = {}
seperate_scale = collections.OrderedDict()
print("\nNormalizing model...\n")
for node_idx, node in enumerate(tqdm.tqdm(nodes)):
output = node.output
input = node.input
op = node.op_type
if input[0] == 'input': # single input model
l = output_statistics[input[0]]['shape'][1]
node_scaled_range[input[0]] = np.ones(l) if channelwise else 1.0 \
* output_statistics[input[0]][
statistic_key]
if op in ['Conv', 'Gemm']:
weight_idx = topo_analyser.param_idx[input[1]]
bias_idx = topo_analyser.param_idx[input[2]] if len(input) == 3 else None
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None
l = output_statistics[output[0]]['shape'][1]
input_real_range = node_scaled_range[input[0]]
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
demand = np.ones(l) if channelwise else 1.0
w_scale = demand / (output_range + eps) * (
input_range / (input_real_range + eps)) if not channelwise else \
(demand / (output_range + eps)).reshape(-1, 1).dot(
(input_range / (input_real_range + eps)).reshape(1, -1))
w_scale = w_scale.reshape([*w_scale.shape, 1, 1]) if len(weight.shape) == 4 else w_scale
b_scale = 1 / (output_range + eps)
node_scaled_range[output[0]] = demand
weight_ = weight * w_scale
name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
if bias is not None:
bias_ = bias * b_scale
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
elif op == 'BatchNormalization': # var=1 mean=0
weight_idx = topo_analyser.param_idx[input[1]]
bias_idx = topo_analyser.param_idx[input[2]]
weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
bias = np.array(numpy_helper.to_array(initializer[bias_idx]))
# node_scaled_range[output[0]] = node_scaled_range[input[0]] * self.output_statistics[input[0]][statistic_key] / self.output_statistics[output[0]][statistic_key]
# lamda_last = self.output_statistics[input[0]][statistic_key]
# lamda = self.output_statistics[output[0]][statistic_key]
# weight_ = weight * node_scaled_range[output[0]]
# bias_ = bias / lamda
print(output_statistics[output[0]])
input_real_range = node_scaled_range[input[0]]
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
demand = 1.0
w_scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
b_scale = 1 / (output_range + eps)
node_scaled_range[output[0]] = demand
weight_ = weight * w_scale
bias_ = bias * b_scale
# print(output[0],op,input[0], input_range, output_range, demand, input_real_range, w_scale)
name = initializer[weight_idx].name
initializer.remove(initializer[weight_idx])
initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
name = initializer[bias_idx].name
initializer.remove(initializer[bias_idx])
initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
elif op == 'Add':
l = output_statistics[output[0]]['shape'][1]
demand = np.ones(l) if channelwise else 1.0
node_scaled_range[output[0]] = demand
output_range = output_statistics[output[0]][statistic_key]
# node_scaled_range[output[0]] = 1.0
# lamda = self.output_statistics[output[0]][statistic_key]
# lamda_lasts = {}
for i in input:
if i in output_statistics.keys():
# lamda_lasts[i] = self.output_statistics[i][statistic_key]
# scale = lamda_lasts[i] / lamda
input_real_range = node_scaled_range[i]
input_range = output_statistics[i][statistic_key]
scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
# print(output[0], op, i, input_range, output_range, demand, input_real_range, scale)
idx, _ = find_node_by_output(i, graph)
if idx is not None and nodes[idx].op_type in ['Conv', 'Gemm', 'BatchNormalization']:
scale_node_weight_bias(topo_analyser, graph, idx, scale)
else:
scale = scale.reshape(
[1, *scale.shape] + [1 for _ in range(len(output_statistics[i]['shape']) - 2)]) \
if len(scale.shape) == 1 else scale
initializer.append(numpy_helper.from_array(scale.astype(np.float32), "scale_" + i))
if idx not in seperate_scale.keys():
# seperate_scale[node_idx] = [(i,"scale_"+i,"scaled_"+i)]
seperate_scale[node_idx] = {i: ("scale_" + i, "scaled_" + i)}
else:
# seperate_scale[node_idx].append((i,"scale_"+i,"scaled_"+i))
seperate_scale[node_idx][i] = ("scale_" + i, "scaled_" + i)
pass
elif op in ['Gather', 'Unsqueeze', 'Shape', 'Concat']:
continue
# elif op == "Concat":
# raise NotImplementedError("Not supported %s yet!"%(op))
elif op == "Softmax":
raise NotImplementedError("Not supported %s yet!" % (op))
else: # single input single output module
# print(op,self.output_statistics[output[0]]['shape'])
input_range = output_statistics[input[0]][statistic_key]
output_range = output_statistics[output[0]][statistic_key]
input_scaled_range = node_scaled_range[input[0]]
output_scaled_range = input_scaled_range / (input_range + eps) * output_range
node_scaled_range[output[0]] = output_scaled_range
# print(output[0], op, input[0], input_range, output_range, output_scaled_range)
# print(op, node_scaled_range[output[0]],'=',input_scaled_range,'/',input_range,'*',output_range)
# else:
# raise NotImplementedError("Not supported yet! %s"%(op))
if len(seperate_scale.keys()) != 0:
print("Making new scale node...")
for node_idx in reversed(seperate_scale.keys()):
args = {}
for attr in nodes[node_idx].attribute:
args[attr.name] = helper.get_attribute_value(attr)
input = [str(i) if i not in seperate_scale[node_idx].keys() else seperate_scale[node_idx][i][1] \
for i in nodes[node_idx].input]
output = [str(i) for i in nodes[node_idx].output]
new_node = onnx.helper.make_node(
nodes[node_idx].op_type,
inputs=input,
outputs=output,
**args
)
nodes.remove(nodes[node_idx])
nodes.insert(node_idx, new_node)
for i in seperate_scale[node_idx].keys():
new_node = onnx.helper.make_node(
'Mul',
inputs=[seperate_scale[node_idx][i][0], i],
outputs=[seperate_scale[node_idx][i][1]]
)
nodes.insert(node_idx, new_node)
print("Finished normalizing model!")
return model
def _pre_onnx_shape_inference(model:onnx.ModelProto):
'''
为了对模型进行shape inference,需要先对onnxmodel运行此函数进行准备
To perform shape inference for model, need to run this function on onnxmodel to prepare
This function has referenced code in https://github.com/onnx/onnx/issues/2660#issuecomment-605874784
'''
if model.ir_version < 4:
return
def add_some_graph_info(graph:onnx.GraphProto):
inputs = {i.name for i in graph.input}
vi_dict = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
if init.name in inputs:
continue
vi = vi_dict.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
tensor_type = vi.type.tensor_type
if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
tensor_type.elem_type = init.data_type
if not tensor_type.HasField("shape"):
tensor_type.shape.dim.extend([])
for dim in init.dims:
tensor_type.shape.dim.add().dim_value = dim
for node in graph.node:
for attr in node.attribute:
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_some_graph_info(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_some_graph_info(g)
return add_some_graph_info(model.graph)
class _pt_model(nn.Module):
def __init__(self, path_or_model, _converter=None):
super(_pt_model, self).__init__()
if path_or_model is not None:
if isinstance(path_or_model, str):
onnx_model = onnx.load(path_or_model)
else:
onnx_model = path_or_model
self.onnx_model = onnx_model
self.loaded_weights = load_parameters(self, onnx_model.graph.initializer)
self.module_list = nn.ModuleList([])
self.op_tree = {}
_pre_onnx_shape_inference(onnx_model)
inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
self.value_info = inferred_model.graph.value_info
self.dim_info = {}
for idx, v in enumerate(self.value_info):
self.dim_info[v.name] = len(v.type.tensor_type.shape.dim)
self.graph = defaultdict(list)
# self.V = set()
for idx, node in enumerate(onnx_model.graph.node):
op = node.op_type
# if op=='MatMul': # TODO temporary
# op = 'Gemm'
(op_idx, inputs, outputs) = getattr(_converter, 'convert_' + op.lower())(node, self)
for out_seq, output in enumerate(outputs):
# self.op_tree[output] = (op_idx, inputs, out_seq)
self.op_tree[output] = (op_idx, inputs, out_seq)
for output in outputs:
for input in inputs:
self.graph[input].append(output)
# self.V.update(inputs)
# self.V.update(outputs)
# print(self.V)
self.input_name = [i.name for i in onnx_model.graph.input]
self.output_name = [i.name for i in onnx_model.graph.output]
for out in onnx_model.graph.output:
self.graph[out.name] = []
def TopologicalSort(G):
in_degrees = dict((u, 0) for u in G)
for u in G:
for v in G[u]:
in_degrees[v] += 1
Q = [u for u in G if in_degrees[u] == 0]
res = []
while Q:
u = Q.pop()
res.append(u)
for v in G[u]:
in_degrees[v] -= 1
if in_degrees[v] == 0:
Q.append(v)
return res
self.compute_seq = TopologicalSort(self.graph)
# print(self.compute_seq)
# self.tensors = {}
for k in self.loaded_weights.keys():
if isinstance(self.loaded_weights[k], torch.FloatTensor):
setattr(self, 'P' + k.replace('.', '@'), torch.nn.Parameter(self.loaded_weights[k]))
else:
# print(self.loaded_weights[k])
self.register_buffer('P' + k.replace('.', '@'), self.loaded_weights[k])
self.reserved_tensors_name = list(self.loaded_weights.keys())
# print(self.reserved_tensors_name)
# print(self.tensors)
def refresh_running_tensor(self):
self.tensors = {}
for k in set(self.tensors.keys()) | set(self.reserved_tensors_name):
if k not in self.reserved_tensors_name:
del self.tensors[k]
else:
self.tensors[k] = getattr(self, 'P' + k.replace('.', '@'))
def forward(self, input):
self.refresh_running_tensor()
if not isinstance(input, list) or not isinstance(input, tuple):
input = [input]
for i, n in enumerate(self.input_name):
self.tensors[n] = input[i]
for name in self.compute_seq:
if name in self.op_tree.keys():
op_idx, inputs, out_seq = self.op_tree[name]
# print(name,op_idx, inputs)
args = []
for input in inputs:
args.append(self.tensors[input])
# print(len(args))
result = self.module_list[op_idx](*args)
if not isinstance(result, tuple):
self.tensors[name] = result
else:
self.tensors[name] = result[out_seq]
if len(self.output_name) == 1:
return self.tensors[self.output_name[0]]
else:
ret = []
for output in self.output_name:
ret.append(self.tensors[output])
return ret
def reduce(self):
import copy
net = _pt_model(None)
for k in self.reserved_tensors_name:
if isinstance(self.loaded_weights[k], torch.FloatTensor):
setattr(net, 'P' + k.replace('.', '@'),
torch.nn.Parameter(getattr(self,'P' + k.replace('.', '@')).data.detach().clone()) )
else:
net.register_buffer('P' + k.replace('.', '@'),
getattr(self,'P' + k.replace('.', '@')).data.clone() )
net.compute_seq = copy.deepcopy(self.compute_seq)
net.input_name = copy.deepcopy(self.input_name)
net.output_name = copy.deepcopy(self.output_name)
net.module_list = copy.deepcopy(self.module_list)
net.op_tree = copy.deepcopy(self.op_tree)
net.reserved_tensors_name = copy.deepcopy(self.reserved_tensors_name)
return net
[文档]def load_parameters(model:_pt_model, initializer):
param_dict = {}
for init in initializer:
param_dict[init.name] = torch.from_numpy(numpy_helper.to_array(init).copy())
return param_dict
class _o2p_converter:
def __init__(self):
'''
* :ref:`API in English <ONNX_Converter.__init__-en>`
.. _ONNX_Converter.__init__-cn:
该类主要将onnx模型转换为Pytorch的ANN模型,从而转换为SpikingJelly的SNN模型
链接中 [#f1]_ 提供了一个onnx-pytorch转换的主要版本。更复杂的版本可以在这里找到。
大多数使用过的onnx运算符已在此处定义,但仍然有一些未被覆盖,或没有被完美实现
用户可以通过添加如下面例子所示的静态方法来定义您的例外情况
* :ref:`API in English <ONNX_Converter.__init__-cn>`
.. _ONNX_Converter.__init__-en:
This class mainly convert an onnx model to Pytorch ANN model, and thus to SpikingJelly SNN model
The link [#f1]_ has provided a primary version of onnx-pytorch conversion. More complex version can be found here.
Most used onnx operators has covered here, yet still there are some left, or not being defined perfectly
User can define your exceptions by adding static method like below
.. [#f1] https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
'''
pass
def add_method(self, op_name, func):
setattr(self, 'convert_'+op_name, staticmethod(func))
@staticmethod
def convert_conv(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
"group": "groups",
"dilations": "dilation"
}
assert len(node.output) == 1
with_bias = False
if len(node.input) == 3:
with_bias = True
bias = model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[2]]
weight = model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[1]]
in_channels = weight.shape[1]
out_channels = weight.shape[0]
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
if 'padding' in kwargs:
assert(kwargs["padding"][0]==kwargs["padding"][2] and kwargs["padding"][1]==kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0],kwargs["padding"][1]
groups = 1 if 'groups' not in kwargs else kwargs['groups']
in_channels *= groups
conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
conv.weight.data = weight
if with_bias:
conv.bias.data = bias
model.module_list.append(conv)
return len(model.module_list)-1, node.input[:1], node.output
@staticmethod
def convert_relu(node, model:_pt_model):
relu = nn.ReLU()
model.module_list.append(relu)
return len(model.module_list)-1, node.input, node.output
@staticmethod
def convert_prelu(node, model:_pt_model):
weight = model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[1]]
prelu = nn.PReLU()
prelu.weight.data = weight
model.module_list.append(prelu)
return len(model.module_list) - 1, node.input[:-1], node.output
@staticmethod
def convert_shape(node, model:_pt_model):
shape = Shape()
model.module_list.append(shape)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_gather(node, model:_pt_model):
attr_map = {
"axis": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
gather = Gather(**kwargs)
model.module_list.append(gather)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_unsqueeze(node, model:_pt_model):
attr_map = {
"axes": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
unsqueeze = Unsqueeze(**kwargs)
model.module_list.append(unsqueeze)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_concat(node, model:_pt_model):
attr_map = {
"axis": "dim"
}
kwargs = {}
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
concat = Concat(**kwargs)
model.module_list.append(concat)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_reshape(node, model:_pt_model):
reshape = Reshape()
model.module_list.append(reshape)
return len(model.module_list) - 1, node.input, node.output
# @staticmethod
# def convert_matmul(node, model:_pt_model):
# class MatMul(nn.Module):
# def __init__(self):
# super().__init__()
# def forward(self,input1,input2):
# return input1 @ input2
# mul = MatMul()
# model.module_list.append(mul)
# return len(model.module_list)-1, node.input, node.output
@staticmethod
def convert_batchnormalization(node, model:_pt_model):
attr_map = {
"epsilon": "eps",
"momentum": "momentum"
}
assert len(node.input) == 5
assert len(node.output) == 1
weight = model.loaded_weights[node.input[1]]
bias = model.loaded_weights[node.input[2]]
running_mean = model.loaded_weights[node.input[3]]
running_var = model.loaded_weights[node.input[4]]
del model.loaded_weights[node.input[1]]
del model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[3]]
del model.loaded_weights[node.input[4]]
dim = weight.shape[0]
kwargs = {}
# _check_attr(node.attribute, rebuild_batchnormalization.bn_attr_map)
for att in node.attribute:
if att.name in attr_map:
kwargs[attr_map[att.name]] = att.f
bn = None
if model.dim_info[node.output[0]] == 5:
bn = nn.BatchNorm3d(num_features=dim)
elif model.dim_info[node.output[0]] == 4:
bn = nn.BatchNorm2d(num_features=dim)
elif model.dim_info[node.output[0]] == 2 or model.dim_info[node.output[0]] == 3:
bn = nn.BatchNorm1d(num_features=dim)
bn.weight.data = weight
bn.bias.data = bias
bn.running_mean.data = running_mean
bn.running_var.data = running_var
model.module_list.append(bn)
return len(model.module_list)-1, node.input[:1], node.output
@staticmethod
def convert_add(node, model:_pt_model):
add = Add()
model.module_list.append(add)
return len(model.module_list)-1, node.input, node.output
@staticmethod
def convert_mul(node, model:_pt_model):
mul = Mul()
model.module_list.append(mul)
return len(model.module_list)-1, node.input, node.output
@staticmethod
def convert_averagepool(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
}
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints)
if 'padding' in kwargs:
assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
ap = nn.AvgPool2d(**kwargs)
model.module_list.append(ap)
return len(model.module_list)-1, node.input, node.output
@staticmethod
def convert_globalaveragepool(node, model:_pt_model):
gap = nn.AdaptiveAvgPool2d((1, 1))
model.module_list.append(gap)
model.module_list.append(gap)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_maxpool(node, model:_pt_model):
attr_map = {
"pads": "padding",
"strides": "stride",
"kernel_shape": "kernel_size",
}
kwargs = {}
for att in node.attribute:
kwargs[attr_map[att.name]] = list(att.ints)
if 'padding' in kwargs:
assert (kwargs["padding"][0] == kwargs["padding"][2] and kwargs["padding"][1] == kwargs["padding"][3])
kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
ap = nn.MaxPool2d(**kwargs)
model.module_list.append(ap)
return len(model.module_list) - 1, node.input, node.output
@staticmethod
def convert_flatten(node, model:_pt_model):
if len(node.attribute) == 0:
axis = 1
else:
axis = node.attribute[0].i
if axis==1:
flatten = nn.Flatten()
model.module_list.append(flatten)
return len(model.module_list)-1, node.input, node.output
else:
raise NotImplementedError("Not Implemented yet!")
@staticmethod
def convert_gemm(node, model:_pt_model):
weight = model.loaded_weights[node.input[1]]
bias = model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[2]]
del model.loaded_weights[node.input[1]]
in_features = weight.shape[1]
out_features = weight.shape[0]
linear = nn.Linear(in_features=in_features, out_features=out_features)
linear.weight.data = weight
linear.bias.data = bias
model.module_list.append(linear)
return len(model.module_list)-1, node.input[:1], node.output
@staticmethod
def convert_pad(node, model:_pt_model):
mode = node.attribute[0].s
pads = list(node.attribute[1].ints)
value = node.attribute[2].f
try:
assert(mode == b'constant')
assert(sum(pads[:4]) == 0)
except AssertionError:
print("Now only support converting to nn.ConstantPad2d")
pad = nn.ConstantPad2d([*pads[2:4],*pads[3:5]],value)
model.module_list.append(pad)
return len(model.module_list)-1, node.input, node.output