# Source code for spikingjelly.activation_based.op_counter.ac

from collections import defaultdict
from typing import Any, Callable

import torch
import torch.nn as nn

from .base import BaseCounter

# Shorthand for the ATen operator namespace; the operator overloads looked up
# on it (e.g. ``aten.mm.default``) are the keys of ``ACCounter.rules``.
aten = torch.ops.aten
# Public API of this module: only the counter class is exported.
__all__ = ["ACCounter"]


def _prod(dims):
    p = 1
    for v in dims:
        p *= v
    return p


def _spike_nnz(x: torch.Tensor) -> int | None:
    """Return the number of non-zero elements if *x* is a binary spike tensor, else None."""
    if x.dtype == torch.bool:
        return int(x.count_nonzero().item())
    is_binary = bool(x.eq(0).logical_or_(x.eq(1)).all().item())
    if not is_binary:
        return None
    return int(x.count_nonzero().item())


def _ac_element_wise(args, kwargs, out):
    """Count the ACs of an element-wise add/sub style operator.

    :param args: positional arguments of the op; ``args[0]`` is a tensor and
        ``args[1]``, when present, is either a tensor or a scalar
    :param kwargs: keyword arguments of the op (unused)
    :param out: result of the op
    :return: the sum of the non-zero counts of both operands when both are
        binary tensors, otherwise the number of elements in *out*
    """
    lhs = args[0]
    rhs_is_tensor = len(args) > 1 and torch.is_tensor(args[1])
    lhs_nnz = _spike_nnz(lhs)
    rhs_nnz = _spike_nnz(args[1]) if rhs_is_tensor else None

    # Anything other than "two binary tensors" costs one AC per output element.
    if lhs_nnz is None or rhs_nnz is None:
        return out.numel()
    return lhs_nnz + rhs_nnz  # two binary tensors


def _ac_mm(args, kwargs, out):
    """Count the ACs of a 2-D matrix multiplication (``aten.mm``).

    :param args: positional arguments of the op; the first two entries are
        the left and right matrices
    :param kwargs: keyword arguments of the op (unused)
    :param out: result of the op
    :return: AC count. When both operands are binary this is the sum of all
        entries of *out*; when only one operand is binary it is that
        operand's non-zero count times the free dimension of the other
        operand; when neither operand is binary it is 0.
    """
    lhs, rhs = args[:2]
    lhs_nnz = _spike_nnz(lhs)
    rhs_nnz = _spike_nnz(rhs)
    if lhs_nnz is not None and rhs_nnz is not None:
        # Accumulate in float64: summing in the output's own dtype can
        # overflow to inf for half-precision tensors (int(inf) would raise)
        # and loses exactness for large float32 totals.
        return int(out.sum(dtype=torch.float64).item())
    if lhs_nnz is not None:
        # Each spike in lhs adds one full row of rhs into the output.
        return lhs_nnz * rhs.shape[1]
    if rhs_nnz is not None:
        # Each spike in rhs adds one full column of lhs into the output.
        return rhs_nnz * lhs.shape[0]
    return 0


def _ac_addmm(args, kwargs, out):
    """Count the ACs of the matrix-product part of ``aten.addmm``.

    The first positional argument (the tensor that is added to the product)
    and *out* are not used by the count.

    :param args: positional arguments of the op; ``args[1]`` and ``args[2]``
        are the two matrices being multiplied
    :param kwargs: keyword arguments of the op; ``alpha`` (default 1) is read
    :param out: result of the op (unused)
    :return: AC count. When both matrices are binary, the sum of their
        product recomputed in float64; otherwise 0 if ``alpha != 1``,
        otherwise the binary operand's non-zero count times the free
        dimension of the other operand, or 0 if neither is binary.
    """
    _added, mat1, mat2 = args[:3]
    alpha = kwargs.get("alpha", 1)
    mat1_nnz = _spike_nnz(mat1)
    mat2_nnz = _spike_nnz(mat2)

    if mat1_nnz is None or mat2_nnz is None:
        if alpha != 1:
            return 0
        if mat1_nnz is not None:
            return mat1_nnz * mat2.shape[1]
        if mat2_nnz is not None:
            return mat2_nnz * mat1.shape[0]
        return 0

    # Both operands are binary: recompute the plain product with the Python
    # dispatch key excluded and sum its entries.
    with (
        torch.no_grad(),
        torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        ),
    ):
        product = torch.ops.aten.mm.default(mat1.double(), mat2.double())
    return int(product.sum().item())


