from collections import defaultdict
from typing import Any, Callable, Optional

import torch
import torch.nn as nn

from .base import BaseCounter
# Shorthand for the aten operator namespace; its overloads are the keys of the
# rule table built in ``FlopCounter.__init__``.
aten = torch.ops.aten
# Names exported by ``from <module> import *``.
__all__ = ["FlopCounter"]
def _prod(dims):
    """Return the product of all entries of ``dims``.

    :param dims: Iterable of numbers, e.g. a shape or a kernel size
    :type dims: Iterable
    :return: Product of the entries; ``1`` when ``dims`` is empty
    :rtype: int
    """
    # Function-scope import keeps this helper self-contained.
    import math

    return math.prod(dims)
def _flop_null(args, kwargs, out):
    """Rule for operators that are counted as zero FLOPs.

    :param args: Positional aten arguments (unused)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :return: Always ``0``
    :rtype: int
    """
    zero_cost = 0
    return zero_cost
def _flop_mm(args, kwargs, out):
    """Compute FLOPs for matrix multiplication ``out = x @ y``.

    Each of the ``m * n`` output elements costs ``k`` multiplications and
    ``k - 1`` additions.

    :param args: Positional aten arguments; the first two are the 2-D operands
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    :raises AssertionError: If the inner dimensions of the operands differ
    """
    x, y = args[:2]
    m, k = x.shape
    kk, n = y.shape
    if k != kk:
        raise AssertionError(f"mm: inner dimensions mismatch [{x.shape} and {y.shape}]")
    # Clamp so that an empty inner dimension (k == 0) yields 0 FLOPs rather
    # than the negative value ``m * n * -1``.
    return m * n * max(2 * k - 1, 0)
def _flop_addmm(args, kwargs, out):
    """Compute FLOPs for ``out = beta * bias + alpha * (x @ y)``.

    :param args: Positional aten arguments ``(bias, x, y, ...)``
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``alpha`` and ``beta`` are read
        from here and default to ``1``
    :type kwargs: dict
    :param out: Aten output (unused)
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    :raises AssertionError: If the inner dimensions of ``x`` and ``y`` differ
    """
    bias, x, y = args[:3]
    m, k = x.shape
    kk, n = y.shape
    if k != kk:
        raise AssertionError(
            f"addmm: inner dimensions mismatch [{x.shape} and {y.shape}]"
        )
    alpha = kwargs.get("alpha", 1)
    beta = kwargs.get("beta", 1)
    out_elems = m * n
    # matmul: 2k-1 flops per output element, clamped so that an empty inner
    # dimension (k == 0) does not contribute a negative amount.
    flops = out_elems * max(2 * k - 1, 0)
    if alpha != 1:
        flops += out_elems  # scale the product by alpha
    if beta == 1:
        flops += out_elems  # add the bias to the m*n matrix
    elif beta != 0:
        # scale the bias, then add it to the m*n matrix
        flops += bias.numel() + out_elems
    # beta == 0: the bias term is not counted at all
    return flops
def _flop_bmm(args, kwargs, out):
    """Compute FLOPs for batched matrix multiplication.

    Each of the ``b * m * n`` output elements costs ``k`` multiplications and
    ``k - 1`` additions.

    :param args: Positional aten arguments; the first two are the 3-D operands
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    :raises AssertionError: If batch or inner dimensions of the operands differ
    """
    x, y = args[:2]
    b, m, k = x.shape
    bb, kk, n = y.shape
    if b != bb or k != kk:
        raise AssertionError(
            f"bmm: batch or inner dimensions mismatch [{x.shape} and {y.shape}]"
        )
    # Clamp so that an empty inner dimension (k == 0) yields 0 FLOPs rather
    # than a negative value.
    return b * m * n * max(2 * k - 1, 0)
