# Source code for spikingjelly.activation_based.op_counter.synop

from collections import defaultdict
from typing import Any, Callable

import torch
import torch.nn as nn

from .base import BaseCounter

aten = torch.ops.aten
__all__ = ["SynOpCounter"]


def _prod(dims):
    p = 1
    for v in dims:
        p *= v
    return p


def _spike_nnz(x: torch.Tensor) -> int | None:
    """Return the number of non-zero elements if *x* is a binary spike tensor, else None."""
    if x.dtype == torch.bool:
        return int(x.count_nonzero().item())
    is_binary = bool(x.eq(0).logical_or_(x.eq(1)).all().item())
    if not is_binary:
        return None
    return int(x.count_nonzero().item())


def _synop_mm(args, kwargs, out):
    """Count synaptic operations for a 2-D matrix product ``x @ y``.

    An operand is considered a spike tensor when :func:`_spike_nnz` reports it
    as binary. The count is:

    * both operands binary: the sum of all elements of ``out``;
    * only ``x`` binary: ``nnz(x) * y.shape[1]``;
    * only ``y`` binary: ``nnz(y) * x.shape[0]``;
    * neither binary: ``0``.

    :param args: positional arguments of the op; the first two are the operands
    :param kwargs: keyword arguments of the op (unused)
    :param out: tensor summed in the both-binary case (presumably the op's
        output, i.e. ``x @ y`` -- confirm against the caller)
    :return: number of synaptic operations
    :rtype: int
    """
    x, y = args[:2]
    nnz_x = _spike_nnz(x)
    nnz_y = _spike_nnz(y)
    if nnz_x is not None and nnz_y is not None:
        # Accumulate in float64: summing in the dtype of ``out`` rounds totals
        # above 2**24 for float32 and can overflow to inf for half precision,
        # which would make ``int()`` raise.
        return int(out.sum(dtype=torch.float64).item())
    elif nnz_x is not None:
        return nnz_x * y.shape[1]
    elif nnz_y is not None:
        return nnz_y * x.shape[0]
    else:
        return 0


def _synop_addmm(args, kwargs, out):
    """Count synaptic operations for ``addmm`` (matrix product plus an added term).

    Only the matrix product of the second and third positional arguments is
    considered; the first positional argument is ignored.

    * both operands binary: the sum of their product, recomputed in float64;
    * ``alpha != 1`` (and not both binary): ``0``;
    * only the first operand binary: ``nnz * <columns of the second operand>``;
    * only the second operand binary: ``nnz * <rows of the first operand>``;
    * neither binary: ``0``.

    :param args: positional arguments of the op; the last two of the first
        three are the matrix operands
    :param kwargs: keyword arguments of the op; ``alpha`` is read (default 1)
    :param out: unused
    :return: number of synaptic operations
    :rtype: int
    """
    _added, mat1, mat2 = args[:3]
    scale = kwargs.get("alpha", 1)
    spikes1 = _spike_nnz(mat1)
    spikes2 = _spike_nnz(mat2)

    if spikes1 is not None and spikes2 is not None:
        # The product is recomputed from the operands rather than read from
        # ``out`` (presumably because ``out`` also contains the added term).
        # The Python dispatch key is excluded for the recomputation --
        # presumably so an active Python dispatch mode does not intercept it.
        python_key = torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        with torch.no_grad(), torch._C._ExcludeDispatchKeyGuard(python_key):
            product = torch.ops.aten.mm.default(mat1.double(), mat2.double())
        return int(product.sum().item())

    if scale != 1 or (spikes1 is None and spikes2 is None):
        return 0
    if spikes1 is not None:
        return spikes1 * mat2.shape[1]
    return spikes2 * mat1.shape[0]


