from __future__ import annotations
import re
import warnings
from collections import defaultdict
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
import torch
import torch.nn as nn
from torch.overrides import resolve_name
from torch.utils._python_dispatch import TorchDispatchMode
from .config import MemoryHierarchyConfig, MemoryInstanceSpec
from .utils import _is_spike, _prod
# Public API of this module. ``MemoryHierarchyConfig`` is re-exported from
# ``.config``; the other names are defined in this module (not all of them
# are visible in this excerpt).
__all__ = [
    "MemoryHierarchyConfig",
    "NeuroMCEnergyProfiler",
    "NeuroMCRuntimeEnergyReport",
    "estimate_neuromc_runtime_energy",
]
# Core types accepted by ``NeuroMCEnergyProfiler(core_type=...)``.
# The fragment builders map spike-input work to the fp_* / bp_* / wg types
# and dense (non-spike) work to the ann_* types. ``bp_grad_opt`` is not
# assigned anywhere in the visible excerpt.
_ALLOWED_CORE_TYPES = {
    "fp_soma",
    "fp_bn",
    "bp_grad",
    "bp_bn",
    "bp_grad_opt",
    "wg",
    "ann_fe",
    "ann_be",
    "ann_we",
    "ann_bn",
    "ann_bn_bp",
}
# Energy per auxiliary primitive, in picojoules (per the _PJ suffix).
# "mux" is priced at 1/16 of an "add".
# NOTE(review): consumers of this table are outside the visible excerpt.
_EXTRA_OP_COST_PJ = {
    "mux": 0.548 * (1.0 / 16.0),
    "add": 0.548,
    "mul": 0.812,
    "comp": 0.056,
    "sqrt": 0.514,
}
# Energy per MAC (pJ) by core type, built from the primitive costs above:
# fp_soma / wg = add + mux, all other entries = add + mul.
# NOTE(review): presumably the spike-input cores replace the multiply with a
# mux-gated add because their input is 1-bit -- confirm with the hardware model.
_MAC_COST_PJ = {
    "fp_soma": 0.548 + 0.548 * (1.0 / 16.0),
    "bp_grad": 0.548 + 0.812,
    "wg": 0.548 + 0.548 * (1.0 / 16.0),
    "ann_fe": 0.548 + 0.812,
    "ann_be": 0.548 + 0.812,
    "ann_we": 0.548 + 0.812,
}
# ATen op-name prefixes skipped by the trace fallback (matched with
# ``str.startswith``): mostly view / alias / copy ops plus profiler markers.
_IGNORED_OP_PREFIXES = (
    "aten.detach",
    "aten.view",
    "aten.t.default",
    "aten.transpose",
    "aten.permute",
    "aten.expand",
    "aten.slice",
    "aten.select",
    "aten.alias",
    "aten._unsafe_view",
    "aten.as_strided",
    "aten.clone",
    "aten.copy_",
    "profiler.",
)
# Exact ATen op names that the trace fallback also skips, alongside the
# prefixes above. They cover elementwise arithmetic, comparisons, logical ops,
# reductions, tensor creation / conversion and a few losses that are not
# turned into fragments.
_AUXILIARY_ATEN_OPS = {
    "aten.all.default",
    "aten._local_scalar_dense.default",
    "aten.lift_fresh.default",
    "aten._to_copy.default",
    "aten.add.Tensor",
    "aten.add_.Tensor",
    "aten.add.Scalar",
    "aten.add_.Scalar",
    "aten.sub.Tensor",
    "aten.sub_.Tensor",
    "aten.sub.Scalar",
    "aten.sub_.Scalar",
    "aten.rsub.Tensor",
    "aten.rsub.Scalar",
    "aten.mul.Tensor",
    "aten.mul_.Tensor",
    "aten.mul.Scalar",
    "aten.mul_.Scalar",
    "aten.div.Tensor",
    "aten.div_.Tensor",
    "aten.div.Scalar",
    "aten.div_.Scalar",
    "aten.empty.memory_format",
    "aten.cat.default",
    "aten.stack.default",
    "aten.split.Tensor",
    "aten.index.Tensor",
    "aten.full_like.default",
    "aten.mse_loss.default",
    "aten.mse_loss_backward.default",
    "aten.where.self",
    "aten.where.ScalarOther",
    "aten.where.ScalarSelf",
    "aten.sqrt.default",
    "aten.sqrt_.default",
    "aten.rsqrt.default",
    "aten.sigmoid.default",
    "aten.sigmoid_.default",
    "aten.sum.default",
    "aten.sum.dim_IntList",
    "aten.mean.dim",
    "aten.eq.Tensor",
    "aten.eq.Scalar",
    "aten.ne.Tensor",
    "aten.ne.Scalar",
    "aten.lt.Tensor",
    "aten.lt.Scalar",
    "aten.le.Tensor",
    "aten.le.Scalar",
    "aten.gt.Tensor",
    "aten.gt.Scalar",
    "aten.ge.Tensor",
    "aten.ge.Scalar",
    "aten.logical_and.default",
    "aten.logical_or.default",
    "aten.logical_xor.default",
    "aten.logical_not.default",
    "aten.new_empty_strided.default",
    "aten.ones_like.default",
    "aten.zeros_like.default",
}
# [docs]
@dataclass(init=False)
class NeuroMCRuntimeEnergyReport:
    """Energy profiling report produced by the NeuroMC runtime profiler.

    **API Language:**
    :ref:`English <NeuroMCRuntimeEnergyReport-en>`

    ----

    .. _NeuroMCRuntimeEnergyReport-cn:

    .. _NeuroMCRuntimeEnergyReport-en:

    * **English**

    Holds the results of one energy-profiling session. This covers total,
    compute and memory energy, and breakdowns by stage, operation, core type,
    process key and memory level. It also holds raw primitive and per-core
    counts, memory-access bit volumes, and any warnings raised while
    profiling. All ``*_pj`` values are in picojoules.

    The constructor copies every container argument one level deep. Rebinding
    or adding keys in the caller's dicts/lists therefore does not affect the
    report.

    :param energy_total_pj: Total energy consumption.
    :type energy_total_pj: float
    :param energy_compute_pj: Total compute energy.
    :type energy_compute_pj: float
    :param energy_memory_pj: Total memory-access energy.
    :type energy_memory_pj: float
    :param energy_by_stage: Energy keyed by execution-stage label.
    :type energy_by_stage: ``dict[str, float]``
    :param energy_by_op: Energy keyed by operation name.
    :type energy_by_op: ``dict[str, float]``
    :param primitive_counts: Raw primitive operation counts.
    :type primitive_counts: ``dict[str, Any]``
    :param memory_bits_by_level: Memory-access bits keyed by hierarchy level.
    :type memory_bits_by_level: ``dict[str, Any]``
    :param warnings: Warnings generated during profiling.
    :type warnings: ``list[str]``
    :param energy_mac_pj: Energy of MAC operations.
    :type energy_mac_pj: float
    :param energy_base_memory_pj: Base memory energy.
    :type energy_base_memory_pj: float
    :param energy_extra_memory_pj: Extra memory energy.
    :type energy_extra_memory_pj: float
    :param energy_extra_compute_pj: Extra (non-MAC) compute energy.
    :type energy_extra_compute_pj: float
    :param energy_by_core_type: Energy keyed by NeuroMC core type.
    :type energy_by_core_type: ``dict[str, float]``
    :param energy_by_process_key: Energy keyed by process key.
    :type energy_by_process_key: ``dict[str, float]``
    :param energy_by_memory_level_dir: Nested energy breakdown keyed by memory
        level and then (presumably) access direction.
    :type energy_by_memory_level_dir: ``dict[str, dict[str, float]]``
    :param counts_by_core_type: Per-core-type count dictionaries.
    :type counts_by_core_type: ``dict[str, dict[str, int]]``
    :param counts_by_process_key: Per-process-key count dictionaries.
    :type counts_by_process_key: ``dict[str, dict[str, int]]``
    :param mapping_summary: Mapping records; their schema is defined by the
        code that builds the report.
    :type mapping_summary: ``list[dict[str, Any]]``
    """

    energy_total_pj: float = 0.0
    energy_compute_pj: float = 0.0
    energy_memory_pj: float = 0.0
    energy_by_stage: dict[str, float] = field(default_factory=dict)
    energy_by_op: dict[str, float] = field(default_factory=dict)
    primitive_counts: dict[str, Any] = field(default_factory=dict)
    memory_bits_by_level: dict[str, Any] = field(default_factory=dict)
    warnings: list[str] = field(default_factory=list)
    energy_mac_pj: float = 0.0
    energy_base_memory_pj: float = 0.0
    energy_extra_memory_pj: float = 0.0
    energy_extra_compute_pj: float = 0.0
    energy_by_core_type: dict[str, float] = field(default_factory=dict)
    energy_by_process_key: dict[str, float] = field(default_factory=dict)
    energy_by_memory_level_dir: dict[str, dict[str, float]] = field(default_factory=dict)
    counts_by_core_type: dict[str, dict[str, int]] = field(default_factory=dict)
    counts_by_process_key: dict[str, dict[str, int]] = field(default_factory=dict)
    mapping_summary: list[dict[str, Any]] = field(default_factory=list)

    def __init__(
        self,
        energy_total_pj: float = 0.0,
        energy_compute_pj: float = 0.0,
        energy_memory_pj: float = 0.0,
        energy_by_stage: Optional[dict[str, float]] = None,
        energy_by_op: Optional[dict[str, float]] = None,
        primitive_counts: Optional[dict[str, Any]] = None,
        memory_bits_by_level: Optional[dict[str, Any]] = None,
        warnings: Optional[list[str]] = None,
        energy_mac_pj: float = 0.0,
        energy_base_memory_pj: float = 0.0,
        energy_extra_memory_pj: float = 0.0,
        energy_extra_compute_pj: float = 0.0,
        energy_by_core_type: Optional[dict[str, float]] = None,
        energy_by_process_key: Optional[dict[str, float]] = None,
        energy_by_memory_level_dir: Optional[dict[str, dict[str, float]]] = None,
        counts_by_core_type: Optional[dict[str, dict[str, int]]] = None,
        counts_by_process_key: Optional[dict[str, dict[str, int]]] = None,
        mapping_summary: Optional[list[dict[str, Any]]] = None,
    ):
        # A hand-written __init__ (dataclass(init=False)) lets container
        # arguments default to None and be copied, so callers' objects are
        # never aliased by the report.
        self.energy_total_pj = energy_total_pj
        self.energy_compute_pj = energy_compute_pj
        self.energy_memory_pj = energy_memory_pj
        self.energy_by_stage = dict(energy_by_stage or {})
        self.energy_by_op = dict(energy_by_op or {})
        self.primitive_counts = dict(primitive_counts or {})
        self.memory_bits_by_level = dict(memory_bits_by_level or {})
        self.warnings = list(warnings or [])
        self.energy_mac_pj = energy_mac_pj
        self.energy_base_memory_pj = energy_base_memory_pj
        self.energy_extra_memory_pj = energy_extra_memory_pj
        self.energy_extra_compute_pj = energy_extra_compute_pj
        self.energy_by_core_type = dict(energy_by_core_type or {})
        self.energy_by_process_key = dict(energy_by_process_key or {})
        self.energy_by_memory_level_dir = dict(energy_by_memory_level_dir or {})
        self.counts_by_core_type = dict(counts_by_core_type or {})
        self.counts_by_process_key = dict(counts_by_process_key or {})
        self.mapping_summary = list(mapping_summary or [])
@dataclass
class _TraceTensor:
    """Data-free snapshot of a tensor observed while tracing.

    It stores only metadata (shape, dtype, grad flag, element count and the
    spike flag), so the original tensor storage is not kept alive. ``ndim``
    and ``numel()`` mirror the ``torch.Tensor`` API so helpers can accept
    either.
    """

    shape: tuple[int, ...]
    dtype: torch.dtype
    requires_grad: bool
    numel_value: int
    is_spike: bool

    @property
    def ndim(self) -> int:
        """Number of dimensions, like ``Tensor.ndim``."""
        dims = self.shape
        return len(dims)

    def numel(self) -> int:
        """Element count captured at snapshot time, like ``Tensor.numel()``."""
        count = self.numel_value
        return count
@dataclass
class _TraceEvent:
    """One ATen op observed by the trace dispatch mode, with snapshotted I/O."""

    # Resolved op name, e.g. "aten.convolution.default".
    op_name: str
    # Stage label active when the op ran ("unlabeled" outside stage()).
    stage: str
    # "forward", "backward" or "optimizer", derived from the stage label.
    phase: str
    # Positional args, kwargs and output. Tensors are replaced by _TraceTensor
    # snapshots; tuples, lists and dicts are copied recursively.
    args: Any
    kwargs: Any
    out: Any
@dataclass
class _Fragment:
    """One unit of modelled NeuroMC work (a layer pass or gradient computation)."""

    # Stage label and derived phase ("forward" / "backward" / "optimizer").
    stage: str
    phase: str
    # Descriptive name, e.g. "conv.forward" or an ATen op name for trace fragments.
    op_name: str
    # One of _ALLOWED_CORE_TYPES, chosen by the fragment builder.
    core_type: str
    # Builders use "with_nothing", "with_bn" or "with_sg".
    process_key: str
    # Loop bounds keyed "B", "T", "C", "K", "OY", "OX", "FY", "FX" (each >= 1).
    loop_dims: dict[str, int]
    # Operand bit widths; spike inputs are recorded as 1 bit by the builders.
    input_precision_bits: int
    weight_precision_bits: int
    output_precision_bits: int
    # Element counts of the input, weight-like and output operands.
    input_numel: int
    weight_numel: int
    output_numel: int
    # Product of all loop_dims for MAC-based fragments, 0 for BN / soma.
    mac_count: int
    # "--" or "without_bp_bn" (parsed from the stage label for BN backward).
    conv_type: str = "--"
    # Integers parsed from "b<N>" / "t<N>" tokens in the stage label (0 if absent).
    b_type: int = 0
    t_type: int = 0
    # "trace" by default; hook-built fragments use "module".
    source: str = "trace"
    # NOTE(review): optimizer flags are not set anywhere in the visible
    # excerpt; presumably filled by optimizer-stage fragment builders.
    optimizer_has_momentum: bool = False
    optimizer_has_weight_decay: bool = False
    optimizer_has_momentum_buffer: bool = False
def _module_overlap_signature(fragment: _Fragment) -> tuple[Any, ...]:
    """Return a hashable key identifying a fragment's modelled workload.

    B and T are folded into a single "BT" product, so fragments that differ
    only in how work is split between batch and time compare equal. All other
    loop bounds, precisions, operand sizes, MAC count and position metadata
    are kept as-is.
    """
    dims = fragment.loop_dims
    loop_key = (("BT", dims["B"] * dims["T"]),) + tuple(
        (axis, dims[axis]) for axis in ("C", "K", "OY", "OX", "FY", "FX")
    )
    head = (
        fragment.stage,
        fragment.phase,
        fragment.core_type,
        fragment.process_key,
    )
    tail_fields = (
        "input_precision_bits",
        "weight_precision_bits",
        "output_precision_bits",
        "input_numel",
        "weight_numel",
        "output_numel",
        "mac_count",
        "conv_type",
        "b_type",
        "t_type",
    )
    tail = tuple(getattr(fragment, name) for name in tail_fields)
    return head + (loop_key,) + tail
def _tensor_numel(x: Any) -> int:
    """Element count of a tensor or ``_TraceTensor`` snapshot; 0 for anything else."""
    if isinstance(x, _TraceTensor):
        return x.numel_value
    return int(x.numel()) if torch.is_tensor(x) else 0
def _tensor_layout(
    x: torch.Tensor | _TraceTensor, module: nn.Module | None = None
) -> tuple[int, int, int, tuple[int, ...]]:
    """Split a tensor shape into ``(T, B, C, spatial)``.

    If ``module.step_mode == "m"``, the shape is read as
    ``[T, N, C, *spatial]`` and at least 3 dims are required. Otherwise
    ``T = 1`` and the shape is read as ``[N, C, *spatial]`` when
    ``ndim >= 2``. A 0-D or 1-D tensor is treated as a single sample whose
    elements are all channels (``B = 1``, ``C = numel``). In every case
    ``T * B * C * prod(spatial) == numel``, so elements are never counted
    twice.

    :param x: Tensor or ``_TraceTensor`` snapshot (needs ``ndim``, ``shape``
        and ``numel()``).
    :param module: Optional module whose ``step_mode`` attribute selects the
        layout; a missing attribute or ``None`` means single-step ("s").
    :return: ``(T, B, C, spatial)``.
    :raises ValueError: In multi-step mode when ``x`` has fewer than 3 dims.
    """
    step_mode = getattr(module, "step_mode", "s")
    shape = tuple(int(v) for v in x.shape)
    if step_mode == "m":
        if len(shape) < 3:
            raise ValueError(
                f"Expected multi-step tensor with shape [T, N, C, ...], got {tuple(x.shape)}."
            )
        return shape[0], shape[1], shape[2], shape[3:]
    if len(shape) >= 2:
        return 1, shape[0], shape[1], shape[2:]
    # 0-D / 1-D: one unbatched sample. Using B = N here would square the
    # element count (B * C = N * N).
    return 1, 1, int(x.numel()), ()
def _snapshot_trace_value(value: Any) -> Any:
    """Recursively replace tensors in *value* with ``_TraceTensor`` metadata.

    Tuples (including tuple subclasses) become plain tuples, lists become new
    lists, and dicts are rebuilt with the same keys. Any other object is
    returned unchanged.
    """
    if torch.is_tensor(value):
        shape = tuple(int(v) for v in value.shape)
        dtype = value.dtype
        needs_grad = bool(value.requires_grad)
        count = int(value.numel())
        spiking = _is_spike(value)
        return _TraceTensor(
            shape=shape,
            dtype=dtype,
            requires_grad=needs_grad,
            numel_value=count,
            is_spike=spiking,
        )
    if isinstance(value, tuple):
        return tuple(map(_snapshot_trace_value, value))
    if isinstance(value, list):
        return list(map(_snapshot_trace_value, value))
    if isinstance(value, dict):
        return {key: _snapshot_trace_value(item) for key, item in value.items()}
    return value
def _is_spike_like(x: Any) -> bool:
    """Spike check that also accepts ``_TraceTensor`` snapshots (uses the cached flag)."""
    return x.is_spike if isinstance(x, _TraceTensor) else _is_spike(x)
def _shape_tuple(x: Any) -> tuple[int, ...]:
    """Shape of a tensor or ``_TraceTensor`` as a tuple of ints; ``()`` otherwise."""
    if not isinstance(x, (_TraceTensor, torch.Tensor)):
        return ()
    return tuple(map(int, x.shape))
def _forward_core_type(is_spike_input: bool) -> str:
    """Forward-pass core: ``fp_soma`` for spike input, ``ann_fe`` otherwise."""
    if not is_spike_input:
        return "ann_fe"
    return "fp_soma"
def _grad_input_core_type(is_spike_input: bool) -> str:
    """Input-gradient core: ``bp_grad`` for spike input, ``ann_be`` otherwise."""
    if not is_spike_input:
        return "ann_be"
    return "bp_grad"
def _weight_grad_core_type(is_spike_input: bool) -> str:
    """Weight-gradient core: ``wg`` for spike input, ``ann_we`` otherwise."""
    if not is_spike_input:
        return "ann_we"
    return "wg"
def _optimizer_param_count(fragment: _Fragment) -> int:
    """Return ``K + C * FY * FX`` from the fragment's loop dimensions.

    NOTE(review): this is a sum, not the weight element count
    ``K * C * FY * FX``; confirm which quantity the optimizer model expects.
    """
    dims = fragment.loop_dims
    kernel_elems = dims["FX"] * dims["FY"] * dims["C"]
    return dims["K"] + kernel_elems
def _resolve_loss_fn(loss_fn: Callable | None):
    """Return *loss_fn* unchanged.

    ``None``, ``nn.Module`` losses and plain callables are all passed through
    as-is. This single return replaces branches that each returned the
    argument unchanged.
    """
    return loss_fn
def _call_model(model: nn.Module, inputs):
    """Call *model*; a tuple or list of *inputs* is unpacked into positional args."""
    positional = tuple(inputs) if isinstance(inputs, (tuple, list)) else (inputs,)
    return model(*positional)
def _clear_existing_grads(model: nn.Module, optimizer: torch.optim.Optimizer | None):
    """Reset gradients to ``None``.

    With an optimizer this delegates to ``zero_grad(set_to_none=True)``;
    otherwise every parameter of *model* has its ``grad`` cleared directly.
    """
    if optimizer is None:
        for param in model.parameters():
            param.grad = None
    else:
        optimizer.zero_grad(set_to_none=True)
class _TraceMode(TorchDispatchMode):
    """Dispatch mode that counts every ATen op and reports it to the profiler.

    While the owning profiler is suspended, ops run untouched: they are
    neither counted nor recorded.
    """

    def __init__(self, profiler: "NeuroMCEnergyProfiler"):
        super().__init__()
        self.profiler = profiler
        # Op name (e.g. "aten.add.Tensor") -> number of dispatches seen.
        self.op_counts: dict[str, int] = {}

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = kwargs if kwargs is not None else {}
        owner = self.profiler
        if not owner._suspended:
            name = resolve_name(func)
            self.op_counts[name] = 1 + self.op_counts.get(name, 0)
            result = func(*args, **call_kwargs)
            owner._maybe_record_trace_event(name, args, call_kwargs, result)
            return result
        return func(*args, **call_kwargs)
# [docs]
class NeuroMCEnergyProfiler:
"""High-level energy profiler for spiking neural networks using the NeuroMC framework.

**API Language:**
:ref:`English <NeuroMCEnergyProfiler-en>`

----

.. _NeuroMCEnergyProfiler-cn:

.. _NeuroMCEnergyProfiler-en:

* **English**

Profiles a model's energy consumption across forward, backward and optimizer
stages. Module-level work (Conv1d/Conv2d, Linear, BatchNorm and spiking
neurons) is captured through forward/backward hooks installed by
:meth:`bind_model`. Every ATen operation is observed through a
``TorchDispatchMode`` while the profiler is active.

Use the profiler as a context manager. Label work with :meth:`stage` and
exclude work from profiling with :meth:`suspend`.
"""
def __init__(
    self,
    *,
    core_type: str = "fp_soma",
    memory_config: MemoryHierarchyConfig | None = None,
    strict: bool = False,
    verbose: bool = False,
    extra_ignore_modules: list[nn.Module] | None = None,
):
    """Create a profiler. Hooks are installed later by ``bind_model``.

    :param core_type: Type of compute core; must be one of
        ``_ALLOWED_CORE_TYPES`` (e.g. ``"fp_soma"``).
    :type core_type: str
    :param memory_config: Memory hierarchy configuration. If ``None``, uses
        ``MemoryHierarchyConfig.neuromc_like_v1()``. The chosen config is
        validated immediately.
    :type memory_config: MemoryHierarchyConfig | None
    :param strict: Stored as ``self.strict``; intended to make unknown
        operations raise instead of warn (consumed outside this method).
    :type strict: bool
    :param verbose: Stored as ``self.verbose``; intended to enable progress
        output (consumed outside this method).
    :type verbose: bool
    :param extra_ignore_modules: Module *instances* that ``bind_model`` skips.
        Matching is by list membership, so pass the exact submodule objects,
        not their types.
    :type extra_ignore_modules: list[nn.Module] | None
    :raises ValueError: If ``core_type`` is not in the supported set.
    """
    if core_type not in _ALLOWED_CORE_TYPES:
        raise ValueError(
            f"Unsupported NeuroMC core_type={core_type}. "
            f"Supported: {sorted(_ALLOWED_CORE_TYPES)}."
        )
    self.core_type = core_type
    # Explicit None check: a caller-supplied config must never be replaced by
    # the default just because the object happens to be falsy.
    if memory_config is None:
        memory_config = MemoryHierarchyConfig.neuromc_like_v1()
    self.memory_config = memory_config
    self.memory_config.validate()
    self.strict = strict
    self.verbose = verbose
    self.extra_ignore_modules = list(extra_ignore_modules or [])
    # Active stage labels (at most one; stage() forbids nesting).
    self._stage_stack: list[str] = []
    self._warnings: list[str] = []
    self._trace_mode = _TraceMode(self)
    self._trace_events: list[_TraceEvent] = []
    self._fragments: list[_Fragment] = []
    self._bound_model: nn.Module | None = None
    # True between __enter__ and __exit__.
    self._active = False
    # Set by suspend(); hooks and the trace mode skip recording while True.
    self._suspended = False
    self._model_bound = False
    self._optimizer: torch.optim.Optimizer | None = None
    # RemovableHandle objects from bind_model, removed in __exit__.
    self._hook_handles = []