def _flop_baddbmm(args, kwargs, out):
    """Compute FLOPs for ``out = beta * bias + alpha * bmm(x, y)``.

    :param args: Positional aten arguments ``(bias, x, y, ...)``
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``alpha`` and ``beta`` are read
        from here and default to ``1``
    :type kwargs: dict
    :param out: Aten output (unused)
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    :raises AssertionError: If batch or inner dimensions of ``x`` and ``y`` differ
    """
    bias, x, y = args[:3]
    b, m, k = x.shape
    bb, kk, n = y.shape
    if b != bb or k != kk:
        # The message now names the operator correctly (was "baddmm").
        raise AssertionError(
            f"baddbmm: batch or inner dimensions mismatch [{x.shape}, {y.shape}]"
        )
    alpha = kwargs.get("alpha", 1)
    beta = kwargs.get("beta", 1)
    out_elems = b * m * n
    # batched matmul: 2k-1 flops per output element, clamped so that an empty
    # inner dimension (k == 0) does not contribute a negative amount.
    flops = out_elems * max(2 * k - 1, 0)
    if alpha != 1:
        flops += out_elems  # scale the product by alpha
    if beta == 1:
        flops += out_elems  # add the bias to the b*m*n tensor
    elif beta != 0:
        # scale the bias, then add it to the b*m*n tensor
        flops += bias.numel() + out_elems
    # beta == 0: the bias term is not counted at all
    return flops
def _flop_convolution(args, kwargs, out):
    """Compute FLOPs for convolution.

    The input, weight and bias are taken from the first three positional
    arguments. The ``transposed`` flag is read from ``kwargs`` if present,
    otherwise from ``args[6]`` if that many positional arguments exist, and
    defaults to ``False``.

    :param args: Positional aten arguments
    :type args: tuple
    :param kwargs: Keyword aten arguments
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    """
    inp, weight, bias = args[:3]
    if "transposed" in kwargs:
        is_transposed = kwargs["transposed"]
    elif len(args) > 6:
        is_transposed = args[6]
    else:
        is_transposed = False

    batch = inp.shape[0]
    dim0_channels, dim1_channels, *kernel_dims = weight.shape
    # Positions the kernel is applied at: the input's spatial extent for a
    # transposed convolution, the output's otherwise.
    positions = inp.shape[2:] if is_transposed else out.shape[2:]

    per_position = 2 * dim1_channels * _prod(kernel_dims)
    total = per_position * _prod(positions) * dim0_channels * batch
    if bias is None:
        # Without a bias, the first add of every output element is avoided.
        # With a bias that saving is cancelled by the bias add itself.
        total -= out.numel()
    return total
def _flop_convolution_backward(args, kwargs, out):
    """Compute FLOPs for convolution backward.

    Exactly eleven positional arguments are expected. The last one is a
    three-entry mask selecting which gradients (input, weight, bias) are
    counted.

    :param args: Positional aten arguments
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output tuple ``(grad_input, grad_weight, ...)``
    :type out: tuple
    :return: Estimated FLOPs
    :rtype: int
    """
    (
        grad_output,
        inp,
        weight,
        bias,
        _stride,
        _padding,
        _dilation,
        is_transposed,
        _output_padding,
        _groups,
        mask,
    ) = args

    total = 0

    if mask[0]:
        # grad w.r.t. the input: a convolution of grad_output with the weight,
        # with the transposed flag flipped.
        total += _flop_convolution(
            [grad_output, weight, None], {"transposed": not is_transposed}, out[0]
        )

    if mask[1]:
        # grad w.r.t. the weight: a plain convolution between the input and
        # grad_output with their first two dimensions swapped.
        if is_transposed:
            lhs, rhs = grad_output, inp
        else:
            lhs, rhs = inp, grad_output
        total += _flop_convolution(
            [lhs.transpose(0, 1), rhs.transpose(0, 1), None],
            {"transposed": False},
            out[1],
        )

    if mask[2] and bias is not None:
        # grad w.r.t. the bias: reduce grad_output over batch and spatial
        # dimensions, separately for every channel.
        go_shape = grad_output.shape
        per_channel = go_shape[0] * _prod(go_shape[2:])
        total += go_shape[1] * (per_channel - 1)

    return total
