spikingjelly.activation_based.ann2snn package#

Converter#

class spikingjelly.activation_based.ann2snn.converter.Converter(dataloader, device=None, mode='Max', momentum=0.1, fuse_flag=True)[源代码]#

基类:Module

API Language: 中文 | English


  • 中文

Converter 用于将带有ReLU的ANN转换为SNN。

ANN2SNN教程见此处 ANN转换SNN

目前支持三种转换模式,由参数mode进行设置。

转换后ReLU模块被删除,SNN需要的新模块(包括VoltageScaler、IFNode等)被创建并存放在snn tailor父模块中。

由于返回值的类型为fx.GraphModule,建议使用print(fx.GraphModule.graph)查看计算图及前向传播关系。更多API参见 GraphModule

警告

必须确保ANN中的 ReLU 为module而非function。

您最好在ANN模型中使用平均池化而不是最大池化。否则,可能会损害转换后的SNN模型的性能。

参数:
  • dataloader (Dataloader) -- 数据加载器

  • device (str) -- Device

  • mode (str, float) -- 转换模式。目前支持三种模式: 最大电流转换模式mode='max',99.9%电流转换模式mode='99.9%',以及缩放转换模式mode=x(0<x<=1)

  • momentum (float) -- 动量值,用于modules.VoltageHook

  • fuse_flag (bool) -- 标志位,设置为True,则进行conv与bn的融合,反之不进行。


  • English

Converter is used to convert ANN with to SNN.

ANN2SNN tutorial is here ANN2SNN .

Three common methods are implemented here, which can be selected by the value of parameter mode.

After converting, ReLU modules will be removed. And new modules needed by SNN, such as VoltageScaler and IFNode, will be created and stored in the parent module 'snn tailor'.

Due to the type of the return model is fx.GraphModule, you can use 'print(fx.GraphModule.graph)' to view how modules links and the how the forward method works. More APIs are here GraphModule .

警告

Make sure that ReLU is module rather than function.

You'd better use avgpool rather than maxpool in your ann model. If not, the performance of the converted snn model may be ruined.

参数:
  • dataloader (Dataloader) -- Dataloader for converting

  • device (str) -- Device

  • mode (str, float) -- Conversion mode. Now support three mode, MaxNorm(mode='max'), RobustNorm(mode='99.9%'), and scaling mode(mode=x, where 0<x<=1)

  • momentum (float) -- Momentum value used by modules.VoltageHook

  • fuse_flag (bool) -- Bool specifying if fusion of the conv and the bn happens, by default it happens.

forward(ann: Module)[源代码]#

API Language: 中文 | English


  • 中文

  • 中文

参数:

ann (torch.nn.Module) -- 待转换的ann

返回:

转换得到的snn

返回类型:

torch.fx.GraphModule


  • English

  • English

参数:

ann (torch.nn.Module) -- ann to be converted

返回:

snn

返回类型:

torch.fx.GraphModule

static fuse(fx_model: GraphModule, fuse_flag: bool = True) GraphModule[源代码]#

API Language: 中文 | English


  • 中文

fuse 用于conv与bn的融合。

参数:
  • fx_model (torch.fx.GraphModule) -- 原模型

  • fuse_flag (bool) -- 标志位,设置为True,则进行conv与bn的融合,反之不进行。

返回:

conv层和bn层融合后的模型.

返回类型:

torch.fx.GraphModule


  • English

fuse is used to fuse conv layer and bn layer.

参数:
  • fx_model (torch.fx.GraphModule) -- Original fx_model

  • fuse_flag (bool) -- Bool specifying if fusion of the conv and the bn happens, by default it happens.

返回:

fx_model whose conv layer and bn layer have been fused.

返回类型:

torch.fx.GraphModule

static set_voltagehook(fx_model: GraphModule, mode='Max', momentum=0.1) GraphModule[源代码]#

API Language: 中文 | English


  • 中文

set_voltagehook 用于给模型添加VoltageHook模块。这里实现了常见的三种模式,同上。

参数:
  • fx_model (torch.fx.GraphModule) -- 原模型

  • mode (str, float) -- 转换模式。目前支持三种模式,最大电流转换模式,99.9%电流转换模式,以及缩放转换模式

  • momentum (float) -- 动量值,用于VoltageHook

