Source code for spikingjelly.clock_driven.ann2snn.kernels.onnx

import onnx
import onnx.helper as helper
import onnx.numpy_helper as numpy_helper
import collections
import numpy as np
import torch
import torch.nn as nn
import os
import tqdm
import onnxruntime as ort
from collections import defaultdict
import json


class Mul(nn.Module):
    """Elementwise multiplication module mirroring the ONNX ``Mul`` op."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        # Broadcasting follows the semantics of the ``*`` operator.
        product = input1 * input2
        return product
class Add(nn.Module):
    """Elementwise addition module mirroring the ONNX ``Add`` op."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        # Broadcasting follows the semantics of the ``+`` operator.
        total = input1 + input2
        return total
class Reshape(nn.Module):
    """ONNX ``Reshape``: reshapes ``input1`` to the shape held in tensor ``input2``."""

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2):
        # input2 is a 1-D tensor of dimension sizes, as produced by ONNX Shape/Constant.
        target_shape = list(input2)
        return input1.reshape(target_shape)
class Concat(nn.Module):
    """ONNX ``Concat``: concatenate inputs along one axis.

    All inputs are cast to the dtype of the first input before concatenation.

    :param dim: concatenation axis, as an int or a single-element list
        (the list form is kept for backward compatibility).
    """

    def __init__(self, dim=1):
        super().__init__()
        # Fix: the original used a mutable default ``dim=[1]`` and wrote
        # ``self.dim[i] = int(d)`` in place, mutating the shared default /
        # the caller's list. Copy into a fresh list instead.
        dims = list(dim) if isinstance(dim, list) else [dim]
        self.dim = [int(d) for d in dims]

    def forward(self, *args):
        first = args[0]
        # Cast every operand to the first operand's dtype, as before.
        tensors = [a.type_as(first) for a in args]
        return torch.cat(tensors, dim=self.dim[0])
class Shape(nn.Module):
    """ONNX ``Shape``: return the input's shape as an int32 tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        # torch.IntTensor keeps the original int32 dtype of the result.
        return torch.IntTensor(list(input.shape))
class Gather(nn.Module):
    """Gather values along an axis with ``torch.gather``.

    NOTE(review): ONNX ``Gather`` has index-select semantics, while this module
    applies ``torch.gather`` — preserved as-is from the original implementation.
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = int(dim)

    def forward(self, input1, input2):
        # Indices are forced onto the CPU before gathering, as in the original.
        indices = input2.cpu()
        return torch.gather(input1, dim=self.dim, index=indices)
class Unsqueeze(nn.Module):
    """ONNX ``Unsqueeze``: insert size-1 dimensions at the given axes.

    :param dim: axis or list of axes to unsqueeze, applied in order.
    """

    def __init__(self, dim=1):
        super().__init__()
        # Fix: the original used a mutable default ``dim=[1]`` and assigned
        # ``self.dim[i] = int(d)`` in place, mutating the shared default /
        # the caller's list. Copy into a fresh list instead.
        dims = list(dim) if isinstance(dim, list) else [dim]
        self.dim = [int(d) for d in dims]

    def forward(self, input):
        x = input
        for d in self.dim:
            x = torch.unsqueeze(x, dim=d)
        return x
class TopologyAnalyser:
    def __init__(self):
        '''
        This class analyzes the topological structure of a model through onnx
        to facilitate subsequent processing.
        (There are better possible implementations here; developers are
        welcome to keep optimizing.)
        '''
        self.data_nodes = []
        self.module_output = collections.OrderedDict()
        self.module_op = collections.OrderedDict()
        self.module_idx = collections.OrderedDict()
        self.param_idx = collections.OrderedDict()
        self.edge = collections.OrderedDict()
        self.reverse_edge = collections.OrderedDict()  # for fast predecessor lookup

    def add_data_node(self, a):
        # Register a value name once, preserving first-seen order.
        if a not in self.data_nodes:
            self.data_nodes.append(a)

    def insert(self, a, b, info=None):
        # Add a directed edge a -> b annotated with `info`.
        self.add_data_node(a)
        self.add_data_node(b)
        if a not in self.edge.keys():
            self.edge[a] = [(b, info)]
        else:
            self.edge[a].append((b, info))
        if b not in self.reverse_edge.keys():
            self.reverse_edge[b] = [a]
        else:
            self.reverse_edge[b].append(a)

    def findNext(self, id):
        # Return outgoing (target, info) pairs for a value name, or for each
        # name in a list of names.
        if isinstance(id, str):
            if id in self.edge.keys():
                return self.edge[id]
            else:
                return []
        elif isinstance(id, list):
            l = []
            for i in id:
                l += self.findNext(i)
            return l

    def findPre(self, id):
        # Return the edges that produce `id`'s predecessors (two hops back).
        # NOTE: raises KeyError if `id` has no recorded predecessor.
        l = []
        if isinstance(id, str):
            for pre_id in self.reverse_edge[id]:
                if pre_id in self.reverse_edge.keys():
                    for pre_pre_id in self.reverse_edge[pre_id]:
                        if pre_pre_id in self.edge.keys():
                            for item in self.edge[pre_pre_id]:
                                if item[0] == pre_id:
                                    l.append(item)
        elif isinstance(id, list):
            for i in id:
                l += self.findPre(i)
        return l

    def find_pre_module(self, module_name):
        # Set of "Op:module" descriptors feeding this module's outputs.
        if module_name in self.module_output.keys():
            ids = self.module_output[module_name]
            return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findPre(ids)])
        else:
            return set()

    def find_next_module(self, module_name):
        # Set of "Op:module" descriptors consuming this module's outputs.
        if module_name in self.module_output.keys():
            ids = self.module_output[module_name]
            return set(['%s:%s' % (k[1]['op'], k[1]['param_module_name']) for k in self.findNext(ids)])
        else:
            return set()

    def update_module_idx(self, onnx_graph):
        # Refresh module -> (op, node index) maps after the graph was mutated.
        for idx, n in enumerate(onnx_graph.node):
            trainable_input = n.input[1:]
            op = n.op_type
            k = set()
            # Fix: use a distinct variable so the node `n` is not clobbered.
            for i in trainable_input:
                mod_name = self._get_module_name_from_value_name(i)
                if mod_name is not None:
                    k.add(mod_name)
            if len(k) > 1:
                # TODO: sanity check, raise error
                pass
            if len(k) == 1:
                param_module_name = list(k)[0]
                self.module_op[param_module_name] = op
                self.module_idx[param_module_name] = idx

    def analyse(self, onnx_graph):
        # The incoming graph must already keep every constant in `initializer`,
        # so the nodes below contain only computation (see
        # move_constant_to_initializer).
        for idx, constant in enumerate(onnx_graph.initializer):
            self.param_idx[constant.name] = idx
        for idx, n in enumerate(onnx_graph.node):
            param_module_name = None
            op = n.op_type
            inputs = n.input
            outputs = n.output
            k = set()
            trainable_input = inputs[1:]
            # Fix: use a distinct variable so the node `n` is not clobbered.
            for i in trainable_input:
                mod_name = self._get_module_name_from_value_name(i)
                if mod_name is not None:
                    k.add(mod_name)
            if len(k) > 1:
                # TODO: sanity check, raise error
                pass
            if len(k) == 1:
                param_module_name = list(k)[0]
                self.module_op[param_module_name] = op
                self.module_idx[param_module_name] = idx
            if op is not None:
                for o in outputs:
                    for i in inputs:
                        self.insert(i, o, {'op': op, 'param_module_name': param_module_name})
                    if param_module_name is not None:
                        if param_module_name not in self.module_output.keys():
                            self.module_output[param_module_name] = [o]
                        else:
                            self.module_output[param_module_name].append(o)
        return self

    @staticmethod
    def _get_module_name_from_value_name(value_name):
        # "conv1.weight" -> "conv1"; names without a dot map to None.
        module_name = None
        if len(value_name.split('.')) > 1:
            module_name = '.'.join(value_name.split('.')[:-1])
            # Fix: the original called str.replace and discarded the result
            # (strings are immutable), so spaces were never stripped.
            module_name = module_name.replace(' ', '')
        return module_name
