spikingjelly.activation_based.op_counter.memory_access 源代码

from collections import defaultdict
from typing import Any, Callable, Optional

import torch
import torch.nn as nn

from .base import BaseCounter

__all__ = ["MemoryAccessCounter"]

# Shorthand for the aten operator namespace; operator overloads such as
# ``aten.mm.default`` are looked up from it.
aten = torch.ops.aten


def _bytes(x: Any) -> int:
    """Return the storage footprint of ``x`` in bytes.

    Aten argument tuples mix tensors with ``None``, Python scalars and shape
    lists, so any non-tensor value is accepted and counts as ``0`` bytes.

    :param x: Any aten argument or output value
    :type x: Any
    :return: ``element_size() * numel()`` for tensors, otherwise ``0``
    :rtype: int
    """
    return x.element_size() * x.numel() if torch.is_tensor(x) else 0


def _memory_null(args, kwargs, out):
    """Rule for operations that are counted as moving no data.

    :param args: Positional aten arguments (ignored)
    :type args: tuple
    :param kwargs: Keyword aten arguments (ignored)
    :type kwargs: dict
    :param out: Aten output (ignored)
    :return: Always ``0``
    :rtype: int
    """
    del args, kwargs, out  # unused; kept only to match the rule signature
    return 0


def _memory_mm(args, kwargs, out):
    """Estimate memory access for a 2-D matrix product ``out = lhs @ rhs``.

    Each operand is read once and the result is written once.

    :param args: Positional aten arguments; the first two are the 2-D operands
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    :raises AssertionError: If the inner dimensions of the operands differ
    """
    lhs, rhs = args[:2]
    # Unpacking also enforces that both operands are 2-D.
    (_, inner_lhs), (inner_rhs, _) = lhs.shape, rhs.shape
    if inner_lhs == inner_rhs:
        return _bytes(lhs) + _bytes(rhs) + _bytes(out)
    raise AssertionError(
        f"mm: inner dimensions mismatch [{lhs.shape} and {rhs.shape}]"
    )


def _memory_addmm(args, kwargs, out):
    """Estimate memory access for ``out = beta * bias + alpha * (lhs @ rhs)``.

    Both matrix operands are read and the result is written. ``bias`` is
    only counted when it actually contributes, i.e. when ``beta`` (default
    ``1.0``) is non-zero.

    :param args: Positional aten arguments ``(bias, lhs, rhs, ...)``
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``beta`` is read if present
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    :raises AssertionError: If the inner dimensions of the operands differ
    """
    bias, lhs, rhs = args[:3]
    (_, inner_lhs), (inner_rhs, _) = lhs.shape, rhs.shape
    if inner_lhs != inner_rhs:
        raise AssertionError(
            f"addmm: inner dimensions mismatch [{lhs.shape} and {rhs.shape}]"
        )

    total = _bytes(lhs) + _bytes(rhs) + _bytes(out)
    bias_is_read = kwargs.get("beta", 1.0) != 0
    return total + (_bytes(bias) if bias_is_read else 0)


def _memory_bmm(args, kwargs, out):
    """Estimate memory access for a batched product ``out[b] = lhs[b] @ rhs[b]``.

    Each operand is read once and the result is written once.

    :param args: Positional aten arguments; the first two are 3-D operands
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    :raises AssertionError: If batch or inner dimensions of the operands differ
    """
    lhs, rhs = args[:2]
    (batch_l, _, inner_l), (batch_r, inner_r, _) = lhs.shape, rhs.shape
    if (batch_l, inner_l) != (batch_r, inner_r):
        raise AssertionError(
            f"bmm: batch or inner dimensions mismatch [{lhs.shape} and {rhs.shape}]"
        )
    operands = (lhs, rhs, out)
    return sum(_bytes(t) for t in operands)


