# Source code for spikingjelly.activation_based.op_counter.spikesim.core

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn

from ...neuron.base_node import BaseNode, SimpleBaseNode
from ...neuron.integrate_and_fire import IFNode, SimpleIFNode
from ...neuron.lif import LIFNode, SimpleLIFNode
from ..base import DispatchCounterMode
from .config import SpikeSimEnergyConfig
from .counter import SpikeSimCounter
from .formulas import (
    compute_spikesim_dense_energy_breakdown,
    compute_spikesim_event_energy_breakdown,
)

# Public names exported by this module. The ``SpikeSimEvent*`` entries are
# aliases of the corresponding ``SpikeSimEnergy*`` names defined below.
__all__ = [
    "SpikeSimEnergyConfig",
    "SpikeSimCounter",
    "SpikeSimEnergyReport",
    "SpikeSimEnergyProfiler",
    "SpikeSimEventEnergyReport",
    "SpikeSimEventEnergyProfiler",
    "estimate_spikesim_event_energy",
]


def _call_model(model: nn.Module, inputs):
    """Invoke ``model`` on ``inputs`` and return whatever the call returns.

    A ``tuple`` or ``list`` is unpacked into positional arguments
    (``model(*inputs)``); any other object is passed as the single argument.
    """
    unpack_as_args = isinstance(inputs, (tuple, list))
    return model(*inputs) if unpack_as_args else model(inputs)


# Neuron classes treated as supported by ``_unsupported_neuron_modules``:
# the IF and LIF nodes together with their ``Simple*`` counterparts.
_SUPPORTED_SPIKESIM_NEURONS = (
    IFNode,
    LIFNode,
    SimpleIFNode,
    SimpleLIFNode,
)


def _unsupported_neuron_modules(model: nn.Module) -> list[tuple[str, str]]:
    """List the neuron submodules of ``model`` that are not IF/LIF nodes.

    Every module yielded by ``model.named_modules()`` that is a ``BaseNode``
    or ``SimpleBaseNode`` but not one of ``_SUPPORTED_SPIKESIM_NEURONS`` is
    reported.

    :param model: module whose submodules are inspected
    :return: ``(name, class_name)`` pairs in traversal order; when the module
        name is empty the class name is used in its place
    """
    offenders: list[tuple[str, str]] = []
    for qualified_name, submodule in model.named_modules():
        is_neuron = isinstance(submodule, (BaseNode, SimpleBaseNode))
        if is_neuron and not isinstance(submodule, _SUPPORTED_SPIKESIM_NEURONS):
            class_name = submodule.__class__.__name__
            offenders.append((qualified_name or class_name, class_name))
    return offenders


@dataclass
class SpikeSimEnergyReport:
    r"""
    Report for the SpikeSim runtime energy estimator.

    The same report type is used for both the ``dense`` and ``event``
    activity modes. Fields include total energy, stage-wise energy breakdown,
    event stats, stage metadata, and warnings. ``event_stats_by_stage`` is
    populated regardless of the selected activity mode.
    """

    # Total estimated energy in pJ.
    energy_total_pj: float
    # Energy per stage, keyed by stage name.
    energy_by_stage: dict[str, float]
    # Component-level view; the profiler fills it with a "totals" mapping and
    # a "by_stage" mapping.
    energy_by_component: dict[str, Any]
    # Runtime statistics collected for each stage.
    event_stats_by_stage: dict[str, dict[str, Any]]
    # Metadata recorded for each stage.
    stage_metadata: dict[str, dict[str, Any]]
    # Human-readable warnings gathered while profiling.
    warnings: list[str]
    # Per-component energy totals accumulated over all stages.
    breakdown_pj: dict[str, float]
    # Integer counters describing the profiled run.
    counts: dict[str, int]