def _flop_max_pool2d_with_indices(args, kwargs, out):
    """Compute FLOPs for max pooling that also returns indices.

    :param args: Positional aten arguments; ``args[1]`` is the kernel size
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output sequence; ``out[0]`` holds the pooled values
    :type out: tuple
    :return: Estimated FLOPs
    :rtype: int
    """
    window = args[1]
    pooled = out[0]
    # K - 1 comparisons per output element for a window of K entries.
    comparisons_per_output = _prod(window) - 1
    return pooled.numel() * comparisons_per_output
def _flop_avg_pool2d(args, kwargs, out):
    """Compute FLOPs for average pooling.

    :param args: Positional aten arguments; ``args[1]`` is the kernel size
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    """
    window = args[1]
    # K - 1 additions plus 1 division per output element.
    ops_per_output = _prod(window)
    return out.numel() * ops_per_output
def _flop_sum(args, kwargs, out):
    """Compute FLOPs for a sum reduction.

    Reducing ``N`` input elements to ``M`` outputs takes ``N - M`` additions.

    :param args: Positional aten arguments; ``args[0]`` is the input tensor
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated FLOPs, never negative
    :rtype: int
    """
    x = args[0]
    # Clamp at zero: when the input has fewer elements than the output (an
    # empty input), the plain difference would be negative.
    return max(x.numel() - out.numel(), 0)
def _flop_mean(args, kwargs, out):
    """Compute FLOPs for a mean reduction.

    :param args: Positional aten arguments; ``args[0]`` is the input tensor
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :type out: torch.Tensor
    :return: Number of elements of the input tensor
    :rtype: int
    """
    return args[0].numel()
def _flop_add(args, kwargs, out):
    """Compute FLOPs for ``out = x + alpha * other``.

    :param args: Positional aten arguments; ``args[1]`` is the second operand
        (only inspected when ``alpha`` is not ``1``)
    :type args: tuple
    :param kwargs: Keyword aten arguments; ``alpha`` defaults to ``1.0``
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    """
    additions = out.numel()
    if kwargs.get("alpha", 1.0) == 1.0:
        return additions
    # A non-unit alpha first scales the second operand: one multiply per
    # element for a tensor, a single multiply for a plain scalar.
    other = args[1]
    scaling = other.numel() if torch.is_tensor(other) else 1
    return scaling + additions
def _flop_element_wise(args, kwargs, out):
    """Compute FLOPs for an element-wise operator: one per output element.

    :param args: Positional aten arguments (unused)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Number of elements of the output
    :rtype: int
    """
    n_out = out.numel()
    return n_out
def _flop_sigmoid(args, kwargs, out):
    """Compute FLOPs for sigmoid, counted as four per output element.

    :param args: Positional aten arguments (unused)
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output
    :type out: torch.Tensor
    :return: Estimated FLOPs
    :rtype: int
    """
    flops_per_element = 4
    return flops_per_element * out.numel()
def _flop_native_batch_norm(args, kwargs, out):
    """Compute FLOPs for batch normalization (forward).

    Reads the input from ``args[0]``, the affine weight from ``args[1]``, the
    running mean from ``args[3]`` and the training flag from ``args[5]``.

    :param args: Positional aten arguments
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :return: Estimated FLOPs
    :rtype: int
    """
    inp, training = args[0], args[5]
    numel, channels = inp.numel(), inp.shape[1]
    affine = args[1] is not None
    tracks_running_stats = args[3] is not None

    # Work shared by training and evaluation mode.
    total = 2 * channels  # sqrt(var + eps)
    total += 2 * numel  # (x - mean) / std
    if affine:
        total += 2 * numel  # * gamma, + beta

    if training:
        total += numel  # batch mean
        total += 2 * numel + 2 * channels  # batch var; E[x^2] - E[x]^2
        if tracks_running_stats:
            total += 6 * channels  # old + m * (new - old); mean and var
    return total