def _memory_baddbmm(args, kwargs, out):
    """Estimate memory access for ``out = beta * bias + alpha * bmm(x, y)``.

    Both batched operands are read and the result is written. ``bias`` is
    only counted when it actually contributes, i.e. when ``beta`` (default
    ``1.0``) is non-zero.

    :param args: Positional aten arguments ``(bias, x, y, ...)``
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``beta`` is read if present
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    :raises AssertionError: If batch or inner dimensions of ``x`` and ``y`` differ
    """
    bias, x, y = args[:3]
    b, _, k = x.shape
    bb, kk, _ = y.shape
    if b != bb or k != kk:
        # Message names the op correctly ("baddbmm") and matches the bmm format.
        raise AssertionError(
            f"baddbmm: batch or inner dimensions mismatch [{x.shape} and {y.shape}]"
        )

    # Separate accumulator name: the original reused ``m`` from the shape
    # unpacking as the byte counter.
    total = _bytes(x) + _bytes(y) + _bytes(out)
    if kwargs.get("beta", 1.0) != 0:
        total += _bytes(bias)
    return total


def _memory_convolution(args, kwargs, out):
    """Estimate memory access for a forward convolution.

    Reads the input and the weight, writes the output, and additionally
    reads the bias when one is given.

    :param args: Positional aten arguments ``(input, weight, bias, ...)``
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    inp, weight, bias = args[:3]
    operands = (inp, weight, out) if bias is None else (inp, weight, out, bias)
    return sum(_bytes(t) for t in operands)


def _memory_convolution_backward(args, kwargs, out):
    """Estimate memory access for ``convolution_backward``.

    ``grad_output`` is always read. Each gradient selected by ``output_mask``
    adds the write of that gradient plus the read of the operand it needs:

    * ``grad_input`` (mask index 0) reads ``weight``;
    * ``grad_weight`` (mask index 1) reads ``input``;
    * ``grad_bias`` (mask index 2) needs no extra operand.

    :param args: Positional aten arguments; exactly 11 values ending with
        ``output_mask``
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output tuple ``(grad_input, grad_weight, grad_bias)``
    :type out: tuple
    :return: Estimated memory access in bytes
    :rtype: int
    """
    (
        grad_output,
        inp,
        weight,
        _bias_sizes,
        _stride,
        _padding,
        _dilation,
        _transposed,
        _output_padding,
        _groups,
        output_mask,
    ) = args

    # Operand each gradient reads besides grad_output, indexed like output_mask.
    companions = (weight, inp, None)
    total = _bytes(grad_output)
    for idx, companion in enumerate(companions):
        if output_mask[idx]:
            total += _bytes(companion) + _bytes(out[idx])
    return total


def _memory_max_pool2d_with_indices(args, kwargs, out):
    """Estimate memory access for ``max_pool2d_with_indices``.

    Reads the input once and writes both outputs: the pooled values and
    the argmax indices.

    :param args: Positional aten arguments; ``args[0]`` is the input
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output pair ``(values, indices)``
    :type out: tuple
    :return: Estimated memory access in bytes
    :rtype: int
    """
    inp = args[0]
    values, indices = out
    return sum(_bytes(t) for t in (inp, values, indices))


def _memory_max_pool2d_with_indices_backward(args, kwargs, out):
    """Estimate memory access for ``max_pool2d_with_indices_backward``.

    The aten schema is ``(grad_output, self, kernel_size, stride, padding,
    dilation, ceil_mode, indices)``. ``indices`` is therefore the LAST
    argument, not ``args[1]``, which is the forward input (``self``). The
    count covers reading ``grad_output`` and ``indices`` and writing
    ``grad_input``. ``self`` is not counted; presumably it is only needed
    for its shape. NOTE(review): confirm against the target backend's kernel.

    :param args: Positional aten arguments (see schema above)
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``indices`` is read from here if
        it was not passed positionally
    :type kwargs: dict
    :param out: Aten output (``grad_input``)
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    grad_output = args[0]
    indices = args[7] if len(args) > 7 else kwargs["indices"]
    return _bytes(grad_output) + _bytes(indices) + _bytes(out)


def _memory_avg_pool2d(args, kwargs, out):
    """Estimate memory access for ``avg_pool2d``.

    Reads the input once and writes the pooled output once.

    :param args: Positional aten arguments; ``args[0]`` is the input
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    pooled_from = args[0]
    return _bytes(out) + _bytes(pooled_from)


def _memory_mean(args, kwargs, out):
    """Estimate memory access for a reduction such as ``sum`` or ``mean``.

    Reads the source tensor once and writes the reduced result once.

    :param args: Positional aten arguments; ``args[0]`` is the source tensor
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    source = args[0]
    read_bytes = _bytes(source)
    write_bytes = _bytes(out)
    return read_bytes + write_bytes