class SpikeSimEnergyProfiler:
    r"""
    Runtime SpikeSim-aligned energy profiler.

    Usage:

    - wrap one real forward pass in the profiler context
    - call ``get_report()`` afterwards to build the energy report
    """

    def __init__(
        self,
        *,
        config: SpikeSimEnergyConfig | None = None,
        strict: bool = False,
        verbose: bool = False,
    ):
        r"""
        :param config: SpikeSim energy config; defaults to
            ``SpikeSimEnergyConfig()``
        :param strict: whether to raise immediately on unsupported behaviors
        :param verbose: whether to print per-stage runtime statistics
        :return: None
        :rtype: None
        """
        # Keep a validated copy so the caller's config object is not shared.
        self.config = (config or SpikeSimEnergyConfig()).copy()
        self.config.validate()
        self.strict = strict
        self.verbose = verbose
        self._counter = SpikeSimCounter(
            config=self.config,
            strict=self.strict,
            verbose=self.verbose,
        )
        # ``strict`` is forwarded to the counter only; the dispatch mode is
        # always created with ``strict=False``.
        self._dispatch_mode = DispatchCounterMode(
            [self._counter],
            strict=False,
            verbose=self.verbose,
        )
        self._active = False
        self._warnings: list[str] = []

    def __enter__(self):
        """Enter the underlying dispatch mode and return this profiler."""
        self._dispatch_mode.__enter__()
        self._active = True
        return self

    def __exit__(self, exc_type, exc, tb):
        """Leave the underlying dispatch mode and return its exit result."""
        self._active = False
        return self._dispatch_mode.__exit__(exc_type, exc, tb)

    def add_warnings(self, messages: list[str]) -> None:
        """Append extra warning messages to include in later reports.

        :param messages: warning strings; they are placed ahead of the
            counter's own warnings in the report
        """
        self._warnings.extend(messages)

    def get_report(self) -> SpikeSimEnergyReport:
        r"""
        Build and return the full runtime SpikeSim energy report.

        :return: report with per-stage and per-component energy, the
            counter's stage stats and metadata, counts, and warnings
        """
        event_stats_by_stage = self._counter.get_stage_stats()
        stage_metadata = self._counter.get_stage_metadata()

        energy_by_stage: dict[str, float] = {}
        component_by_stage: dict[str, dict[str, float]] = {}
        component_totals: dict[str, float] = defaultdict(float)

        for stage, stats in event_stats_by_stage.items():
            metadata = stage_metadata[stage]
            if self.config.activity_mode == "event":
                breakdown = compute_spikesim_event_energy_breakdown(
                    stats=stats, metadata=metadata, config=self.config
                )
            else:
                breakdown = compute_spikesim_dense_energy_breakdown(
                    stats=stats, metadata=metadata, config=self.config
                )
            component_by_stage[stage] = breakdown
            energy_by_stage[stage] = breakdown["total_pj"]
            # Accumulate every component except the per-stage total, which is
            # tracked separately in ``energy_by_stage``.
            for key, value in breakdown.items():
                if key == "total_pj":
                    continue
                component_totals[key] += value

        # Externally added warnings come first, then the counter's warnings.
        warnings = list(self._warnings) + list(self._counter.warnings)
        if not energy_by_stage:
            warnings.append(
                "No supported Conv2d forward inference stages were profiled by "
                "SpikeSim energy."
            )

        totals_dict = dict(component_totals)
        total_pj = sum(energy_by_stage.values())
        counts = {
            "dense_pe_cycle_count": int(self._counter.get_total()),
            "stage_count": len(stage_metadata),
        }

        if self.config.activity_mode == "event":
            warnings.append(
                "SpikeSim activity_mode='event' is an experimental sparse runtime "
                "extension; activity_mode='dense' is the original SpikeSim-aligned "
                "default."
            )

        return SpikeSimEnergyReport(
            energy_total_pj=total_pj,
            energy_by_stage=energy_by_stage,
            energy_by_component={
                "totals": {**totals_dict, "total_pj": total_pj},
                "by_stage": component_by_stage,
            },
            event_stats_by_stage=event_stats_by_stage,
            stage_metadata=stage_metadata,
            warnings=warnings,
            breakdown_pj=totals_dict,
            counts=counts,
        )

    def get_total(self) -> float:
        r"""
        Return total energy in pJ.

        A fresh report is built on every call.
        """
        return self.get_report().energy_total_pj

    def get_counts(self) -> dict[str, Any]:
        r"""
        Return event stats and stage metadata in an ``op_counter``-like count
        shape.

        :return: dict with the keys ``"event_stats_by_stage"`` and
            ``"stage_metadata"``, taken from a freshly built report
        """
        report = self.get_report()
        return {
            "event_stats_by_stage": report.event_stats_by_stage,
            "stage_metadata": report.stage_metadata,
        }
# Aliases: the ``SpikeSimEvent*`` names refer to the same objects as the
# ``SpikeSimEnergy*`` report and profiler classes.
SpikeSimEventEnergyReport = SpikeSimEnergyReport
SpikeSimEventEnergyProfiler = SpikeSimEnergyProfiler
def estimate_spikesim_event_energy(
    model: nn.Module,
    inputs,
    *,
    config: SpikeSimEnergyConfig | None = None,
    strict: bool = False,
    verbose: bool = False,
) -> SpikeSimEventEnergyReport:
    r"""
    Convenience entry for runtime SpikeSim energy estimation.

    It runs one real forward pass and returns the energy report.

    :param model: model to profile
    :param inputs: model input; tuple/list will be passed as
        ``model(*inputs)``
    :param config: SpikeSim energy config
    :param strict: whether to raise immediately on unsupported behaviors
    :param verbose: whether to print per-stage runtime statistics
    :return: the energy report built after the forward pass
    :raises ValueError: if ``strict`` is true and the model is in training
        mode, or if ``strict`` is true and ``config.require_if_lif_neurons``
        is set while the model contains neuron modules other than IF/LIF
    """
    cfg = (config or SpikeSimEnergyConfig()).copy()
    cfg.validate()

    training_message = (
        "SpikeSim energy only covers forward inference; call model.eval() before "
        "profiling."
    )
    # In strict mode a training-mode model is an error; otherwise the same
    # message is added to the report as a warning after profiling.
    if model.training and strict:
        raise ValueError(training_message)

    neuron_warnings: list[str] = []
    if cfg.require_if_lif_neurons:
        unsupported_neurons = _unsupported_neuron_modules(model)
        if unsupported_neurons:
            details = ", ".join(
                f"{name} ({class_name})" for name, class_name in unsupported_neurons
            )
            message = (
                "SpikeSim assumes IF/LIF neuronal modules. Unsupported neuron "
                f"modules found: {details}."
            )
            if strict:
                raise ValueError(message)
            neuron_warnings.append(message)

    with SpikeSimEnergyProfiler(config=cfg, strict=strict, verbose=verbose) as profiler:
        profiler.add_warnings(neuron_warnings)
        with torch.no_grad():
            _ = _call_model(model, inputs)

    # Build the report once the profiled forward pass has finished.
    report = profiler.get_report()
    if model.training:
        # Put the training-mode notice ahead of all other warnings.
        report.warnings.insert(0, training_message)
    return report