返回:

带有VoltageHook的模型.

返回类型:

torch.fx.GraphModule


  • English

set_voltagehook is used to add VoltageHook to fx_model. Three common methods are implemented here, the same as Converter.mode.

参数:
  • fx_model (torch.fx.GraphModule) -- Original fx_model

  • mode (str, float) -- Conversion mode. Now support three mode, MaxNorm, RobustNorm(99.9%), and scaling mode

  • momentum (float) -- momentum value used by VoltageHook

返回:

fx_model with VoltageHook.

返回类型:

torch.fx.GraphModule

static replace_by_ifnode(fx_model: GraphModule) GraphModule[源代码]#

API Language: 中文 | English


  • 中文

  • 中文

replace_by_ifnode 用于将模型的ReLU替换为IF脉冲神经元。

参数:

fx_model (torch.fx.GraphModule) -- 原模型

返回:

将ReLU替换为IF脉冲神经元后的模型.

返回类型:

torch.fx.GraphModule


  • English

  • English

replace_by_ifnode is used to replace ReLU with IF neuron.

参数:

fx_model (torch.fx.GraphModule) -- Original fx_model

返回:

fx_model whose ReLU has been replaced by IF neuron.

返回类型:

torch.fx.GraphModule

Helper Modules and Functions#

class spikingjelly.activation_based.ann2snn.modules.VoltageHook(scale=1.0, momentum=0.1, mode='Max')[源代码]#

基类:Module

API Language: 中文 | English


  • 中文

VoltageHook 的构造函数。

参数:
  • scale (float) -- 缩放初始值

  • momentum (float) -- 动量值

  • mode (str, float) -- 模式。"Max" 表示记录ANN激活最大值;"99.9%" 表示记录99.9%分位点; 0-1 的 float 表示记录激活最大值的对应倍数

返回:

None

返回类型:

None


  • English

Constructor of VoltageHook.

参数:
  • scale (float) -- initial scaling value

  • momentum (float) -- momentum value

  • mode (str, float) -- Mode. "Max" means recording the maximum value of ANN activation; "99.9%" means recording the 99.9% percentile; a float of 0-1 means recording the corresponding multiple of the maximum value

返回:

None

返回类型:

None

forward(x)[源代码]#

API Language: 中文 | English


  • 中文

前向传播函数。不对输入张量做任何处理,只是抓取ReLU的激活值用于确定ANN激活范围。

参数:

x (torch.Tensor) -- 输入张量

返回:

原输入张量

返回类型:

torch.Tensor


  • English

Forward function. It doesn't process input tensors, but hooks the activation values of ReLU to determine ANN activation ranges.

参数:

x (torch.Tensor) -- input tensor

返回:

original input tensor

返回类型:

torch.Tensor

class spikingjelly.activation_based.ann2snn.modules.VoltageScaler(scale=1.0)[源代码]#

基类:Module

API Language: 中文 | English


  • 中文

VoltageScaler 的构造函数。用于SNN推理中缩放电流。

参数:

scale (float) -- 缩放值

返回:

None

返回类型:

None


  • English

Constructor of VoltageScaler. Used for scaling current in SNN inference.

参数:

scale (float) -- scaling value

返回:

None

返回类型:

None

forward(x)[源代码]#

API Language: 中文 | English


  • 中文

前向传播函数。对输入电流进行缩放。

参数:

x (torch.Tensor) -- 输入张量,亦即输入电流

返回:

缩放后的电流

返回类型:

torch.Tensor


  • English

Forward function. Scales the input current.

参数:

x (torch.Tensor) -- input tensor, or input current

返回:

current after scaling

返回类型:

torch.Tensor

spikingjelly.activation_based.ann2snn.utils.download_url(url, dst)[源代码]#

API Language: 中文 | English


  • 中文

从指定 URL 下载文件并保存到目标路径。支持断点续传。

参数:
  • url (str) -- 文件的下载链接

  • dst (str) -- 保存文件的目标路径

返回:

文件的总大小(以字节为单位)

返回类型:

int


  • English

Download a file from a given URL and save it to a destination path. Supports resuming interrupted downloads.

参数:
  • url (str) -- the download URL of the file

  • dst (str) -- the destination path to save the file

返回:

the total file size in bytes

返回类型:

int

Examples#