spikingjelly.clock_driven.ann2snn.kernels package

Submodules

spikingjelly.clock_driven.ann2snn.kernels.onnx module

class spikingjelly.clock_driven.ann2snn.kernels.onnx.Mul

Bases: torch.nn.modules.module.Module

forward(input1, input2)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Add

Bases: torch.nn.modules.module.Module

forward(input1, input2)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Reshape

Bases: torch.nn.modules.module.Module

forward(input1, input2)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Concat(dim=[1])

Bases: torch.nn.modules.module.Module

forward(*args)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Shape

Bases: torch.nn.modules.module.Module

forward(input)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Gather(dim=1)

Bases: torch.nn.modules.module.Module

forward(input1, input2)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.Unsqueeze(dim=[1])

Bases: torch.nn.modules.module.Module

forward(input)
training: bool
class spikingjelly.clock_driven.ann2snn.kernels.onnx.TopologyAnalyser

Bases: object

This class analyzes the topological structure of a model via ONNX to simplify subsequent processing. There is room for better implementations, and developers are welcome to keep improving it.

add_data_node(a)
insert(a, b, info=None)
findNext(id)
findPre(id)
find_pre_module(module_name)
find_next_module(module_name)
update_module_idx(onnx_graph)
analyse(onnx_graph)
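The analyser's role can be pictured with a small, self-contained sketch. The names MiniTopology, build_topology and next_nodes, and the (name, inputs, outputs) node tuples, are hypothetical stand-ins and not SpikingJelly's API. The sketch only assumes what the method names suggest: insert records a directed edge, findNext/findPre return successors and predecessors, and the *_module lookups hop over the tensor between two operator nodes.

```python
from collections import defaultdict


class MiniTopology:
    """Hypothetical, minimal topology store: directed edges a -> b."""

    def __init__(self):
        self.next = defaultdict(list)  # id -> successor ids
        self.pre = defaultdict(list)   # id -> predecessor ids
        self.edge_info = {}            # (a, b) -> optional metadata

    def insert(self, a, b, info=None):
        # Record the edge a -> b once; keep (or overwrite) its metadata.
        if b not in self.next[a]:
            self.next[a].append(b)
            self.pre[b].append(a)
        self.edge_info[(a, b)] = info

    def find_next(self, id_):
        return list(self.next.get(id_, []))

    def find_pre(self, id_):
        return list(self.pre.get(id_, []))


def build_topology(nodes):
    """Link each node to the tensors it reads and the tensors it writes."""
    topo = MiniTopology()
    for name, inputs, outputs in nodes:
        for t in inputs:
            topo.insert(t, name)
        for t in outputs:
            topo.insert(name, t)
    return topo


def next_nodes(topo, node):
    # Two hops: node -> output tensor(s) -> consuming node(s).
    return [n for t in topo.find_next(node) for n in topo.find_next(t)]


nodes = [
    ("conv1", ["x", "w1"], ["c1"]),
    ("bn1", ["c1", "gamma", "beta", "mean", "var"], ["b1"]),
    ("relu1", ["b1"], ["r1"]),
]
topo = build_topology(nodes)
print(topo.find_next("c1"))      # ['bn1']
print(next_nodes(topo, "conv1"))  # ['bn1']
```

Storing tensors and operators in one graph keeps the edge logic trivial; the price is the extra hop needed to go from one operator to the next, which is what a helper like find_next_module has to hide.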
spikingjelly.clock_driven.ann2snn.kernels.onnx.pytorch2onnx_model(model: torch.nn.modules.module.Module, data, **kargs) → onnx.ModelProto
Parameters
  • model – the PyTorch model to be converted

  • data – the data used for conversion (it determines the input dimensions)

  • log_dir – output folder, passed as a keyword argument through **kargs

Convert a PyTorch model to an ONNX model.

spikingjelly.clock_driven.ann2snn.kernels.onnx.onnx2pytorch_model(model: onnx.ModelProto, _converter) → torch.nn.modules.module.Module
spikingjelly.clock_driven.ann2snn.kernels.onnx.layer_reduction(model: onnx.ModelProto) → onnx.ModelProto
spikingjelly.clock_driven.ann2snn.kernels.onnx.rate_normalization(model: onnx.ModelProto, data: torch.Tensor, **kargs) → onnx.ModelProto
Parameters
  • model – the ANN model, of type onnx.ModelProto

  • data – the data used for conversion, of type torch.Tensor

  • channelwise – if True, the statistic that controls the activation magnitude is computed per channel; otherwise, it is computed per layer

  • robust – if True, the statistic that controls the activation magnitude is the 99.9th percentile of the activations; otherwise, it is the maximum activation

  • eps – epsilon; defaults to 1e-5 if not set

Normalize the firing rates. channelwise, robust and eps are keyword arguments passed through **kargs.
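A plain-Python sketch of data-based rate normalization, assuming the common scheme in which each layer's activation scale λ is the maximum recorded activation (or its 99.9th percentile when robust is set), and a layer's weights are rescaled by λ_prev / λ and its bias by 1 / λ. This illustrates the idea only and is not the library's implementation; the helper names, the row-per-sample activation layout and the way eps guards the division are all assumptions.

```python
import math


def percentile(values, q):
    """q-th percentile (0 <= q <= 100), linear interpolation between ranks."""
    xs = sorted(values)
    if not xs:
        raise ValueError("no activations recorded")
    pos = (len(xs) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def scale_of(values, robust):
    # robust=True: the 99.9th percentile ignores rare outliers; else the maximum.
    return percentile(values, 99.9) if robust else max(values)


def activation_scale(acts, robust=True, channelwise=False):
    """acts: recorded activations, one row per sample, one column per channel.

    Returns one scale per channel if channelwise, else a single layer scale.
    """
    if channelwise:
        return [scale_of([row[c] for row in acts], robust) for c in range(len(acts[0]))]
    return scale_of([v for row in acts for v in row], robust)


def normalize_layer(weight, bias, prev_scale, scale, eps=1e-5):
    """Layerwise rescaling: W <- W * prev_scale / scale, b <- b / scale.

    eps keeps the division finite if a layer never activates (assumption).
    """
    denom = scale + eps
    new_w = [[w * prev_scale / denom for w in row] for row in weight]
    new_b = [b / denom for b in bias]
    return new_w, new_b


# Recorded activations of one layer: 3 samples x 2 channels.
acts = [[0.5, 2.0], [1.0, 4.0], [0.2, 3.0]]
layer_max = activation_scale(acts, robust=False)                     # 4.0
channel_max = activation_scale(acts, robust=False, channelwise=True)  # [1.0, 4.0]
layer_robust = activation_scale(acts, robust=True)                    # about 3.995
w, b = normalize_layer([[1.0, 2.0]], [0.4], prev_scale=1.0, scale=layer_max, eps=0.0)
```

For the first layer, prev_scale is taken as 1 (the input is not rescaled). The robust variant trades a few saturated neurons for higher firing rates everywhere else, which is why it is often preferred for deep networks.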

spikingjelly.clock_driven.ann2snn.kernels.onnx.save_model(model: onnx.ModelProto, f=None)
spikingjelly.clock_driven.ann2snn.kernels.onnx.move_constant_to_initializer(graph)
spikingjelly.clock_driven.ann2snn.kernels.onnx.print_onnx_model(graph)
spikingjelly.clock_driven.ann2snn.kernels.onnx.absorb_bn(graph, topo_analyser)
spikingjelly.clock_driven.ann2snn.kernels.onnx.remove_unreferenced_initializer(graph)
spikingjelly.clock_driven.ann2snn.kernels.onnx.update_topology(graph)
spikingjelly.clock_driven.ann2snn.kernels.onnx.find_node_by_output(output_name, graph)
spikingjelly.clock_driven.ann2snn.kernels.onnx.scale_node_weight_bias(topo_analyser, graph, node_idx, scale)
spikingjelly.clock_driven.ann2snn.kernels.onnx.get_onnx_output(model, numpy_tensor)
spikingjelly.clock_driven.ann2snn.kernels.onnx.get_intermediate_output_statistics(model, numpy_tensor, channelwise=False, debug=None)
spikingjelly.clock_driven.ann2snn.kernels.onnx.normalize_model(model, output_statistics, topo_analyser, robust_norm=True, channelwise=False, eps=1e-05)
spikingjelly.clock_driven.ann2snn.kernels.onnx.load_parameters(model: spikingjelly.clock_driven.ann2snn.kernels.onnx._pt_model, initializer)
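Several of the helpers above clean up the ONNX graph after it has been rewritten. As an illustration of what a function such as remove_unreferenced_initializer plausibly does, the sketch below uses a hypothetical, simplified graph model (the Node and Graph dataclasses are not ONNX types) and drops every initializer that no node reads and no graph output names, for example BatchNorm parameters left behind once the BatchNorm node has been folded away.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    name: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class Graph:
    nodes: List[Node]
    initializers: dict = field(default_factory=dict)  # name -> weight values
    outputs: List[str] = field(default_factory=list)


def remove_unreferenced_initializers(graph):
    """Delete initializers that no node reads and no graph output names."""
    used = {name for node in graph.nodes for name in node.inputs}
    used.update(graph.outputs)
    # Collect first, then delete, so the dict is not mutated while iterating.
    removed = [k for k in graph.initializers if k not in used]
    for k in removed:
        del graph.initializers[k]
    return removed


# After folding, only conv1 remains, but the BN parameters are still stored.
g = Graph(
    nodes=[Node("conv1", ["x", "w1", "b1"], ["y"])],
    initializers={"w1": [1.0], "b1": [0.0], "gamma": [1.0], "beta": [0.0]},
    outputs=["y"],
)
removed = remove_unreferenced_initializers(g)  # ['gamma', 'beta']
```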

spikingjelly.clock_driven.ann2snn.kernels.pytorch module

spikingjelly.clock_driven.ann2snn.kernels.pytorch.layer_reduction(model: torch.nn.modules.module.Module) → torch.nn.modules.module.Module
spikingjelly.clock_driven.ann2snn.kernels.pytorch.rate_normalization(model: torch.nn.modules.module.Module, data: torch.Tensor, **kargs) → torch.nn.modules.module.Module
spikingjelly.clock_driven.ann2snn.kernels.pytorch.save_model(model: torch.nn.modules.module.Module, f)
spikingjelly.clock_driven.ann2snn.kernels.pytorch.absorb(param_module, bn_module)
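absorb(param_module, bn_module) presumably folds a BatchNorm layer into the preceding parameterized layer. The standard folding, with s = γ / sqrt(σ² + ε) per output channel, is W' = s·W and b' = s·(b − μ) + β. The plain-Python sketch below assumes this standard rule and a fully connected layer; fold_bn, linear and batch_norm are hypothetical helpers, whereas the real function operates on PyTorch modules.

```python
import math


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm statistics into a preceding fully connected layer.

    weight: out_features rows, each a list of in_features weights.
    bias: list of out_features biases (use zeros if the layer had none).
    """
    new_w, new_b = [], []
    for c, row in enumerate(weight):
        s = gamma[c] / math.sqrt(var[c] + eps)  # per-channel BN scale
        new_w.append([w * s for w in row])
        new_b.append((bias[c] - mean[c]) * s + beta[c])
    return new_w, new_b


def linear(weight, bias, x):
    # y_c = sum_i W[c][i] * x[i] + b[c]
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    # Inference-mode BatchNorm using the stored running statistics.
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]


W = [[1.0, -2.0], [0.5, 3.0]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [4.0, 0.25]
x = [0.7, -1.1]

reference = batch_norm(linear(W, b, x), gamma, beta, mean, var)
W_f, b_f = fold_bn(W, b, gamma, beta, mean, var)
folded = linear(W_f, b_f, x)  # matches reference up to rounding
```

Folding only holds in inference mode, where BatchNorm uses fixed running statistics; this is why the conversion can remove the BatchNorm layer entirely before normalizing firing rates.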

Module contents