def _ac_bmm(args, kwargs, out):
    """Count the ACs of a batched matrix multiplication (``aten.bmm``).

    :param args: positional arguments of the op; the first two entries are
        the left and right batched matrices
    :param kwargs: keyword arguments of the op (unused)
    :param out: result of the op
    :return: AC count. When both operands are binary this is the sum of all
        entries of *out*; when only one operand is binary it is that
        operand's non-zero count times the free matrix dimension of the
        other operand; when neither operand is binary it is 0.
    """
    lhs, rhs = args[:2]
    lhs_nnz = _spike_nnz(lhs)
    rhs_nnz = _spike_nnz(rhs)
    if lhs_nnz is not None and rhs_nnz is not None:
        # Accumulate in float64: summing in the output's own dtype can
        # overflow to inf for half-precision tensors (int(inf) would raise)
        # and loses exactness for large float32 totals.
        return int(out.sum(dtype=torch.float64).item())
    if lhs_nnz is not None:
        return lhs_nnz * rhs.shape[2]
    if rhs_nnz is not None:
        return rhs_nnz * lhs.shape[1]
    return 0


def _ac_baddbmm(args, kwargs, out):
    """Count the ACs of the batched-product part of ``aten.baddbmm``.

    The first positional argument (the tensor that is added to the product)
    and *out* are not used by the count.

    :param args: positional arguments of the op; ``args[1]`` and ``args[2]``
        are the two batched matrices being multiplied
    :param kwargs: keyword arguments of the op; ``alpha`` (default 1) is read
    :param out: result of the op (unused)
    :return: AC count. When both operands are binary, the sum of their
        batched product recomputed in float64; otherwise 0 if
        ``alpha != 1``, otherwise the binary operand's non-zero count times
        the free matrix dimension of the other operand, or 0 if neither is
        binary.
    """
    _added, batch1, batch2 = args[:3]
    alpha = kwargs.get("alpha", 1)
    batch1_nnz = _spike_nnz(batch1)
    batch2_nnz = _spike_nnz(batch2)

    if batch1_nnz is None or batch2_nnz is None:
        if alpha != 1:
            return 0
        if batch1_nnz is not None:
            return batch1_nnz * batch2.shape[2]
        if batch2_nnz is not None:
            return batch2_nnz * batch1.shape[1]
        return 0

    # Both operands are binary: recompute the plain batched product with the
    # Python dispatch key excluded and sum its entries.
    with (
        torch.no_grad(),
        torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        ),
    ):
        product = torch.ops.aten.bmm.default(batch1.double(), batch2.double())
    return int(product.sum().item())


def _ac_convolution(args, _kwargs, out):
    """Count the ACs of ``aten.convolution``.

    :param args: positional arguments of the op, in the order ``(input,
        weight, bias, stride, padding, dilation, transposed, output_padding,
        groups)``; the bias entry is ignored
    :param _kwargs: keyword arguments of the op (unused)
    :param out: result of the op
    :return: AC count.

        * binary input and binary weight: sum of a bias-free float64
          convolution of the two;
        * binary input only: sum of a bias-free float64 convolution of the
          input with an all-ones kernel of the weight's shape;
        * binary weight only: the weight's non-zero count times the batch
          size and the spatial size of the input (transposed convolution) or
          of *out* (regular convolution);
        * neither binary: 0.
    """
    x, w, _, stride, padding, dilation, transposed, output_padding, groups = args[:9]
    nnz_x = _spike_nnz(x)
    nnz_w = _spike_nnz(w)

    if nnz_x is None:
        if nnz_w is None:
            return 0
        ref = x if transposed else out
        return nnz_w * ref.shape[0] * _prod(ref.shape[2:])

    # Binary input: the count is the sum of a bias-free convolution, using
    # the real kernel when it is binary too and an all-ones kernel otherwise.
    weight_is_binary = nnz_w is not None
    ones_kernel = (
        None
        if weight_is_binary
        else torch.ones(w.shape, dtype=torch.float64, device=x.device)
    )
    with (
        torch.no_grad(),
        torch._C._ExcludeDispatchKeyGuard(
            torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        ),
    ):
        x64 = x.double()
        kernel = w.double() if weight_is_binary else ones_kernel
        result = torch.ops.aten.convolution.default(
            x64,
            kernel,
            None,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
        )
    return int(result.sum().item())


def _ac_avg_pool2d(args, kwargs, out):
    """Count the ACs of ``aten.avg_pool2d``.

    :param args: positional arguments of the op; ``args[1]`` is the kernel
        size
    :param kwargs: keyword arguments of the op (unused)
    :param out: result of the op
    :return: number of output elements times (kernel window size - 1)
    """
    window_elems = _prod(args[1])
    adds_per_output = window_elems - 1
    return out.numel() * adds_per_output


def _ac_sum(args, kwargs, out):
    x = args[0]
    return x.numel() - out.numel()


def _ac_mean(args, kwargs, out):
    x = args[0]
    return x.numel() - out.numel()


