# Source code for spikingjelly.activation_based.op_counter.spikesim.counter

from __future__ import annotations

import math
from collections import defaultdict
from dataclasses import asdict, dataclass
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..base import BaseCounter, is_binary_tensor
from .config import SpikeSimEnergyConfig

# Public API of this module: only the counter class is exported by ``import *``.
__all__ = ["SpikeSimCounter"]
# Shorthand for the ATen operator namespace; ``aten.convolution.default`` is
# used below as the key of the counter's rule table.
aten = torch.ops.aten


def _shape_tuple(x: torch.Tensor) -> tuple[int, ...]:
    return tuple(int(v) for v in x.shape)


def _pair_tuple(x: Any) -> tuple[int, int]:
    if isinstance(x, int):
        return (int(x), int(x))
    if len(x) != 2:
        raise ValueError(f"expected a pair, but got {x}.")
    return (int(x[0]), int(x[1]))


def _exclude_python_dispatch_guard():
    return torch._C._ExcludeDispatchKeyGuard(
        torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
    )


@dataclass
class _StageStats:
    """Running event counters for one convolution stage (one scope name).

    ``SpikeSimCounter._update_stage`` adds the counts of every counted
    convolution call to these fields. The ``active_*`` fields come from the
    event-driven path when it is used; on the dense fallback path they receive
    the same values as the matching ``dense_*`` fields. Below, "sites" means
    ``N * H_out * W_out`` of the convolution output.
    """

    # Sum of ceil(C_in / xbar_size) * ceil(C_out / xbar_size) * sites.
    dense_pe_cycle_count: int = 0
    # Event path: number of (sample, input tile, output position) entries whose
    # receptive-field occupancy is greater than zero.
    active_patch_tile_count: int = 0
    # Event path: total receptive-field occupancy summed over all tiles.
    active_row_count: int = 0
    # Per-input-tile breakdown of ``active_row_count``. ``None`` before the
    # first update; ``_update_stage`` also sets it to ``None`` when a call's
    # number of tiles differs from the accumulated list.
    active_row_count_by_tile: list[int] | None = None
    # Event path: out_channel_tiles * number of (sample, output position)
    # entries that have at least one active input tile.
    active_output_tile_site_count: int = 0
    # sites * number of input-channel tiles.
    dense_patch_tile_count: int = 0
    # Sum over input tiles of sites * tile_channels * k_h * k_w.
    dense_row_count: int = 0
    # Per-input-tile breakdown of ``dense_row_count``; same ``None`` handling
    # as ``active_row_count_by_tile``.
    dense_row_count_by_tile: list[int] | None = None
    # out_channel_tiles * sites.
    dense_output_tile_site_count: int = 0

    def as_dict(self) -> dict[str, Any]:
        """Return all fields as a plain ``dict`` built by ``dataclasses.asdict``."""
        return asdict(self)


@dataclass
class _StageMetadata:
    """Static description and call bookkeeping for one convolution stage.

    ``SpikeSimCounter._update_stage`` creates one instance the first time a
    scope is seen, filling the geometry fields from that first call, and
    updates the counters and the mismatch flag on every later call.
    """

    # Scope name the stage is keyed by.
    scope: str
    # Taken from the weight tensor: w.shape[1] and w.shape[0].
    in_channels: int
    out_channels: int
    # (w.shape[2], w.shape[3]) of the first call.
    kernel_size: tuple[int, int]
    # Convolution hyper-parameters of the first call, normalised to pairs.
    stride: tuple[int, int]
    padding: tuple[int, int]
    dilation: tuple[int, int]
    # Full input / output tensor shapes of the first call.
    input_shape: tuple[int, ...]
    output_shape: tuple[int, ...]
    # ceil(out_channels / xbar_size).
    out_channel_tiles: int
    # Channels held by each input tile: xbar_size for full tiles, the
    # remainder for a trailing partial tile.
    input_tile_channels: list[int]
    # Number of calls counted for this scope.
    total_calls: int = 0
    # Calls counted with the event-driven path (activity_mode == "event" and
    # a binary input tensor).
    event_driven_calls: int = 0
    # Calls counted with dense counts instead.
    dense_fallback_calls: int = 0
    # Set to True when a later call's shapes or hyper-parameters differ from
    # the values recorded above; never reset by the code in this file.
    shape_mismatch_detected: bool = False

    def as_dict(self) -> dict[str, Any]:
        """Return all fields as a plain ``dict`` built by ``dataclasses.asdict``."""
        return asdict(self)


