spikingjelly.activation_based.distributed.api 源代码

from __future__ import annotations

from typing import Mapping, Optional, Sequence, Union

import torch.nn as nn

from .adapters import resolve_adapter
from .dtensor import (
    SNNDistributedAnalysis,
    SNNDistributedConfig,
    SNN_DISTRIBUTED_PREFERENCES,
    analyze_snn_distributed_capability,
    configure_snn_distributed,
    recommend_snn_distributed_strategy,
)
from .planner import (
    DistributedFeatureSet,
    SNNDistributedPlan,
)
from .runtime import SNNDistributedRuntime
from .topology import SNNDistributedTopology


def _normalize_mode(mode: Optional[str]) -> Optional[str]:
    if mode is None:
        return None
    mode = mode.lower()
    valid_modes = ("none", "dp", "tp", "fsdp2", "fsdp2_tp", "pp")
    if mode not in valid_modes:
        raise ValueError(f"Unsupported mode='{mode}'. Expected one of {valid_modes}.")
    if mode == "pp":
        raise NotImplementedError(
            "Pipeline parallelism ('pp') is not supported by the unified analyze/plan/apply API. "
            "Please use the dedicated pipeline configuration path directly."
        )
    return mode


[文档] def analyze( model: nn.Module, *, model_family: Optional[str] = None, roots: Optional[Sequence[str]] = None, ) -> SNNDistributedAnalysis: return analyze_snn_distributed_capability(model, tensor_parallel_roots=roots)
[文档] def plan( *, analysis: SNNDistributedAnalysis, objective: str, topology: Union[Mapping[str, int], SNNDistributedTopology], backend: str, batch_size: int, model_family: Optional[str] = None, mode: Optional[str] = None, features: Optional[DistributedFeatureSet] = None, ) -> SNNDistributedPlan: objective = objective.lower() if objective not in SNN_DISTRIBUTED_PREFERENCES: raise ValueError( f"Unsupported objective='{objective}'. Expected one of {SNN_DISTRIBUTED_PREFERENCES}." ) resolved_topology = ( topology if isinstance(topology, SNNDistributedTopology) else SNNDistributedTopology.from_mapping(topology) ) features = features or DistributedFeatureSet() resolved_model_family = model_family or "generic" tp_mesh_dim = ( resolved_topology.ordered_dim_names.index("tp") if "tp" in resolved_topology.ordered_dim_names else 0 ) dp_mesh_dim = ( resolved_topology.ordered_dim_names.index("dp") if "dp" in resolved_topology.ordered_dim_names else None ) recommendation = recommend_snn_distributed_strategy( model=resolved_model_family, world_size=resolved_topology.world_size, prefer=objective, batch_size=batch_size, backend=backend, pipelining_available=False, ) mode = _normalize_mode(mode) notes = list(analysis.notes) selected_mode = mode or recommendation.mode if selected_mode == "fsdp2_tp": if "dp" not in resolved_topology.ordered_dim_names or "tp" not in resolved_topology.ordered_dim_names: raise ValueError( "Hybrid 'fsdp2_tp' mode requires both 'dp' and 'tp' dimensions in the topology." ) if ( selected_mode in ("tp", "fsdp2_tp") and not analysis.tensor_parallel_candidate_names ): raise ValueError( f"mode='{selected_mode}' requires at least one tensor-parallel candidate, but analysis found none." ) optimizer_strategy = recommendation.optimizer_sharding if selected_mode != "dp" or not features.allow_zero_optimizer: optimizer_strategy = "none" if selected_mode == "dp" and not features.allow_zero_optimizer: notes.append( "Zero optimizer was disabled by DistributedFeatureSet; planner fell back to optimizer_strategy='none'." ) mesh_shape = resolved_topology.mesh_shape pp_microbatches = recommendation.pp_microbatches pp_schedule = recommendation.pp_schedule pp_virtual_stages = recommendation.pp_virtual_stages pp_layout = recommendation.pp_layout pp_delay_wgrad = recommendation.pp_delay_wgrad return SNNDistributedPlan( mode=selected_mode, objective=objective, topology=resolved_topology, model_family=resolved_model_family, backend=backend, batch_size=batch_size, optimizer_strategy=optimizer_strategy, memopt_level=recommendation.memopt_level, rationale=tuple(recommendation.rationale), notes=tuple(notes), tensor_parallel_roots=analysis.tensor_parallel_roots, mesh_shape=mesh_shape, tp_mesh_dim=tp_mesh_dim, dp_mesh_dim=dp_mesh_dim, pp_microbatches=pp_microbatches, pp_schedule=pp_schedule, pp_virtual_stages=pp_virtual_stages, pp_layout=pp_layout, pp_delay_wgrad=pp_delay_wgrad, experimental_features=features, )
[文档] def apply( *, model: nn.Module, plan: SNNDistributedPlan, device_type: str = "cuda", device_mesh=None, ) -> SNNDistributedRuntime: wrapped = getattr(model, "module", None) if isinstance(wrapped, nn.Module): model = wrapped topology = ( plan.topology if isinstance(plan.topology, SNNDistributedTopology) else SNNDistributedTopology.from_mapping(plan.topology) ) if device_mesh is not None: mesh_tensor = getattr(device_mesh, "mesh", None) mesh_volume = None if mesh_tensor is not None: mesh_volume = int(mesh_tensor.numel()) elif hasattr(device_mesh, "size"): try: mesh_volume = int(device_mesh.size()) except TypeError: mesh_volume = None if mesh_volume is not None and mesh_volume != topology.world_size: raise ValueError( f"device_mesh spans {mesh_volume} ranks, but plan.topology.world_size={topology.world_size}." ) if plan.mode == "pp": raise NotImplementedError( "Pipeline parallelism ('pp') is not supported via the unified `apply` API " "because it requires an `example_input` to partition the model and measure stage costs. " "Please use the dedicated pipeline configuration path directly." ) use_adapter = plan.mode in ("tp", "fsdp2", "fsdp2_tp") or ( plan.experimental_features.allow_experimental_conv_tp or plan.experimental_features.allow_experimental_spikformer_tp ) adapter = resolve_adapter(model, plan.model_family) if use_adapter else None if adapter is not None: return adapter.apply( model, plan, device_type=device_type, device_mesh=device_mesh, ) config = SNNDistributedConfig( device_type=device_type, mesh_shape=plan.mesh_shape or topology.mesh_shape, device_mesh=device_mesh, tp_mesh_dim=plan.tp_mesh_dim, dp_mesh_dim=plan.dp_mesh_dim, enable_data_parallel=plan.mode == "dp", enable_fsdp2=plan.mode in ("fsdp2", "fsdp2_tp"), tensor_parallel_roots=plan.tensor_parallel_roots, auto_tensor_parallel=plan.mode in ("tp", "fsdp2_tp"), ) configured_model, mesh, analysis = configure_snn_distributed(model, config) return SNNDistributedRuntime( kind="eager", model=configured_model, mesh=mesh, analysis=analysis, plan=plan, )