def _flop_native_batch_norm_backward(args, kwargs, out):
    """Compute FLOPs for batch normalization (backward).

    Reads ``grad_output`` from ``args[0]``, the affine weight from ``args[2]``,
    the training flag from ``args[-3]`` and the three-entry output mask
    ``(grad_input, grad_gamma, grad_beta)`` from ``args[-1]``.

    :param args: Positional aten arguments
    :type args: tuple
    :param kwargs: Keyword aten arguments (unused)
    :type kwargs: dict
    :param out: Aten output (unused)
    :return: Estimated FLOPs
    :rtype: int
    """
    grad_output, gamma, train, output_mask = args[0], args[2], args[-3], args[-1]
    n = grad_output.numel()
    # The weight may be None (batch norm without affine parameters); in that
    # case take the channel count from dimension 1 of grad_output, the same
    # dimension the forward rule uses, instead of failing on ``None.numel()``.
    c = gamma.numel() if gamma is not None else grad_output.shape[1]
    flops = 0
    if train:
        if output_mask[0]:  # grad_input
            flops += 2 * n  # x_hat = (x - mean) * invstd
            flops += n - c  # term1: sum(grad_output) per channel (grad_beta)
            # term2: sum(grad_output * x_hat) per channel (grad_gamma)
            flops += 2 * n - c
            # invstd*gamma/n * (grad_output*n - term1 - term2*x_hat)
            flops += 5 * n + 2 * c
        # grad_gamma / grad_beta are by-products of grad_input in training
        # mode, so they are only counted when grad_input is not requested.
        if output_mask[1] and not output_mask[0]:  # grad_gamma
            flops += 2 * n  # x_hat
            flops += 2 * n - c
        if output_mask[2] and not output_mask[0]:  # grad_beta
            flops += n - c
    else:
        if output_mask[0]:  # grad_input
            flops += 2 * n  # grad_output * saved_invstd * gamma
        if output_mask[1]:  # grad_gamma
            flops += 2 * n  # x_hat = (x - mean) / std
            flops += 2 * n - c
        if output_mask[2]:  # grad_beta
            flops += n - c
    return flops