def _memory_element_wise_binary(args, kwargs, out):
    """Estimate memory access for an element-wise binary op.

    Reads both operands and writes the result. Non-tensor operands (the
    ``*.Scalar`` overloads) contribute ``0`` bytes. If fewer than two
    positional arguments are present, for example a unary op such as
    ``logical_not`` registered under this rule, only the available operand
    is counted. The original would have raised ``ValueError`` on unpacking.

    :param args: Positional aten arguments; the first two are the operands
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    return sum(_bytes(operand) for operand in args[:2]) + _bytes(out)


def _memory_element_wise_unary(args, kwargs, out):
    """Estimate memory access for an element-wise unary op.

    One read of the operand plus one write of the result. For in-place
    overloads the result is still counted as a separate write.

    :param args: Positional aten arguments; ``args[0]`` is the operand
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    operand = args[0]
    total = _bytes(operand)
    total += _bytes(out)
    return total


def _memory_stack(args, kwargs, out):
    """Estimate memory access for ``stack``.

    Every tensor in the input sequence is read once and the stacked result
    is written once.

    :param args: Positional aten arguments; ``args[0]`` is the tensor sequence
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    total = _bytes(out)
    for piece in args[0]:
        total += _bytes(piece)
    return total


def _memory_clone(args, kwargs, out):
    """Estimate memory access for a copying op (``clone`` / ``_to_copy``).

    The source is read once and the copy is written once.

    :param args: Positional aten arguments; ``args[0]`` is the source tensor
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    src = args[0]
    return sum(map(_bytes, (src, out)))