def _synop_bmm(args, kwargs, out):
    """Count synaptic operations for a batched matrix product.

    An operand is considered a spike tensor when :func:`_spike_nnz` reports it
    as binary. The count is:

    * both operands binary: the sum of all elements of ``out``;
    * only ``x`` binary: ``nnz(x) * y.shape[2]``;
    * only ``y`` binary: ``nnz(y) * x.shape[1]``;
    * neither binary: ``0``.

    :param args: positional arguments of the op; the first two are the operands
    :param kwargs: keyword arguments of the op (unused)
    :param out: tensor summed in the both-binary case (presumably the op's
        output -- confirm against the caller)
    :return: number of synaptic operations
    :rtype: int
    """
    x, y = args[:2]
    nnz_x = _spike_nnz(x)
    nnz_y = _spike_nnz(y)
    if nnz_x is not None and nnz_y is not None:
        # Accumulate in float64: summing in the dtype of ``out`` rounds totals
        # above 2**24 for float32 and can overflow to inf for half precision,
        # which would make ``int()`` raise.
        return int(out.sum(dtype=torch.float64).item())
    elif nnz_x is not None:
        return nnz_x * y.shape[2]
    elif nnz_y is not None:
        return nnz_y * x.shape[1]
    else:
        return 0


def _synop_baddbmm(args, kwargs, out):
    """Count synaptic operations for ``baddbmm`` (batched product plus an added term).

    Only the batched product of the second and third positional arguments is
    considered; the first positional argument is ignored.

    * both operands binary: the sum of their product, recomputed in float64;
    * ``alpha != 1`` (and not both binary): ``0``;
    * only the first operand binary: ``nnz * batch2.shape[2]``;
    * only the second operand binary: ``nnz * batch1.shape[1]``;
    * neither binary: ``0``.

    :param args: positional arguments of the op; the last two of the first
        three are the batched operands
    :param kwargs: keyword arguments of the op; ``alpha`` is read (default 1)
    :param out: unused
    :return: number of synaptic operations
    :rtype: int
    """
    _added, batch1, batch2 = args[:3]
    scale = kwargs.get("alpha", 1)
    spikes1 = _spike_nnz(batch1)
    spikes2 = _spike_nnz(batch2)

    if spikes1 is not None and spikes2 is not None:
        # The product is recomputed from the operands rather than read from
        # ``out`` (presumably because ``out`` also contains the added term).
        # The Python dispatch key is excluded for the recomputation --
        # presumably so an active Python dispatch mode does not intercept it.
        python_key = torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
        with torch.no_grad(), torch._C._ExcludeDispatchKeyGuard(python_key):
            product = torch.ops.aten.bmm.default(batch1.double(), batch2.double())
        return int(product.sum().item())

    if scale != 1 or (spikes1 is None and spikes2 is None):
        return 0
    if spikes1 is not None:
        return spikes1 * batch2.shape[2]
    return spikes2 * batch1.shape[1]


def _synop_convolution(args, _kwargs, out):
    """Count synaptic operations for ``aten.convolution``.

    * input and weight both binary: the sum of a bias-free float64
      convolution of the input with the weight;
    * only the input binary: the sum of a bias-free float64 convolution of
      the input with an all-ones kernel of the weight's shape;
    * only the weight binary: ``nnz(w) * ref.shape[0] * prod(ref.shape[2:])``,
      where ``ref`` is the input for transposed convolutions and ``out``
      otherwise;
    * neither binary: ``0``.

    :param args: positional arguments of the op; the first nine are
        ``(input, weight, bias, stride, padding, dilation, transposed,
        output_padding, groups)``
    :param _kwargs: keyword arguments of the op (unused)
    :param out: tensor whose shape is used in the weight-only case
        (presumably the op's output -- confirm against the caller)
    :return: number of synaptic operations
    :rtype: int
    """
    x, w, _bias, stride, padding, dilation, transposed, output_padding, groups = args[:9]
    spikes_in = _spike_nnz(x)
    spikes_w = _spike_nnz(w)

    if spikes_in is None:
        if spikes_w is None:
            return 0
        # Weight-only case: closed-form count from the reference tensor's shape.
        anchor = x if transposed else out
        return spikes_w * anchor.shape[0] * _prod(anchor.shape[2:])

    # From here on the input is a spike tensor. The bias slot is passed as None.
    geometry = (None, stride, padding, dilation, transposed, output_padding, groups)
    if spikes_w is not None:
        unit_kernel = None
    else:
        unit_kernel = torch.ones(w.shape, dtype=torch.float64, device=x.device)

    # The Python dispatch key is excluded for the auxiliary convolution --
    # presumably so an active Python dispatch mode does not intercept it.
    python_key = torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
    with torch.no_grad(), torch._C._ExcludeDispatchKeyGuard(python_key):
        dense_x = x.double()
        kernel = w.double() if unit_kernel is None else unit_kernel
        total = torch.ops.aten.convolution.default(dense_x, kernel, *geometry)
    return int(total.sum().item())