def pytorch2onnx_model(model: nn.Module, data, **kargs) -> 'onnx.ModelProto':
    '''
    Convert a PyTorch model to an onnx model.

    :param model: the PyTorch model to be converted
    :param data: the data used for conversion (used to determine the input dimension)
    :param log_dir: output folder (required keyword argument)
    :return: the exported model re-loaded as ``onnx.ModelProto``
    :raises ValueError: if ``log_dir`` is not supplied
    :raises NotImplementedError: if the export fails (e.g. multi-input models)
    '''
    # Fix: the original only printed a message on a missing log_dir and then
    # crashed with NameError; fail fast with a clear exception instead.
    if 'log_dir' not in kargs:
        raise ValueError('pytorch2onnx_model need argument log_dir!')
    log_dir = kargs['log_dir']
    dump_input_size = [data.size(i) for i in range(len(data.size()))]
    dump_input_size[0] = 1  # export with batch size 1; batch dim is dynamic below
    fname = os.path.join(log_dir, 'onnxmodel')
    try:
        dynamic_axes = {'input': {0: 'batch_size'},
                        'output': {0: 'batch_size'}}
        torch.onnx.export(model, torch.ones(dump_input_size), fname,
                          input_names=['input'],
                          output_names=['output'],
                          dynamic_axes=dynamic_axes)
    except BaseException as e:
        # Chain the original error so the real export failure is not lost.
        raise NotImplementedError("Models with multiple inputs are not supported yet!") from e
    return onnx.load(fname)
def onnx2pytorch_model(model: onnx.ModelProto, _converter) -> nn.Module:
    """Convert an ONNX model into an executable PyTorch module.

    ``_pt_model`` rebuilds the graph as torch modules; ``reduce()`` then strips
    loader-only state so only what inference needs is kept.
    """
    wrapped = _pt_model(model, _converter)
    return wrapped.reduce()
def layer_reduction(model: onnx.ModelProto) -> onnx.ModelProto:
    """Absorb BatchNorm layers into preceding Conv/Gemm nodes and prune
    initializers that are no longer referenced; the model is edited in place
    and also returned."""
    graph = model.graph
    analyser = TopologyAnalyser()
    graph = move_constant_to_initializer(graph)
    analyser.analyse(graph)
    absorb_bn(graph, analyser)
    remove_unreferenced_initializer(graph)
    update_topology(graph)
    print("Finish layer reduction!")
    return model
def rate_normalization(model: onnx.ModelProto, data: torch.Tensor, **kargs) -> onnx.ModelProto:
    '''
    Normalize the firing rate of the model.

    :param model: ANN model, of type onnx.ModelProto
    :param data: the data used for conversion, of type torch.Tensor
    :param channelwise: if ``True``, activation-amplitude statistics are
        channelwise; otherwise layerwise (default ``False``)
    :param robust: if ``True``, the statistic used is the 99.9th percentile of
        activations; otherwise the maximum (default ``False``)
    :param eps: epsilon; defaults to 1e-5 when not set
    '''
    # Idiom fix: the original used three try/except KeyError blocks; dict.get
    # with defaults is equivalent and clearer. The unused output_debug local
    # was removed.
    channelwise = kargs.get('channelwise', False)
    robust_norm = kargs.get('robust', False)
    eps = kargs.get('eps', 1e-5)
    topo_analyser = update_topology(model.graph)
    # if debugging is wanted, pass debug=... through to the statistics call
    output_statistics = get_intermediate_output_statistics(model, data, channelwise=channelwise)
    model = normalize_model(model, output_statistics, topo_analyser,
                            robust_norm=robust_norm, channelwise=channelwise, eps=eps)
    return model
def save_model(model: onnx.ModelProto, f=None):
    """Serialize `model` to bytes; if `f` is a file-like object or a path,
    also write the bytes there. Always returns the serialized bytes."""
    serialized = model.SerializeToString()
    if f is None:
        return serialized
    if hasattr(f, 'write'):
        # Already an open file-like object.
        f.write(serialized)
    else:
        # Treat as a filesystem path.
        with open(f, "wb") as fp:
            fp.write(serialized)
    return serialized
def move_constant_to_initializer(graph):
    """Rewrite every ``Constant`` node as a graph initializer (named after the
    node's output), removing the node; edits `graph` in place and returns it."""
    constant_positions = [i for i, node in enumerate(graph.node)
                          if node.op_type == 'Constant']
    # Remove back-to-front so earlier indices stay valid.
    for i in reversed(constant_positions):
        node = graph.node[i]
        tensor = numpy_helper.from_array(numpy_helper.to_array(node.attribute[0].t),
                                         node.output[0])
        graph.initializer.append(tensor)
        graph.node.remove(node)
    return graph
