spikingjelly.activation_based.ann2snn.modules 源代码

import torch.nn as nn
import torch
import numpy as np


__all__ = ["VoltageHook", "VoltageScaler"]


[文档] class VoltageHook(nn.Module): def __init__(self, scale=1.0, momentum=0.1, mode="Max"): r""" **API Language:** :ref:`中文 <VoltageHook.__init__-cn>` | :ref:`English <VoltageHook.__init__-en>` ---- .. _VoltageHook.__init__-cn: * **中文** :class:`VoltageHook` 的构造函数。 :param scale: 缩放初始值 :type scale: float :param momentum: 动量值 :type momentum: float :param mode: 模式。``"Max"`` 表示记录ANN激活最大值;``"99.9%"`` 表示记录99.9%分位点; 0-1 的 float 表示记录激活最大值的对应倍数 :type mode: str, float :return: None :rtype: None ---- .. _VoltageHook.__init__-en: * **English** Constructor of :class:`VoltageHook`. :param scale: initial scaling value :type scale: float :param momentum: momentum value :type momentum: float :param mode: Mode. ``"Max"`` means recording the maximum value of ANN activation; ``"99.9%"`` means recording the 99.9% percentile; a float of 0-1 means recording the corresponding multiple of the maximum value :type mode: str, float :return: None :rtype: None """ super().__init__() self.register_buffer("scale", torch.tensor(scale)) self.mode = mode self.num_batches_tracked = 0 self.momentum = momentum
[文档] def forward(self, x): r""" **API Language:** :ref:`中文 <VoltageHook.forward-cn>` | :ref:`English <VoltageHook.forward-en>` ---- .. _VoltageHook.forward-cn: * **中文** 前向传播函数。不对输入张量做任何处理,只是抓取ReLU的激活值用于确定ANN激活范围。 :param x: 输入张量 :type x: torch.Tensor :return: 原输入张量 :rtype: torch.Tensor ---- .. _VoltageHook.forward-en: * **English** Forward function. It doesn't process input tensors, but hooks the activation values of ReLU to determine ANN activation ranges. :param x: input tensor :type x: torch.Tensor :return: original input tensor :rtype: torch.Tensor """ err_msg = "You have used a non-defined VoltageScale Method." if isinstance(self.mode, str): if self.mode[-1] == "%": try: s_t = torch.tensor( np.percentile(x.detach().cpu(), float(self.mode[:-1])) ) except ValueError: raise NotImplementedError(err_msg) elif self.mode.lower() in ["max"]: s_t = x.max().detach() else: raise NotImplementedError(err_msg) elif isinstance(self.mode, float) and self.mode <= 1 and self.mode > 0: s_t = x.max().detach() * self.mode else: raise NotImplementedError(err_msg) if self.num_batches_tracked == 0: self.scale = s_t else: self.scale = (1 - self.momentum) * self.scale + self.momentum * s_t self.num_batches_tracked += x.shape[0] return x
[文档] class VoltageScaler(nn.Module): def __init__(self, scale=1.0): r""" **API Language:** :ref:`中文 <VoltageScaler.__init__-cn>` | :ref:`English <VoltageScaler.__init__-en>` ---- .. _VoltageScaler.__init__-cn: * **中文** :class:`VoltageScaler` 的构造函数。用于SNN推理中缩放电流。 :param scale: 缩放值 :type scale: float :return: None :rtype: None ---- .. _VoltageScaler.__init__-en: * **English** Constructor of :class:`VoltageScaler`. Used for scaling current in SNN inference. :param scale: scaling value :type scale: float :return: None :rtype: None """ super().__init__() self.register_buffer("scale", torch.tensor(scale))
[文档] def forward(self, x): r""" **API Language:** :ref:`中文 <VoltageScaler.forward-cn>` | :ref:`English <VoltageScaler.forward-en>` ---- .. _VoltageScaler.forward-cn: * **中文** 前向传播函数。对输入电流进行缩放。 :param x: 输入张量,亦即输入电流 :type x: torch.Tensor :return: 缩放后的电流 :rtype: torch.Tensor ---- .. _VoltageScaler.forward-en: * **English** Forward function. Scales the input current. :param x: input tensor, or input current :type x: torch.Tensor :return: current after scaling :rtype: torch.Tensor """ return x * self.scale
def extra_repr(self): return "%f" % self.scale.item()