spikingjelly.activation_based.ann2snn.converter 源代码

import warnings
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import fx
from torch.nn.utils.fusion import fuse_conv_bn_eval
from tqdm import tqdm

from spikingjelly.activation_based.ann2snn.factories import HookFactory, NeuronFactory
from spikingjelly.activation_based.ann2snn.operators import (
    TDGELU,
    TDLayerNorm,
    TDLinear,
    TDMultiheadAttention,
    TDScaledDotProductAttention,
)
from spikingjelly.activation_based.ann2snn.rules import ActivationRule, ReLURule
from spikingjelly.activation_based.ann2snn.threshold import ThresholdOptimizer


class Converter:
    r"""ANN-to-SNN conversion driver built on ``torch.fx`` graph rewriting."""

    def __init__(
        self,
        dataloader: Iterable,
        device: Optional[Union[torch.device, str]] = None,
        mode: Union[str, float] = "Max",
        momentum: float = 0.1,
        fuse_flag: bool = True,
        rules: Optional[List[ActivationRule]] = None,
        neuron_factory: Optional[NeuronFactory] = None,
        threshold_optimizer: Optional[ThresholdOptimizer] = None,
    ) -> None:
        r"""
        Create a converter.

        ``Converter`` is a conversion driver, not a :class:`torch.nn.Module`
        for inference. It exposes two explicit conversion paths:
        :meth:`convert_to_spiking_neurons` (calibrated activation-to-neuron
        conversion, ReLU-to-IFNode by default) and
        :meth:`replace_by_td_operators` (temporal-difference operator
        replacement, which does not use the calibration data).

        Both paths return an ``fx.GraphModule``; print its ``graph`` to
        inspect the generated computation graph.

        .. warning::

            ``ReLU`` must be used as a module, not as a function. Average
            pooling is recommended over max pooling in the ANN, otherwise
            the converted SNN may perform poorly.

        :param dataloader: Calibration data. Batches are unpacked by
            :meth:`_extract_batch_input`, e.g. ``(input, label)`` or
            ``(input,)``.
        :type dataloader: Iterable
        :param device: Device used for conversion. When ``None`` it is
            inferred from the model being converted.
        :type device: torch.device or str or None
        :param mode: Conversion mode: ``"max"`` (case-insensitive), a
            percentage string such as ``"99.9%"``, or a float ``x`` with
            ``0 < x <= 1``.
        :type mode: str, float
        :param momentum: Momentum forwarded to the calibration hooks.
        :type momentum: float
        :param fuse_flag: Whether Conv-BN pairs are fused before calibration.
        :type fuse_flag: bool
        :param rules: Activation conversion rules. Defaults to
            ``[ReLURule()]``.
        :type rules: Optional[List[ActivationRule]]
        :param neuron_factory: Spiking neuron factory. Defaults to
            ``NeuronFactory()``.
        :type neuron_factory: Optional[NeuronFactory]
        :param threshold_optimizer: Threshold optimizer. Defaults to
            ``ThresholdOptimizer()``.
        :type threshold_optimizer: Optional[ThresholdOptimizer]
        :raises NotImplementedError: If ``mode`` is not a supported value.
        """
        self.mode = mode
        self.fuse_flag = fuse_flag
        self.dataloader = dataloader
        # Validate the mode eagerly so a bad value fails at construction time.
        self._check_mode()
        self.device = device
        self.momentum = momentum
        self.rules = rules if rules is not None else [ReLURule()]
        self.neuron_factory = (
            neuron_factory if neuron_factory is not None else NeuronFactory()
        )
        self.threshold_optimizer = (
            threshold_optimizer
            if threshold_optimizer is not None
            else ThresholdOptimizer()
        )

    @staticmethod
    def _infer_device(ann: nn.Module) -> torch.device:
        r"""
        Return the device of the first parameter of ``ann``.

        :param ann: Model to inspect.
        :type ann: torch.nn.Module
        :return: Device of the first parameter, or CPU when the model has no
            parameters.
        :rtype: torch.device
        """
        first_param = next(ann.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def convert_to_spiking_neurons(self, ann: nn.Module) -> torch.fx.GraphModule:
        r"""
        Convert an ANN to an SNN ``GraphModule`` using calibration data.

        The method symbolically traces ``ann``, switches it to eval mode,
        optionally fuses Conv-BN pairs, inserts voltage hooks according to
        ``self.rules``, runs every batch of ``self.dataloader`` through the
        hooked model under ``torch.no_grad()``, and finally replaces the
        matched activations with spiking neurons.

        :param ann: ANN to be converted.
        :type ann: torch.nn.Module
        :return: Converted SNN.
        :rtype: torch.fx.GraphModule
        """
        device = self.device
        if device is None:
            device = self._infer_device(ann)
            if next(ann.parameters(), None) is not None:
                # Keep the legacy behaviour of remembering the inferred
                # device on the converter; the CPU fallback used for
                # parameterless models is deliberately not cached.
                self.device = device

        traced = fx.symbolic_trace(ann).to(device)
        traced.eval()
        ann_fused = self.fuse(traced, fuse_flag=self.fuse_flag).to(device)
        ann_with_hook = self.set_voltagehook(
            ann_fused, momentum=self.momentum, mode=self.mode, rules=self.rules
        ).to(device)

        # Calibration pass: only the hooks' statistics matter, outputs are
        # discarded.
        with torch.no_grad():
            for data in tqdm(self.dataloader):
                imgs = self._extract_batch_input(data)
                ann_with_hook(
                    torch.as_tensor(imgs).to(device=device, dtype=torch.float)
                )

        snn = self.replace_by_neurons(ann_with_hook).to(device)
        return snn

    def replace_by_td_operators(self, ann: nn.Module) -> torch.fx.GraphModule:
        r"""
        Replace supported modules with temporal-difference (TD) operators.

        Exact instances of :class:`torch.nn.Linear`,
        :class:`torch.nn.LayerNorm`, :class:`torch.nn.GELU` and supported
        :class:`torch.nn.MultiheadAttention` calls are swapped for their TD
        counterparts, and every
        :func:`torch.nn.functional.scaled_dot_product_attention` call in the
        graph is rewritten to a ``TDScaledDotProductAttention`` submodule.
        No voltage hooks are inserted and the dataloader is not used.
        Replaced modules keep the ``training`` flag of the module they
        replace.

        This path targets complete time-sequence inputs: converted models
        conventionally use dimension 0 as the time dimension (``[T, ...]``
        with ``T > 0``), and TD operators output floating-point differential
        values rather than binary spikes.

        NOTE(review): the traced ``GraphModule`` is moved with
        ``.to(device)``; confirm whether that can affect submodules shared
        with the input ``ann`` when the devices differ.

        :param ann: ANN to be converted.
        :type ann: torch.nn.Module
        :return: ``GraphModule`` with supported operators replaced.
        :rtype: torch.fx.GraphModule
        :raises ValueError: If the graph contains an unsupported attention
            configuration (see :meth:`_parse_sdpa_node` and
            :meth:`_check_mha_node`).
        """
        device = self.device if self.device is not None else self._infer_device(ann)

        fx_model = fx.symbolic_trace(ann).to(device)

        # Pass 1: swap supported call_module targets in place.
        modules = dict(fx_model.named_modules())
        for node in list(fx_model.graph.nodes):
            if node.op != "call_module":
                continue
            if not isinstance(node.target, str) or node.target not in modules:
                continue
            replacement = self._make_td_operator(modules[node.target], node)
            if replacement is None:
                continue
            self._replace_submodule(fx_model, node.target, replacement)
            # Record the swap so a second call to the same target sees the
            # TD module and is skipped.
            modules[node.target] = replacement

        # Pass 2: rewrite functional SDPA calls into TD submodule calls.
        sdpa_index = 0
        for node in list(fx_model.graph.nodes):
            if (
                node.op != "call_function"
                or node.target is not F.scaled_dot_product_attention
            ):
                continue
            sdpa_kwargs = self._parse_sdpa_node(node)
            target = f"td_scaled_dot_product_attention_{sdpa_index}"
            sdpa_index += 1
            fx_model.add_submodule(
                target,
                TDScaledDotProductAttention(
                    is_causal=sdpa_kwargs["is_causal"],
                    scale=sdpa_kwargs["scale"],
                ),
            )
            with fx_model.graph.inserting_after(node):
                new_node = fx_model.graph.call_module(
                    target,
                    args=(
                        sdpa_kwargs["query"],
                        sdpa_kwargs["key"],
                        sdpa_kwargs["value"],
                        sdpa_kwargs["attn_mask"],
                    ),
                )
            node.replace_all_uses_with(new_node)
            fx_model.graph.erase_node(node)

        fx_model.graph.lint()
        fx_model.delete_all_unused_submodules()
        fx_model.recompile()
        return fx_model

    @staticmethod
    def _extract_batch_input(data):
        r"""
        Pull the model input out of one dataloader batch.

        Tensors are returned as-is; for a list or tuple the first element is
        used; for a dict the first present key among ``input``, ``image``,
        ``img``, ``x``, ``data``, ``pixel_values`` is used, falling back to
        the first value. Any other object is returned unchanged.

        :param data: One batch yielded by the dataloader.
        :return: The object to feed to the model.
        :raises ValueError: If ``data`` is an empty list, tuple or dict.
        """
        if isinstance(data, torch.Tensor):
            return data
        if isinstance(data, (list, tuple)):
            if not data:
                raise ValueError("Batch data is an empty list or tuple.")
            return data[0]
        if isinstance(data, dict):
            if not data:
                raise ValueError("Batch data is an empty dictionary.")
            for key in ("input", "image", "img", "x", "data", "pixel_values"):
                if key in data:
                    return data[key]
            return next(iter(data.values()))
        return data

    def _check_mode(self) -> None:
        r"""
        Validate ``self.mode``.

        Accepted values are ``"max"`` (case-insensitive), a string ending in
        ``%`` whose prefix parses as a float, or a float in ``(0, 1]``.

        :raises NotImplementedError: For any other value.
        """
        err_msg = "You have used a non-defined VoltageScale Method."
        mode = self.mode
        if isinstance(mode, str):
            # ``endswith`` also handles the empty string, which indexing
            # with ``[-1]`` would turn into an IndexError.
            if mode.endswith("%"):
                try:
                    float(mode[:-1])
                except ValueError as err:
                    raise NotImplementedError(err_msg) from err
            elif mode.lower() != "max":
                raise NotImplementedError(err_msg)
        elif isinstance(mode, float):
            # Explicit comparison instead of ``assert`` so the check is not
            # stripped when Python runs with ``-O``.
            if not 0 < mode <= 1:
                raise NotImplementedError(err_msg)
        else:
            raise NotImplementedError(err_msg)

    @staticmethod
    def _replace_submodule(
        fx_model: torch.fx.GraphModule, target: str, module: nn.Module
    ) -> None:
        r"""
        Install ``module`` at the dotted path ``target`` inside ``fx_model``.

        :param fx_model: Graph module owning the submodule.
        :param target: Qualified submodule name, e.g. ``"block.fc"``.
        :param module: Replacement module.
        """
        parent_name, _, child_name = target.rpartition(".")
        parent = fx_model.get_submodule(parent_name) if parent_name else fx_model
        setattr(parent, child_name, module)

    @staticmethod
    def _get_literal_argument(
        node: fx.Node,
        name: str,
        position: int,
        default: Any,
    ) -> Any:
        r"""
        Read an argument of an FX node by keyword, then by position.

        :param node: FX node whose call arguments are inspected.
        :param name: Keyword name of the argument.
        :param position: Positional index of the argument.
        :param default: Value returned when the argument is absent.
        :return: The keyword value if present, else the positional value,
            else ``default``.
        """
        if name in node.kwargs:
            return node.kwargs[name]
        if len(node.args) > position:
            return node.args[position]
        return default

    @staticmethod
    def _parse_sdpa_node(node: fx.Node) -> Dict[str, Any]:
        r"""
        Validate a ``scaled_dot_product_attention`` node and extract its
        arguments.

        :param node: ``call_function`` node targeting SDPA.
        :return: Dict with keys ``query``, ``key``, ``value``, ``attn_mask``,
            ``is_causal`` and ``scale`` (``scale`` is ``None`` or a float).
        :raises ValueError: If query/key/value are not passed positionally,
            ``dropout_p`` is not a literal zero, ``enable_gqa`` is not
            ``False``, ``is_causal`` is not a literal bool, or ``scale`` is
            neither ``None`` nor a literal number.
        """
        if len(node.args) < 3:
            raise ValueError("SDPA node must have query, key, and value arguments.")

        dropout_p = Converter._get_literal_argument(node, "dropout_p", 4, 0.0)
        if not isinstance(dropout_p, (int, float)) or float(dropout_p) != 0.0:
            raise ValueError(
                "TD SDPA conversion only supports literal dropout_p=0.0, "
                f"but got {dropout_p!r}."
            )

        enable_gqa = Converter._get_literal_argument(node, "enable_gqa", 7, False)
        if enable_gqa is not False:
            raise ValueError("TD SDPA conversion does not support enable_gqa=True.")

        is_causal = Converter._get_literal_argument(node, "is_causal", 5, False)
        if not isinstance(is_causal, bool):
            raise ValueError(
                "TD SDPA conversion only supports literal bool is_causal, "
                f"but got {is_causal!r}."
            )

        scale = Converter._get_literal_argument(node, "scale", 6, None)
        if scale is not None and not isinstance(scale, (int, float)):
            raise ValueError(
                "TD SDPA conversion only supports literal numeric scale or None, "
                f"but got {scale!r}."
            )

        return {
            "query": node.args[0],
            "key": node.args[1],
            "value": node.args[2],
            "attn_mask": Converter._get_literal_argument(node, "attn_mask", 3, None),
            "is_causal": is_causal,
            "scale": None if scale is None else float(scale),
        }

    @staticmethod
    def _check_mha_node(module: nn.MultiheadAttention, node: fx.Node) -> None:
        r"""
        Reject ``nn.MultiheadAttention`` configurations the TD path cannot
        convert.

        :param module: The attention module being replaced.
        :param node: The ``call_module`` node invoking it; its call
            arguments are inspected as well.
        :raises ValueError: If the module uses dropout, ``batch_first=False``,
            differing ``kdim``/``vdim``, ``add_bias_kv`` or ``add_zero_attn``,
            or if the call does not pass ``need_weights=False``, passes a
            ``key_padding_mask``, or sets ``average_attn_weights`` to
            anything other than ``True``.
        """
        if module.dropout != 0.0:
            raise ValueError("TD MHA conversion only supports dropout=0.0.")
        if not module.batch_first:
            raise ValueError("TD MHA conversion only supports batch_first=True.")
        if module.kdim != module.embed_dim or module.vdim != module.embed_dim:
            raise ValueError(
                "TD MHA conversion only supports kdim == vdim == embed_dim."
            )
        if module.bias_k is not None or module.bias_v is not None:
            raise ValueError("TD MHA conversion does not support add_bias_kv.")
        if module.add_zero_attn:
            raise ValueError("TD MHA conversion does not support add_zero_attn.")

        # The default here is True, so the call site must pass
        # need_weights=False explicitly.
        need_weights = Converter._get_literal_argument(node, "need_weights", 4, True)
        if need_weights is not False:
            raise ValueError("TD MHA conversion requires need_weights=False.")

        key_padding_mask = Converter._get_literal_argument(
            node, "key_padding_mask", 3, None
        )
        if key_padding_mask is not None:
            raise ValueError("TD MHA conversion does not support key_padding_mask.")

        average_attn_weights = Converter._get_literal_argument(
            node, "average_attn_weights", 6, True
        )
        if average_attn_weights is not True:
            raise ValueError(
                "TD MHA conversion does not support average_attn_weights=False."
            )

    @staticmethod
    def _copy_mha_parameters(
        source: nn.MultiheadAttention,
        target: TDMultiheadAttention,
    ) -> None:
        r"""
        Copy projection weights from a packed ``nn.MultiheadAttention``.

        The packed ``in_proj_weight`` / ``in_proj_bias`` are split into three
        equal chunks along dim 0 and copied into ``target.q_proj``,
        ``target.k_proj`` and ``target.v_proj``; the output projection is
        copied directly.

        :param source: Module to read parameters from.
        :param target: TD module to write parameters into.
        :raises ValueError: If ``source.in_proj_weight`` is ``None``.
        """
        if source.in_proj_weight is None:
            raise ValueError("TD MHA conversion requires packed in_proj_weight.")
        with torch.no_grad():
            q_weight, k_weight, v_weight = source.in_proj_weight.chunk(3, dim=0)
            target.q_proj.weight.copy_(q_weight)
            target.k_proj.weight.copy_(k_weight)
            target.v_proj.weight.copy_(v_weight)
            if source.in_proj_bias is not None:
                q_bias, k_bias, v_bias = source.in_proj_bias.chunk(3, dim=0)
                target.q_proj.bias.copy_(q_bias)
                target.k_proj.bias.copy_(k_bias)
                target.v_proj.bias.copy_(v_bias)
            target.out_proj.weight.copy_(source.out_proj.weight)
            if source.out_proj.bias is not None:
                target.out_proj.bias.copy_(source.out_proj.bias)

    @staticmethod
    def _make_td_operator(
        module: nn.Module,
        node: Optional[fx.Node] = None,
    ) -> Optional[nn.Module]:
        r"""
        Build the TD replacement for ``module``, if one exists.

        Matching uses exact type identity, so subclasses of the supported
        modules are left untouched. Parameters are copied and the
        ``training`` flag of the source module is propagated.

        :param module: Candidate module.
        :param node: FX node calling ``module``; required for
            ``nn.MultiheadAttention``.
        :return: The TD module, or ``None`` when ``module`` is unsupported.
        :raises ValueError: For ``nn.MultiheadAttention`` without a node or
            with an unsupported configuration.
        """
        if type(module) is nn.Linear:
            td_module = TDLinear(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
                device=module.weight.device,
                dtype=module.weight.dtype,
            )
            with torch.no_grad():
                td_module.weight.copy_(module.weight)
                if module.bias is not None:
                    td_module.bias.copy_(module.bias)
            td_module.train(module.training)
            return td_module

        if type(module) is nn.LayerNorm:
            has_weight = module.weight is not None
            td_module = TDLayerNorm(
                module.normalized_shape,
                eps=module.eps,
                elementwise_affine=module.elementwise_affine,
                bias=module.bias is not None,
                device=module.weight.device if has_weight else None,
                dtype=module.weight.dtype if has_weight else None,
            )
            with torch.no_grad():
                if has_weight:
                    td_module.weight.copy_(module.weight)
                if module.bias is not None:
                    td_module.bias.copy_(module.bias)
            td_module.train(module.training)
            return td_module

        if type(module) is nn.GELU:
            td_module = TDGELU(approximate=module.approximate)
            td_module.train(module.training)
            return td_module

        if type(module) is nn.MultiheadAttention:
            if node is None:
                raise ValueError("TD MHA conversion requires an FX node.")
            Converter._check_mha_node(module, node)
            td_module = TDMultiheadAttention(
                module.embed_dim,
                module.num_heads,
                dropout=module.dropout,
                bias=module.in_proj_bias is not None,
                batch_first=module.batch_first,
                device=module.in_proj_weight.device,
                dtype=module.in_proj_weight.dtype,
            )
            Converter._copy_mha_parameters(module, td_module)
            td_module.train(module.training)
            return td_module

        return None

    @staticmethod
    def fuse(
        fx_model: torch.fx.GraphModule, fuse_flag: bool = True
    ) -> torch.fx.GraphModule:
        r"""
        Fuse Conv-BN pairs of a traced model.

        A BatchNorm ``call_module`` node is folded into the Conv node feeding
        it when both have the exact expected types (1d/2d/3d) and the Conv
        output has no other user. Fusion is done with
        ``fuse_conv_bn_eval``.

        :param fx_model: Traced model.
        :type fx_model: torch.fx.GraphModule
        :param fuse_flag: When ``False`` the model is returned untouched.
        :type fuse_flag: bool
        :return: The same ``GraphModule`` with Conv-BN pairs fused.
        :rtype: torch.fx.GraphModule
        """
        if not fuse_flag:
            return fx_model

        def matches_module_pattern(
            pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]
        ) -> bool:
            # True when (first input of node, node) are call_module nodes
            # whose modules have exactly the types listed in ``pattern``.
            if len(node.args) == 0:
                return False
            candidates: Tuple[Any, fx.Node] = (node.args[0], node)
            for expected_type, current_node in zip(pattern, candidates):
                if not isinstance(current_node, fx.Node):
                    return False
                if current_node.op != "call_module":
                    return False
                if not isinstance(current_node.target, str):
                    return False
                if current_node.target not in modules:
                    return False
                if type(modules[current_node.target]) is not expected_type:
                    return False
            return True

        def replace_node_module(
            node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module
        ) -> None:
            # Install ``new_module`` under the node's target, both in the
            # lookup dict and on the owning parent ("" is the root module).
            parent_path, _, child_name = node.target.rpartition(".")
            modules[node.target] = new_module
            setattr(modules[parent_path], child_name, new_module)

        patterns = [
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
        ]

        modules = dict(fx_model.named_modules())
        for pattern in patterns:
            for node in fx_model.graph.nodes:
                if not matches_module_pattern(pattern, node, modules):
                    continue
                conv_node = node.args[0]
                # The conv output feeds something besides the BN, so folding
                # the BN into it would change that other consumer's input.
                if len(conv_node.users) > 1:
                    continue
                conv = modules[conv_node.target]
                bn = modules[node.target]
                fused_conv = fuse_conv_bn_eval(conv, bn)
                replace_node_module(conv_node, modules, fused_conv)
                node.replace_all_uses_with(conv_node)
                fx_model.graph.erase_node(node)

        fx_model.graph.lint()
        fx_model.delete_all_unused_submodules()
        fx_model.recompile()
        return fx_model

    @staticmethod
    def set_voltagehook(
        fx_model: torch.fx.GraphModule,
        mode: Union[str, float] = "Max",
        momentum: float = 0.1,
        rules: Optional[List[ActivationRule]] = None,
    ) -> torch.fx.GraphModule:
        r"""
        Insert calibration hooks around activations matched by ``rules``.

        Each ``call_module`` node is offered to the rules in order; the first
        rule whose ``match`` succeeds inserts its hooks and the remaining
        rules are skipped for that node.

        :param fx_model: Traced model.
        :type fx_model: torch.fx.GraphModule
        :param mode: Conversion mode forwarded to ``HookFactory``.
        :type mode: str, float
        :param momentum: Momentum forwarded to ``HookFactory``.
        :type momentum: float
        :param rules: Activation matching rules. Defaults to
            ``[ReLURule()]`` when ``None``.
        :type rules: Optional[List[ActivationRule]]
        :return: The same ``GraphModule`` with hooks inserted.
        :rtype: torch.fx.GraphModule
        """
        hook_factory = HookFactory(mode=mode, momentum=momentum)
        hook_counts_per_prefix: Dict[str, int] = {}
        active_rules = rules if rules is not None else [ReLURule()]

        modules = dict(fx_model.named_modules())
        # Iterate over a snapshot because rules add nodes to the graph.
        for node in list(fx_model.graph.nodes):
            if node.op != "call_module":
                continue
            if node.target not in modules:
                continue
            for rule in active_rules:
                if not rule.match(node, modules):
                    continue
                rule.insert_hooks(
                    fx_model, node, hook_factory, hook_counts_per_prefix
                )
                # Refresh the lookup after the rule has modified the model.
                modules = dict(fx_model.named_modules())
                break

        fx_model.graph.lint()
        fx_model.recompile()
        return fx_model

    def replace_by_neurons(
        self, fx_model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        r"""
        Replace activations matched by ``self.rules`` with spiking neurons.

        Delegates to :meth:`_replace_by_neurons_impl` with this converter's
        rules, neuron factory and threshold optimizer.

        :param fx_model: ``GraphModule`` with calibration hooks inserted.
        :type fx_model: torch.fx.GraphModule
        :return: Model with activations replaced by spiking neurons.
        :rtype: torch.fx.GraphModule
        """
        return Converter._replace_by_neurons_impl(
            fx_model,
            self.rules,
            self.neuron_factory,
            self.threshold_optimizer,
        )

    @staticmethod
    def _replace_by_neurons_impl(
        fx_model: torch.fx.GraphModule,
        rules: List[ActivationRule],
        neuron_factory: NeuronFactory,
        threshold_optimizer: ThresholdOptimizer,
    ) -> torch.fx.GraphModule:
        r"""
        Apply every rule's neuron replacement to ``fx_model``.

        Each ``(activation_node, hook_node)`` pair reported by a rule is
        replaced at most once: a pair is skipped when either of its nodes was
        already handled, including by an earlier rule.

        :param fx_model: ``GraphModule`` with calibration hooks inserted.
        :param rules: Rules providing ``find_replacements`` and
            ``replace_with_neurons``.
        :param neuron_factory: Factory passed through to the rules.
        :param threshold_optimizer: Optimizer passed through to the rules.
        :return: The same ``GraphModule`` after replacement.
        """
        replaced_hooks = set()
        replaced_activations = set()
        for rule in rules:
            modules = dict(fx_model.named_modules())
            # Materialise first: replacing nodes mutates the graph.
            replacements = list(rule.find_replacements(fx_model, modules))
            for activation_node, hook_node in replacements:
                already_done = (
                    hook_node in replaced_hooks
                    or activation_node in replaced_activations
                )
                if already_done:
                    continue
                replaced_hooks.add(hook_node)
                replaced_activations.add(activation_node)
                rule.replace_with_neurons(
                    fx_model,
                    activation_node,
                    hook_node,
                    neuron_factory,
                    threshold_optimizer,
                )

        fx_model.graph.lint()
        fx_model.delete_all_unused_submodules()
        fx_model.recompile()
        return fx_model

    @staticmethod
    def replace_by_ifnode(fx_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        r"""
        Replace ReLU with IF neurons (legacy API).

        .. deprecated::
            Use :meth:`replace_by_neurons` instead.

        Emits a ``DeprecationWarning`` and runs the replacement with
        ``[ReLURule()]``, ``NeuronFactory()`` and ``ThresholdOptimizer()``.

        :param fx_model: Model with calibration hooks inserted.
        :type fx_model: torch.fx.GraphModule
        :return: Model with ReLU replaced by IF neurons.
        :rtype: torch.fx.GraphModule
        """
        warnings.warn(
            "replace_by_ifnode is deprecated, use replace_by_neurons instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return Converter._replace_by_neurons_impl(
            fx_model,
            [ReLURule()],
            NeuronFactory(),
            ThresholdOptimizer(),
        )