spikingjelly.activation_based.op_counter.mac 源代码

from collections import defaultdict
from typing import Any, Callable

import torch
import torch.nn as nn

from .base import BaseCounter

aten = torch.ops.aten
__all__ = ["MACCounter"]


def _prod(dims):
    p = 1
    for v in dims:
        p *= v
    return p


def _is_spike(x: torch.Tensor) -> bool:
    """Return ``True`` if ``x`` is a binary spike tensor.

    :param x: Input tensor
    :type x: torch.Tensor
    :return: Whether ``x`` is a spike tensor
    :rtype: bool
    """
    if x.dtype == torch.bool:
        return True
    return bool(x.eq(0).logical_or_(x.eq(1)).all().item())


def _mac_mm(args, kwargs, out):
    x, y = args[:2]
    if _is_spike(x) or _is_spike(y):
        return 0
    m, k = x.shape
    _, n = y.shape
    return m * n * k


def _mac_addmm(args, kwargs, out):
    _, x, y = args[:3]
    if _is_spike(x) or _is_spike(y):
        return 0
    m, k = x.shape
    _, n = y.shape
    return m * n * k


def _mac_bmm(args, kwargs, out):
    x, y = args[:2]
    if _is_spike(x) or _is_spike(y):
        return 0
    b, m, k = x.shape
    _, _, n = y.shape
    return b * m * n * k


def _mac_baddbmm(args, kwargs, out):
    _, x, y = args[:3]
    if _is_spike(x) or _is_spike(y):
        return 0
    b, m, k = x.shape
    _, _, n = y.shape
    return b * m * n * k


def _mac_convolution(args, _kwargs, out):
    x, w, transposed = args[0], args[1], args[6]
    if _is_spike(x) or _is_spike(w):
        return 0
    b = x.shape[0]
    spatial_shape = x.shape[2:] if transposed else out.shape[2:]
    c_out, c_in, *kernel_shape = w.shape

    mac_per_position = c_in * _prod(kernel_shape)
    return mac_per_position * _prod(spatial_shape) * c_out * b


def _mac_native_batch_norm(args, kwargs, out):
    x, train = args[0], args[5]
    c = x.shape[1]
    has_affine = args[1] is not None
    has_running_stats = args[3] is not None
    mac = 0
    if has_affine:
        mac += x.numel()  # x_hat * gamma + beta
    if train and has_running_stats:
        mac += 2 * c  # old + m * (new - old) ; mean & var
    return mac


[文档] class MACCounter(BaseCounter): r""" **API Language:** :ref:`中文 <MACCounter-cn>` | :ref:`English <MACCounter-en>` ---- .. _MACCounter-cn: * **中文** * **中文** 硬件级乘累加(MAC)计数器。 该计数器统计网络中的 MAC 操作次数,并与 :class:`ACCounter <spikingjelly.activation_based.op_counter.ac.ACCounter>` 形成互补视角,用于近似刻画硬件上的乘累加开销。具体构造参数见 :meth:`__init__ <MACCounter.__init__-cn>`。 ---- .. _MACCounter-en: * **English** * **English** Hardware-level multiply-accumulate (MAC) counter. This counter tracks MAC operations in a network and complements :class:`ACCounter <spikingjelly.activation_based.op_counter.ac.ACCounter>` for approximate hardware-oriented compute analysis. See :meth:`__init__ <MACCounter.__init__-en>` for constructor parameters. """ def __init__( self, extra_rules: dict[Any, Callable] = {}, extra_ignore_modules: list[nn.Module] = [], ): r""" **API Language:** :ref:`中文 <MACCounter.__init__-cn>` | :ref:`English <MACCounter.__init__-en>` ---- .. _MACCounter.__init__-cn: * **中文** 硬件级乘累加(Multiply-Accumulate,MAC)操作计数器,统计网络中所有 MAC 操作次数。 MAC 乘法结果立即累加到累加器(如矩阵内积),而非写入新的内存位置。MAC与AC互斥:若一个计算步骤 计入 MAC,则不会计入 AC;反之亦然。 ``MACCounter`` 应与 :class:`DispatchCounterMode <spikingjelly.activation_based.op_counter.base.DispatchCounterMode>` 搭配使用。 .. warning:: ``MACCounter`` 只能统计前向传播期间的 MAC 数量。部分专用于反向传播的算子还未覆盖。 目前,``MACCounter`` 支持的 aten 操作类型有限。查看源代码以获取操作列表。如需添加新操作, 可以使用 ``extra_rules`` 参数;也欢迎提交 pull request 来完善默认的 :attr:`rules`! .. warning:: ``MACCounter`` 会如实考虑 BN 内部的 MAC 操作。如果想在推理时忽略 BN 内部的 MAC,请将 BN 融合到线性层中;或者使用 ``extra_ignore_modules`` 参数将 BN 模块加入忽略列表。 :param extra_rules: 额外的操作规则,格式为 ``{aten_op: func}``, 其中 ``func`` 是一个函数,接受 ``(args, kwargs, out)`` 并返回 MAC 次数 :type extra_rules: dict[Any, Callable] :param extra_ignore_modules: 额外需要忽略的模块列表,这些模块中的操作不会被计数 :type extra_ignore_modules: list[torch.nn.Module] ---- .. _MACCounter.__init__-en: * **English** Hardware-level Multiply-Accumulate (MAC) operation counter that counts all MAC operations in a network. MAC's multiply result is immediately accumulated into a running accumulator (not written to a new memory location). ``MACCounter`` is mutually exclusive with :class:`ACCounter <spikingjelly.activation_based.op_counter.ac.ACCounter>` : if a computation step is counted as MAC, it will not be counted as AC, and vice versa. .. warning:: ``MACCounter`` can only count MACs during the forward pass. Some operators dedicated to backward pass are not yet covered. Currently, ``MACCounter`` supports a limited number of aten operations. See the source code for the operation list. If you want to add new operations, use the ``extra_rules`` parameter. Welcome to submit a pull request to improve the default :attr:`rules`! :param extra_rules: additional operation rules, format as ``{aten_op: func}``, where ``func`` is a function that takes ``(args, kwargs, out)`` and returns the MAC count :type extra_rules: dict[Any, Callable] :param extra_ignore_modules: additional list of modules to ignore. Operations within these modules will not be counted :type extra_ignore_modules: list[torch.nn.Module] ---- * **代码示例 | Example** .. code-block:: python from spikingjelly.activation_based.op_counter import ( MACCounter, ACCounter, DispatchCounterMode, ) import torch import torch.nn as nn model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10)) x = (torch.randn(32, 100) < 0.1).float() # sparse binary input mac_counter = MACCounter() with DispatchCounterMode([mac_counter]): output = model(x) print(f"Total MACs: {mac_counter.get_total()}") # only the 2nd layer counts :return: None :rtype: None """ self.records: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int)) self.rules: dict[Any, Callable] = { aten.mm.default: _mac_mm, aten.addmm.default: _mac_addmm, aten.bmm.default: _mac_bmm, aten.baddbmm.default: _mac_baddbmm, aten.convolution.default: _mac_convolution, aten.native_batch_norm.default: _mac_native_batch_norm, # other aten ops do not involve MAC operations } self.ignore_modules = [] self.rules.update(extra_rules) self.ignore_modules.extend(extra_ignore_modules)