def _ac_sigmoid(args, kwargs, out):
    return out.numel()


def _ac_native_batch_norm(args, kwargs, out):
    x, train = args[0], args[5]
    n, c = x.numel(), x.shape[1]
    has_running_stats = args[3] is not None
    if train:
        ac = n - c  # mean: reduction sum per channel
        ac += n  # var: E[x^2] - E[x]^2
        ac += c  # var + eps
        ac += n  # x - mean
        if has_running_stats:
            ac += 2 * c  # old + m * (new - old) ; mean and var
    else:
        ac = c + n  # var + eps, x - mean
    return ac


class ACCounter(BaseCounter):
    def __init__(
        self,
        extra_rules: dict[Any, Callable] | None = None,
        extra_ignore_modules: list[nn.Module] | None = None,
    ):
        r"""
        Hardware-level Accumulate (AC) operation counter that counts pure
        additions in a network from a hardware perspective.

        Compared with
        :class:`SynOpCounter <spikingjelly.activation_based.op_counter.synop.SynOpCounter>`:
        ``SynOpCounter`` only covers spike-driven matmul and conv;
        ``ACCounter`` also covers BN, add/sub, etc., thus is broader but more
        semantically general. For example, the addition at the residual
        connections of SEW ResNet is counted as AC.

        ``ACCounter`` should be used with
        :class:`DispatchCounterMode <spikingjelly.activation_based.op_counter.base.DispatchCounterMode>`.

        .. warning::

            ``ACCounter`` can only count ACs during the forward pass. Some
            operators dedicated to the backward pass are not yet covered.

            Currently, ``ACCounter`` supports a limited number of aten
            operations. See the source code for the operation list. If you
            want to add new operations, use the ``extra_rules`` parameter.
            Pull requests that improve the default :attr:`rules` are welcome!

        .. warning::

            ``ACCounter`` counts AC operations inside BN. To ignore AC inside
            BN during inference, please fuse BN into linear/conv layers; or
            use the ``extra_ignore_modules`` parameter to add BN modules to
            the ignore list.

        :param extra_rules: additional operation rules, formatted as
            ``{aten_op: func}``, where ``func`` is a function that takes
            ``(args, kwargs, out)`` and returns the AC count. ``None``
            (the default) adds no rules.
        :type extra_rules: dict[Any, Callable] | None

        :param extra_ignore_modules: additional list of modules to ignore.
            ``None`` (the default) adds no modules.
        :type extra_ignore_modules: list[torch.nn.Module] | None

        ----

        * **Example**

        .. code-block:: python

            from spikingjelly.activation_based.op_counter import (
                ACCounter,
                DispatchCounterMode,
            )
            import torch
            import torch.nn as nn

            model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
            spike = (torch.rand(32, 100) < 0.1).float()  # sparse binary input

            ac_counter = ACCounter()
            with DispatchCounterMode([ac_counter]):
                model(spike)

            print(f"Total ACs: {ac_counter.get_total()}")  # only the 1st layer counts

        :return: None
        :rtype: None
        """
        # Nested mapping of integer counts; missing entries default to 0.
        self.records: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        # Maps an aten operator overload to the function computing its ACs.
        self.rules: dict[Any, Callable] = {
            aten.mm.default: _ac_mm,
            aten.addmm.default: _ac_addmm,
            aten.bmm.default: _ac_bmm,
            aten.baddbmm.default: _ac_baddbmm,
            aten.convolution.default: _ac_convolution,
            aten.native_batch_norm.default: _ac_native_batch_norm,
            aten.avg_pool2d.default: _ac_avg_pool2d,
            aten.sum.default: _ac_sum,
            aten.sum.dim_IntList: _ac_sum,
            aten.mean.dim: _ac_mean,
            aten.add.Tensor: _ac_element_wise,
            aten.add_.Tensor: _ac_element_wise,
            aten.add.Scalar: _ac_element_wise,
            aten.add_.Scalar: _ac_element_wise,
            aten.sub.Tensor: _ac_element_wise,
            aten.sub_.Tensor: _ac_element_wise,
            aten.sub.Scalar: _ac_element_wise,
            aten.sub_.Scalar: _ac_element_wise,
            aten.rsub.Tensor: _ac_element_wise,
            aten.rsub.Scalar: _ac_element_wise,
            aten.sigmoid_.default: _ac_sigmoid,
            # other aten ops do not involve AC operations
        }
        self.ignore_modules = []
        # ``None`` sentinels replace the former mutable ``{}`` / ``[]``
        # defaults, which were shared between every call of ``__init__``.
        if extra_rules is not None:
            self.rules.update(extra_rules)
        if extra_ignore_modules is not None:
            self.ignore_modules.extend(extra_ignore_modules)