def _memory_full_like(args, kwargs, out):
    """Estimate memory access for fill-style constructors (``full_like`` etc.).

    Only the freshly filled output is counted. The template argument is not
    counted.

    :param args: Positional aten arguments (not counted)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    written = _bytes(out)
    return written


def _memory_select_backward(args, kwargs, out):
    """Estimate memory access for ``select_backward``.

    Reads the incoming gradient (``args[0]``) and writes the full-size
    gradient tensor.

    :param args: Positional aten arguments; ``args[0]`` is the incoming gradient
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated memory access in bytes
    :rtype: int
    """
    grad_output = args[0]
    result_bytes = _bytes(out)
    return _bytes(grad_output) + result_bytes


def _memory_native_batch_norm(args, kwargs, out):
    """Estimate memory access for ``native_batch_norm``.

    The aten schema is ``(input, weight, bias, running_mean, running_var,
    training, momentum, eps)`` and the op returns
    ``(output, save_mean, save_invstd)``.

    Every tensor input is read once; ``None`` optional inputs count as 0.

    * Training mode: the output and both saved statistics are written, and
      ``running_mean`` / ``running_var`` are updated in place (one extra
      write each).
    * Eval mode: only the output is written.

    :param args: Positional aten arguments (see schema above)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output tuple ``(output, save_mean, save_invstd)``
    :type out: tuple
    :return: Estimated memory access in bytes
    :rtype: int
    """
    # Names follow the aten schema. The original mislabelled them as
    # (x, mean, var, gamma, beta).
    x, weight, bias, running_mean, running_var, training = args[:6]
    m = (
        _bytes(x)
        + _bytes(weight)
        + _bytes(bias)
        + _bytes(running_mean)
        + _bytes(running_var)
    )
    if training:
        # write output, save_mean, save_invstd
        m += _bytes(out[0]) + _bytes(out[1]) + _bytes(out[2])
        # running statistics are written back in place
        m += _bytes(running_mean) + _bytes(running_var)
    else:
        m += _bytes(out[0])  # write the normalized output only
    return m


def _memory_native_batch_norm_backward(args, kwargs, out):
    """Estimate memory access for ``native_batch_norm_backward``.

    The aten schema is ``(grad_out, input, weight, running_mean, running_var,
    save_mean, save_invstd, train, eps, output_mask)`` and the op returns
    ``(grad_input, grad_weight, grad_bias)``.

    Reads (each operand counted at most once):

    * ``grad_out`` whenever any gradient is requested;
    * the inverse-std statistic for ``grad_input`` or ``grad_weight``;
    * ``weight`` for ``grad_input``;
    * ``input`` and the mean statistic for ``grad_weight``, and also for
      ``grad_input`` in training mode (its formula involves the batch
      reductions over ``input``).

    In training mode the statistics are ``save_mean`` / ``save_invstd``. In
    eval mode the kernel normalizes with ``running_mean`` / ``running_var``
    instead, so those are counted. Every requested gradient is counted as a
    write, consistent with the other rules in this module.

    :param args: Positional aten arguments (see schema above)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output tuple ``(grad_input, grad_weight, grad_bias)``
    :type out: tuple
    :return: Estimated memory access in bytes
    :rtype: int
    """
    grad_output, x, weight, running_mean, running_var, save_mean, save_invstd = args[:7]
    train, output_mask = args[-3], args[-1]
    need_input, need_weight, need_bias = output_mask[0], output_mask[1], output_mask[2]

    if train:
        mean_stat, std_stat = save_mean, save_invstd
    else:
        mean_stat, std_stat = running_mean, running_var

    m = 0
    if need_input or need_weight or need_bias:
        m += _bytes(grad_output)
    if need_input or need_weight:
        m += _bytes(std_stat)
    if need_input:
        m += _bytes(weight)
    if need_weight or (train and need_input):
        m += _bytes(x) + _bytes(mean_stat)

    # Writes of the requested gradients.
    for requested, grad in zip((need_input, need_weight, need_bias), out):
        if requested:
            m += _bytes(grad)
    return m


class MemoryAccessCounter(BaseCounter):
    r"""
    Memory-access estimation counter.

    .. _MemoryAccessCounter-cn:

    .. _MemoryAccessCounter-en:

    The counter estimates a lower bound of operator memory access from the
    byte size of input and output tensors. It is useful for coarse-grained
    memory-traffic analysis across network structures. See
    :meth:`__init__ <MemoryAccessCounter.__init__-en>` for constructor
    parameters.
    """

    def __init__(
        self,
        extra_rules: Optional[dict[Any, Callable]] = None,
        extra_ignore_modules: Optional[list[nn.Module]] = None,
    ):
        r"""
        .. _MemoryAccessCounter.__init__-cn:

        .. _MemoryAccessCounter.__init__-en:

        Memory access counter for estimating memory access in deep networks.

        This counter tracks the **byte** count of the input and output
        tensors of each operation as a **lower-bound estimate** of memory
        access. The actual amount of memory access depends on the load/store
        patterns of specific kernel implementations and is not modeled here.

        ``MemoryAccessCounter`` should be used with
        :class:`DispatchCounterMode <spikingjelly.activation_based.op_counter.base.DispatchCounterMode>`.

        .. warning::

            Currently, ``MemoryAccessCounter`` supports a limited number of
            aten operations; see the default :attr:`rules` in the source code
            for the list. To add a new operation, use the ``extra_rules``
            parameter. Pull requests that extend the default :attr:`rules`
            are welcome!

        :param extra_rules: additional operation rules in the form
            ``{aten_op: func}``, where ``func`` takes ``(args, kwargs, out)``
            and returns a byte count. Entries replace default rules for the
            same op. ``None`` (the default) means no extra rules.
        :type extra_rules: Optional[dict[Any, Callable]]
        :param extra_ignore_modules: additional modules to ignore; operations
            inside these modules are not counted. ``None`` (the default)
            means no extra modules.
        :type extra_ignore_modules: Optional[list[nn.Module]]

        * **Example**

        .. code-block:: python

            from spikingjelly.activation_based.op_counter import (
                MemoryAccessCounter,
                DispatchCounterMode,
            )
            import torch
            import torch.nn as nn

            model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
            x = torch.randn(32, 100)

            memory_counter = MemoryAccessCounter()
            with DispatchCounterMode([memory_counter]):
                output = model(x)

            total_bytes = memory_counter.get_total()
            print(f"Total memory access: {total_bytes / 1024:.2f} KB")
        """
        self.records: dict[str, dict[Any, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self.rules: dict[Any, Callable] = {
            # matrix products
            aten.mm.default: _memory_mm,
            aten.addmm.default: _memory_addmm,
            aten.bmm.default: _memory_bmm,
            aten.baddbmm.default: _memory_baddbmm,
            # convolution / normalization / pooling
            aten.convolution.default: _memory_convolution,
            aten.convolution_backward.default: _memory_convolution_backward,
            aten.native_batch_norm.default: _memory_native_batch_norm,
            aten.native_batch_norm_backward.default: _memory_native_batch_norm_backward,
            aten.max_pool2d_with_indices.default: _memory_max_pool2d_with_indices,
            aten.max_pool2d_with_indices_backward.default: _memory_max_pool2d_with_indices_backward,
            aten.avg_pool2d.default: _memory_avg_pool2d,
            # reductions
            aten.sum.default: _memory_mean,
            aten.sum.dim_IntList: _memory_mean,
            aten.mean.dim: _memory_mean,
            # element-wise arithmetic
            aten.add.Tensor: _memory_element_wise_binary,
            aten.add_.Tensor: _memory_element_wise_binary,
            aten.add.Scalar: _memory_element_wise_binary,
            aten.add_.Scalar: _memory_element_wise_binary,
            aten.sub.Tensor: _memory_element_wise_binary,
            aten.sub_.Tensor: _memory_element_wise_binary,
            aten.sub.Scalar: _memory_element_wise_binary,
            aten.sub_.Scalar: _memory_element_wise_binary,
            aten.rsub.Tensor: _memory_element_wise_binary,
            aten.rsub.Scalar: _memory_element_wise_binary,
            aten.neg.default: _memory_element_wise_unary,
            aten.neg_.default: _memory_element_wise_unary,
            aten.mul.Tensor: _memory_element_wise_binary,
            aten.mul_.Tensor: _memory_element_wise_binary,
            aten.mul.Scalar: _memory_element_wise_binary,
            aten.mul_.Scalar: _memory_element_wise_binary,
            aten.div.Tensor: _memory_element_wise_binary,
            aten.div_.Tensor: _memory_element_wise_binary,
            aten.div.Scalar: _memory_element_wise_binary,
            aten.div_.Scalar: _memory_element_wise_binary,
            # element-wise comparison / logic
            aten.eq.Tensor: _memory_element_wise_binary,
            aten.eq.Scalar: _memory_element_wise_binary,
            aten.ne.Tensor: _memory_element_wise_binary,
            aten.ne.Scalar: _memory_element_wise_binary,
            aten.lt.Tensor: _memory_element_wise_binary,
            aten.lt.Scalar: _memory_element_wise_binary,
            aten.le.Tensor: _memory_element_wise_binary,
            aten.le.Scalar: _memory_element_wise_binary,
            aten.gt.Tensor: _memory_element_wise_binary,
            aten.gt.Scalar: _memory_element_wise_binary,
            aten.ge.Tensor: _memory_element_wise_binary,
            aten.ge.Scalar: _memory_element_wise_binary,
            aten.logical_and.default: _memory_element_wise_binary,
            aten.logical_or.default: _memory_element_wise_binary,
            aten.logical_xor.default: _memory_element_wise_binary,
            # logical_not takes a single operand; the binary rule would fail
            # to unpack it.
            aten.logical_not.default: _memory_element_wise_unary,
            aten.sigmoid_.default: _memory_element_wise_unary,
            # copies / constructors
            aten.stack.default: _memory_stack,
            aten.clone.default: _memory_clone,
            aten._to_copy.default: _memory_clone,
            aten.full_like.default: _memory_full_like,
            aten.ones_like.default: _memory_full_like,
            # ops counted as moving no data
            aten.view.default: _memory_null,
            aten.empty.memory_format: _memory_null,
            aten.select.int: _memory_null,  # return a view
            aten.select_backward.default: _memory_select_backward,  # involve load store
            aten.detach.default: _memory_null,
            aten.t.default: _memory_null,
            aten.expand.default: _memory_null,
        }
        self.ignore_modules: list = []
        # None defaults replace the original shared mutable ``{}`` / ``[]``.
        if extra_rules is not None:
            self.rules.update(extra_rules)
        if extra_ignore_modules is not None:
            self.ignore_modules.extend(extra_ignore_modules)