spikingjelly.activation_based.ann2snn package#
Converter#
- class spikingjelly.activation_based.ann2snn.converter.Converter(dataloader, device=None, mode='Max', momentum=0.1, fuse_flag=True)[源代码]#
基类:
Module
中文
Converter用于将带有ReLU的ANN转换为SNN。ANN2SNN教程见此处 ANN转换SNN 。
目前支持三种转换模式,由参数mode进行设置。
转换后ReLU模块被删除,SNN需要的新模块(包括VoltageScaler、IFNode等)被创建并存放在snn tailor父模块中。
由于返回值的类型为fx.GraphModule,建议使用print(fx.GraphModule.graph)查看计算图及前向传播关系。更多API参见 GraphModule 。
警告
必须确保ANN中的
ReLU为module而非function。您最好在ANN模型中使用平均池化而不是最大池化。否则,可能会损害转换后的SNN模型的性能。
- 参数:
English
Converteris used to convert ANN with to SNN.ANN2SNN tutorial is here ANN2SNN .
Three common methods are implemented here, which can be selected by the value of parameter mode.
After converting, ReLU modules will be removed. And new modules needed by SNN, such as VoltageScaler and IFNode, will be created and stored in the parent module 'snn tailor'.
Due to the type of the return model is fx.GraphModule, you can use 'print(fx.GraphModule.graph)' to view how modules links and the how the forward method works. More APIs are here GraphModule .
警告
Make sure that
ReLUis module rather than function.You'd better use
avgpoolrather thanmaxpoolin your ann model. If not, the performance of the converted snn model may be ruined.- 参数:
dataloader (Dataloader) -- Dataloader for converting
device (str) -- Device
mode (str, float) -- Conversion mode. Now support three mode, MaxNorm(mode='max'), RobustNorm(mode='99.9%'), and scaling mode(mode=x, where 0<x<=1)
momentum (float) -- Momentum value used by modules.VoltageHook
fuse_flag (bool) -- Bool specifying if fusion of the conv and the bn happens, by default it happens.
- forward(ann: Module)[源代码]#
-
中文
中文
- 参数:
ann (torch.nn.Module) -- 待转换的ann
- 返回:
转换得到的snn
- 返回类型:
English
English
- 参数:
ann (torch.nn.Module) -- ann to be converted
- 返回:
snn
- 返回类型:
- static fuse(fx_model: GraphModule, fuse_flag: bool = True) GraphModule[源代码]#
-
中文
fuse用于conv与bn的融合。- 参数:
fx_model (torch.fx.GraphModule) -- 原模型
fuse_flag (bool) -- 标志位,设置为True,则进行conv与bn的融合,反之不进行。
- 返回:
conv层和bn层融合后的模型.
- 返回类型:
English
fuseis used to fuse conv layer and bn layer.- 参数:
fx_model (torch.fx.GraphModule) -- Original fx_model
fuse_flag (bool) -- Bool specifying if fusion of the conv and the bn happens, by default it happens.
- 返回:
fx_model whose conv layer and bn layer have been fused.
- 返回类型:
- static set_voltagehook(fx_model: GraphModule, mode='Max', momentum=0.1) GraphModule[源代码]#
-
中文
set_voltagehook用于给模型添加VoltageHook模块。这里实现了常见的三种模式,同上。- 参数:
fx_model (torch.fx.GraphModule) -- 原模型
mode (str, float) -- 转换模式。目前支持三种模式,最大电流转换模式,99.9%电流转换模式,以及缩放转换模式
momentum (float) -- 动量值,用于VoltageHook
- 返回:
带有VoltageHook的模型.
- 返回类型:
English
set_voltagehookis used to add VoltageHook to fx_model. Three common methods are implemented here, the same as Converter.mode.- 参数:
fx_model (torch.fx.GraphModule) -- Original fx_model
mode (str, float) -- Conversion mode. Now support three mode, MaxNorm, RobustNorm(99.9%), and scaling mode
momentum (float) -- momentum value used by VoltageHook
- 返回:
fx_model with VoltageHook.
- 返回类型:
- static replace_by_ifnode(fx_model: GraphModule) GraphModule[源代码]#
-
中文
中文
replace_by_ifnode用于将模型的ReLU替换为IF脉冲神经元。- 参数:
fx_model (torch.fx.GraphModule) -- 原模型
- 返回:
将ReLU替换为IF脉冲神经元后的模型.
- 返回类型:
English
English
replace_by_ifnodeis used to replace ReLU with IF neuron.- 参数:
fx_model (torch.fx.GraphModule) -- Original fx_model
- 返回:
fx_model whose ReLU has been replaced by IF neuron.
- 返回类型:
Helper Modules and Functions#
- class spikingjelly.activation_based.ann2snn.modules.VoltageHook(scale=1.0, momentum=0.1, mode='Max')[源代码]#
基类:
Module
中文
VoltageHook的构造函数。- 参数:
- 返回:
None
- 返回类型:
None
English
Constructor of
VoltageHook.- 参数:
- 返回:
None
- 返回类型:
None
- forward(x)[源代码]#
-
中文
前向传播函数。不对输入张量做任何处理,只是抓取ReLU的激活值用于确定ANN激活范围。
- 参数:
x (torch.Tensor) -- 输入张量
- 返回:
原输入张量
- 返回类型:
English
Forward function. It doesn't process input tensors, but hooks the activation values of ReLU to determine ANN activation ranges.
- 参数:
x (torch.Tensor) -- input tensor
- 返回:
original input tensor
- 返回类型:
- class spikingjelly.activation_based.ann2snn.modules.VoltageScaler(scale=1.0)[源代码]#
基类:
Module
中文
VoltageScaler的构造函数。用于SNN推理中缩放电流。- 参数:
scale (float) -- 缩放值
- 返回:
None
- 返回类型:
None
English
Constructor of
VoltageScaler. Used for scaling current in SNN inference.- 参数:
scale (float) -- scaling value
- 返回:
None
- 返回类型:
None
- forward(x)[源代码]#
-
中文
前向传播函数。对输入电流进行缩放。
- 参数:
x (torch.Tensor) -- 输入张量,亦即输入电流
- 返回:
缩放后的电流
- 返回类型:
English
Forward function. Scales the input current.
- 参数:
x (torch.Tensor) -- input tensor, or input current
- 返回:
current after scaling
- 返回类型: