# Source code of spikingjelly.activation_based.op_counter.base

import logging
from collections import defaultdict
from typing import Any, Callable, Optional

import torch
import torch.nn as nn
from torch.autograd.graph import register_multi_grad_hook
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten
from torch.utils.module_tracker import ModuleTracker

# Module-level logger; the module tracker below reports hierarchy-tracking
# inconsistencies through it with ``logger.info``.
logger = logging.getLogger(__name__)
# Single glyph (U+2937) that prefixes the per-counter lines of verbose output.
_arrow = "\u2937"


# Public API of this module (helpers such as ``_arrow`` stay private).
__all__ = [
    "ActiveModuleTracker", "BaseCounter", "is_binary_tensor",
    "DispatchCounterMode", "FunctionCounterMode",
]


def is_binary_tensor(x: torch.Tensor) -> bool:
    r"""
    Check whether ``x`` is a binary tensor.

    A tensor is considered binary when its ``dtype`` is ``torch.bool`` or
    when every one of its elements equals 0 or 1.

    :param x: input tensor
    :type x: torch.Tensor
    :return: ``True`` if ``x`` is a bool tensor or all of its elements are in
        {0, 1}, otherwise ``False``
    :rtype: bool
    """
    # A bool tensor can only hold 0/1, so the element-wise scan is skipped.
    if x.dtype == torch.bool:
        return True
    is_zero = x.eq(0)
    is_one = x.eq(1)
    # ``.item()`` turns the 0-d reduction result into a Python bool.
    return bool((is_zero | is_one).all().item())
[文档] class ActiveModuleTracker(ModuleTracker): def __init__(self): r""" **API Language:** :ref:`中文 <ActiveModuleTracker.__init__-cn>` | :ref:`English <ActiveModuleTracker.__init__-en>` ---- .. _ActiveModuleTracker.__init__-cn: * **中文** * **中文** 模块追踪器,用于在 PyTorch 的前向和反向传播过程中追踪模块的调用层次结构。 它通过在模块的前向和反向钩子上进行回调来记录当前活跃的模块 :attr:`active_modules` 。 :attr:`active_modules` 和 :attr:`parents` 的区别在于:前者是 ``nn.Module`` 的集合, 后者是 ``str`` (模块名)的集合。 ---- .. _ActiveModuleTracker.__init__-en: * **English** * **English** Module tracker that tracks the module call hierarchy during PyTorch forward and backward passes. It records the currently executing module instances to :attr:`active_modules` through callbacks on module forward and backward hooks. Attributes :attr:`active_modules` and :attr:`parents` are different: the former is a set of ``nn.Module`` instances, while the latter is a set of ``str`` (module names). :return: None :rtype: None """ super().__init__() self.active_modules: set[nn.Module] = set() # align with self.parents: set[str] def _get_append_fn(self, mod, name, is_bw): def fn(*args) -> None: if is_bw: self._maybe_set_engine_callback() if name in self.parents: logger.info( "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s", name, "backward" if is_bw else "forward", ) self.parents.add(name) self.active_modules.add(mod) return fn def _get_pop_fn(self, mod, name, is_bw): def fn(*args) -> None: if name in self.parents: self.parents.remove(name) else: logger.info( "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s", name, "backward" if is_bw else "forward", ) if not self.active_modules: raise RuntimeError("active_modules stack underflow") if mod in self.active_modules: self.active_modules.remove(mod) else: logger.info( "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. 
%s during %s", name, "backward" if is_bw else "forward", ) return fn def _fw_pre_hook(self, mod, input) -> None: name = self._get_mod_name(mod) self._get_append_fn(mod, name, False)() args, _ = tree_flatten(input) tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] if tensors: self._hooks.append( register_multi_grad_hook(tensors, self._get_pop_fn(mod, name, True)) ) def _fw_post_hook(self, mod, input, output) -> None: name = self._get_mod_name(mod) self._get_pop_fn(mod, name, False)() args, _ = tree_flatten(output) tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] if tensors: self._hooks.append( register_multi_grad_hook(tensors, self._get_append_fn(mod, name, True)) )
class BaseCounter:
    def __init__(self):
        r"""
        Base class for operation counters.

        All concrete counters inherit from this class. It provides:

        - :attr:`records`: accumulated counts, structured as
          ``records[scope][operation] = count``
        - :attr:`rules`: maps an operation to a function
          ``rule(args, kwargs, out)`` whose result (converted to ``int``) is
          the count of one call
        - :attr:`ignore_modules`: module *classes*; the counter modes skip an
          operation for this counter when any currently active module is an
          instance of one of them

        Subclasses populate :attr:`rules` to define how particular operations
        are counted.

        :return: None
        :rtype: None
        """
        self.records: dict[str, dict[Any, int]] = self._empty_records()
        self.rules: dict[Any, Callable] = {}
        # Classes checked with ``isinstance`` by the counter modes, not instances.
        self.ignore_modules: list[type[nn.Module]] = []

    @staticmethod
    def _empty_records() -> dict[str, dict[Any, int]]:
        # Nested defaultdicts so ``records[scope][func] += value`` works for
        # scopes and operations that have not been seen yet.
        return defaultdict(lambda: defaultdict(int))

    def has_rule(self, func) -> bool:
        r"""
        Check whether ``func`` has a counting rule.

        :param func: the function or operation to be checked. Its type should
            match the keys of :attr:`rules`
        :type func: Any
        :return: whether ``func`` has a corresponding counting rule
        :rtype: bool
        """
        return func in self.rules

    def count(
        self,
        func,
        args: tuple,
        kwargs: dict,
        out,
        active_modules: Optional[set[nn.Module]] = None,
        parent_names: Optional[set[str]] = None,
    ) -> int:
        r"""
        Compute the count of one call of ``func`` according to :attr:`rules`.

        The base implementation calls ``rules[func](args, kwargs, out)`` and
        ignores ``active_modules`` and ``parent_names``; counters that need
        module context for semantic counting may override this method and use
        them.

        :param func: the function or operation that was called. Its type
            should match the keys of :attr:`rules`
        :type func: Any
        :param args: positional arguments of ``func``
        :type args: tuple
        :param kwargs: keyword arguments of ``func``
        :type kwargs: dict
        :param out: output of ``func``
        :type out: Any
        :param active_modules: currently active module instances
        :type active_modules: Optional[set[nn.Module]]
        :param parent_names: names of the currently active parent modules
        :type parent_names: Optional[set[str]]
        :return: the computed count
        :rtype: int
        :raises KeyError: if ``func`` has no rule (check with :meth:`has_rule`)
        """
        return int(self.rules[func](args, kwargs, out))

    def record(self, scope, func, value):
        r"""
        Add ``value`` to the count of ``func`` under ``scope`` in :attr:`records`.

        :param scope: module scope string, e.g. ``"SimpleNet.lif1"``
        :type scope: str
        :param func: the function or operation being recorded. Its type
            should match the keys of :attr:`rules`
        :type func: Any
        :param value: the count to add
        :type value: int
        """
        self.records[scope][func] += value

    def get_counts(self) -> dict[str, dict[Any, int]]:
        r"""
        Return a plain-dict copy of all count records.

        :return: ``{scope: {operation: count}}`` built from :attr:`records`
        :rtype: dict[str, dict[Any, int]]
        """
        return {scope: dict(ops) for scope, ops in self.records.items()}

    def get_total(self) -> int:
        r"""
        Return the sum of all counts recorded under the top-level ``"Global"`` scope.

        :return: the total count in the ``"Global"`` scope (0 if nothing was
            recorded)
        :rtype: int
        """
        # ``.get`` avoids inserting an empty "Global" entry into the
        # defaultdict as a side effect of merely reading the total.
        return sum(self.records.get("Global", {}).values())

    def reset(self):
        r"""
        Reset the counter and clear all recorded counts.

        :attr:`records` is replaced with an empty nested dictionary, dropping
        every previously accumulated count. Call it before starting a new
        counting session when a counter instance is reused.

        :return: ``None``
        :rtype: None
        """
        self.records = self._empty_records()
