spikingjelly.activation_based.distributed package

This package provides experimental distributed-training helpers for multi-step SNNs in spikingjelly.activation_based, built on torch.distributed, DTensor, tensor parallelism, and FSDP2.

Distributed Helpers

SNNDistributedConfig

High-level configuration for DTensor-ready SNN distribution.

SNNDistributedAnalysis

Capability analysis for stateful modules and tensor-parallel candidates.

ensure_distributed_initialized

Initialize torch.distributed when needed.

build_device_mesh

Build a DeviceMesh for tensor/data parallelism.

configure_snn_distributed

The main low-level entry for DTensor-ready SNN distribution.

configure_cifar10dvs_vgg_distributed

Convenience helper for CIFAR10DVSVGG with DP / TP.

configure_cifar10dvs_vgg_fsdp2

Convenience helper for CIFAR10DVSVGG with FSDP2 / FSDP2+TP.

materialize_dtensor_output

Convert a DTensor output back to a regular tensor when needed.

Distributed training support module with tensor-parallel and data-parallel utilities.

class spikingjelly.activation_based.distributed.DistributedFeatureSet(allow_experimental_conv_tp: bool = False, allow_experimental_spikformer_tp: bool = False, allow_pipeline: bool = True, allow_zero_optimizer: bool = True)[source]

Bases: object

Parameters:
  • allow_experimental_conv_tp (bool)

  • allow_experimental_spikformer_tp (bool)

  • allow_pipeline (bool)

  • allow_zero_optimizer (bool)

allow_experimental_conv_tp: bool = False
allow_experimental_spikformer_tp: bool = False
allow_pipeline: bool = True
allow_zero_optimizer: bool = True
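
The four flags above can be pictured with a small stand-in. The sketch below does not import spikingjelly; it defines a local dataclass that mirrors only the documented field names and defaults. Treating the class as a frozen dataclass is an assumption (the repr shown in the SNNDistributedPlan signature suggests a dataclass, but the page does not say so).

```python
from dataclasses import dataclass, replace


# Stand-in mirroring the documented fields and defaults. The real class lives
# in spikingjelly.activation_based.distributed; frozen=True is an assumption.
@dataclass(frozen=True)
class DistributedFeatureSet:
    allow_experimental_conv_tp: bool = False
    allow_experimental_spikformer_tp: bool = False
    allow_pipeline: bool = True
    allow_zero_optimizer: bool = True


# Documented defaults: experimental TP paths off, pipeline and ZeRO allowed.
default_features = DistributedFeatureSet()

# Opt in to experimental convolution tensor parallelism, keep the rest as is.
conv_tp_features = replace(default_features, allow_experimental_conv_tp=True)
print(conv_tp_features)
```

With the real class, the same keyword arguments would presumably be passed to the constructor directly.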
class spikingjelly.activation_based.distributed.SNNDistributedPlan(mode: str, objective: str, topology: SNNDistributedTopology, model_family: str, backend: str, batch_size: int, optimizer_strategy: str, memopt_level: int, rationale: Tuple[str, ...], notes: Tuple[str, ...], tensor_parallel_roots: Optional[Tuple[str, ...]] = None, mesh_shape: Optional[Tuple[int, ...]] = None, tp_mesh_dim: int = 0, dp_mesh_dim: Optional[int] = None, pp_microbatches: Optional[int] = None, pp_schedule: str = '1f1b', pp_virtual_stages: int = 1, pp_layout: Optional[Tuple[int, ...]] = None, pp_delay_wgrad: bool = False, experimental_features: DistributedFeatureSet = DistributedFeatureSet(allow_experimental_conv_tp=False, allow_experimental_spikformer_tp=False, allow_pipeline=True, allow_zero_optimizer=True))[source]

Bases: object

dp_mesh_dim: int | None = None
experimental_features: DistributedFeatureSet = DistributedFeatureSet(allow_experimental_conv_tp=False, allow_experimental_spikformer_tp=False, allow_pipeline=True, allow_zero_optimizer=True)
mesh_shape: Tuple[int, ...] | None = None
pp_delay_wgrad: bool = False
pp_layout: Tuple[int, ...] | None = None
pp_microbatches: int | None = None
pp_schedule: str = '1f1b'
pp_virtual_stages: int = 1
tensor_parallel_roots: Tuple[str, ...] | None = None
tp_mesh_dim: int = 0
mode: str
objective: str
topology: SNNDistributedTopology
model_family: str
backend: str
batch_size: int
optimizer_strategy: str
memopt_level: int
rationale: Tuple[str, ...]
notes: Tuple[str, ...]
class spikingjelly.activation_based.distributed.SNNDistributedAnalysis(memory_module_names, tensor_parallel_candidate_names, unsupported_tensor_parallel_names, notes, tensor_parallel_roots=None)[source]

Bases: object

SNN distributed training analyzer. Analyzes the model structure and recommends a parallelization strategy.

Parameters:
  • memory_module_names (Tuple[str, ...])

  • tensor_parallel_candidate_names (Tuple[str, ...])

  • unsupported_tensor_parallel_names (Tuple[str, ...])

  • notes (Tuple[str, ...])

  • tensor_parallel_roots (Tuple[str, ...] | None)

tensor_parallel_roots: Tuple[str, ...] | None = None
memory_module_names: Tuple[str, ...]
tensor_parallel_candidate_names: Tuple[str, ...]
unsupported_tensor_parallel_names: Tuple[str, ...]
notes: Tuple[str, ...]
class spikingjelly.activation_based.distributed.SNNDistributedRuntime(kind: str, model: nn.Module, mesh: Optional[object], analysis: Optional[SNNDistributedAnalysis], plan: Optional[SNNDistributedPlan] = None, mode: str = 'none', pipeline_runtime: Optional[SNNPipelineRuntime] = None)[source]

Bases: object

build_optimizer(optimizer_cls=<class 'torch.optim.adam.Adam'>, lr=0.001, weight_decay=0.0, **kwargs)[source]
forward_loss(criterion, images, labels)[source]
classmethod from_legacy(*, kind, model, mesh, analysis, mode, pipeline_runtime=None)[source]
Return type:

SNNDistributedRuntime

mode: str = 'none'
pipeline_runtime: SNNPipelineRuntime | None = None
plan: SNNDistributedPlan | None = None
prepare_classification_output(outputs, labels, *, return_metadata=False)[source]
Return type:

Tuple[Tensor, Tensor] | PreparedModelOutput

prepare_dataloader(*, dataset, batch_size, shuffle, num_workers, drop_last, pin_memory=True)[source]
Return type:

DataLoader

static reduce_classification_output(outputs, labels)[source]
Return type:

Tuple[Tensor, Tensor]

reset_state()[source]

Reset all stateful modules in the model (e.g. neuron membrane potentials).
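
Why resetting matters can be seen with a toy neuron. The sketch below is a pure-Python illustration and not part of spikingjelly: ToyIFNeuron is a hypothetical integrate-and-fire unit with a hard reset, standing in for a stateful module whose membrane potential survives across forward calls.

```python
class ToyIFNeuron:
    """Hypothetical integrate-and-fire neuron; a stand-in for a stateful module."""

    def __init__(self, v_threshold: float = 1.0):
        self.v_threshold = v_threshold
        self.v = 0.0  # membrane potential, carried across forward() calls

    def forward(self, x: float) -> float:
        self.v += x
        if self.v >= self.v_threshold:
            self.v = 0.0  # hard reset after a spike
            return 1.0
        return 0.0

    def reset(self) -> None:
        self.v = 0.0


neuron = ToyIFNeuron()
first = [neuron.forward(0.6) for _ in range(3)]   # fresh state
leaked = [neuron.forward(0.6) for _ in range(3)]  # starts from leftover potential
neuron.reset()
clean = [neuron.forward(0.6) for _ in range(3)]   # fresh state again
print(first, leaked, clean)
```

The second run differs from the first only because potential left over from the previous sequence leaks into it; after the reset, the output matches the first run again. This is the usual reason for resetting state between independent samples or batches.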

kind: str
model: Module
mesh: object | None
analysis: SNNDistributedAnalysis | None
class spikingjelly.activation_based.distributed.SNNDistributedTopology(world_size: int, dims: Mapping[str, int])[source]

Bases: object

classmethod from_mapping(dims, *, world_size=None)[source]
Return type:

SNNDistributedTopology

property mesh_shape: Tuple[int, ...]
property ordered_dim_names: Tuple[str, ...]
world_size: int
dims: Mapping[str, int]
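
The page does not describe how dims, world_size, mesh_shape, and ordered_dim_names relate to each other. One plausible reading is sketched below with a local stand-in class named TopologySketch (hypothetical). It assumes that the product of the dimension sizes must equal world_size, that dimension order follows the mapping's insertion order, and that mesh_shape lists the sizes in that order. None of these assumptions is confirmed by the documentation.

```python
from dataclasses import dataclass
from math import prod
from typing import Mapping, Optional, Tuple


@dataclass(frozen=True)
class TopologySketch:
    """Hypothetical stand-in; not the real SNNDistributedTopology."""

    world_size: int
    dims: Mapping[str, int]

    @classmethod
    def from_mapping(
        cls, dims: Mapping[str, int], *, world_size: Optional[int] = None
    ) -> "TopologySketch":
        dims = dict(dims)
        total = prod(dims.values())
        if world_size is None:
            world_size = total  # assumed: infer world size from the dims
        elif world_size != total:
            raise ValueError(
                f"dims multiply to {total}, but world_size is {world_size}"
            )
        return cls(world_size=world_size, dims=dims)

    @property
    def ordered_dim_names(self) -> Tuple[str, ...]:
        # Assumed ordering: insertion order of the mapping.
        return tuple(self.dims)

    @property
    def mesh_shape(self) -> Tuple[int, ...]:
        return tuple(self.dims[name] for name in self.ordered_dim_names)


topology = TopologySketch.from_mapping({"dp": 2, "tp": 4})
print(topology.world_size, topology.ordered_dim_names, topology.mesh_shape)
```

The dimension names "dp" and "tp" are illustrative; the names the real class accepts are not listed on this page.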
class spikingjelly.activation_based.distributed.TensorShardMemoryModule(source, shard_dim, logical_dim_size=None, process_group=None)[source]

Bases: MemoryModule

Base memory module supporting tensor-parallel sharding.

Parameters:
  • source (MemoryModule) -- Source MemoryModule

  • shard_dim (int) -- Dimension along which to shard

  • logical_dim_size (Optional[int]) -- Logical (unsharded) size of the sharded dimension, used to validate sharding

  • process_group (Any) -- Distributed process group

extra_repr()[source]
Return type:

str

forward(x)[source]
Parameters:

x (Tensor)

reset()[source]
property store_v_seq
property supported_backends
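
The logical_dim_size parameter of TensorShardMemoryModule is documented as a way to validate sharding. The kind of check this implies can be sketched in plain Python. The helper names expected_local_size and check_shard are hypothetical, and the sketch assumes even sharding (every rank holds logical_dim_size / world_size elements along shard_dim); the real module may validate differently.

```python
from typing import Sequence


def expected_local_size(logical_dim_size: int, world_size: int) -> int:
    """Size each rank should hold along the sharded dim, assuming even shards."""
    if logical_dim_size % world_size != 0:
        raise ValueError(
            f"{logical_dim_size} is not divisible by world size {world_size}"
        )
    return logical_dim_size // world_size


def check_shard(
    local_shape: Sequence[int],
    shard_dim: int,
    logical_dim_size: int,
    world_size: int,
) -> None:
    """Raise if the local tensor shape is inconsistent with the logical size."""
    expected = expected_local_size(logical_dim_size, world_size)
    actual = local_shape[shard_dim]
    if actual != expected:
        raise ValueError(
            f"dim {shard_dim}: local size {actual}, expected {expected}"
        )


# A [T, N, C] = [4, 8, 512] activation sharded over C across 4 ranks
# leaves each rank with a local shape of [4, 8, 128].
check_shard((4, 8, 128), shard_dim=2, logical_dim_size=512, world_size=4)
print(expected_local_size(512, 4))
```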
spikingjelly.activation_based.distributed.analyze(model, *, model_family=None, roots=None)[source]
Return type:

SNNDistributedAnalysis

spikingjelly.activation_based.distributed.apply(*, model, plan, device_type='cuda', device_mesh=None)[source]
Return type:

SNNDistributedRuntime

spikingjelly.activation_based.distributed.apply_pipeline_stage_memopt(runtime, *, memopt_level, compress_x=False, stage_budget_ratio=0.5, use_plan_cache=True)[source]
Parameters:
  • runtime (SNNPipelineRuntime)

  • memopt_level (int)

  • compress_x (bool)

  • stage_budget_ratio (float)

  • use_plan_cache (bool)

Return type:

Tuple[SNNPipelineRuntime, float, bool]

spikingjelly.activation_based.distributed.build_snn_optimizer(module, mode, lr, weight_decay=0.0, optimizer_sharding='none', foreach=None, optimizer_cls=<class 'torch.optim.adam.Adam'>, **optimizer_kwargs)[source]
spikingjelly.activation_based.distributed.build_device_mesh(device_type='cuda', mesh_shape=None, mesh_dim_names=None)[source]
Parameters:
  • device_type (str)

  • mesh_shape (Tuple[int, ...] | None)

  • mesh_dim_names (Tuple[str, ...] | None)

Return type:

DeviceMesh

spikingjelly.activation_based.distributed.enable_tp_communication_debug(enabled=True)[source]
Parameters:

enabled (bool)

Return type:

None

spikingjelly.activation_based.distributed.ensure_distributed_initialized(backend=None, init_method=None, rank=None, world_size=None)[source]
Parameters:
  • backend (str | None)

  • init_method (str | None)

  • rank (int | None)

  • world_size (int | None)

Return type:

bool

spikingjelly.activation_based.distributed.get_tp_communication_debug_stats()[source]
Return type:

Dict[str, int]

spikingjelly.activation_based.distributed.plan(*, analysis, objective, topology, backend, batch_size, model_family=None, mode=None, features=None)[source]
Return type:

SNNDistributedPlan

spikingjelly.activation_based.distributed.recommended_pipeline_microbatches(batch_size, num_stages)[source]

Recommend the number of microbatches for pipeline parallelism.

Parameters:
  • batch_size (int)

  • num_stages (int)

Return type:

int
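
The heuristic behind recommended_pipeline_microbatches is not documented here. As an illustration of the trade-off such a recommendation has to balance, the sketch below uses a hypothetical rule: pick the smallest divisor of batch_size that is at least num_stages. It also computes the idle ("bubble") fraction of a non-interleaved pipeline schedule with equal stage times, (p - 1) / (m + p - 1) for p stages and m microbatches. Both function names are assumptions and are not part of the package.

```python
def sketch_recommended_microbatches(batch_size: int, num_stages: int) -> int:
    """Smallest divisor of batch_size that is >= num_stages (assumed heuristic)."""
    if batch_size < 1 or num_stages < 1:
        raise ValueError("batch_size and num_stages must be positive")
    for candidate in range(num_stages, batch_size + 1):
        if batch_size % candidate == 0:
            return candidate
    # Fewer samples than stages: fall back to one sample per microbatch.
    return batch_size


def bubble_fraction(num_microbatches: int, num_stages: int) -> float:
    """Idle fraction of a non-interleaved schedule with equal stage times."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


microbatches = sketch_recommended_microbatches(batch_size=30, num_stages=4)
print(microbatches, bubble_fraction(microbatches, num_stages=4))
```

More microbatches shrink the bubble but also shrink each microbatch, which can hurt per-kernel efficiency, so an actual recommendation may weigh both.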

spikingjelly.activation_based.distributed.recommend_snn_distributed_strategy(model, world_size, prefer, batch_size, backend='inductor', zero_redundancy_optimizer_available=None, pipelining_available=None, fsdp2_available=None, tensor_parallel_available=None)[source]

Recommend a distributed training strategy for an SNN.

Parameters:
  • model (str)

  • world_size (int)

  • prefer (str)

  • batch_size (int)

  • backend (str)

  • zero_redundancy_optimizer_available (bool | None)

  • pipelining_available (bool | None)

  • fsdp2_available (bool | None)

  • tensor_parallel_available (bool | None)

Return type:

SNNDistributedRecommendation

spikingjelly.activation_based.distributed.recommend_pipeline_memopt_stages(stage_costs, stage_budget_ratio=0.5)[source]
Return type:

Tuple[int, ...]

spikingjelly.activation_based.distributed.reset_tp_communication_debug_stats()[source]
Return type:

None

spikingjelly.activation_based.distributed.resolve_data_parallel_partition(device_mesh, dp_mesh_dim, sharded_by_data_parallel)[source]
Parameters:
  • device_mesh (DeviceMesh | None)

  • dp_mesh_dim (int | None)

  • sharded_by_data_parallel (bool)

Return type:

Tuple[int, int]
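
The meaning of the returned pair is not stated on this page. A natural guess is (number of data-parallel replicas, this process's replica index), which is what a distributed sampler needs. The sketch below illustrates that guess with plain tuples in place of a DeviceMesh. The function name, the coordinate argument, and the tuple ordering are all assumptions.

```python
from typing import Optional, Sequence, Tuple


def sketch_data_parallel_partition(
    mesh_shape: Optional[Sequence[int]],
    coordinate: Optional[Sequence[int]],
    dp_mesh_dim: Optional[int],
    sharded_by_data_parallel: bool,
) -> Tuple[int, int]:
    """Return (num_replicas, replica_rank); the ordering is an assumption."""
    if (
        not sharded_by_data_parallel
        or mesh_shape is None
        or coordinate is None
        or dp_mesh_dim is None
    ):
        # No data-parallel sharding: every process sees the full dataset.
        return 1, 0
    return mesh_shape[dp_mesh_dim], coordinate[dp_mesh_dim]


# A 2 x 4 (dp x tp) mesh; this process sits at coordinate (1, 3).
partition = sketch_data_parallel_partition(
    (2, 4), (1, 3), dp_mesh_dim=0, sharded_by_data_parallel=True
)
print(partition)
```

In this reading, all ranks within one tensor-parallel group share the same replica index and therefore load identical batches.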

spikingjelly.activation_based.distributed.resolve_tensor_parallel_group_size(device_mesh, tp_mesh_dim, tensor_parallel_enabled)[source]
Return type:

int

spikingjelly.activation_based.distributed.unwrap_parallel_module(module)[source]
Parameters:

module (Module)

Return type:

Module