# [docs]
def bind_model(self, model: nn.Module):
if self._model_bound:
if self._bound_model is model:
return
raise RuntimeError(
"NeuroMCEnergyProfiler is already bound to a different model."
)
from ...neuron.base_node import BaseNode
supported = (
nn.Conv1d,
nn.Conv2d,
nn.Linear,
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
BaseNode,
)
for module in model.modules():
if module in self.extra_ignore_modules:
continue
if isinstance(module, nn.Conv3d):
raise ValueError(
"Exact NeuroMC runtime does not support nn.Conv3d yet."
)
self._bound_model = model
for module in model.modules():
if module in self.extra_ignore_modules:
continue
if isinstance(module, supported):
self._hook_handles.append(
module.register_forward_hook(self._forward_hook)
)
self._hook_handles.append(
module.register_full_backward_hook(self._backward_hook)
)
self._model_bound = True
# [docs]
def bind_optimizer(self, optimizer: torch.optim.Optimizer | None):
self._optimizer = optimizer
def __enter__(self):
self._trace_mode.__enter__()
self._active = True
return self
def __exit__(self, exc_type, exc, tb):
self._active = False
try:
for handle in self._hook_handles:
handle.remove()
if self._bound_model is not None:
for module in self._bound_model.modules():
if hasattr(module, "_neuromc_last_input"):
delattr(module, "_neuromc_last_input")
self._model_bound = False
self._bound_model = None
finally:
self._hook_handles.clear()
return self._trace_mode.__exit__(exc_type, exc, tb)
# [docs]
@contextmanager
def stage(self, name: str):
if not self._active:
raise RuntimeError(
"stage() can only be used inside active profiler context."
)
if self._stage_stack:
raise RuntimeError("Nested stage() is not supported in NeuroMC v2.")
self._stage_stack.append(name)
try:
yield self
finally:
self._stage_stack.pop()
# [docs]
@contextmanager
def suspend(self):
old = self._suspended
self._suspended = True
try:
yield self
finally:
self._suspended = old
def _current_stage(self) -> str:
if not self._stage_stack:
return "unlabeled"
return self._stage_stack[-1]
def _bn_core_type(self, phase: str, is_spike_input: bool) -> str:
if is_spike_input:
return "fp_bn" if phase == "forward" else "bp_bn"
return "ann_bn" if phase == "forward" else "ann_bn_bp"
def _stage_phase(self, stage: str | None = None) -> str:
name = stage or self._current_stage()
lowered = name.lower()
if "backward" in lowered:
return "backward"
if "optimizer" in lowered or "update" in lowered:
return "optimizer"
return "forward"
def _stage_position(self, stage: str | None = None) -> tuple[int, int]:
name = stage or self._current_stage()
lowered = name.lower()
b_match = re.search(r"(?:^|[^a-z0-9])b(\d+)(?:[^a-z0-9]|$)", lowered)
t_match = re.search(r"(?:^|[^a-z0-9])t(\d+)(?:[^a-z0-9]|$)", lowered)
b_type = int(b_match.group(1)) if b_match is not None else 0
t_type = int(t_match.group(1)) if t_match is not None else 0
return b_type, t_type
def _stage_conv_type(self, stage: str | None = None) -> str:
name = (stage or self._current_stage()).lower()
if "without_bp_bn" in name:
return "without_bp_bn"
return "--"
def _maybe_record_trace_event(self, op_name: str, args, kwargs, out):
if not self._active or self._suspended:
return
self._trace_events.append(
_TraceEvent(
op_name=op_name,
stage=self._current_stage(),
phase=self._stage_phase(),
args=_snapshot_trace_value(args),
kwargs=_snapshot_trace_value(kwargs),
out=_snapshot_trace_value(out),
)
)
def _forward_hook(self, module: nn.Module, args, out):
if not self._active or self._suspended:
return
stage = self._current_stage()
from ...neuron.base_node import BaseNode
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
x = args[0]
module._neuromc_last_input = x
self._fragments.append(
self._make_conv_forward_fragment(stage, module, x, out)
)
elif isinstance(module, nn.Linear):
x = args[0]
module._neuromc_last_input = x
self._fragments.append(
self._make_linear_forward_fragment(stage, module, x, out)
)
elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
x = args[0]
module._neuromc_last_input = x
self._fragments.append(
self._make_bn_forward_fragment(stage, module, x, out)
)
elif isinstance(module, BaseNode):
x = args[0]
self._fragments.append(
self._make_soma_forward_fragment(stage, module, x, out)
)
def _backward_hook(self, module: nn.Module, grad_input, grad_output):
if not self._active or self._suspended:
return
stage = self._current_stage()
from ...neuron.base_node import BaseNode
if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
self._fragments.extend(
self._make_conv_backward_fragments(
stage, module, grad_input, grad_output
)
)
elif isinstance(module, nn.Linear):
self._fragments.extend(
self._make_linear_backward_fragments(
stage, module, grad_input, grad_output
)
)
elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
self._fragments.append(
self._make_bn_backward_fragment(stage, module, grad_input, grad_output)
)
elif isinstance(module, BaseNode):
self._fragments.append(
self._make_soma_backward_fragment(
stage, module, grad_input, grad_output
)
)
def _make_loop_dims(
self,
*,
batch_size: int,
time_steps: int = 1,
channels_in: int,
channels_out: int,
oy: int,
ox: int,
fy: int,
fx: int,
) -> dict[str, int]:
return {
"B": max(int(batch_size), 1),
"T": max(int(time_steps), 1),
"C": max(int(channels_in), 1),
"K": max(int(channels_out), 1),
"OY": max(int(oy), 1),
"OX": max(int(ox), 1),
"FY": max(int(fy), 1),
"FX": max(int(fx), 1),
}
def _make_conv_forward_fragment(
    self, stage: str, module: nn.Module, x, out
) -> _Fragment:
    """Describe one forward convolution call as a NeuroMC fragment.

    T and B come from the input's layout, C/K from the module's channel
    counts, OY/OX from the output's spatial extent and FY/FX from
    ``kernel_size``. 1-D spatial extents and kernels are padded to 2-D with a
    trailing 1.

    :param stage: Stage label active when the hook fired.
    :param module: Convolution module (reads ``kernel_size``,
        ``in_channels``, ``out_channels`` and ``weight``).
    :param x: Input passed to the module.
    :param out: Output produced by the module.
    :return: A ``conv.forward`` fragment with ``source="module"``.
    """
    # NOTE(review): module.groups is not consulted, so grouped/depthwise
    # convolutions are counted with C = in_channels per output; confirm.
    pos_b, pos_t = self._stage_position(stage)
    spike_in = _is_spike(x)
    steps, batch, _, _ = _tensor_layout(x, module)
    out_hw = _tensor_layout(out, module)[3]
    if not out_hw:
        out_hw = (1, 1)
    elif len(out_hw) == 1:
        out_hw = (out_hw[0], 1)
    ks = module.kernel_size
    kernel_hw = tuple(ks) if isinstance(ks, tuple) else (ks,)
    if len(kernel_hw) == 1:
        kernel_hw = (kernel_hw[0], 1)
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=int(module.in_channels),
        channels_out=int(module.out_channels),
        oy=int(out_hw[0]),
        ox=int(out_hw[1]),
        fy=int(kernel_hw[0]),
        fx=int(kernel_hw[1]),
    )
    return _Fragment(
        stage=stage,
        phase="forward",
        op_name="conv.forward",
        core_type=_forward_core_type(spike_in),
        process_key="with_nothing",
        loop_dims=dims,
        # Spike inputs are modelled as 1-bit operands.
        input_precision_bits=1 if spike_in else 16,
        weight_precision_bits=16,
        output_precision_bits=16,
        input_numel=_tensor_numel(x),
        weight_numel=_tensor_numel(module.weight),
        output_numel=_tensor_numel(out),
        mac_count=self._mac_count(dims),
        b_type=pos_b,
        t_type=pos_t,
        source="module",
    )
def _make_linear_forward_fragment(
    self, stage: str, module: nn.Linear, x, out
) -> _Fragment:
    """Describe one forward ``nn.Linear`` call as a NeuroMC fragment.

    ``nn.Linear`` maps only the last dimension, so every leading element of
    ``x`` is an independent row. The batch loop bound is therefore
    ``numel(x) / in_features`` rows divided by the number of time steps:

    - ``[N, in]`` gives ``N``;
    - multi-step ``[T, N, in]`` gives ``N``;
    - ``[N, L, in]`` gives ``N * L``;
    - unbatched ``[in]`` gives 1.

    :param stage: Stage label active when the hook fired.
    :param module: The ``nn.Linear`` module.
    :param x: Input passed to the module.
    :param out: Output produced by the module.
    :return: A ``linear.forward`` fragment with ``source="module"``.
    """
    b_type, t_type = self._stage_position(stage)
    is_spike_input = _is_spike(x)
    # Only T is taken from the layout helper. It also rejects multi-step
    # inputs with fewer than 3 dims.
    t, _, _, _ = _tensor_layout(x, module)
    in_features = int(module.in_features)
    rows = _tensor_numel(x) // max(in_features, 1)
    batch = rows // max(int(t), 1)
    loop_dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=t,
        channels_in=in_features,
        channels_out=int(module.out_features),
        oy=1,
        ox=1,
        fy=1,
        fx=1,
    )
    return _Fragment(
        stage=stage,
        phase="forward",
        op_name="linear.forward",
        core_type=_forward_core_type(is_spike_input),
        process_key="with_nothing",
        loop_dims=loop_dims,
        # Spike inputs are modelled as 1-bit operands.
        input_precision_bits=1 if is_spike_input else 16,
        weight_precision_bits=16,
        output_precision_bits=16,
        input_numel=_tensor_numel(x),
        weight_numel=_tensor_numel(module.weight),
        output_numel=_tensor_numel(out),
        mac_count=self._mac_count(loop_dims),
        b_type=b_type,
        t_type=t_type,
        source="module",
    )
def _make_bn_forward_fragment(
    self, stage: str, module: nn.Module | None, x, out
) -> _Fragment:
    """Describe one forward BatchNorm call as a (MAC-free) NeuroMC fragment.

    Channels map to both C and K, and all spatial positions are flattened into
    OY. The weight count is ``2 * C`` (one pair of values per channel).

    :return: A ``bn.forward`` fragment with ``mac_count=0``.
    """
    # NOTE(review): unlike the conv/linear/soma forward builders, no
    # b_type/t_type is attached here; confirm this is intended.
    spike_in = _is_spike_like(x)
    steps, batch, channels, spatial = _tensor_layout(x, module)
    positions = max(_prod(spatial), 1) if spatial else 1
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=channels,
        channels_out=channels,
        oy=positions,
        ox=1,
        fy=1,
        fx=1,
    )
    result = out[0] if isinstance(out, (tuple, list)) else out
    return _Fragment(
        stage=stage,
        phase="forward",
        op_name="bn.forward",
        core_type=self._bn_core_type("forward", spike_in),
        process_key="with_bn",
        loop_dims=dims,
        input_precision_bits=16,
        weight_precision_bits=16,
        output_precision_bits=16,
        input_numel=_tensor_numel(x),
        weight_numel=channels * 2,
        output_numel=_tensor_numel(result),
        mac_count=0,
        source="module",
    )
def _make_soma_forward_fragment(
    self, stage: str, module: nn.Module | None, x, out
) -> _Fragment:
    """Describe one spiking-neuron forward call as a NeuroMC fragment.

    The layout is read from the neuron's output (the first element if it is a
    tuple/list). If that is not a tensor, every loop bound defaults to 1. The
    output is recorded as 1-bit spikes.

    :return: A ``soma.forward`` fragment on the ``fp_soma`` core.
    """
    b_type, t_type = self._stage_position(stage)
    spikes = out[0] if isinstance(out, (tuple, list)) else out
    if not torch.is_tensor(spikes):
        steps, batch, channels, positions = 1, 1, 1, 1
    else:
        steps, batch, channels, spatial = _tensor_layout(spikes, module)
        positions = max(_prod(spatial), 1) if spatial else 1
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=channels,
        channels_out=channels,
        oy=positions,
        ox=1,
        fy=1,
        fx=1,
    )
    return _Fragment(
        stage=stage,
        phase="forward",
        op_name="soma.forward",
        core_type="fp_soma",
        process_key="with_sg",
        loop_dims=dims,
        input_precision_bits=16,
        weight_precision_bits=16,
        output_precision_bits=1,
        input_numel=_tensor_numel(x),
        weight_numel=0,
        output_numel=_tensor_numel(spikes),
        mac_count=0,
        b_type=b_type,
        t_type=t_type,
        source="module",
    )
def _make_conv_backward_fragments(self, stage, module, grad_input, grad_output):
    """Describe one convolution backward call as up to two NeuroMC fragments.

    - ``conv.backward.grad_input`` is emitted when ``grad_input[0]`` is a tensor.
    - ``conv.backward.grad_weight`` is emitted when ``module.weight.requires_grad``.

    Both fragments share loop bounds built from ``grad_output``'s layout and
    the module's channels and kernel. The core type depends on whether the
    input cached by the forward hook (``module._neuromc_last_input``) was
    spiking.

    :return: List of zero, one or two fragments (``source="module"``).
    """
    # NOTE(review): module.groups is ignored here, as in the forward builder;
    # confirm how grouped/depthwise convolutions should be counted.
    if isinstance(grad_output, (tuple, list)):
        grad_out = grad_output[0]
    else:
        grad_out = grad_output
    cached_input = getattr(module, "_neuromc_last_input", None)
    spike_in = _is_spike(cached_input)
    steps, batch, _, out_hw = _tensor_layout(grad_out, module)
    if not out_hw:
        out_hw = (1, 1)
    elif len(out_hw) == 1:
        out_hw = (out_hw[0], 1)
    ks = module.kernel_size
    kernel_hw = tuple(ks) if isinstance(ks, tuple) else (ks,)
    if len(kernel_hw) == 1:
        kernel_hw = (kernel_hw[0], 1)
    if isinstance(grad_input, (tuple, list)) and grad_input:
        grad_in = grad_input[0]
    else:
        grad_in = None
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=int(module.in_channels),
        channels_out=int(module.out_channels),
        oy=int(out_hw[0]),
        ox=int(out_hw[1]),
        fy=int(kernel_hw[0]),
        fx=int(kernel_hw[1]),
    )
    macs = self._mac_count(dims)
    results = []
    if torch.is_tensor(grad_in):
        results.append(
            _Fragment(
                stage=stage,
                phase="backward",
                op_name="conv.backward.grad_input",
                core_type=_grad_input_core_type(spike_in),
                process_key="with_nothing",
                loop_dims=dims,
                input_precision_bits=16,
                weight_precision_bits=16,
                output_precision_bits=16,
                input_numel=_tensor_numel(grad_out),
                weight_numel=_tensor_numel(module.weight),
                output_numel=_tensor_numel(grad_in),
                mac_count=macs,
                source="module",
            )
        )
    if module.weight.requires_grad:
        results.append(
            _Fragment(
                stage=stage,
                phase="backward",
                op_name="conv.backward.grad_weight",
                core_type=_weight_grad_core_type(spike_in),
                process_key="with_nothing",
                loop_dims=dims,
                # The weight-gradient pass re-reads the (possibly 1-bit) input.
                input_precision_bits=1 if spike_in else 16,
                weight_precision_bits=16,
                output_precision_bits=16,
                input_numel=_tensor_numel(cached_input),
                weight_numel=_tensor_numel(grad_out),
                output_numel=_tensor_numel(module.weight),
                mac_count=macs,
                source="module",
            )
        )
    return results
def _make_linear_backward_fragments(self, stage, module, grad_input, grad_output):
    """Describe one ``nn.Linear`` backward call as up to two NeuroMC fragments.

    - ``linear.backward.grad_input`` is emitted when ``grad_input[0]`` is a tensor.
    - ``linear.backward.grad_weight`` is emitted when ``module.weight.requires_grad``.

    The batch bound counts every row of ``grad_output``
    (``numel / out_features``) divided by the time steps. Extra leading
    dimensions such as ``[N, L, out]`` are therefore included, and an
    unbatched ``[out]`` gradient counts as one row. The core type depends on
    whether the input cached by the forward hook was spiking.

    :return: List of zero, one or two fragments (``source="module"``).
    """
    grad_out = (
        grad_output[0] if isinstance(grad_output, (tuple, list)) else grad_output
    )
    cached_input = getattr(module, "_neuromc_last_input", None)
    is_spike_input = _is_spike(cached_input)
    # Only T comes from the layout helper; nn.Linear acts on the last dim, so
    # every leading element of grad_out is a row.
    t, _, _, _ = _tensor_layout(grad_out, module)
    out_features = int(module.out_features)
    rows = _tensor_numel(grad_out) // max(out_features, 1)
    batch = rows // max(int(t), 1)
    loop_dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=t,
        channels_in=int(module.in_features),
        channels_out=out_features,
        oy=1,
        ox=1,
        fy=1,
        fx=1,
    )
    grad_in = (
        grad_input[0]
        if isinstance(grad_input, (tuple, list)) and grad_input
        else None
    )
    mac_count = self._mac_count(loop_dims)
    fragments = []
    if torch.is_tensor(grad_in):
        fragments.append(
            _Fragment(
                stage=stage,
                phase="backward",
                op_name="linear.backward.grad_input",
                core_type=_grad_input_core_type(is_spike_input),
                process_key="with_nothing",
                loop_dims=loop_dims,
                input_precision_bits=16,
                weight_precision_bits=16,
                output_precision_bits=16,
                input_numel=_tensor_numel(grad_out),
                weight_numel=_tensor_numel(module.weight),
                output_numel=_tensor_numel(grad_in),
                mac_count=mac_count,
                source="module",
            )
        )
    if module.weight.requires_grad:
        fragments.append(
            _Fragment(
                stage=stage,
                phase="backward",
                op_name="linear.backward.grad_weight",
                core_type=_weight_grad_core_type(is_spike_input),
                process_key="with_nothing",
                loop_dims=loop_dims,
                input_precision_bits=1 if is_spike_input else 16,
                weight_precision_bits=16,
                output_precision_bits=16,
                input_numel=_tensor_numel(cached_input),
                weight_numel=_tensor_numel(grad_out),
                output_numel=_tensor_numel(module.weight),
                mac_count=mac_count,
                source="module",
            )
        )
    return fragments
def _make_bn_backward_fragment(self, stage, module, grad_input, grad_output):
    """Describe one BatchNorm backward call as a (MAC-free) NeuroMC fragment.

    The layout comes from ``grad_output``. If ``grad_input`` is not a
    non-empty tuple/list, the output gradient stands in for the input
    gradient. ``conv_type`` is taken from the stage label
    (``"without_bp_bn"`` or ``"--"``).

    :return: A ``bn.backward`` fragment with ``mac_count=0``.
    """
    spike_in = _is_spike(getattr(module, "_neuromc_last_input", None))
    if isinstance(grad_output, (tuple, list)):
        grad_out = grad_output[0]
    else:
        grad_out = grad_output
    if isinstance(grad_input, (tuple, list)) and grad_input:
        grad_in = grad_input[0]
    else:
        grad_in = grad_out
    steps, batch, channels, spatial = _tensor_layout(grad_out, module)
    positions = max(_prod(spatial), 1) if spatial else 1
    variant = self._stage_conv_type(stage)
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=channels,
        channels_out=channels,
        oy=positions,
        ox=1,
        fy=1,
        fx=1,
    )
    return _Fragment(
        stage=stage,
        phase="backward",
        op_name="bn.backward",
        core_type=self._bn_core_type("backward", spike_in),
        process_key="with_bn",
        loop_dims=dims,
        input_precision_bits=16,
        weight_precision_bits=16,
        output_precision_bits=16,
        input_numel=_tensor_numel(grad_out),
        weight_numel=channels * 2,
        output_numel=_tensor_numel(grad_in),
        mac_count=0,
        conv_type=variant,
        source="module",
    )
def _make_soma_backward_fragment(self, stage, module, grad_input, grad_output):
    """Describe one spiking-neuron backward call as a NeuroMC fragment.

    The layout and output size come from ``grad_input[0]`` when it is a
    tensor, otherwise from the output gradient. Always uses the ``bp_grad``
    core with the ``with_sg`` process key.

    :return: A ``soma.backward`` fragment with ``mac_count=0``.
    """
    if isinstance(grad_input, (tuple, list)) and grad_input:
        grad_in = grad_input[0]
    else:
        grad_in = None
    if isinstance(grad_output, (tuple, list)):
        grad_out = grad_output[0]
    else:
        grad_out = grad_output
    reference = grad_in if torch.is_tensor(grad_in) else grad_out
    steps, batch, channels, spatial = _tensor_layout(reference, module)
    positions = max(_prod(spatial), 1) if spatial else 1
    dims = self._make_loop_dims(
        batch_size=batch,
        time_steps=steps,
        channels_in=channels,
        channels_out=channels,
        oy=positions,
        ox=1,
        fy=1,
        fx=1,
    )
    return _Fragment(
        stage=stage,
        phase="backward",
        op_name="soma.backward",
        core_type="bp_grad",
        process_key="with_sg",
        loop_dims=dims,
        input_precision_bits=16,
        weight_precision_bits=16,
        output_precision_bits=16,
        input_numel=_tensor_numel(grad_out),
        weight_numel=0,
        output_numel=_tensor_numel(reference),
        mac_count=0,
        source="module",
    )
def _mac_count(self, loop_dims: dict[str, int]) -> int:
return int(
loop_dims["B"]
* loop_dims["T"]
* loop_dims["C"]
* loop_dims["K"]
* loop_dims["OY"]
* loop_dims["OX"]
* loop_dims["FY"]
* loop_dims["FX"]
)
def _matches_trainable_param_shape(self, out: Any) -> bool:
if self._bound_model is None or not (
torch.is_tensor(out) or isinstance(out, _TraceTensor)
):
return False
out_shape = tuple(out.shape)
for p in self._bound_model.parameters():
if p.requires_grad and tuple(p.shape) == out_shape:
return True
return False
def _gemm_backward_fragment_kind(
self, x: torch.Tensor | _TraceTensor, y: torch.Tensor | _TraceTensor, out: Any
) -> str | None:
x_req = bool(x.requires_grad)
y_req = bool(y.requires_grad)
if x_req and (not y_req):
return "wg"
if (not x_req) and y_req:
return "bp_grad"
if self._matches_trainable_param_shape(out):
return "wg"
return None
def _fallback_fragments_from_trace(
    self, *, supplemental_only: bool = False
) -> list[_Fragment]:
    """Derive NeuroMC fragments from the recorded aten dispatch trace.

    Walks ``self._trace_events`` in recording order and maps the handled
    aten ops (convolution forward, addmm/mm/bmm forward and backward,
    native batch-norm forward/backward) onto ``_Fragment`` objects.
    Ops in ``_AUXILIARY_ATEN_OPS`` or matching ``_IGNORED_OP_PREFIXES``
    are skipped. Any other op falls through without producing a fragment;
    reporting such ops is left to ``_unsupported_ops``.

    Args:
        supplemental_only: When ``True``, only the ops listed in
            ``supplemental_ops`` are considered at all (``get_report`` sets
            this when module-hook fragments already exist).

    Returns:
        The recognised fragments, in trace order.

    Raises:
        ValueError: for convolution or batch-norm trace inputs with more
            than 4 dimensions (multi-step / 3D layouts are unsupported).
    """
    fragments: list[_Fragment] = []
    # Spike-ness of each forward GEMM input, keyed by
    # (op, forward-input shape, forward-output shape); consumed later by
    # the matching backward grad-input GEMM.
    gemm_forward_is_spike: dict[
        tuple[str, tuple[int, ...], tuple[int, ...]], deque[bool]
    ] = defaultdict(deque)
    supplemental_ops = {
        "aten.convolution.default",
        "aten.addmm.default",
        "aten.mm.default",
        "aten.bmm.default",
        "aten.native_batch_norm.default",
        "aten.native_batch_norm_backward.default",
    }
    for event in self._trace_events:
        op = event.op_name
        b_type, t_type = self._stage_position(event.stage)
        if supplemental_only and op not in supplemental_ops:
            continue
        if op in _AUXILIARY_ATEN_OPS or op.startswith(_IGNORED_OP_PREFIXES):
            continue
        if op == "aten.convolution.default" and event.phase == "forward":
            x, w = event.args[:2]
            out = event.out
            if x.ndim > 4 or out.ndim > 4:
                raise ValueError(
                    "Exact NeuroMC runtime does not support multi-step or 3D "
                    "Conv trace fallback yet."
                )
            is_spike_input = _is_spike_like(x)
            spatial = tuple(out.shape[2:]) if out.ndim > 2 else (1, 1)
            # Given the ndim <= 4 check above this looks unreachable; kept
            # as a defensive guard.
            if len(spatial) > 2:
                raise ValueError(
                    "Exact NeuroMC runtime does not support Conv3d fallback yet."
                )
            # Conv1d: pad the single spatial extent to a 2-D (OY, OX) pair.
            if len(spatial) == 1:
                spatial = (spatial[0], 1)
            kernel = tuple(w.shape[2:]) if w.ndim > 2 else (1, 1)
            if len(kernel) > 2:
                raise ValueError(
                    "Exact NeuroMC runtime does not support Conv3d fallback yet."
                )
            if len(kernel) == 1:
                kernel = (kernel[0], 1)
            # NOTE(review): channels_in is x.shape[1] regardless of the
            # conv's groups argument, so grouped/depthwise convs would be
            # counted as dense -- confirm whether that is intended.
            loop_dims = self._make_loop_dims(
                batch_size=int(x.shape[0]),
                channels_in=int(x.shape[1]),
                channels_out=int(out.shape[1]),
                oy=int(spatial[0]),
                ox=int(spatial[1]),
                fy=int(kernel[0]),
                fx=int(kernel[1]),
            )
            fragments.append(
                _Fragment(
                    stage=event.stage,
                    phase=event.phase,
                    op_name=op,
                    core_type=_forward_core_type(is_spike_input),
                    process_key="with_nothing",
                    loop_dims=loop_dims,
                    input_precision_bits=1 if is_spike_input else 16,
                    weight_precision_bits=16,
                    output_precision_bits=16,
                    input_numel=_tensor_numel(x),
                    weight_numel=_tensor_numel(w),
                    output_numel=_tensor_numel(out),
                    mac_count=self._mac_count(loop_dims),
                    b_type=b_type,
                    t_type=t_type,
                )
            )
        elif op in {"aten.addmm.default", "aten.mm.default", "aten.bmm.default"}:
            # The last two positional args are the two matrix operands
            # (for addmm the first arg is the additive term).
            x = event.args[-2]
            y = event.args[-1]
            out = event.out
            if event.phase == "forward":
                is_spike_input = _is_spike_like(x)
                # Remember this input's spike-ness for the backward pass.
                gemm_forward_is_spike[
                    (op, _shape_tuple(x), _shape_tuple(out))
                ].append(is_spike_input)
                # 3-D (bmm) operands fold the leading two dims into the batch.
                batch = (
                    int(x.shape[0]) if x.ndim == 2 else int(x.shape[0] * x.shape[1])
                )
                k = int(x.shape[-1])
                n = int(y.shape[-1])
                loop_dims = self._make_loop_dims(
                    batch_size=batch,
                    channels_in=k,
                    channels_out=n,
                    oy=1,
                    ox=1,
                    fy=1,
                    fx=1,
                )
                fragments.append(
                    _Fragment(
                        stage=event.stage,
                        phase=event.phase,
                        op_name=op,
                        core_type=_forward_core_type(is_spike_input),
                        process_key="with_nothing",
                        loop_dims=loop_dims,
                        input_precision_bits=1 if is_spike_input else 16,
                        weight_precision_bits=16,
                        output_precision_bits=16,
                        input_numel=_tensor_numel(x),
                        weight_numel=_tensor_numel(y),
                        output_numel=_tensor_numel(out),
                        mac_count=self._mac_count(loop_dims),
                        b_type=b_type,
                        t_type=t_type,
                    )
                )
            else:
                kind = self._gemm_backward_fragment_kind(x, y, out)
                if kind == "wg":
                    # Weight-gradient GEMM: x's last dim is treated as the
                    # batch (reduction) axis and x's channel dim as C.
                    is_spike_input = _is_spike_like(x)
                    if x.ndim == 2:
                        batch = int(x.shape[-1])
                        c = int(x.shape[0])
                        k = int(y.shape[-1])
                    else:
                        batch = int(x.shape[0] * x.shape[-1])
                        c = int(x.shape[1])
                        k = int(y.shape[-1])
                    loop_dims = self._make_loop_dims(
                        batch_size=batch,
                        channels_in=c,
                        channels_out=k,
                        oy=1,
                        ox=1,
                        fy=1,
                        fx=1,
                    )
                    fragments.append(
                        _Fragment(
                            stage=event.stage,
                            phase=event.phase,
                            op_name=op,
                            core_type=_weight_grad_core_type(is_spike_input),
                            process_key="with_nothing",
                            loop_dims=loop_dims,
                            input_precision_bits=1 if is_spike_input else 16,
                            weight_precision_bits=16,
                            output_precision_bits=16,
                            input_numel=_tensor_numel(x),
                            weight_numel=_tensor_numel(y),
                            output_numel=_tensor_numel(out),
                            mac_count=self._mac_count(loop_dims),
                            b_type=b_type,
                            t_type=t_type,
                        )
                    )
                else:
                    # Any kind other than "wg" (including None) is treated as
                    # a grad-input GEMM. Its output presumably has the forward
                    # input's shape and x (the incoming gradient) the forward
                    # output's shape, hence the swapped shapes in the key.
                    forward_key = (op, _shape_tuple(out), _shape_tuple(x))
                    is_spike_input = False
                    forward_queue = gemm_forward_is_spike.get(forward_key)
                    # pop() takes the most recently recorded matching forward
                    # (LIFO), presumably mirroring the reverse layer order of
                    # the backward pass.
                    if forward_queue:
                        is_spike_input = forward_queue.pop()
                    elif op == "aten.addmm.default":
                        # Fall back to the other GEMM flavour: the forward and
                        # backward of the same layer may lower to addmm vs mm.
                        alt_key = (
                            "aten.mm.default",
                            _shape_tuple(out),
                            _shape_tuple(x),
                        )
                        alt_queue = gemm_forward_is_spike.get(alt_key)
                        if alt_queue:
                            is_spike_input = alt_queue.pop()
                    elif op == "aten.mm.default":
                        alt_key = (
                            "aten.addmm.default",
                            _shape_tuple(out),
                            _shape_tuple(x),
                        )
                        alt_queue = gemm_forward_is_spike.get(alt_key)
                        if alt_queue:
                            is_spike_input = alt_queue.pop()
                    # Missing forward provenance should bias toward the denser,
                    # more expensive path instead of undercounting energy.
                    batch = (
                        int(x.shape[0])
                        if x.ndim == 2
                        else int(x.shape[0] * x.shape[1])
                    )
                    k = int(x.shape[-1])
                    n = int(y.shape[-1])
                    loop_dims = self._make_loop_dims(
                        batch_size=batch,
                        channels_in=k,
                        channels_out=n,
                        oy=1,
                        ox=1,
                        fy=1,
                        fx=1,
                    )
                    fragments.append(
                        _Fragment(
                            stage=event.stage,
                            phase=event.phase,
                            op_name=op,
                            core_type=_grad_input_core_type(is_spike_input),
                            process_key="with_nothing",
                            loop_dims=loop_dims,
                            input_precision_bits=16,
                            weight_precision_bits=16,
                            output_precision_bits=16,
                            input_numel=_tensor_numel(x),
                            weight_numel=_tensor_numel(y),
                            output_numel=_tensor_numel(out),
                            mac_count=self._mac_count(loop_dims),
                            b_type=b_type,
                            t_type=t_type,
                        )
                    )
        elif op == "aten.native_batch_norm.default" and event.phase == "forward":
            x = event.args[0]
            if x.ndim > 4:
                raise ValueError(
                    "Exact NeuroMC runtime does not support multi-step or 3D "
                    "BatchNorm trace fallback yet."
                )
            # native_batch_norm may return a tuple; the normalised output is
            # taken as its first element.
            out = (
                event.out[0] if isinstance(event.out, (tuple, list)) else event.out
            )
            fragments.append(
                self._make_bn_forward_fragment(event.stage, None, x, out)
            )
        elif op == "aten.native_batch_norm_backward.default":
            fragments.append(
                self._make_bn_backward_fragment(
                    event.stage, None, event.args, event.out
                )
            )
    return fragments
def _unsupported_ops(self) -> list[str]:
    """List recorded aten ops that this profiler cannot map to fragments.

    An op from ``self._trace_mode.op_counts`` is listed unless one of these
    holds:

    - it matches ``_IGNORED_OP_PREFIXES``;
    - it is in ``_AUXILIARY_ATEN_OPS``;
    - it is one of the GEMM / convolution / batch-norm ops handled by the
      trace fallback.

    Returns:
        Sorted ``"<op> (calls=<count>)"`` strings.
    """
    handled = frozenset(
        {
            "aten.convolution.default",
            "aten.addmm.default",
            "aten.mm.default",
            "aten.bmm.default",
            "aten.native_batch_norm.default",
            "aten.native_batch_norm_backward.default",
        }
    )
    offenders = [
        f"{op_name} (calls={count})"
        for op_name, count in self._trace_mode.op_counts.items()
        if not (
            op_name.startswith(_IGNORED_OP_PREFIXES)
            or op_name in _AUXILIARY_ATEN_OPS
            or op_name in handled
        )
    ]
    return sorted(offenders)
def _extra_counts(self, fragment: _Fragment) -> dict[str, int]:
counts = {"mux": 0, "add": 0, "mul": 0, "comp": 0, "sqrt": 0}
ld = fragment.loop_dims
oyoxkbt = ld["OY"] * ld["OX"] * ld["K"] * ld["B"] * ld["T"]
oyoxcbt = ld["OY"] * ld["OX"] * ld["C"] * ld["B"] * ld["T"]
ct = ld["C"] * ld["T"]
if fragment.process_key == "with_opt":
kt = ld["K"]
fyfxkc = ld["FX"] * ld["FY"] * ld["C"]
else:
kt = ld["K"] * ld["T"]
fyfxkc = ld["FY"] * ld["FX"] * ld["K"] * ld["C"]
if fragment.process_key == "with_sg":
if fragment.phase == "forward":
counts["mux"] += oyoxkbt
counts["add"] += oyoxkbt
counts["mul"] += oyoxkbt
counts["comp"] += oyoxkbt * 3
else:
counts["mux"] += oyoxcbt
counts["add"] += oyoxcbt * 2
counts["mul"] += oyoxcbt * 4
elif fragment.process_key == "with_bn":
if fragment.phase == "forward":
counts["add"] += oyoxkbt * 3 + kt * 2
counts["mul"] += oyoxkbt * 3 + kt * 4
counts["sqrt"] += kt
else:
counts["add"] += oyoxcbt * 7
counts["mul"] += oyoxcbt * 3 + ct * 22
if fragment.conv_type == "without_bp_bn":
counts["add"] = 0
counts["mul"] = 0
elif fragment.process_key == "with_opt":
if fragment.op_name in {"adam", "adamw"}:
counts["add"] += 8 * kt + 4 * fyfxkc
counts["mul"] += 22 * kt + 11 * fyfxkc
counts["sqrt"] += 2 * kt + fyfxkc
elif fragment.op_name == "sgd":
total_params = _optimizer_param_count(fragment)
counts["add"] += total_params
counts["mul"] += total_params
if fragment.optimizer_has_weight_decay:
counts["add"] += total_params
counts["mul"] += total_params
if (
fragment.optimizer_has_momentum
and fragment.optimizer_has_momentum_buffer
):
counts["add"] += total_params
counts["mul"] += total_params
return counts
def _memory_energy_per_element(
self, spec: MemoryInstanceSpec, precision_bits: int, read: bool
) -> float:
bw = spec.r_bw if read else spec.w_bw
cost = spec.r_cost if read else spec.w_cost
if precision_bits <= 0 or bw <= 0:
return 0.0
return cost / (bw / precision_bits)
def _accumulate_memory(
self,
totals: dict[str, dict[str, int]],
energy: dict[str, dict[str, float]],
level: str,
direction: str,
bits: int,
spec: MemoryInstanceSpec,
precision_bits: int,
read: bool,
):
if bits <= 0:
return
if level == "dram" and self.memory_config.zero_dram_in_paper_energy:
return
if level == "noc" and self.memory_config.zero_noc_in_paper_energy:
return
if (
level == "sram"
and self.memory_config.zero_sram_high_directions
and direction in {"rl2h", "wh2l"}
):
return
totals[level][direction] += bits
energy[level][direction] += (
bits / precision_bits
) * self._memory_energy_per_element(spec, precision_bits, read)
def _base_memory_for_fragment(self, fragment: _Fragment):
totals = defaultdict(lambda: defaultdict(int))
energy = defaultdict(lambda: defaultdict(float))
if (
fragment.core_type
not in {"fp_soma", "bp_grad", "wg", "ann_fe", "ann_be", "ann_we"}
or fragment.mac_count == 0
):
return totals, energy
cfg = self.memory_config.memory_instances
if fragment.core_type == "fp_soma":
reg_i1, reg_i2, reg_o = cfg["reg_1b"], cfg["reg_16b"], cfg["reg_16b"]
sram_i1, sram_i2, sram_o = (
cfg["sram_fp_conv_in_s"],
cfg["sram_fp_conv_in_w"],
cfg["sram_fp_conv_out_xi"],
)
elif fragment.core_type == "ann_fe":
reg_i1, reg_i2, reg_o = cfg["reg_16b"], cfg["reg_16b"], cfg["reg_16b"]
sram_i1, sram_i2, sram_o = (
cfg["sram_fp_conv_in_w"],
cfg["sram_fp_conv_in_w"],
cfg["sram_fp_conv_out_xi"],
)
elif fragment.core_type in {"bp_grad", "ann_be"}:
reg_i1, reg_i2, reg_o = cfg["reg_16b"], cfg["reg_16b"], cfg["reg_16b"]
sram_i1, sram_i2, sram_o = (
cfg["sram_bp_conv_in_du"],
cfg["sram_bp_conv_in_w"],
cfg["sram_bp_conv_out_res"],
)
else:
if fragment.core_type == "ann_we":
reg_i1, reg_i2, reg_o = (
cfg["reg_16b"],
cfg["reg_16b"],
cfg["reg_16b"],
)
sram_i1, sram_i2, sram_o = (
cfg["sram_wg_conv_in_du"],
cfg["sram_wg_conv_in_du"],
cfg["sram_wg_conv_out_dw"],
)
else:
reg_i1, reg_i2, reg_o = cfg["reg_1b"], cfg["reg_16b"], cfg["reg_16b"]
sram_i1, sram_i2, sram_o = (
cfg["sram_wg_conv_in_s"],
cfg["sram_wg_conv_in_du"],
cfg["sram_wg_conv_out_dw"],
)
i1_bits = fragment.input_numel * fragment.input_precision_bits
i2_bits = fragment.weight_numel * fragment.weight_precision_bits
o_bits = fragment.output_numel * fragment.output_precision_bits
reuse_weight = (
fragment.b_type > 0 or fragment.t_type > 0
) and i2_bits <= sram_i2.size_bits
self._accumulate_memory(
totals,
energy,
"reg",
"rh2l",
i1_bits,
reg_i1,
fragment.input_precision_bits,
True,
)
self._accumulate_memory(
totals,
energy,
"sram",
"rh2l",
i1_bits,
sram_i1,
fragment.input_precision_bits,
True,
)
if not reuse_weight:
self._accumulate_memory(
totals,
energy,
"reg",
"rh2l",
i2_bits,
reg_i2,
fragment.weight_precision_bits,
True,
)
self._accumulate_memory(
totals,
energy,
"sram",
"rh2l",
i2_bits,
sram_i2,
fragment.weight_precision_bits,
True,
)
self._accumulate_memory(
totals,
energy,
"reg",
"wl2h",
o_bits,
reg_o,
fragment.output_precision_bits,
False,
)
self._accumulate_memory(
totals,
energy,
"sram",
"wl2h",
o_bits,
sram_o,
fragment.output_precision_bits,
False,
)
return totals, energy
def _extra_memory_for_fragment(self, fragment: _Fragment):
    """Compute the extra SRAM traffic of a fragment's processing stage.

    This covers the soma (``with_sg``), batch-norm (``with_bn``) and
    optimizer (``with_opt``) steps, which ``_base_memory_for_fragment``
    does not charge. Only the ``"sram"`` level is charged here, through
    ``_accumulate_memory``. That helper may still drop the traffic,
    depending on ``self.memory_config``.

    Returns:
        ``(totals, energy)``: nested ``defaultdict``s keyed by memory level
        then direction, holding accumulated bits and energy respectively.
        Both stay empty in these cases:

        - ``with_nothing`` fragments;
        - process-key / phase combinations without a table below;
        - optimizers other than sgd / adam / adamw.
    """
    totals = defaultdict(lambda: defaultdict(int))
    energy = defaultdict(lambda: defaultdict(float))
    if fragment.process_key == "with_nothing":
        return totals, energy
    cfg = self.memory_config.memory_instances
    ld = fragment.loop_dims
    if fragment.process_key == "with_opt" and fragment.op_name == "sgd":
        # SGD is costed per parameter element at 32-bit precision against the
        # "sram_6MB" instance instead of through the variable table below.
        total_params = _optimizer_param_count(fragment)
        total_bits = total_params * 32
        sram_spec = cfg["sram_6MB"]
        # Two full-size reads and one full-size write per step (presumably
        # parameter + gradient in, updated parameter out -- TODO confirm).
        self._accumulate_memory(
            totals,
            energy,
            "sram",
            "rh2l",
            total_bits,
            sram_spec,
            32,
            True,
        )
        self._accumulate_memory(
            totals,
            energy,
            "sram",
            "rh2l",
            total_bits,
            sram_spec,
            32,
            True,
        )
        self._accumulate_memory(
            totals,
            energy,
            "sram",
            "wl2h",
            total_bits,
            sram_spec,
            32,
            False,
        )
        # The momentum buffer is read only when the optimizer already holds
        # one, but written whenever momentum is enabled.
        if (
            fragment.optimizer_has_momentum
            and fragment.optimizer_has_momentum_buffer
        ):
            self._accumulate_memory(
                totals,
                energy,
                "sram",
                "rh2l",
                total_bits,
                sram_spec,
                32,
                True,
            )
        if fragment.optimizer_has_momentum:
            self._accumulate_memory(
                totals,
                energy,
                "sram",
                "wl2h",
                total_bits,
                sram_spec,
                32,
                False,
            )
        return totals, energy
    # Element counts referenced by the variable tables. For optimizer
    # fragments KT drops the time axis and FYFXKC drops K (matching
    # _extra_counts).
    scalar_counts = {
        "OYOXKBT": ld["OY"] * ld["OX"] * ld["K"] * ld["B"] * ld["T"],
        "KT": ld["K"] if fragment.process_key == "with_opt" else ld["K"] * ld["T"],
        "OYOXCBT": ld["OY"] * ld["OX"] * ld["C"] * ld["B"] * ld["T"],
        "CT": ld["C"] * ld["T"],
        "FYFXKC": (
            ld["FY"] * ld["FX"] * ld["C"]
            if fragment.process_key == "with_opt"
            else ld["FY"] * ld["FX"] * ld["K"] * ld["C"]
        ),
    }
    # variable name -> (scalar_counts key, bits per element, register
    # instance, SRAM instance). Names are descriptive only; the register
    # instance is not used by the accumulation loop below.
    variables: dict[str, tuple[str, int, str, str]] = {}
    if fragment.process_key == "with_sg" and fragment.phase == "forward":
        variables = {
            "fp_yi1": ("OYOXKBT", 16, "reg_16b", "sram_fp_conv_out_xi"),
            "fp_u_l": ("OYOXKBT", 16, "reg_16b", "sram_fp_soma_u"),
            "fp_s_l": ("OYOXKBT", 1, "reg_1b", "sram_fp_soma_s"),
            "fp_smask_l": ("OYOXKBT", 1, "reg_1b", "sram_fp_soma_smask"),
        }
    elif fragment.process_key == "with_sg" and fragment.phase == "backward":
        variables = {
            "bp_conv_res": ("OYOXCBT", 16, "reg_16b", "sram_bp_conv_out_res"),
            "bp_u_l_pre": ("OYOXCBT", 16, "reg_16b", "sram_bp_grad_in_u"),
            "bp_s_l_pre": ("OYOXCBT", 1, "reg_1b", "sram_bp_grad_in_s"),
            "bp_smask_l_pre": ("OYOXCBT", 1, "reg_1b", "sram_bp_grad_in_smask"),
            "bp_du_l_pre": ("OYOXCBT", 16, "reg_16b", "sram_bp_grad_out_du"),
        }
    elif fragment.process_key == "with_bn" and fragment.phase == "forward":
        variables = {
            "fp_bn_mean_v": ("KT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_variance": ("KT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_n": ("OYOXKBT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_sqrt": ("KT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_xi_": ("OYOXKBT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_y": ("KT", 16, "reg_16b", "sram_2MB"),
            "fp_bn_b": ("KT", 16, "reg_16b", "sram_2MB"),
        }
    elif fragment.process_key == "with_bn" and fragment.phase == "backward":
        variables = {
            "bp_bn_du_l_pre1": ("OYOXCBT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_sqrt": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_y": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_m": ("OYOXCBT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_n": ("OYOXCBT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_sigma_m": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_sigma_n": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_sigma_mn": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_dy": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_db": ("CT", 16, "reg_16b", "sram_2MB"),
            "bp_bn_du_l_pre2": ("OYOXCBT", 16, "reg_16b", "sram_2MB"),
        }
    elif fragment.process_key == "with_opt":
        # Only Adam/AdamW have a table; SGD returned above and any other
        # optimizer name leaves `variables` empty.
        if fragment.op_name in {"adam", "adamw"}:
            variables = {
                "opt_y": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_b": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_w": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_dy": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_db": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_dw": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_v_y": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_v_b": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_v_w": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_s_y": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_s_b": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_s_w": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_vbc_y": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_vbc_b": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_vbc_w": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_sbc_y": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_sbc_b": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_sbc_w": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
                "opt_y_updated": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_b_updated": ("KT", 32, "reg_32b", "sram_6MB"),
                "opt_w_updated": ("FYFXKC", 32, "reg_32b", "sram_6MB"),
            }
    # Every listed variable is read once and written once in full at the
    # SRAM level.
    for _, (count_key, bits_per_elem, reg_name, sram_name) in variables.items():
        total_bits = scalar_counts[count_key] * bits_per_elem
        sram_spec = cfg[sram_name]
        self._accumulate_memory(
            totals,
            energy,
            "sram",
            "rh2l",
            total_bits,
            sram_spec,
            bits_per_elem,
            True,
        )
        self._accumulate_memory(
            totals,
            energy,
            "sram",
            "wl2h",
            total_bits,
            sram_spec,
            bits_per_elem,
            False,
        )
    return totals, energy
def _validate_sgd_group(self, group: dict[str, Any]) -> None:
if group.get("nesterov", False):
raise ValueError(
"Exact NeuroMC SGD modeling currently supports only nesterov=False."
)
if float(group.get("dampening", 0.0)) != 0.0:
raise ValueError(
"Exact NeuroMC SGD modeling currently supports only dampening=0."
)
if group.get("maximize", False):
raise ValueError(
"Exact NeuroMC SGD modeling currently supports only maximize=False."
)
def _make_optimizer_fragment(
    self,
    *,
    stage: str,
    op_name: str,
    kt: int,
    fyfxkc: int,
    optimizer_has_momentum: bool = False,
    optimizer_has_weight_decay: bool = False,
    optimizer_has_momentum_buffer: bool = False,
) -> _Fragment:
    """Build a ``bp_grad_opt`` / ``with_opt`` fragment for one optimizer step.

    Loop dims are first built with ``K = max(kt, 1)`` and
    ``FX = max(fyfxkc, 1)``. They are then overwritten with the raw counts
    (``K = kt``, ``C = 1``, ``FY = 1``, ``FX = fyfxkc``; zeros allowed).

    Args:
        stage: Stage label for the fragment.
        op_name: Optimizer key (e.g. ``"adam"``, ``"adamw"``, ``"sgd"``).
        kt: Element count of parameters bucketed as ``kt``.
        fyfxkc: Element count of parameters bucketed as ``fyfxkc``.
        optimizer_has_momentum: SGD flag, recorded on the fragment.
        optimizer_has_weight_decay: SGD flag, recorded on the fragment.
        optimizer_has_momentum_buffer: SGD flag, recorded on the fragment.

    Returns:
        A 32-bit optimizer-phase fragment with zero MACs.
    """
    dims = self._make_loop_dims(
        batch_size=1,
        channels_in=1,
        channels_out=max(kt, 1),
        oy=1,
        ox=1,
        fy=1,
        fx=max(fyfxkc, 1),
    )
    for axis, extent in (("K", kt), ("C", 1), ("FY", 1), ("FX", fyfxkc)):
        dims[axis] = extent
    return _Fragment(
        stage=stage,
        phase="optimizer",
        op_name=op_name,
        core_type="bp_grad_opt",
        process_key="with_opt",
        loop_dims=dims,
        input_precision_bits=32,
        weight_precision_bits=32,
        output_precision_bits=32,
        input_numel=kt,
        weight_numel=fyfxkc,
        output_numel=kt + fyfxkc,
        mac_count=0,
        source="optimizer",
        optimizer_has_momentum=optimizer_has_momentum,
        optimizer_has_weight_decay=optimizer_has_weight_decay,
        optimizer_has_momentum_buffer=optimizer_has_momentum_buffer,
    )
def _optimizer_fragments(self, stage: str) -> list[_Fragment]:
if self._optimizer is None:
raise RuntimeError("Optimizer stage requires a bound optimizer.")
if isinstance(self._optimizer, (torch.optim.Adam, torch.optim.AdamW)):
kt = 0
fyfxkc = 0
for group in self._optimizer.param_groups:
for p in group["params"]:
if p.grad is None:
continue
if p.ndim <= 1:
kt += int(p.numel())
else:
fyfxkc += int(p.numel())
return [
self._make_optimizer_fragment(
stage=stage,
op_name=type(self._optimizer).__name__.lower(),
kt=kt,
fyfxkc=fyfxkc,
)
]
if isinstance(self._optimizer, torch.optim.SGD):
buckets: dict[tuple[bool, bool, bool, str], int] = defaultdict(int)
for group in self._optimizer.param_groups:
self._validate_sgd_group(group)
has_momentum = float(group.get("momentum", 0.0)) > 0.0
has_weight_decay = float(group.get("weight_decay", 0.0)) > 0.0
for p in group["params"]:
if p.grad is None:
continue
bucket_name = "kt" if p.ndim <= 1 else "fyfxkc"
has_momentum_buffer = (
has_momentum and "momentum_buffer" in self._optimizer.state[p]
)
buckets[
(
has_momentum,
has_weight_decay,
has_momentum_buffer,
bucket_name,
)
] += int(p.numel())
fragments = []
for (
has_momentum,
has_weight_decay,
has_momentum_buffer,
bucket_name,
), count in buckets.items():
if count <= 0:
continue
fragments.append(
self._make_optimizer_fragment(
stage=stage,
op_name="sgd",
kt=count if bucket_name == "kt" else 0,
fyfxkc=count if bucket_name == "fyfxkc" else 0,
optimizer_has_momentum=has_momentum,
optimizer_has_weight_decay=has_weight_decay,
optimizer_has_momentum_buffer=has_momentum_buffer,
)
)
return fragments
raise ValueError(
"Exact NeuroMC optimizer modeling only supports Adam/AdamW and "
"common SGD (nesterov=False, dampening=0, maximize=False); "
f"got {type(self._optimizer).__name__}."
)
def _optimizer_fragment(self, stage: str) -> _Fragment:
fragments = self._optimizer_fragments(stage)
if len(fragments) != 1:
raise RuntimeError(
"Optimizer expands to multiple NeuroMC fragments; use "
"_optimizer_fragments() instead."
)
return fragments[0]
[文档]
def get_report(self) -> NeuroMCRuntimeEnergyReport:
    """Aggregate all recorded fragments into a NeuroMC energy report.

    Fragments come from two sources:

    - ``self._fragments``: module-hook fragments, plus any optimizer
      fragments added by ``record_optimizer_step``;
    - the dispatch-trace fallback.

    When ``self._fragments`` is non-empty, the trace is consulted only for
    supplemental ops. A trace fragment is then dropped when its
    ``_module_overlap_signature`` equals that of a module-sourced
    fragment.

    For each fragment, energy is the sum of four terms:

    - MAC energy: ``mac_count`` times the per-core MAC cost;
    - extra-primitive compute energy;
    - base memory energy;
    - extra memory energy.

    All energies are in pJ. The report also breaks energy down by stage,
    op, core type, process key and memory level/direction, along with
    primitive and memory-bit counts.

    Returns:
        The assembled ``NeuroMCRuntimeEnergyReport``.

    Raises:
        ValueError: if any recorded op is unsupported (the first 30 are
            listed) or if no fragment was recognised at all.
    """
    fragments = list(self._fragments)
    trace_fragments = self._fallback_fragments_from_trace(
        supplemental_only=bool(fragments)
    )
    if fragments:
        # Avoid double counting work already captured by module hooks.
        existing_signatures = {
            _module_overlap_signature(fragment)
            for fragment in fragments
            if fragment.source == "module"
        }
        fragments.extend(
            fragment
            for fragment in trace_fragments
            if _module_overlap_signature(fragment) not in existing_signatures
        )
    else:
        fragments = trace_fragments
    unsupported = self._unsupported_ops()
    if unsupported:
        raise ValueError(
            "Exact NeuroMC runtime does not support these aten ops: "
            + ", ".join(unsupported[:30])
        )
    if not fragments:
        raise ValueError("No NeuroMC-supported runtime fragments were recognized.")
    energy_by_stage = defaultdict(float)
    energy_by_op = defaultdict(float)
    energy_by_core_type = defaultdict(float)
    energy_by_process_key = defaultdict(float)
    energy_by_memory_level_dir = defaultdict(lambda: defaultdict(float))
    counts_by_core_type = defaultdict(lambda: defaultdict(int))
    counts_by_process_key = defaultdict(lambda: defaultdict(int))
    primitive_by_stage = defaultdict(lambda: defaultdict(int))
    primitive_by_op = defaultdict(lambda: defaultdict(int))
    memory_bits_by_level = defaultdict(lambda: defaultdict(int))
    mapping_summary = []
    # NOTE(review): this local shadows the module-level `warnings` import for
    # the rest of the method; harmless here (the module is not used below),
    # but a different name would be clearer.
    warnings = list(self._warnings)
    energy_mac = 0.0
    energy_base_memory = 0.0
    energy_extra_memory = 0.0
    energy_extra_compute = 0.0
    for fragment in fragments:
        extra_counts = self._extra_counts(fragment)
        base_bits, base_energy = self._base_memory_for_fragment(fragment)
        extra_bits, extra_mem_energy = self._extra_memory_for_fragment(fragment)
        # Core types without a MAC cost entry contribute no MAC energy.
        mac_energy = fragment.mac_count * _MAC_COST_PJ.get(fragment.core_type, 0.0)
        extra_compute_energy = (
            extra_counts["mux"] * _EXTRA_OP_COST_PJ["mux"]
            + extra_counts["add"] * _EXTRA_OP_COST_PJ["add"]
            + extra_counts["mul"] * _EXTRA_OP_COST_PJ["mul"]
            + extra_counts["comp"] * _EXTRA_OP_COST_PJ["comp"]
            + extra_counts["sqrt"] * _EXTRA_OP_COST_PJ["sqrt"]
        )
        base_energy_total = sum(
            sum(direction_map.values()) for direction_map in base_energy.values()
        )
        extra_memory_total = sum(
            sum(direction_map.values())
            for direction_map in extra_mem_energy.values()
        )
        total = (
            mac_energy
            + extra_compute_energy
            + base_energy_total
            + extra_memory_total
        )
        energy_mac += mac_energy
        energy_base_memory += base_energy_total
        energy_extra_memory += extra_memory_total
        energy_extra_compute += extra_compute_energy
        energy_by_stage[fragment.stage] += total
        energy_by_op[fragment.op_name] += total
        energy_by_core_type[fragment.core_type] += total
        energy_by_process_key[fragment.process_key] += total
        counts_by_core_type[fragment.core_type]["fragments"] += 1
        counts_by_core_type[fragment.core_type]["mac"] += fragment.mac_count
        counts_by_process_key[fragment.process_key]["fragments"] += 1
        counts_by_process_key[fragment.process_key]["mac"] += fragment.mac_count
        for primitive, count in extra_counts.items():
            primitive_by_stage[fragment.stage][primitive] += count
            primitive_by_op[fragment.op_name][primitive] += count
            counts_by_core_type[fragment.core_type][primitive] += count
            counts_by_process_key[fragment.process_key][primitive] += count
        for level, directions in base_bits.items():
            for direction, bits in directions.items():
                memory_bits_by_level[level][direction] += bits
        for level, directions in extra_bits.items():
            for direction, bits in directions.items():
                memory_bits_by_level[level][direction] += bits
        for level, directions in base_energy.items():
            for direction, value in directions.items():
                energy_by_memory_level_dir[level][direction] += value
        for level, directions in extra_mem_energy.items():
            for direction, value in directions.items():
                energy_by_memory_level_dir[level][direction] += value
        mapping_summary.append(
            {
                "stage": fragment.stage,
                "phase": fragment.phase,
                "op_name": fragment.op_name,
                "core_type": fragment.core_type,
                "process_key": fragment.process_key,
                "input_precision_bits": fragment.input_precision_bits,
                "weight_precision_bits": fragment.weight_precision_bits,
                "output_precision_bits": fragment.output_precision_bits,
                "loop_dims": dict(fragment.loop_dims),
                "b_type": fragment.b_type,
                "t_type": fragment.t_type,
                "mac_count": fragment.mac_count,
                "source": fragment.source,
                "optimizer_has_momentum": fragment.optimizer_has_momentum,
                "optimizer_has_weight_decay": fragment.optimizer_has_weight_decay,
                "optimizer_has_momentum_buffer": fragment.optimizer_has_momentum_buffer,
            }
        )
    # Grand totals per primitive, summed over the per-stage breakdown.
    primitive_totals = {
        primitive: int(
            sum(
                stage_counts.get(primitive, 0)
                for stage_counts in primitive_by_stage.values()
            )
        )
        for primitive in ("mux", "add", "mul", "comp", "sqrt")
    }
    primitive_totals["mac"] = int(sum(fragment.mac_count for fragment in fragments))
    energy_compute_pj = energy_mac + energy_extra_compute
    energy_memory_pj = energy_base_memory + energy_extra_memory
    energy_total_pj = energy_compute_pj + energy_memory_pj
    primitive_counts = {
        "totals": primitive_totals,
        "by_stage": {k: dict(v) for k, v in primitive_by_stage.items()},
        "by_op": {k: dict(v) for k, v in primitive_by_op.items()},
        "core_type": self.core_type,
    }
    memory_report = {
        "preset_name": self.memory_config.preset_name,
        "totals": {
            level: int(sum(directions.values()))
            for level, directions in memory_bits_by_level.items()
        },
        "by_level_dir": {k: dict(v) for k, v in memory_bits_by_level.items()},
    }
    # Convert the nested defaultdicts to plain dicts for the report.
    return NeuroMCRuntimeEnergyReport(
        energy_total_pj=energy_total_pj,
        energy_compute_pj=energy_compute_pj,
        energy_memory_pj=energy_memory_pj,
        energy_by_stage=dict(energy_by_stage),
        energy_by_op=dict(energy_by_op),
        primitive_counts=primitive_counts,
        memory_bits_by_level=memory_report,
        warnings=warnings,
        energy_mac_pj=energy_mac,
        energy_base_memory_pj=energy_base_memory,
        energy_extra_memory_pj=energy_extra_memory,
        energy_extra_compute_pj=energy_extra_compute,
        energy_by_core_type=dict(energy_by_core_type),
        energy_by_process_key=dict(energy_by_process_key),
        energy_by_memory_level_dir={
            level: dict(direction_map)
            for level, direction_map in energy_by_memory_level_dir.items()
        },
        counts_by_core_type={k: dict(v) for k, v in counts_by_core_type.items()},
        counts_by_process_key={
            k: dict(v) for k, v in counts_by_process_key.items()
        },
        mapping_summary=mapping_summary,
    )
[文档]
def get_total(self) -> float:
    """Build a fresh report and return its total energy in pJ."""
    report = self.get_report()
    return report.energy_total_pj
[文档]
def get_counts(self) -> dict[str, Any]:
    """Build a fresh report and return only its count-oriented sections.

    Returns:
        A dict with these keys, taken from the report:
        ``primitive_counts``, ``memory_bits_by_level``,
        ``counts_by_core_type`` and ``counts_by_process_key``.
    """
    report = self.get_report()
    sections = (
        "primitive_counts",
        "memory_bits_by_level",
        "counts_by_core_type",
        "counts_by_process_key",
    )
    return {name: getattr(report, name) for name in sections}
[文档]
def record_optimizer_step(self, stage: str = "optimizer") -> None:
    """Append the bound optimizer's fragments for one step to ``self._fragments``.

    The fragments come from ``_optimizer_fragments(stage)``, whose errors
    propagate unchanged.
    """
    step_fragments = self._optimizer_fragments(stage)
    self._fragments.extend(step_fragments)
[文档]
def estimate_neuromc_runtime_energy(
    model: nn.Module,
    inputs,
    *,
    target: torch.Tensor | None = None,
    loss_fn: Callable | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    core_type: str = "fp_soma",
    op_cost_pj: dict[str, float] | None = None,
    memory_cost_pj_per_bit: dict[str, float] | None = None,
    memory_level_weights: dict[str, float] | None = None,
    memory_model: str | None = None,
    memory_config: MemoryHierarchyConfig | None = None,
    strict: bool = False,
    verbose: bool = False,
    extra_ignore_modules: list[nn.Module] | None = None,
) -> NeuroMCRuntimeEnergyReport:
    r"""
    **API Language:**
    :ref:`Chinese <estimate_neuromc_runtime_energy-cn>` | :ref:`English <estimate_neuromc_runtime_energy-en>`

    ----

    .. _estimate_neuromc_runtime_energy-cn:

    * **Chinese**

    Estimate NeuroMC runtime energy for one forward (and optionally
    backward and optimizer) step.

    :param model: The PyTorch model to profile
    :type model: ``nn.Module``
    :param inputs: Input(s) for the forward pass
    :type inputs: Any
    :param target: Target tensor for the loss; required when ``loss_fn`` is given
    :type target: torch.Tensor | None
    :param loss_fn: Loss function; enables the backward stage
    :type loss_fn: Callable | None
    :param optimizer: Optimizer for the optimizer stage; requires ``loss_fn`` and ``target``
    :type optimizer: torch.optim.Optimizer | None
    :param core_type: Compute core type forwarded to the profiler (e.g. ``\"fp_soma\"``)
    :type core_type: str
    :param op_cost_pj: (Deprecated) Ignored
    :type op_cost_pj: dict[str, float] | None
    :param memory_cost_pj_per_bit: (Deprecated) Ignored
    :type memory_cost_pj_per_bit: dict[str, float] | None
    :param memory_level_weights: (Deprecated) Ignored
    :type memory_level_weights: dict[str, float] | None
    :param memory_model: (Deprecated) Ignored
    :type memory_model: str | None
    :param memory_config: Memory hierarchy configuration (copied). If ``None``, ``MemoryHierarchyConfig.neuromc_like_v1()`` is used
    :type memory_config: MemoryHierarchyConfig | None
    :param strict: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type strict: bool
    :param verbose: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type verbose: bool
    :param extra_ignore_modules: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type extra_ignore_modules: list[nn.Module] | None
    :return: Energy profiling report
    :rtype: NeuroMCRuntimeEnergyReport
    :raises ValueError: if ``loss_fn`` is given without ``target``, or ``optimizer`` without ``loss_fn``

    This is a convenience function. It creates a :class:`NeuroMCEnergyProfiler`,
    binds the model and optional optimizer, profiles the forward, backward and
    optimizer stages, and returns the energy report. Argument combinations
    are validated before the model's existing gradients are cleared.

    ----

    .. _estimate_neuromc_runtime_energy-en:

    * **English**

    Estimate NeuroMC runtime energy for one forward (and optionally
    backward and optimizer) step.

    :param model: The PyTorch model to profile
    :type model: ``nn.Module``
    :param inputs: Input(s) for the forward pass
    :type inputs: Any
    :param target: Target tensor for the loss; required when ``loss_fn`` is given
    :type target: torch.Tensor | None
    :param loss_fn: Loss function; enables the backward stage
    :type loss_fn: Callable | None
    :param optimizer: Optimizer for the optimizer stage; requires ``loss_fn`` and ``target``
    :type optimizer: torch.optim.Optimizer | None
    :param core_type: Compute core type forwarded to the profiler (e.g. ``\"fp_soma\"``)
    :type core_type: str
    :param op_cost_pj: (Deprecated) Ignored
    :type op_cost_pj: dict[str, float] | None
    :param memory_cost_pj_per_bit: (Deprecated) Ignored
    :type memory_cost_pj_per_bit: dict[str, float] | None
    :param memory_level_weights: (Deprecated) Ignored
    :type memory_level_weights: dict[str, float] | None
    :param memory_model: (Deprecated) Ignored
    :type memory_model: str | None
    :param memory_config: Memory hierarchy configuration (copied). If ``None``, ``MemoryHierarchyConfig.neuromc_like_v1()`` is used
    :type memory_config: MemoryHierarchyConfig | None
    :param strict: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type strict: bool
    :param verbose: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type verbose: bool
    :param extra_ignore_modules: Forwarded to :class:`NeuroMCEnergyProfiler`
    :type extra_ignore_modules: list[nn.Module] | None
    :return: Energy profiling report
    :rtype: NeuroMCRuntimeEnergyReport
    :raises ValueError: if ``loss_fn`` is given without ``target``, or ``optimizer`` without ``loss_fn``
    """
    if (
        op_cost_pj is not None
        or memory_cost_pj_per_bit is not None
        or memory_level_weights is not None
        or memory_model is not None
    ):
        warnings.warn(
            "Legacy kwargs (op_cost_pj, memory_cost_pj_per_bit, "
            "memory_level_weights, memory_model) are deprecated and ignored "
            "by exact NeuroMC runtime profiling.",
            DeprecationWarning,
            stacklevel=2,
        )
    optimizer_without_loss_msg = (
        "Exact NeuroMC optimizer modeling requires loss_fn and target; "
        "optimizer.step() without backward is unsupported."
    )
    resolved_loss_fn = _resolve_loss_fn(loss_fn)
    # Validate argument combinations before any side effect. Previously these
    # errors surfaced only after hooks were bound, the caller's gradients
    # were cleared and a full forward pass had run.
    if resolved_loss_fn is not None and target is None:
        raise ValueError("target is required when loss_fn is provided")
    if resolved_loss_fn is None and optimizer is not None:
        raise ValueError(optimizer_without_loss_msg)
    if memory_config is None:
        memory_config = MemoryHierarchyConfig.neuromc_like_v1()
    profiler = NeuroMCEnergyProfiler(
        core_type=core_type,
        memory_config=memory_config.copy(),
        strict=strict,
        verbose=verbose,
        extra_ignore_modules=extra_ignore_modules,
    )
    profiler.bind_model(model)
    profiler.bind_optimizer(optimizer)
    _clear_existing_grads(model, optimizer)
    with profiler:
        with profiler.stage("forward"):
            output = _call_model(model, inputs)
        loss = None
        if resolved_loss_fn is not None:
            # The loss computation itself is not part of the modeled workload.
            with profiler.suspend():
                loss = resolved_loss_fn(output, target)
        if loss is not None:
            with profiler.stage("backward"):
                loss.backward()
            if optimizer is not None:
                with profiler.stage("optimizer"):
                    # The step is modeled analytically; the real step runs
                    # with tracing suspended.
                    profiler.record_optimizer_step("optimizer")
                    with profiler.suspend():
                        optimizer.step()
                        optimizer.zero_grad(set_to_none=True)
        elif optimizer is not None:
            # Only reachable when the loss function itself returned None.
            raise ValueError(optimizer_without_loss_msg)
    return profiler.get_report()