class DispatchCounterMode(TorchDispatchMode):
    def __init__(
        self, counters: list[BaseCounter], strict: bool = False, verbose: bool = False
    ):
        r"""
        **Context manager** that counts aten operations through PyTorch's
        dispatch mechanism.

        It overrides ``__torch_dispatch__`` to intercept every aten operation
        and lets the registered counters count it.

        **Working mechanism:**

        1. Tracks the current module hierarchy with :class:`ActiveModuleTracker`.
        2. For each intercepted operation and each counter, skips the
           operation if it runs inside a module whose class is in the
           counter's ``ignore_modules``.
        3. Otherwise, skips the operation if the counter has no rule for it
           (or raises ``NotImplementedError`` when ``strict`` is ``True``).
        4. Otherwise, computes the count with the counter's rule and records
           it under the scope of every active parent module.

        :param counters: list of counters
        :type counters: list[BaseCounter]
        :param strict: if ``True``, raises ``NotImplementedError`` when an
            operation that is not ignored has no rule; if ``False``, such
            operations are skipped. Defaults to ``False``
        :type strict: bool
        :param verbose: if ``True``, prints each intercepted operation and what
            each counter did with it
        :type verbose: bool
        :return: context manager object
        :rtype: DispatchCounterMode

        * **Example**

        .. code-block:: python

            from spikingjelly.activation_based.op_counter import (
                FlopCounter,
                DispatchCounterMode,
            )
            import torch
            import torch.nn as nn

            class SimpleNet(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear = nn.Linear(100, 50)

                def forward(self, x):
                    return self.linear(x)

            model = SimpleNet()
            x = torch.randn(32, 100)

            # Initialize counter
            flop_counter = FlopCounter()

            with DispatchCounterMode([flop_counter], verbose=True):
                output = model(x)

            # Get and print results
            print("FLOP counts:", flop_counter.get_total())
        """
        super().__init__()
        self.counters = counters
        self.strict = strict
        self.verbose = verbose
        self.module_tracker = ActiveModuleTracker()

    def __enter__(self):
        # Enter the tracker first so module hooks are live before any op is
        # intercepted; undo it if pushing this mode fails so the tracker's
        # hooks are not left registered.
        self.module_tracker.__enter__()
        try:
            return super().__enter__()
        except BaseException:
            self.module_tracker.__exit__(None, None, None)
            raise

    def __exit__(self, *args):
        # Always remove the tracker's hooks, even if leaving the mode raises.
        try:
            return super().__exit__(*args)
        finally:
            self.module_tracker.__exit__(*args)

    def _should_skip(self, counter, func) -> bool:
        # Return True if ``counter`` must not count ``func`` right now: it runs
        # inside an ignored module class, or ``counter`` has no rule for it.
        ignored_types = tuple(counter.ignore_modules)
        for am in self.module_tracker.active_modules:
            if isinstance(am, ignored_types):  # inside an ignored module
                if self.verbose:
                    print(
                        f"{_arrow} ignored by {counter.__class__.__name__} as it is "
                        f"inside {am.__class__.__name__}"
                    )
                return True

        if not counter.has_rule(func):  # stats rule not defined
            if self.strict:
                parent_names = self.module_tracker.parents
                raise NotImplementedError(
                    f"DispatchCounterMode: {parent_names} - {resolve_name(func)}"
                    f" not defined by {counter.__class__.__name__}. "
                    f"To disable this error, "
                    f"set strict=False when initializing DispatchCounterMode."
                )
            if self.verbose:
                print(f"{_arrow} not defined by {counter.__class__.__name__}")
            return True
        return False

    def __torch_dispatch__(self, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)

        # Snapshot the current hierarchy; the same context is handed to every
        # counter and used for recording.
        parent_names = self.module_tracker.parents
        active_modules = set(self.module_tracker.active_modules)
        parent_names_snapshot = set(parent_names)
        if self.verbose:
            print(f"DispatchCounterMode: {parent_names} - {resolve_name(func)}")

        for counter in self.counters:
            if self._should_skip(counter, func):
                continue
            value = counter.count(
                func,
                args,
                kwargs,
                out,
                active_modules=active_modules,
                parent_names=parent_names_snapshot,
            )
            if self.verbose:
                print(f"{_arrow} + {value} [{counter.__class__.__name__}]")
            # Add the count to every ancestor scope.
            for parent in parent_names_snapshot:
                counter.record(parent, func, value)
            if hasattr(counter, "finalize_record"):
                counter.finalize_record()

        return out
class FunctionCounterMode(TorchFunctionMode):
    def __init__(
        self, counters: list[BaseCounter], strict: bool = False, verbose: bool = False
    ):
        r"""
        **Context manager** that counts operations through PyTorch's
        ``__torch_function__`` mechanism.

        It overrides ``__torch_function__`` to intercept every PyTorch function
        call and lets the registered counters count it. The mechanism is the
        same as in :class:`DispatchCounterMode`. For each counter, an
        operation inside a module whose class is in the counter's
        ``ignore_modules`` is skipped. Otherwise, an operation without a rule
        is skipped, or raises ``NotImplementedError`` when ``strict`` is
        ``True``. Every other operation is counted and recorded under the
        scope of every active parent module.

        :param counters: list of counters
        :type counters: list[BaseCounter]
        :param strict: if ``True``, raises ``NotImplementedError`` when an
            operation that is not ignored has no rule; if ``False``, such
            operations are skipped. Defaults to ``False``
        :type strict: bool
        :param verbose: if ``True``, prints each intercepted operation and what
            each counter did with it
        :type verbose: bool
        :return: context manager object
        :rtype: FunctionCounterMode
        """
        super().__init__()
        self.counters = counters
        self.strict = strict
        self.verbose = verbose
        self.module_tracker = ActiveModuleTracker()

    def __enter__(self):
        # Enter the tracker first so module hooks are live before any call is
        # intercepted; undo it if pushing this mode fails so the tracker's
        # hooks are not left registered.
        self.module_tracker.__enter__()
        try:
            return super().__enter__()
        except BaseException:
            self.module_tracker.__exit__(None, None, None)
            raise

    def __exit__(self, *args):
        # Always remove the tracker's hooks, even if leaving the mode raises.
        try:
            return super().__exit__(*args)
        finally:
            self.module_tracker.__exit__(*args)

    def _should_skip(self, counter, func) -> bool:
        # Return True if ``counter`` must not count ``func`` right now. The
        # ignore check runs first (as in DispatchCounterMode) so ``strict``
        # never demands rules for calls inside ignored modules.
        ignored_types = tuple(counter.ignore_modules)
        for am in self.module_tracker.active_modules:
            if isinstance(am, ignored_types):  # inside an ignored module
                if self.verbose:
                    print(
                        f"{_arrow} ignored by {counter.__class__.__name__} as it is "
                        f"inside {am.__class__.__name__}"
                    )
                return True

        if not counter.has_rule(func):  # stats rule not defined
            if self.strict:
                parent_names = self.module_tracker.parents
                raise NotImplementedError(
                    f"FunctionCounterMode: {parent_names} - {resolve_name(func)} "
                    f"not defined by {counter.__class__.__name__}. "
                    f"To disable this error, "
                    f"set strict=False when initializing FunctionCounterMode."
                )
            if self.verbose:
                print(f"{_arrow} not defined by {counter.__class__.__name__}")
            return True

        if self.verbose:
            print(f"{_arrow} counted by {counter.__class__.__name__}")
        return False

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs
        out = func(*args, **kwargs)

        # Snapshot the current hierarchy; the same context is handed to every
        # counter and used for recording.
        parent_names = self.module_tracker.parents
        active_modules = set(self.module_tracker.active_modules)
        parent_names_snapshot = set(parent_names)
        if self.verbose:
            print(f"FunctionCounterMode: {parent_names} - {resolve_name(func)}")

        for counter in self.counters:
            if self._should_skip(counter, func):
                continue
            value = counter.count(
                func,
                args,
                kwargs,
                out,
                active_modules=active_modules,
                parent_names=parent_names_snapshot,
            )
            if self.verbose:
                print(f"{_arrow} + {value}")
            # Add the count to every ancestor scope.
            for parent in parent_names_snapshot:
                counter.record(parent, func, value)
            if hasattr(counter, "finalize_record"):
                counter.finalize_record()

        return out