def absorb_bn(graph, topo_analyser):
    """Fold each BatchNormalization node into the preceding Conv/Gemm when there
    is exactly one such predecessor; otherwise rewrite the BN parameters so it
    acts as a pure affine layer (mean 0, var 1). Edits `graph` in place."""
    print("\nAbsorbing BatchNorm Parameters...\n")
    # Walk modules in reverse discovery order so node indices stay valid longer.
    for mn in tqdm.tqdm(reversed(topo_analyser.module_output.keys())):
        if topo_analyser.module_op[mn] == 'BatchNormalization':
            pre_m = topo_analyser.find_pre_module(mn)
            next_m = topo_analyser.find_next_module(mn)
            # BN inputs: [x, scale, bias, mean, var] -> initializer indices.
            bn_weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[1]]
            bn_weight = np.array(numpy_helper.to_array(graph.initializer[bn_weight_idx]))
            bn_bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[2]]
            bn_bias = np.array(numpy_helper.to_array(graph.initializer[bn_bias_idx]))
            bn_mean_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[3]]
            bn_mean = np.array(numpy_helper.to_array(graph.initializer[bn_mean_idx]))
            bn_var_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[mn]].input[4]]
            bn_var = np.array(numpy_helper.to_array(graph.initializer[bn_var_idx]))
            # NOTE(review): assumes attribute[0] is epsilon — confirm the
            # exporter always emits epsilon first.
            bn_eps = graph.node[topo_analyser.module_idx[mn]].attribute[0].f
            bn_std = np.sqrt(bn_var + bn_eps)
            if len(pre_m) == 1 and list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm']:
                # Case 1: fold BN into the single Conv/Gemm predecessor.
                pre_mn = list(pre_m)[0].split(':')[1]
                weight_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[1]]
                weight = np.array(numpy_helper.to_array(graph.initializer[weight_idx]))
                if len(graph.node[topo_analyser.module_idx[pre_mn]].input) == 2:
                    bias = None
                else:
                    bias_idx = topo_analyser.param_idx[graph.node[topo_analyser.module_idx[pre_mn]].input[2]]
                    bias = np.array(numpy_helper.to_array(graph.initializer[bias_idx]))
                # Gemm weights are 2-D, Conv weights 4-D: broadcast per output channel.
                wrsp_args = (-1, 1) if len(weight.shape) == 2 else (-1, 1, 1, 1)
                weight_ = weight * bn_weight.reshape(*wrsp_args) / bn_std.reshape(*wrsp_args)
                bias_ = ((bias if bias is not None else 0) - bn_mean.reshape(-1)) * bn_weight.reshape(
                    -1) / bn_std.reshape(-1) \
                    + bn_bias.reshape(-1)
                assert (list(pre_m)[0].split(':')[0] in ['Conv', 'Gemm'])
                # Rebuild the Conv/Gemm node with the fused parameters, keeping
                # all of its original attributes.
                args = {}
                for attr in graph.node[topo_analyser.module_idx[pre_mn]].attribute:
                    args[attr.name] = helper.get_attribute_value(attr)
                new_node = onnx.helper.make_node(
                    list(pre_m)[0].split(':')[0],
                    inputs=[graph.node[topo_analyser.module_idx[pre_mn]].input[0],
                            pre_mn + ".new.weight", pre_mn + ".new.bias"],
                    outputs=[graph.node[topo_analyser.module_idx[mn]].output[0]],
                    **args
                )
                graph.initializer.append(numpy_helper.from_array(weight_.astype(np.float32), pre_mn + ".new.weight"))
                graph.initializer.append(numpy_helper.from_array(bias_.astype(np.float32), pre_mn + ".new.bias"))
                # Replace predecessor in place, then drop the BN node entirely.
                graph.node.remove(graph.node[topo_analyser.module_idx[pre_mn]])
                graph.node.insert(topo_analyser.module_idx[pre_mn], new_node)
                graph.node.remove(graph.node[topo_analyser.module_idx[mn]])
            else:
                # Case 2: cannot fold — convert BN to an affine layer by baking
                # mean/std into scale and bias, then zeroing mean and setting var to 1.
                weight_ = bn_weight / bn_std
                bias_ = bn_bias - bn_weight * bn_mean / bn_std
                name = graph.initializer[bn_weight_idx].name
                graph.initializer.remove(graph.initializer[bn_weight_idx])
                graph.initializer.insert(bn_weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
                name = graph.initializer[bn_bias_idx].name
                graph.initializer.remove(graph.initializer[bn_bias_idx])
                graph.initializer.insert(bn_bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
                name = graph.initializer[bn_mean_idx].name
                graph.initializer.remove(graph.initializer[bn_mean_idx])
                graph.initializer.insert(bn_mean_idx,
                                         numpy_helper.from_array(np.zeros_like(bn_mean).astype(np.float32), name))
                name = graph.initializer[bn_var_idx].name
                graph.initializer.remove(graph.initializer[bn_var_idx])
                graph.initializer.insert(bn_var_idx,
                                         numpy_helper.from_array(np.ones_like(bn_var).astype(np.float32), name))
def remove_unreferenced_initializer(graph):
    """Delete every initializer whose name is not used as a node input or
    output anywhere in the graph; edits `graph` in place."""
    referenced = set()
    for node in graph.node:
        referenced.update(node.input)
        referenced.update(node.output)
    initializer_names = {init.name for init in graph.initializer}
    unreferenced = initializer_names - referenced
    # Delete back-to-front so remaining indices stay valid.
    for i in range(len(graph.initializer) - 1, -1, -1):
        if graph.initializer[i].name in unreferenced:
            graph.initializer.remove(graph.initializer[i])
def update_topology(graph):
    """Re-run topology analysis on `graph` (after moving constants into the
    initializer list) and return the fresh :class:`TopologyAnalyser`."""
    analyser = TopologyAnalyser()
    move_constant_to_initializer(graph)
    # analyse() returns the analyser itself.
    return analyser.analyse(graph)
def find_node_by_output(output_name, graph):
    """Return ``(index, node)`` of the first graph node producing
    `output_name`, or ``(None, None)`` if no node produces it."""
    for idx, node in enumerate(graph.node):
        if output_name in node.output:
            return idx, node
    return None, None
def scale_node_weight_bias(topo_analyser, graph, node_idx, scale):
    """Multiply the weight (and bias, if present) initializers of the node at
    `node_idx` by `scale`, rewriting the initializers in place.

    `scale` is a numpy array; a 1-D scale is broadcast across the trailing
    weight dimensions."""
    initializer = graph.initializer
    node = graph.node[node_idx]
    # Nodes without trainable inputs have nothing to scale.
    if len(node.input) < 2:
        return
    weight_idx = topo_analyser.param_idx[node.input[1]]
    bias_idx = topo_analyser.param_idx[node.input[2]] if len(node.input) >= 3 else None
    weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
    bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None
    # Reshape a per-channel (1-D) scale to broadcast over the weight tensor;
    # a full-shape scale is used as-is.
    w_scale = scale.reshape([*scale.shape] + [1 for _ in range(len(weight.shape) - 1)]) \
        if len(scale.shape) == 1 else scale
    b_scale = scale
    weight_ = weight * w_scale
    # Protobuf repeated fields have no item assignment: remove + insert at the
    # same index to replace the tensor while keeping its position and name.
    name = initializer[weight_idx].name
    initializer.remove(initializer[weight_idx])
    initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
    if bias is not None:
        bias_ = bias * b_scale
        name = initializer[bias_idx].name
        initializer.remove(initializer[bias_idx])
        initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
def get_onnx_output(model, numpy_tensor):
    """Run `model` once through onnxruntime, feeding `numpy_tensor` to the
    graph input named ``'input'``; returns the list of outputs."""
    session = ort.InferenceSession(model.SerializeToString())
    return session.run(None, {'input': numpy_tensor})
def get_intermediate_output_statistics(model, numpy_tensor, channelwise=False, debug=None):
    """Compute min/max/99.9th-percentile statistics for the network input and
    for every intermediate output of `model`, by running a pruned sub-model
    through onnxruntime for each output.

    :param model: onnx model whose graph input is named ``'input'``
    :param numpy_tensor: sample input batch used to collect the statistics
    :param channelwise: if ``True``, reduce over (N, H, W) / (N,) so statistics
        are kept per channel; otherwise reduce over all elements
    :param debug: optional dict; when given, raw outputs are stored into it and
        ``(statistics, debug)`` is returned instead of just the statistics
    :return: dict mapping output name (plus ``'input'``) to its statistics
    """
    graph = model.graph
    # For every output value, record which node indices and which graph inputs/
    # initializers are needed to compute it, so a minimal sub-graph can be built.
    output_needed_module = {}
    output_needed_all_input = {}
    for idx, node in enumerate(graph.node):
        output = node.output
        input = node.input
        if 'input' in node.input:
            # Node fed directly by the network input.
            for out in output:
                output_needed_module[out] = set([idx])
                output_needed_all_input[out] = set(input)
        else:
            # Union the requirements of all producing nodes upstream.
            s = set()
            s_i = set()
            for in_ in input:
                s |= (output_needed_module[in_] if in_ in output_needed_module.keys() else set())
                s_i |= (output_needed_all_input[in_] if in_ in output_needed_all_input.keys() else set())
            for out in output:
                output_needed_module[out] = s | set([idx])
                output_needed_all_input[out] = s_i | set(input)
    output_statistics = {}
    # Statistics of the raw network input. Note: for layerwise 'max', a
    # non-positive maximum falls back to |min| so later scaling stays positive.
    if not channelwise:
        statistic = {'shape': numpy_tensor.shape,
                     'min': np.min(numpy_tensor),
                     'max': np.max(numpy_tensor) if np.max(numpy_tensor) > 0 else np.abs(np.min(numpy_tensor)),
                     '99.9': np.percentile(numpy_tensor, 99.9)
                     }
    else:
        # Reduce over batch and spatial dims, keeping the channel axis.
        axis_args = (0, 2, 3) if len(numpy_tensor.shape) == 4 else (0)
        statistic = {'shape': numpy_tensor.shape,
                     'min': np.min(numpy_tensor, axis=axis_args),
                     'max': np.max(numpy_tensor, axis=axis_args),
                     '99.9': np.percentile(numpy_tensor, 99.9, axis=axis_args)
                     }
    output_statistics['input'] = statistic
    print("\nGetting intermediate output statistics...\n")
    for out in tqdm.tqdm(output_needed_module.keys()):
        # Build a sub-model that ends at this output and run it.
        keep_nodes = [graph.node[i] for i in list(output_needed_module[out])]
        keep_initializer = [init for init in graph.initializer
                            if init.name in list(output_needed_all_input[out])]
        var_out = []
        value_info = onnx.ValueInfoProto()
        value_info.name = out
        var_out.append(value_info)
        new_graph = onnx.helper.make_graph(keep_nodes, graph.name, graph.input, var_out, keep_initializer)
        tmp_model = onnx.helper.make_model(new_graph)
        # Copy model-level metadata so the sub-model is accepted by the runtime.
        tmp_model.ir_version = model.ir_version
        tmp_model.producer_name = model.producer_name
        tmp_model.producer_version = model.producer_version
        tmp_model.domain = model.domain
        tmp_model.model_version = model.model_version
        tmp_model.doc_string = model.doc_string
        if len(tmp_model.metadata_props) > 0:
            values = {p.key: p.value for p in model.metadata_props}
            onnx.helper.set_model_props(tmp_model, values)
        # fix opset import
        for oimp in model.opset_import:
            op_set = tmp_model.opset_import.add()
            op_set.domain = oimp.domain
            op_set.version = oimp.version
        ort_session = ort.InferenceSession(tmp_model.SerializeToString())
        outputs = ort_session.run(None, {'input': numpy_tensor})
        if debug is not None:
            debug[out] = outputs[0]
        # Same fallback as above: keep the scaling statistic strictly positive.
        if not channelwise:
            statistic = {'shape': outputs[0].shape,
                         'min': np.min(outputs[0]),
                         'max': np.max(outputs[0]) if np.max(outputs[0]) > 0 else np.abs(np.min(outputs[0])),
                         '99.9': np.percentile(outputs[0], 99.9) if np.percentile(outputs[0], 99.9) > 0 else np.abs(
                             np.min(outputs[0]))
                         }
        else:
            axis_args = (0, 2, 3) if len(outputs[0].shape) == 4 else (0)
            statistic = {'shape': outputs[0].shape,
                         'min': np.min(outputs[0], axis=axis_args),
                         'max': np.max(outputs[0], axis=axis_args),
                         '99.9': np.percentile(outputs[0], 99.9, axis=axis_args)
                         }
        output_statistics[out] = statistic
    print("Finished getting intermediate output statistics!")
    if debug is not None:
        return output_statistics, debug
    else:
        return output_statistics
def normalize_model(model, output_statistics, topo_analyser, robust_norm=True, channelwise=False, eps=1e-5):
    """Rescale the model's weights/biases so every tracked activation is
    normalized by its statistic (max or 99.9th percentile), as required for
    rate-based ANN-to-SNN conversion. Edits `model` in place and returns it.

    :param output_statistics: dict from get_intermediate_output_statistics
    :param robust_norm: use the '99.9' statistic if True, else 'max'
    :param channelwise: per-channel scaling if True, per-layer otherwise
    :param eps: numerical guard added to denominators
    """
    nodes = model.graph.node
    graph = model.graph
    initializer = model.graph.initializer
    if robust_norm:
        statistic_key = '99.9'
    else:
        statistic_key = 'max'
    # node_scaled_range[v]: the range value v currently represents after the
    # rescalings applied so far ("demand" == target normalized range).
    node_scaled_range = {}
    # Add-nodes whose operand must be scaled by an extra Mul node (when the
    # producer has no weights to fold the scale into): node_idx -> {operand: (scale_name, scaled_name)}
    seperate_scale = collections.OrderedDict()
    print("\nNormalizing model...\n")
    for node_idx, node in enumerate(tqdm.tqdm(nodes)):
        output = node.output
        input = node.input
        op = node.op_type
        if input[0] == 'input':  # node fed by the (single) network input
            l = output_statistics[input[0]]['shape'][1]
            node_scaled_range[input[0]] = np.ones(l) if channelwise else 1.0 \
                * output_statistics[input[0]][
                statistic_key]
        if op in ['Conv', 'Gemm']:
            # Fold input/output range compensation into the layer parameters.
            weight_idx = topo_analyser.param_idx[input[1]]
            bias_idx = topo_analyser.param_idx[input[2]] if len(input) == 3 else None
            weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
            bias = np.array(numpy_helper.to_array(initializer[bias_idx])) if bias_idx is not None else None
            l = output_statistics[output[0]]['shape'][1]
            input_real_range = node_scaled_range[input[0]]
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            demand = np.ones(l) if channelwise else 1.0
            # Channelwise: outer product couples out-channel and in-channel scales.
            w_scale = demand / (output_range + eps) * (
                input_range / (input_real_range + eps)) if not channelwise else \
                (demand / (output_range + eps)).reshape(-1, 1).dot(
                    (input_range / (input_real_range + eps)).reshape(1, -1))
            w_scale = w_scale.reshape([*w_scale.shape, 1, 1]) if len(weight.shape) == 4 else w_scale
            b_scale = 1 / (output_range + eps)
            node_scaled_range[output[0]] = demand
            weight_ = weight * w_scale
            # Replace initializers via remove+insert (repeated fields are not assignable).
            name = initializer[weight_idx].name
            initializer.remove(initializer[weight_idx])
            initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
            if bias is not None:
                bias_ = bias * b_scale
                name = initializer[bias_idx].name
                initializer.remove(initializer[bias_idx])
                initializer.insert(bias_idx,
                                   numpy_helper.from_array(bias_.astype(np.float32), name))
        elif op == 'BatchNormalization':
            # After absorb_bn the BN has var=1 mean=0, so only scale/bias remain.
            weight_idx = topo_analyser.param_idx[input[1]]
            bias_idx = topo_analyser.param_idx[input[2]]
            weight = np.array(numpy_helper.to_array(initializer[weight_idx]))
            bias = np.array(numpy_helper.to_array(initializer[bias_idx]))
            input_real_range = node_scaled_range[input[0]]
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            demand = 1.0
            w_scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
            b_scale = 1 / (output_range + eps)
            node_scaled_range[output[0]] = demand
            weight_ = weight * w_scale
            bias_ = bias * b_scale
            name = initializer[weight_idx].name
            initializer.remove(initializer[weight_idx])
            initializer.insert(weight_idx, numpy_helper.from_array(weight_.astype(np.float32), name))
            name = initializer[bias_idx].name
            initializer.remove(initializer[bias_idx])
            initializer.insert(bias_idx, numpy_helper.from_array(bias_.astype(np.float32), name))
        elif op == 'Add':
            # Each operand must arrive at the common normalized range: fold the
            # per-operand scale into its producer when it has parameters,
            # otherwise schedule an explicit Mul node (seperate_scale).
            l = output_statistics[output[0]]['shape'][1]
            demand = np.ones(l) if channelwise else 1.0
            node_scaled_range[output[0]] = demand
            output_range = output_statistics[output[0]][statistic_key]
            for i in input:
                if i in output_statistics.keys():
                    input_real_range = node_scaled_range[i]
                    input_range = output_statistics[i][statistic_key]
                    scale = demand / (output_range + eps) * (input_range / (input_real_range + eps))
                    idx, _ = find_node_by_output(i, graph)
                    if idx is not None and nodes[idx].op_type in ['Conv', 'Gemm', 'BatchNormalization']:
                        scale_node_weight_bias(topo_analyser, graph, idx, scale)
                    else:
                        # Broadcastable shape for a standalone Mul constant.
                        scale = scale.reshape(
                            [1, *scale.shape] + [1 for _ in range(len(output_statistics[i]['shape']) - 2)]) \
                            if len(scale.shape) == 1 else scale
                        initializer.append(numpy_helper.from_array(scale.astype(np.float32), "scale_" + i))
                        if idx not in seperate_scale.keys():
                            seperate_scale[node_idx] = {i: ("scale_" + i, "scaled_" + i)}
                        else:
                            seperate_scale[node_idx][i] = ("scale_" + i, "scaled_" + i)
                        pass
        elif op in ['Gather', 'Unsqueeze', 'Shape', 'Concat']:
            # Shape-manipulating ops carry no ranges to adjust.
            continue
        elif op == "Softmax":
            raise NotImplementedError("Not supported %s yet!" % (op))
        else:  # single-input single-output module (e.g. ReLU, Pooling, Flatten)
            # Just propagate how the represented range changes through the op.
            input_range = output_statistics[input[0]][statistic_key]
            output_range = output_statistics[output[0]][statistic_key]
            input_scaled_range = node_scaled_range[input[0]]
            output_scaled_range = input_scaled_range / (input_range + eps) * output_range
            node_scaled_range[output[0]] = output_scaled_range
    if len(seperate_scale.keys()) != 0:
        # Materialize the scheduled Mul nodes: rewire each Add to consume the
        # scaled operand and insert a Mul producing it just before the Add.
        print("Making new scale node...")
        for node_idx in reversed(seperate_scale.keys()):
            args = {}
            for attr in nodes[node_idx].attribute:
                args[attr.name] = helper.get_attribute_value(attr)
            input = [str(i) if i not in seperate_scale[node_idx].keys() else seperate_scale[node_idx][i][1] \
                     for i in nodes[node_idx].input]
            output = [str(i) for i in nodes[node_idx].output]
            new_node = onnx.helper.make_node(
                nodes[node_idx].op_type,
                inputs=input,
                outputs=output,
                **args
            )
            nodes.remove(nodes[node_idx])
            nodes.insert(node_idx, new_node)
            for i in seperate_scale[node_idx].keys():
                new_node = onnx.helper.make_node(
                    'Mul',
                    inputs=[seperate_scale[node_idx][i][0], i],
                    outputs=[seperate_scale[node_idx][i][1]]
                )
                nodes.insert(node_idx, new_node)
    print("Finished normalizing model!")
    return model
def _pre_onnx_shape_inference(model: onnx.ModelProto):
    '''
    To perform shape inference for a model, run this function on the onnx
    model first to prepare it (it registers value_info entries for
    initializers that lack them).

    This function has referenced code in
    https://github.com/onnx/onnx/issues/2660#issuecomment-605874784
    '''
    # IR version < 4 requires initializers to appear in graph.input, so there
    # is nothing to patch.
    if model.ir_version < 4:
        return

    def add_some_graph_info(graph: onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        vi_dict = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Initializers that are also graph inputs already carry type info.
            if init.name in inputs:
                continue
            vi = vi_dict.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name
            tensor_type = vi.type.tensor_type
            if tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
                tensor_type.elem_type = init.data_type
            if not tensor_type.HasField("shape"):
                # Mark the shape field as present, then copy the dims over.
                tensor_type.shape.dim.extend([])
                for dim in init.dims:
                    tensor_type.shape.dim.add().dim_value = dim
        # Recurse into subgraphs held in node attributes (If/Loop/Scan bodies).
        for node in graph.node:
            for attr in node.attribute:
                if attr.ref_attr_name != "":
                    continue
                if attr.type == onnx.AttributeProto.GRAPH:
                    add_some_graph_info(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_some_graph_info(g)
    return add_some_graph_info(model.graph)


class _pt_model(nn.Module):
    """A PyTorch module reconstructed from an ONNX model.

    Each ONNX node is converted (via `_converter`'s ``convert_<op>`` methods)
    into an entry of ``module_list``; ``op_tree`` records, for every produced
    value name, which module computes it and from which inputs. Weights are
    stored as parameters/buffers named ``'P' + name.replace('.', '@')``.
    """

    def __init__(self, path_or_model, _converter=None):
        super(_pt_model, self).__init__()
        # ``_pt_model(None)`` builds an empty shell; reduce() fills it in.
        if path_or_model is not None:
            if isinstance(path_or_model, str):
                onnx_model = onnx.load(path_or_model)
            else:
                onnx_model = path_or_model
            self.onnx_model = onnx_model
            self.loaded_weights = load_parameters(self, onnx_model.graph.initializer)
            self.module_list = nn.ModuleList([])
            self.op_tree = {}
            _pre_onnx_shape_inference(onnx_model)
            inferred_model = onnx.shape_inference.infer_shapes(onnx_model)
            self.value_info = inferred_model.graph.value_info
            # dim_info: value name -> rank (number of dims) from shape inference.
            self.dim_info = {}
            for idx, v in enumerate(self.value_info):
                self.dim_info[v.name] = len(v.type.tensor_type.shape.dim)
            # Dataflow graph: value name -> list of value names depending on it.
            self.graph = defaultdict(list)
            for idx, node in enumerate(onnx_model.graph.node):
                op = node.op_type
                # if op=='MatMul': # TODO temporary
                #     op = 'Gemm'
                (op_idx, inputs, outputs) = getattr(_converter, 'convert_' + op.lower())(node, self)
                for out_seq, output in enumerate(outputs):
                    self.op_tree[str(output)] = (int(op_idx), [str(i) for i in inputs], out_seq)
                for output in outputs:
                    for input in inputs:
                        self.graph[input].append(output)
            # Serialize op_tree so nn.Module state handling stays simple.
            self.op_tree = json.dumps(self.op_tree)
            self.input_name = [i.name for i in onnx_model.graph.input]
            self.output_name = [i.name for i in onnx_model.graph.output]
            for out in onnx_model.graph.output:
                self.graph[out.name] = []

            def TopologicalSort(G):
                # Kahn's algorithm over the value-dependency graph.
                in_degrees = dict((u, 0) for u in G)
                for u in G:
                    for v in G[u]:
                        in_degrees[v] += 1
                Q = [u for u in G if in_degrees[u] == 0]
                res = []
                while Q:
                    u = Q.pop()
                    res.append(u)
                    for v in G[u]:
                        in_degrees[v] -= 1
                        if in_degrees[v] == 0:
                            Q.append(v)
                return res

            self.compute_seq = TopologicalSort(self.graph)
            # Register float weights as Parameters, everything else as buffers.
            # '.' is not allowed in attribute names, hence the '@' substitution.
            for k in self.loaded_weights.keys():
                if isinstance(self.loaded_weights[k], torch.FloatTensor):
                    setattr(self, 'P' + k.replace('.', '@'), torch.nn.Parameter(self.loaded_weights[k]))
                else:
                    self.register_buffer('P' + k.replace('.', '@'), self.loaded_weights[k])
            self.reserved_tensors_name = list(self.loaded_weights.keys())

    def refresh_running_tensor(self):
        # Rebuild the working tensor dict from the registered weights/buffers.
        self.tensors = {}
        for k in set(self.tensors.keys()) | set(self.reserved_tensors_name):
            if k not in self.reserved_tensors_name:
                del self.tensors[k]
            else:
                self.tensors[k] = getattr(self, 'P' + k.replace('.', '@'))

    def forward(self, input):
        op_tree = json.loads(self.op_tree)
        # Seed the tensor environment with the stored weights/buffers.
        tensors = {}
        for k in set(tensors.keys()) | set(self.reserved_tensors_name):
            if k not in self.reserved_tensors_name:
                del tensors[k]
            else:
                tensors[k] = getattr(self, 'P' + k.replace('.', '@'))
        # NOTE(review): this condition is always True (an object cannot be both
        # a list and a tuple), so even list inputs get wrapped again — looks
        # only correct for single-tensor inputs; confirm multi-input behavior.
        if not isinstance(input, list) or not isinstance(input, tuple):
            input = [input]
        for i, n in enumerate(self.input_name):
            tensors[n] = input[i]
        # Evaluate values in topological order.
        for name in self.compute_seq:
            if name in op_tree.keys():
                op_idx, inputs, out_seq = op_tree[name]
                args = []
                for input in inputs:
                    args.append(tensors[input])
                result = self.module_list[op_idx](*args)
                if not isinstance(result, tuple):
                    tensors[name] = result
                else:
                    # Multi-output op: pick this value's slot.
                    tensors[name] = result[out_seq]
        if len(self.output_name) == 1:
            return tensors[self.output_name[0]]
        else:
            ret = []
            for output in self.output_name:
                ret.append(tensors[output])
            return ret

    def reduce(self):
        # Produce a lean copy: only weights, modules, and the compute plan are
        # kept; ONNX bookkeeping (onnx_model, value_info, ...) is dropped.
        import copy
        net = _pt_model(None)
        for k in self.reserved_tensors_name:
            if isinstance(self.loaded_weights[k], torch.FloatTensor):
                setattr(net, 'P' + k.replace('.', '@'),
                        torch.nn.Parameter(getattr(self, 'P' + k.replace('.', '@')).data.detach().clone()))
            else:
                net.register_buffer('P' + k.replace('.', '@'),
                                    getattr(self, 'P' + k.replace('.', '@')).data.clone())
        net.compute_seq = copy.deepcopy(self.compute_seq)
        net.input_name = copy.deepcopy(self.input_name)
        net.output_name = copy.deepcopy(self.output_name)
        net.module_list = copy.deepcopy(self.module_list)
        net.op_tree = copy.deepcopy(self.op_tree)
        net.reserved_tensors_name = copy.deepcopy(self.reserved_tensors_name)
        return net
def load_parameters(model: _pt_model, initializer):
    """Convert every ONNX initializer into a torch tensor, keyed by name.

    The `model` argument is unused here; it is kept for interface
    compatibility with existing callers.
    """
    return {init.name: torch.from_numpy(numpy_helper.to_array(init).copy())
            for init in initializer}
class _o2p_converter:
    def __init__(self):
        '''
        This class mainly converts an onnx model to a PyTorch ANN model, and
        thus to a SpikingJelly SNN model.

        The link [#f1]_ provides a primary version of onnx-pytorch conversion.
        A more complex version can be found here. Most commonly used onnx
        operators are covered, yet some are still missing or not perfectly
        implemented. Users can define their exceptions by adding a static
        method like the examples below.

        .. [#f1] https://gist.github.com/qinjian623/6aa777037534c1c1dccbb66f832e93b8
        '''
        pass

    def add_method(self, op_name, func):
        # Register a user-supplied converter for an op type not covered below.
        # NOTE(review): a staticmethod object stored on an *instance* is only
        # directly callable on Python >= 3.10 — confirm the supported version.
        setattr(self, 'convert_' + op_name, staticmethod(func))

    @staticmethod
    def convert_conv(node, model: _pt_model):
        # Build an nn.Conv2d from an ONNX Conv node; weights/bias are popped
        # from model.loaded_weights so they are not registered twice.
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
            "group": "groups",
            "dilations": "dilation",
        }
        assert len(node.output) == 1
        with_bias = False
        if len(node.input) == 3:
            with_bias = True
            bias = model.loaded_weights[node.input[2]]
            del model.loaded_weights[node.input[2]]
        weight = model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[1]]
        # ONNX Conv weight layout is (out_channels, in_channels/groups, kH, kW).
        in_channels = weight.shape[1]
        out_channels = weight.shape[0]
        kwargs = {}
        for att in node.attribute:
            # 'group' is a scalar int attribute; the others are int lists.
            kwargs[attr_map[att.name]] = list(att.ints) if att.name != 'group' else att.i
        if 'padding' in kwargs:
            # ONNX pads are (h_begin, w_begin, h_end, w_end); nn.Conv2d only
            # supports symmetric padding.
            assert (kwargs["padding"][0] == kwargs["padding"][2]
                    and kwargs["padding"][1] == kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
        groups = 1 if 'groups' not in kwargs else kwargs['groups']
        in_channels *= groups
        conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=with_bias)
        conv.weight.data = weight
        if with_bias:
            conv.bias.data = bias
        model.module_list.append(conv)
        # Only the data input is kept: weight/bias inputs are baked in.
        return len(model.module_list) - 1, node.input[:1], node.output

    @staticmethod
    def convert_relu(node, model: _pt_model):
        relu = nn.ReLU()
        model.module_list.append(relu)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_prelu(node, model: _pt_model):
        weight = model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[1]]
        prelu = nn.PReLU()
        prelu.weight.data = weight
        model.module_list.append(prelu)
        # Drop the slope input (node.input[-1]); it is baked into the module.
        return len(model.module_list) - 1, node.input[:-1], node.output

    @staticmethod
    def convert_shape(node, model: _pt_model):
        shape = Shape()
        model.module_list.append(shape)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_gather(node, model: _pt_model):
        # BUGFIX: ONNX stores Gather's ``axis`` as an int attribute (``att.i``).
        # The previous code read ``att.f`` (the float field), which is 0.0 for
        # int attributes, so any non-zero axis was silently lost.
        kwargs = {}
        for att in node.attribute:
            if att.name == 'axis':
                kwargs['dim'] = att.i
        gather = Gather(**kwargs)
        model.module_list.append(gather)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_unsqueeze(node, model: _pt_model):
        # BUGFIX: ``axes`` is an int-list attribute (``att.ints``); the
        # previous code read ``att.f``, collapsing every axes list to 0.0.
        kwargs = {}
        for att in node.attribute:
            if att.name == 'axes':
                kwargs['dim'] = list(att.ints)
        unsqueeze = Unsqueeze(**kwargs)
        model.module_list.append(unsqueeze)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_concat(node, model: _pt_model):
        # BUGFIX: Concat's ``axis`` is an int attribute (``att.i``); the
        # previous code read ``att.f``, which always yielded dim 0.
        kwargs = {}
        for att in node.attribute:
            if att.name == 'axis':
                kwargs['dim'] = att.i
        concat = Concat(**kwargs)
        model.module_list.append(concat)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_reshape(node, model: _pt_model):
        reshape = Reshape()
        model.module_list.append(reshape)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_matmul(node, model: _pt_model):
        # Local module wrapping the ``@`` operator, so MatMul fits the
        # uniform module_list execution scheme.
        class MatMul(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input1, input2):
                return input1 @ input2

        mul = MatMul()
        model.module_list.append(mul)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_batchnormalization(node, model: _pt_model):
        attr_map = {
            "epsilon": "eps",
            "momentum": "momentum",
        }
        assert len(node.input) == 5
        assert len(node.output) == 1
        weight = model.loaded_weights[node.input[1]]
        bias = model.loaded_weights[node.input[2]]
        running_mean = model.loaded_weights[node.input[3]]
        running_var = model.loaded_weights[node.input[4]]
        del model.loaded_weights[node.input[1]]
        del model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[3]]
        del model.loaded_weights[node.input[4]]
        dim = weight.shape[0]
        kwargs = {}
        # epsilon/momentum are genuine float attributes, so ``att.f`` is right.
        for att in node.attribute:
            if att.name in attr_map:
                kwargs[attr_map[att.name]] = att.f
        # Pick the BatchNorm flavor from the output tensor's rank.
        # NOTE(review): ``bn`` stays None for unsupported ranks and the
        # attribute assignments below then raise AttributeError — presumably
        # unreachable for exported models; confirm dim_info coverage.
        bn = None
        if model.dim_info[node.output[0]] == 5:
            bn = nn.BatchNorm3d(num_features=dim)
        elif model.dim_info[node.output[0]] == 4:
            bn = nn.BatchNorm2d(num_features=dim)
        elif model.dim_info[node.output[0]] == 2 or model.dim_info[node.output[0]] == 3:
            bn = nn.BatchNorm1d(num_features=dim)
        bn.weight.data = weight
        bn.bias.data = bias
        bn.running_mean.data = running_mean
        bn.running_var.data = running_var
        model.module_list.append(bn)
        return len(model.module_list) - 1, node.input[:1], node.output

    @staticmethod
    def convert_add(node, model: _pt_model):
        add = Add()
        model.module_list.append(add)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_mul(node, model: _pt_model):
        mul = Mul()
        model.module_list.append(mul)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_averagepool(node, model: _pt_model):
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
        }
        kwargs = {}
        for att in node.attribute:
            kwargs[attr_map[att.name]] = list(att.ints)
        if 'padding' in kwargs:
            # Symmetric padding only, as for Conv above.
            assert (kwargs["padding"][0] == kwargs["padding"][2]
                    and kwargs["padding"][1] == kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
        ap = nn.AvgPool2d(**kwargs)
        model.module_list.append(ap)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_globalaveragepool(node, model: _pt_model):
        gap = nn.AdaptiveAvgPool2d((1, 1))
        # BUGFIX: the module was previously appended twice, leaving an
        # orphaned duplicate entry in model.module_list.
        model.module_list.append(gap)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_maxpool(node, model: _pt_model):
        attr_map = {
            "pads": "padding",
            "strides": "stride",
            "kernel_shape": "kernel_size",
        }
        kwargs = {}
        for att in node.attribute:
            kwargs[attr_map[att.name]] = list(att.ints)
        if 'padding' in kwargs:
            assert (kwargs["padding"][0] == kwargs["padding"][2]
                    and kwargs["padding"][1] == kwargs["padding"][3])
            kwargs["padding"] = kwargs["padding"][0], kwargs["padding"][1]
        ap = nn.MaxPool2d(**kwargs)
        model.module_list.append(ap)
        return len(model.module_list) - 1, node.input, node.output

    @staticmethod
    def convert_flatten(node, model: _pt_model):
        # Flatten defaults to axis=1 per the ONNX spec when absent.
        if len(node.attribute) == 0:
            axis = 1
        else:
            axis = node.attribute[0].i
        if axis == 1:
            flatten = nn.Flatten()
            model.module_list.append(flatten)
            return len(model.module_list) - 1, node.input, node.output
        else:
            raise NotImplementedError("Not Implemented yet!")

    @staticmethod
    def convert_gemm(node, model: _pt_model):
        # NOTE(review): alpha/beta/transA/transB attributes are ignored; this
        # assumes the PyTorch-exporter default Gemm (transB=1, alpha=beta=1),
        # where the weight is already (out_features, in_features) — confirm.
        weight = model.loaded_weights[node.input[1]]
        bias = model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[2]]
        del model.loaded_weights[node.input[1]]
        in_features = weight.shape[1]
        out_features = weight.shape[0]
        linear = nn.Linear(in_features=in_features, out_features=out_features)
        linear.weight.data = weight
        linear.bias.data = bias
        model.module_list.append(linear)
        return len(model.module_list) - 1, node.input[:1], node.output

    @staticmethod
    def convert_pad(node, model: _pt_model):
        mode = node.attribute[0].s
        pads = list(node.attribute[1].ints)
        value = node.attribute[2].f
        try:
            assert mode == b'constant'
            # Only spatial (H/W) padding maps onto nn.ConstantPad2d: the
            # batch and channel axes must be unpadded. (The previous check
            # ``sum(pads[:4]) == 0`` tested the begin-pads of *all four*
            # axes, i.e. the wrong ones.)
            assert pads[0] == pads[1] == pads[4] == pads[5] == 0
        except AssertionError:
            # Best-effort: warn but still build the pad, as before.
            print("Now only support converting to nn.ConstantPad2d")
        # BUGFIX: ONNX 4-D pads are [n_b, c_b, h_b, w_b, n_e, c_e, h_e, w_e],
        # while ConstantPad2d wants (left, right, top, bottom) =
        # (w_b, w_e, h_b, h_e). The previous slices pads[2:4]+pads[3:5]
        # produced (h_b, w_b, w_b, n_e). Assumes 4-D NCHW input — TODO confirm.
        pad = nn.ConstantPad2d([pads[3], pads[7], pads[2], pads[6]], value)
        model.module_list.append(pad)
        return len(model.module_list) - 1, node.input, node.output