class SpikeSimCounter(BaseCounter):
    r"""
    .. _SpikeSimCounter-cn:

    .. _SpikeSimCounter-en:

    SpikeSim counter for profiling computation costs in spike-driven
    simulations.

    Only ``aten.convolution.default`` is registered. Each counted convolution
    is attributed to a *stage* (the deepest active module scope) and updates
    two per-stage records: :attr:`stage_stats` (accumulated event counts) and
    :attr:`stage_metadata` (geometry and call bookkeeping). Unsupported calls
    (transposed, grouped, non-4D, or not issued by exactly one active
    ``nn.Conv2d`` in eval mode with gradients disabled) contribute 0 and are
    reported through :attr:`warnings`, or raise in strict mode.

    :param config: SpikeSim energy configuration; this class reads
        ``config.xbar_size`` and ``config.activity_mode``
    :type config: SpikeSimEnergyConfig
    :param strict: if ``True``, every unsupported or inconsistent call raises
        ``NotImplementedError`` instead of recording a warning
    :type strict: bool
    :param verbose: if ``True``, print one line per counted convolution
    :type verbose: bool
    :return: None
    :rtype: None
    """

    def __init__(
        self,
        *,
        config: SpikeSimEnergyConfig,
        strict: bool,
        verbose: bool,
    ):
        super().__init__()
        self.config = config
        self.strict = strict
        self.verbose = verbose
        # Per-scope accumulators; entries are created on first access.
        self.stage_stats: dict[str, _StageStats] = defaultdict(_StageStats)
        self.stage_metadata: dict[str, _StageMetadata] = {}
        # Human-readable warnings (non-strict mode), one per distinct key.
        self.warnings: list[str] = []
        self._warning_keys: set[str] = set()
        # All-ones depthwise kernels keyed by (num_tiles, k_h, k_w, device).
        self._ones_kernel_cache: dict[
            tuple[int, int, int, torch.device], torch.Tensor
        ] = {}
        self.rules = {
            aten.convolution.default: self._count_convolution,
        }

    def count(
        self,
        func,
        args: tuple,
        kwargs: dict,
        out,
        active_modules=None,
        parent_names=None,
    ) -> int:
        r"""
        .. _SpikeSimCounter.count-cn:

        .. _SpikeSimCounter.count-en:

        Count the computation cost of a single operator call.

        :param func: The operator to count
        :type func: Callable
        :param args: Positional arguments of the operator call
        :type args: tuple
        :param kwargs: Keyword arguments of the operator call
        :type kwargs: dict
        :param out: Output of the operator
        :param active_modules: Modules active while the operator ran
        :param parent_names: Names of the enclosing module scopes
        :return: value returned by the registered rule, converted to ``int``
            (for convolutions: the dense PE cycle count, or 0 when the call
            is outside the supported scope)
        :rtype: int
        :raises KeyError: if ``func`` has no registered rule; check
            :meth:`has_rule` first
        :raises NotImplementedError: in strict mode, when the call is
            unsupported or inconsistent with earlier calls of the same stage
        """
        return int(
            self.rules[func](
                args,
                kwargs,
                out,
                active_modules=active_modules,
                parent_names=parent_names,
            )
        )

    def has_rule(self, func) -> bool:
        """Return ``True`` if ``func`` has a counting rule registered."""
        return func in self.rules

    def _warn_or_raise(self, key: str, message: str) -> None:
        """Report an unsupported situation identified by ``key``.

        In strict mode this always raises, including on repeated occurrences
        of the same key. Otherwise ``message`` is appended to
        :attr:`warnings` once per key.

        :raises NotImplementedError: when ``self.strict`` is true
        """
        if self.strict:
            # Strict mode must never swallow a problem, so the per-key
            # de-duplication below only applies to warnings.
            self._warning_keys.add(key)
            raise NotImplementedError(message)
        if key in self._warning_keys:
            return
        self._warning_keys.add(key)
        self.warnings.append(message)

    def _leaf_scope(self, parent_names: set[str] | None) -> str:
        """Pick the deepest scope name from ``parent_names``.

        Depth is the number of dots in the name, then the name length. The
        name itself is the final tie-breaker so that the choice does not
        depend on set iteration order. Returns ``"Global"`` when no other
        name is present.
        """
        names = [name for name in (parent_names or set()) if name != "Global"]
        if not names:
            return "Global"
        return max(names, key=lambda name: (name.count("."), len(name), name))

    def _input_tile_channels(self, in_channels: int) -> list[int]:
        """Split ``in_channels`` into tiles of at most ``config.xbar_size``.

        :return: number of channels in each tile; the last one may be partial
        """
        xbar_size = self.config.xbar_size
        return [
            min(xbar_size, in_channels - start)
            for start in range(0, in_channels, xbar_size)
        ]

    def _dense_event_counts(
        self,
        *,
        w: torch.Tensor,
        out: torch.Tensor,
        out_channel_tiles: int,
    ) -> tuple[int, int, list[int], int]:
        """Compute the counts of a fully dense (activity-independent) call.

        :return: ``(patch_tile_count, row_count, row_count_by_tile,
            output_tile_site_count)``, where sites are ``N * H_out * W_out``
        """
        num_sites = int(out.shape[0] * out.shape[2] * out.shape[3])
        input_tile_channels = self._input_tile_channels(int(w.shape[1]))
        dense_a = num_sites * len(input_tile_channels)
        # Every tile contributes one row per channel and kernel position.
        dense_row_count_by_tile = [
            num_sites * tile_channels * w.shape[2] * w.shape[3]
            for tile_channels in input_tile_channels
        ]
        dense_r = sum(dense_row_count_by_tile)
        dense_z = out_channel_tiles * num_sites
        return dense_a, dense_r, dense_row_count_by_tile, dense_z

    def _dense_pe_cycles(
        self,
        *,
        w: torch.Tensor,
        out: torch.Tensor,
    ) -> int:
        """Return ``ceil(C_in/xbar) * ceil(C_out/xbar) * N * H_out * W_out``."""
        p_i = math.ceil(int(w.shape[1]) / self.config.xbar_size)
        q_i = math.ceil(int(w.shape[0]) / self.config.xbar_size)
        num_sites = int(out.shape[0] * out.shape[2] * out.shape[3])
        return int(p_i * q_i * num_sites)

    def _spike_event_counts(
        self,
        *,
        x: torch.Tensor,
        w: torch.Tensor,
        stride: tuple[int, int],
        padding: tuple[int, int],
        dilation: tuple[int, int],
        out: torch.Tensor,
        out_channel_tiles: int,
    ) -> tuple[int, int, list[int], int]:
        """Compute activity-dependent counts from the actual input ``x``.

        The input channels are zero-padded to a multiple of ``xbar_size`` and
        summed within each tile. A depthwise convolution with an all-ones
        kernel then gives, per tile and output position, the total input
        activity inside the receptive field (the "occupancy").

        :return: ``(active_patch_tile_count, active_row_count,
            active_row_count_by_tile, active_output_tile_site_count)``
        """
        xbar_size = self.config.xbar_size
        c_in = x.shape[1]
        num_tiles = math.ceil(c_in / xbar_size)
        k_h, k_w = w.shape[2], w.shape[3]

        padded_channels = num_tiles * xbar_size
        if padded_channels == c_in:
            x_padded = x
        else:
            # Pad only the channel dimension (third pair from the end).
            x_padded = F.pad(x, (0, 0, 0, 0, 0, padded_channels - c_in))
        tile_sums = (
            x_padded.reshape(x.shape[0], num_tiles, xbar_size, x.shape[2], x.shape[3])
            .sum(dim=2)
            .to(dtype=torch.float32)
        )

        cache_key = (num_tiles, k_h, k_w, tile_sums.device)
        if cache_key not in self._ones_kernel_cache:
            self._ones_kernel_cache[cache_key] = tile_sums.new_ones(
                (num_tiles, 1, k_h, k_w)
            )
        ones_kernel = self._ones_kernel_cache[cache_key]

        # groups == num_tiles makes this a per-tile (depthwise) convolution.
        with torch.no_grad():
            with _exclude_python_dispatch_guard():
                occupancy = F.conv2d(
                    tile_sums,
                    ones_kernel,
                    None,
                    stride,
                    padding,
                    dilation,
                    num_tiles,
                )

        active_patch = occupancy.gt(0)
        active_patch_tile_count = int(active_patch.sum().item())
        # float64 accumulation keeps the per-tile totals exact for large maps.
        active_row_count_by_tile = [
            int(v)
            for v in occupancy.sum(dim=(0, 2, 3), dtype=torch.float64).tolist()
        ]
        active_row_count = int(sum(active_row_count_by_tile))
        # A site is active if at least one input tile is active there.
        active_site_mask = active_patch.any(dim=1)
        active_output_tile_site_count = out_channel_tiles * int(
            active_site_mask.sum().item()
        )
        return (
            active_patch_tile_count,
            active_row_count,
            active_row_count_by_tile,
            active_output_tile_site_count,
        )

    @staticmethod
    def _accumulate_by_tile(
        running: list[int] | None,
        running_total: int,
        delta: list[int],
    ) -> list[int] | None:
        """Add ``delta`` to a per-tile breakdown, keeping it consistent.

        The invariant maintained is: the breakdown is either ``None`` or sums
        to the matching scalar total.

        :param running: breakdown accumulated so far (``None`` if unavailable)
        :param running_total: scalar total accumulated *before* this update
        :param delta: per-tile counts of the current call
        :return: the updated breakdown (mutated in place when possible), or
            ``None`` when a breakdown is no longer meaningful
        """
        if running is None:
            # None with a zero total means nothing was accumulated yet, so a
            # fresh list is exact. None with a non-zero total means an earlier
            # tile-count mismatch invalidated the breakdown; restarting it
            # would silently drop the earlier calls.
            return list(delta) if running_total == 0 else None
        if len(running) != len(delta):
            return None
        for i, value in enumerate(delta):
            running[i] += value
        return running

    def _update_stage(
        self,
        *,
        scope: str,
        x: torch.Tensor,
        w: torch.Tensor,
        out: torch.Tensor,
        stride: tuple[int, int],
        padding: tuple[int, int],
        dilation: tuple[int, int],
        spike_like_input: bool,
    ) -> int:
        """Record one convolution call for ``scope`` and accumulate its counts.

        :return: dense PE cycle count of this call
        :raises NotImplementedError: in strict mode, when the call's shapes or
            hyper-parameters differ from the first call of the same scope
        """
        out_channel_tiles = math.ceil(w.shape[0] / self.config.xbar_size)
        kernel_size = (int(w.shape[2]), int(w.shape[3]))
        input_shape = _shape_tuple(x)
        output_shape = _shape_tuple(out)

        metadata = self.stage_metadata.get(scope)
        if metadata is None:
            metadata = _StageMetadata(
                scope=scope,
                in_channels=int(w.shape[1]),
                out_channels=int(w.shape[0]),
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                input_shape=input_shape,
                output_shape=output_shape,
                out_channel_tiles=out_channel_tiles,
                input_tile_channels=self._input_tile_channels(int(w.shape[1])),
            )
            self.stage_metadata[scope] = metadata
        else:
            same_shape = (
                metadata.input_shape == input_shape
                and metadata.output_shape == output_shape
                and metadata.kernel_size == kernel_size
                and metadata.stride == stride
                and metadata.padding == padding
                and metadata.dilation == dilation
            )
            if not same_shape:
                metadata.shape_mismatch_detected = True
                self._warn_or_raise(
                    f"shape-mismatch:{scope}",
                    f"SpikeSim stage '{scope}' was invoked with inconsistent shapes; "
                    "runtime energy is accumulated across calls anyway.",
                )
        metadata.total_calls += 1

        dense_a, dense_r, dense_r_by_tile, dense_z = self._dense_event_counts(
            w=w,
            out=out,
            out_channel_tiles=out_channel_tiles,
        )
        if self.config.activity_mode == "event" and spike_like_input:
            metadata.event_driven_calls += 1
            active_a, active_r, active_r_by_tile, active_z = self._spike_event_counts(
                x=x,
                w=w,
                stride=stride,
                padding=padding,
                dilation=dilation,
                out=out,
                out_channel_tiles=out_channel_tiles,
            )
        else:
            # Non-binary input or non-event mode: charge the dense counts.
            metadata.dense_fallback_calls += 1
            active_a, active_r, active_r_by_tile, active_z = (
                dense_a,
                dense_r,
                dense_r_by_tile,
                dense_z,
            )
        dense_pe_cycles = self._dense_pe_cycles(w=w, out=out)

        stats = self.stage_stats[scope]
        stats.dense_pe_cycle_count += dense_pe_cycles
        stats.active_patch_tile_count += active_a
        # Breakdowns are merged before their scalar totals are bumped, because
        # the helper needs the total accumulated prior to this call.
        stats.active_row_count_by_tile = self._accumulate_by_tile(
            stats.active_row_count_by_tile,
            stats.active_row_count,
            active_r_by_tile,
        )
        stats.active_row_count += active_r
        stats.active_output_tile_site_count += active_z
        stats.dense_patch_tile_count += dense_a
        stats.dense_row_count_by_tile = self._accumulate_by_tile(
            stats.dense_row_count_by_tile,
            stats.dense_row_count,
            dense_r_by_tile,
        )
        stats.dense_row_count += dense_r
        stats.dense_output_tile_site_count += dense_z
        return dense_pe_cycles

    def _handle_convolution(
        self,
        scope: str,
        args: tuple[Any, ...],
        out,
        active_modules: set[nn.Module] | None,
    ) -> int:
        """Validate one ``aten.convolution.default`` call and count it.

        ``args`` is read positionally: input, weight, (bias), stride, padding,
        dilation, transposed, (output_padding), groups.

        :return: dense PE cycle count, or 0 for calls outside the supported
            scope (transposed, not a single eval-mode ``nn.Conv2d`` under
            no-grad, grouped, or non-4D tensors)
        :raises NotImplementedError: in strict mode, for unsupported calls
        """
        x, w = args[0], args[1]
        stride, padding, dilation = args[3], args[4], args[5]
        transposed = bool(args[6])
        groups = int(args[8])

        if transposed:
            self._warn_or_raise(
                f"transposed-conv:{scope}",
                "SpikeSim energy only covers original SpikeSim Conv2d inference "
                f"stages; transposed convolutions are outside scope: {scope}.",
            )
            return 0
        if not self._is_forward_inference_conv(active_modules):
            self._warn_or_raise(
                f"outside-scope:{scope}",
                "SpikeSim energy only covers Conv2d forward inference stages from "
                f"nn.Conv2d modules: {scope}.",
            )
            return 0
        if groups != 1:
            self._warn_or_raise(
                f"grouped-conv:{scope}",
                "SpikeSim event energy does not support grouped/depthwise "
                f"convolutions: {scope}.",
            )
            return 0
        if x.dim() != 4 or w.dim() != 4 or out.dim() != 4:
            self._warn_or_raise(
                f"conv-rank:{scope}",
                "SpikeSim event energy only supports Conv2d-like calls, but got "
                f"x.shape={tuple(x.shape)}, w.shape={tuple(w.shape)}, "
                f"out.shape={tuple(out.shape)}.",
            )
            return 0

        stride = _pair_tuple(stride)
        padding = _pair_tuple(padding)
        dilation = _pair_tuple(dilation)
        spike_like_input = is_binary_tensor(x)
        dense_pe_cycles = self._update_stage(
            scope=scope,
            x=x,
            w=w,
            out=out,
            stride=stride,
            padding=padding,
            dilation=dilation,
            spike_like_input=spike_like_input,
        )
        if self.verbose:
            mode = (
                "event"
                if self.config.activity_mode == "event" and spike_like_input
                else "dense"
            )
            print(
                f"SpikeSimCounter: {scope} - aten.convolution.default "
                f"[{mode}] x={tuple(x.shape)} w={tuple(w.shape)} out={tuple(out.shape)}"
            )
        return dense_pe_cycles

    def _count_convolution(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
        out,
        *,
        active_modules=None,
        parent_names=None,
    ) -> int:
        """Rule for ``aten.convolution.default``; ``kwargs`` is not used."""
        return self._handle_convolution(
            self._leaf_scope(parent_names), args, out, active_modules
        )

    def _is_forward_inference_conv(self, active_modules: set[nn.Module] | None) -> bool:
        """Return ``True`` only for an eval-mode ``nn.Conv2d`` under no-grad.

        Exactly one ``nn.Conv2d`` must be among ``active_modules``; it must
        not be in training mode and gradient tracking must be disabled.
        """
        modules = active_modules or set()
        conv2d_modules = [module for module in modules if isinstance(module, nn.Conv2d)]
        if len(conv2d_modules) != 1:
            return False
        conv = conv2d_modules[0]
        return (not conv.training) and (not torch.is_grad_enabled())

    def get_stage_stats(self) -> dict[str, dict[str, Any]]:
        """Return the accumulated counters of every stage as plain dicts."""
        return {stage: stats.as_dict() for stage, stats in self.stage_stats.items()}

    def get_stage_metadata(self) -> dict[str, dict[str, Any]]:
        """Return the recorded metadata of every stage as plain dicts."""
        return {
            stage: metadata.as_dict()
            for stage, metadata in self.stage_metadata.items()
        }