# [docs]  (Sphinx source-link marker left over from the HTML export)
class FlopCounter(BaseCounter):
    r"""
    .. _FlopCounter-cn:
    .. _FlopCounter-en:

    Floating-point operation (FLOP) counter.

    This counter tracks arithmetic FLOPs of forward operators and some backward
    operators as a coarse estimate of compute cost. See
    :meth:`__init__ <FlopCounter.__init__-en>` for constructor parameters.
    """

    def __init__(
        self,
        extra_rules: Optional[dict[Any, Callable]] = None,
        extra_ignore_modules: Optional[list[nn.Module]] = None,
    ):
        r"""
        .. _FlopCounter.__init__-cn:
        .. _FlopCounter.__init__-en:

        FLOP counter for calculating the number of floating-point operations
        in deep networks.

        FLOP (Floating Point Operations) is a common metric for measuring
        computational complexity:

        - 1 multiplication = 1 FLOP; 1 addition = 1 FLOP; ......
        - Element-wise operations are also considered.

        ``FlopCounter`` should be used with
        :class:`DispatchCounterMode <spikingjelly.activation_based.op_counter.base.DispatchCounterMode>` .

        .. warning::

            Currently, ``FlopCounter`` supports a limited number of aten
            operations. See the source code for the operation list.
            If you want to add new operations, use the ``extra_rules``
            parameter. Welcome to submit a pull request to improve the
            default :attr:`rules` !

        :param extra_rules: additional operation rules, format as
            ``{aten_op: func}``, where ``func`` is a function that takes
            ``(args, kwargs, out)`` and returns the count value. Entries
            override default rules with the same key. ``None`` (the default)
            adds nothing
        :type extra_rules: dict[Any, Callable] | None
        :param extra_ignore_modules: additional list of modules to ignore.
            Operations within these modules will not be counted. ``None``
            (the default) adds nothing
        :type extra_ignore_modules: list[torch.nn.Module] | None

        Example:

        .. code-block:: python

            from spikingjelly.activation_based.op_counter import (
                FlopCounter,
                DispatchCounterMode,
            )
            import torch
            import torch.nn as nn

            model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 10))
            x = torch.randn(32, 100)
            flop_counter = FlopCounter()
            with DispatchCounterMode([flop_counter]):
                output = model(x)

            # Get FLOP counts
            total_flops = flop_counter.get_total()
            print(f"Total FLOPs: {total_flops}")

        :return: None
        :rtype: None
        """
        # Two-level table of integer counts; missing entries start at 0.
        self.records: dict[str, dict[Any, int]] = defaultdict(lambda: defaultdict(int))

        # Default mapping from aten overload to its FLOP rule.
        self.rules: dict[Any, Callable] = {
            # matrix products
            aten.mm.default: _flop_mm,
            aten.addmm.default: _flop_addmm,
            aten.bmm.default: _flop_bmm,
            aten.baddbmm.default: _flop_baddbmm,
            # convolution
            aten.convolution.default: _flop_convolution,
            aten.convolution_backward.default: _flop_convolution_backward,
            # batch normalization
            aten.native_batch_norm.default: _flop_native_batch_norm,
            aten.native_batch_norm_backward.default: _flop_native_batch_norm_backward,
            # pooling
            aten.max_pool2d_with_indices.default: _flop_max_pool2d_with_indices,
            aten.max_pool2d_with_indices_backward.default: _flop_null,
            aten.avg_pool2d.default: _flop_avg_pool2d,
            # reductions
            aten.sum.default: _flop_sum,
            aten.sum.dim_IntList: _flop_sum,
            aten.mean.dim: _flop_mean,
            # addition / subtraction
            aten.add.Tensor: _flop_add,
            aten.add_.Tensor: _flop_add,
            aten.add.Scalar: _flop_add,
            aten.add_.Scalar: _flop_add,
            aten.sub.Tensor: _flop_add,
            aten.sub_.Tensor: _flop_add,
            aten.sub.Scalar: _flop_add,
            aten.sub_.Scalar: _flop_add,
            aten.rsub.Tensor: _flop_add,
            aten.rsub.Scalar: _flop_add,
            # element-wise arithmetic
            aten.neg.default: _flop_element_wise,
            aten.neg_.default: _flop_element_wise,
            aten.mul.Tensor: _flop_element_wise,
            aten.mul_.Tensor: _flop_element_wise,
            aten.mul.Scalar: _flop_element_wise,
            aten.mul_.Scalar: _flop_element_wise,
            aten.div.Tensor: _flop_element_wise,
            aten.div_.Tensor: _flop_element_wise,
            aten.div.Scalar: _flop_element_wise,
            aten.div_.Scalar: _flop_element_wise,
            # comparisons
            aten.eq.Tensor: _flop_element_wise,
            aten.eq.Scalar: _flop_element_wise,
            aten.ne.Tensor: _flop_element_wise,
            aten.ne.Scalar: _flop_element_wise,
            aten.lt.Tensor: _flop_element_wise,
            aten.lt.Scalar: _flop_element_wise,
            aten.le.Tensor: _flop_element_wise,
            aten.le.Scalar: _flop_element_wise,
            aten.gt.Tensor: _flop_element_wise,
            aten.gt.Scalar: _flop_element_wise,
            aten.ge.Tensor: _flop_element_wise,
            aten.ge.Scalar: _flop_element_wise,
            # logical operators
            aten.logical_and.default: _flop_element_wise,
            aten.logical_or.default: _flop_element_wise,
            aten.logical_xor.default: _flop_element_wise,
            aten.logical_not.default: _flop_element_wise,
            # activations
            aten.sigmoid_.default: _flop_sigmoid,
            # operators counted as zero FLOPs
            aten.stack.default: _flop_null,
            aten.clone.default: _flop_null,
            aten._to_copy.default: _flop_null,
            aten.full_like.default: _flop_null,
            aten.ones_like.default: _flop_null,
            aten.view.default: _flop_null,
            aten.empty.memory_format: _flop_null,
            aten.select.int: _flop_null,
            aten.select_backward.default: _flop_null,
            aten.detach.default: _flop_null,
            aten.t.default: _flop_null,
            aten.expand.default: _flop_null,
        }
        self.ignore_modules = []

        # ``None`` defaults replace the former mutable ``{}`` / ``[]`` defaults,
        # which were shared between all calls of the constructor.
        if extra_rules is not None:
            self.rules.update(extra_rules)
        if extra_ignore_modules is not None:
            self.ignore_modules.extend(extra_ignore_modules)