class SynOpCounter(BaseCounter):
    def __init__(
        self,
        extra_rules: dict[Any, Callable] | None = None,
        extra_ignore_modules: list[nn.Module] | None = None,
    ):
        r"""
        Synaptic Operations (SynOps) counter that tracks spike-driven weight
        accumulations in SNNs.

        Compared with :class:`ACCounter <spikingjelly.activation_based.op_counter.ac.ACCounter>`:
        ``ACCounter`` also covers BN internals, add/sub, ...
        ``SynOpCounter`` is narrower: only spike-driven matmul and conv are considered.
        This makes it more directly aligned with the intuitive concept of
        "synaptic operations" in neuromorphic computing. For example, the
        addition at the residual connections of SEW ResNet is not counted as
        SynOps.

        ``SynOpCounter`` should be used with
        :class:`DispatchCounterMode <spikingjelly.activation_based.op_counter.base.DispatchCounterMode>`.

        .. warning::

            ``SynOpCounter`` can only count SynOps during the forward pass.
            Some operators dedicated to backward pass are not yet covered.

            Currently, ``SynOpCounter`` supports mm, addmm, bmm, baddbmm, and
            convolution. If you want to add new operations, use the
            ``extra_rules`` parameter. Welcome to submit a pull request to
            improve the default :attr:`rules`!

        :param extra_rules: additional operation rules, format as
            ``{aten_op: func}``, where ``func`` is a function that takes
            ``(args, kwargs, out)`` and returns the SynOps count. Entries
            override the default rule registered for the same op. ``None``
            (the default) adds no rules
        :type extra_rules: dict[Any, Callable] | None

        :param extra_ignore_modules: additional list of modules to ignore.
            Operations within these modules will not be counted. ``None``
            (the default) ignores no modules
        :type extra_ignore_modules: list[torch.nn.Module] | None

        ----

        * **Example**

        .. code-block:: python

            from spikingjelly.activation_based.op_counter import (
                SynOpCounter,
                DispatchCounterMode,
            )
            import torch
            import torch.nn as nn

            model = nn.Linear(10, 5, bias=False)
            spike = (torch.rand(4, 10) < 0.2).float()

            counter = SynOpCounter()
            with DispatchCounterMode([counter]):
                model(spike)

            print(f"SynOp count: {counter.get_total()}")

        :return: None
        :rtype: None
        """
        # Nested mapping of str -> {key -> int count}; missing entries start at 0.
        self.records: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))
        self.rules: dict[Any, Callable] = {
            aten.mm.default: _synop_mm,
            aten.addmm.default: _synop_addmm,
            aten.bmm.default: _synop_bmm,
            aten.baddbmm.default: _synop_baddbmm,
            aten.convolution.default: _synop_convolution,
            # other aten ops do not involve SynOp operations
        }
        self.ignore_modules = []
        # ``None`` defaults avoid sharing one mutable default object across calls.
        if extra_rules is not None:
            self.rules.update(extra_rules)
        if extra_ignore_modules is not None:
            self.ignore_modules.extend(extra_ignore_modules)