SNN 分布式训练(DTensor / FSDP2)#

English version: Distributed SNN Training (DTensor / FSDP2)

本教程介绍 spikingjelly.activation_based.distributed 中新增的实验性分布式训练工具。当前实现重点支持:

  • DP:常规数据并行

  • TP:面向 SNN 的简单 tensor parallel

  • FSDP2:基于 DTensor 的参数、梯度与优化器状态分片

  • FSDP2 + TP:推荐的混合分布式训练方案

  • PP:实验性的 pipeline parallel(当前先支持 CIFAR10DVSVGGSpikformer

其中,传统的 DDP + TP 组合在当前 PyTorch 版本下仍会在参数同步阶段遇到 Tensor / DTensor 混用问题,因此本实现会直接提示用户改用 FSDP2 + TP

快速开始#

当前低层入口是 configure_snn_distributed,它位于 spikingjelly.activation_based.distributed.dtensor。统一的新公开接口则是 spikingjelly.activation_based.distributed 下的 analyze / plan / apply

例如,对 CIFAR10DVSVGG 启用纯 FSDP2:

from spikingjelly.activation_based.distributed.dtensor import (
    SNNDistributedConfig,
    configure_snn_distributed,
)
from spikingjelly.activation_based.examples.memopt.models import CIFAR10DVSVGG
import torch.distributed as dist

world_size = dist.get_world_size()  # 需要先 init_process_group()
model = CIFAR10DVSVGG(dropout=0.0, backend='inductor')
model, mesh, analysis = configure_snn_distributed(
    model,
    SNNDistributedConfig(
        device_type='cuda',
        mesh_shape=(world_size,),
        auto_tensor_parallel=False,
        enable_fsdp2=True,
        fsdp_shard_roots=['features', 'classifier'],
        fsdp_shard_module_root=True,
        dp_mesh_dim=0,
    ),
)

若想启用 FSDP2 + TP,可使用 2D mesh:

model, mesh, analysis = configure_snn_distributed(
    model,
    SNNDistributedConfig(
        device_type='cuda',
        mesh_shape=(2, 2),   # (dp, tp)
        enable_fsdp2=True,
        fsdp_shard_roots=['features'],
        fsdp_shard_module_root=False,
        tensor_parallel_roots=['classifier'],
        auto_tensor_parallel=True,
        experimental_conv_tensor_parallel=True,
        conv_tensor_parallel_roots=['features'],
        dp_mesh_dim=0,
        tp_mesh_dim=1,
    ),
)

训练脚本#

仓库中提供了一个真实训练入口:

  • spikingjelly/activation_based/examples/memopt/train_distributed.py

示例:

torchrun --nproc_per_node=4 \
  spikingjelly/activation_based/examples/memopt/train_distributed.py \
  --data-dir /path/to/cifar10dvs \
  --distributed-mode fsdp2_tp \
  --mesh-shape 2 2 \
  --backend inductor \
  --batch-size 16 \
  --epochs 1 \
  --print-summary

该脚本支持的模式有:

  • none

  • dp

  • tp

  • fsdp2

  • fsdp2_tp

  • pp:当前是实验性训练入口,优先面向 smoke benchmark 和结构验证

PP 还支持一组更接近 Megatron 风格的调度与布局参数:

  • --pp-schedule gpipe:最简单的 GPipe 调度;

  • --pp-schedule 1f1b:标准 1F1B;

  • --pp-schedule interleaved:interleaved / VPP 风格调度;

  • --pp-schedule zero_bubble:基于 delayed-wgrad 的实验性 zero-bubble 调度;

  • --pp-virtual-stages N:每个物理 stage 持有 N 个虚拟 chunk;

  • --pp-layout:显式指定逻辑 stage 的连续切分,例如 1|2|2|1

  • --pp-delay-wgrad:在可用调度下显式打开 delayed-wgrad 风格优化。

历史上的 hybrid``(``DDP + TP)组合当前仍不支持,也没有在脚本里继续暴露;推荐直接使用 fsdp2_tp

如果希望在纯 dp 路径上进一步压缩优化器状态,还可以启用 ZeroRedundancyOptimizer

torchrun --nproc_per_node=2 \
  benchmark/benchmark_snn_distributed.py \
  --model cifar10dvs_vgg \
  --mode dp \
  --optimizer-sharding zero \
  --backend inductor \
  --batch-size 2 \
  --T 10

当前实现范围#

  1. 线性层使用官方 tensor-parallel API。

  2. 逐元素脉冲神经元现在会显式跟随上游 shard:

    • [T, N, C] 激活,按最后一维 C 切分;

    • [T, N, C, H, W] 激活,按通道维 C 切分。

    这意味着神经元内部状态 v 只保留本地 shard,而不是完整复制一份全局状态。

  3. CIFAR10DVSVGGConv + BN + Neuron 主干支持实验性的 channel tensor parallel。

  4. FSDP2 + TP 当前优先对 features 做 FSDP2 分片;当 classifier 已经启用 TP 时,不再额外对其做 root fully-shard,以避免跨 mesh 维度重复切分。

  5. 传统 hybrid``(即 ``DDP + TP)当前显式不支持,接口会直接提示改用 fsdp2_tp

  6. PP 当前通过手工 stage 切分实现,而不是依赖 torch.export 整图切分;这样可以兼容标准脉冲神经元的内部状态写入。

  7. PP 在 microbatch 之间会显式重置每个 stage 内的神经元状态,避免不同样本的状态串扰。

服务器实测结果(小网络 smoke benchmark)#

以下数据来自单机多卡服务器(RTX 4090),网络为 CIFAR10DVSVGG,后端为 inductor,输入配置为 batch_size=2T=10,指标为短步数训练 benchmark。表中的 global_samples/s 统一表示整个分布式作业的全局吞吐。这个工作负载非常小,更多用于 smoke test 和显存趋势对比,不适合作为最终扩展效率结论。

模式

GPU 数

step_ms

global_samples/s

peak_allocated_mb

备注

none

1

12.86

155.52

401.63

单卡基线

dp

2

83.71

47.78

434.25

纯 DDP,小 batch 下通信开销占主导

dp + zero

2

96.79

41.33

410.22

纯 DDP + ZeroRedundancyOptimizer

tp

2

86.58

23.10

308.88

纯 TP,神经元按 shard 后特征/通道局部执行

fsdp2

2

97.11

41.19

400.61

纯 FSDP2

fsdp2_tp

4

26.68

149.91

316.27

推荐的 FSDP2 + TP 方案

hybrid``(``DDP + TP

4

当前显式不支持;请改用 fsdp2_tp

从这组小网络 smoke benchmark 可以看出:

  • TPFSDP2 + TP 都已经可以在标准神经元 backend='inductor' 下完成真实 SNN 训练 step。

  • 显式 neuron shard 后,神经元状态会随特征/通道切分,而不再保持完整复制。

  • 即使在很小的网络和 batch 上,TP / FSDP2 + TP 也已经能带来可见的单卡显存下降。

  • DDP + TP 目前仍不推荐,建议直接使用 fsdp2_tp

实验性 PP benchmark(服务器复测)#

当前 PP 已经支持:

  • 基于 dry-run 实际耗时的 stage balance,而不是简单按层数均分;

  • 自动选择更积极的 pp_microbatches``(优先选择 ``batch_size 的可整除值);

  • gpipe / 1f1b / interleaved / zero_bubble 多种调度;

  • 显式 pp_layout 覆盖自动切分;

  • 更轻量的 microbatch reset 逻辑,减少每次调度的遍历开销。

下面的结果来自重新在服务器上跑的 schedule 对比。它们更适合回答“当前哪种 PP 调度更值得默认推荐”,而不是把 PP 说成已经是吞吐主力。

CIFAR10DVSVGGbackend='inductor',2 张 GPU,batch_size=8T=4

调度

pp_virtual_stages

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

gpipe

1

0

0.0

93.70

85.38

507.84

1f1b

1

0

0.0

102.65

77.93

259.09

interleaved

2

0

0.0

87.63

91.29

361.45

interleaved

2

1

1521.40

84.39

94.79

361.45

zero_bubble

2

0

0.0

145.17

55.11

452.67

zero_bubble

2

1

1535.38

118.00

67.80

452.67

spikformer_tibackend='inductor',2 张 GPU,batch_size=4T=8image_size=224

调度

pp_virtual_stages

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

gpipe

1

0

0.0

423.64

9.44

1286.03

1f1b

1

0

0.0

461.92

8.66

679.22

interleaved

2

0

0.0

394.63

10.14

1389.71

interleaved

2

1

112.83

423.73

9.44

541.91

zero_bubble

2

0

0.0

455.79

8.78

1356.73

zero_bubble

2

1

164.35

473.41

8.45

483.31

这组服务器复测说明:

  • PP 与标准神经元 backend='inductor' 已经能够真实训练;

  • CIFAR10DVSVGG 这类小型卷积 SNN,interleaved 目前是最好的默认调度,吞吐最好;1f1b 的优势更多体现在显存;

  • spikformer_tiinterleaved 同样是当前最好的默认调度;如果叠加 memopt level=1,可以把 peak_allocated_mb 从约 1.39 GB 压到约 0.54 GB

  • zero_bubble 已经能在 CIFAR10DVSVGGspikformer_ti 上功能跑通,但当前吞吐都还不占优;

  • spikformer_tizero_bubble + memopt level=1 现在也已经可用,并能把 peak_allocated_mb 压到约 0.48 GB

  • 不过 zero_bubble 仍会伴随额外的 inductor 重编译告警,因此当前更适合手动实验和容量优先场景,而不是默认推荐。

Spikformer 与 memopt 组合结果#

在更接近 ImageNet 训练设置的 spikformer_ti 上,TPFSDP2 + TP 也已经可以和 memopt level=1 结合使用。下面的实验使用:

  • 模型:spikformer_ti

  • 输入:224x224

  • batch_size=4

  • T=8

  • 后端:inductor

  • GPU:RTX 4090

模式

optimizer_sharding

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

none

none

0

0.0

36.70

109.00

2070.34

none

none

1

26852.97

57.35

69.74

1298.16

dp

none

0

0.0

126.56

63.21

2070.93

dp

zero

0

0.0

122.28

65.42

2055.70

dp

none

1

22591.25

134.48

59.49

1315.71

dp

zero

1

23030.79

149.21

53.61

1297.59

fsdp2

none

0

0.0

111.08

72.02

2033.86

fsdp2

none

1

22919.87

132.65

60.31

1272.13

tp

none

0

0.0

196.41

20.37

1321.38

tp

none

1

26913.14

173.65

23.03

767.51

fsdp2_tp

none

0

0.0

131.90

60.65

1319.68

fsdp2_tp

none

1

26403.47

103.95

76.96

761.26

可以看到:

  • memopt level=1none / dp / fsdp2 / tp / fsdp2_tp 都已经可以组合使用;

  • tp / fsdp2_tp / pp 上更高 level 的 memopt``(``level >= 2)现在也已经打通,做法是在 TP/FSDP2/PP 物化之前先完成 split-search;不过这类搜索开销很大,更适合离线调优或小规模 smoke 验证;

  • Spikformer 这类更大的 SNN,TP / FSDP2 + TPinductor 神经元下已经能明显降低单卡峰值显存;

  • 再叠加 memopt level=1 后,tpfsdp2_tp 的单卡峰值显存都可以压到约 0.76 GB

  • 这组 benchmark 里,fsdp2_tp + memopt level=1 同时拿到了更低显存和更好的吞吐;

  • dp + zero 是否优于纯 dp 取决于工作负载,在较大模型上更值得尝试。

推荐组合#

如果你的目标比较明确,可以按下面的经验规则选择:

  • 吞吐优先,显存压力不大

    • 对小模型或单卡训练,先看 none

    • 对更大的分布式 workload,优先尝试 fsdp2fsdp2_tp

    • dp + zero 可以作为纯数据并行路线的一个可选增强,但收益和 workload 强相关。

  • 单卡显存优先,尤其是 ImageNet / Transformer 型 SNN

    • 优先尝试 tp + memopt level=1fsdp2_tp + memopt level=1

    • 当前实测里,这两种组合都能把 Spikformer 的单卡峰值显存压到约 0.76 GB

  • 希望在速度和显存之间取得折中

    • fsdp2_tp 仍然是最稳妥的主推荐;

    • 如果你的工作负载与这里的 Spikformer benchmark 接近,可以直接试 fsdp2_tp + memopt level=1

    • 如果显存已经足够,则保留 fsdp2_tp 而不开 memopt,可以减少优化前处理时间。

  • 只想要最省心、最稳妥的分布式训练入口

    • dp 开始;

    • 如果要进一步扩展到更大模型,再迁移到 fsdp2fsdp2_tp

如果你不想自己手工挑模式,现在训练脚本和 benchmark 也支持高层自动推荐器:

torchrun --nproc_per_node=4 \
  spikingjelly/activation_based/examples/memopt/train_distributed.py \
  --data-dir /path/to/cifar10dvs \
  --distributed-mode auto \
  --prefer memory \
  --backend inductor \
  --batch-size 16

其中:

  • --prefer speed 倾向于选择吞吐优先的组合;

  • --prefer memory 倾向于选择单卡显存更低的组合;

  • --prefer capacity 倾向于选择更容易放下大模型的组合(优先考虑 PP)。

prefer=capacity 且环境允许时,自动推荐器会优先选择:

  • mode=pp

  • pp_virtual_stages=2

  • pp_schedule=interleaved

  • memopt level=1

zero_bubble 仍然作为显式可选项保留在命令行里。它现在已经能稳定跑通,但当前默认仍建议优先使用更稳、更快的 interleavedzero_bubble 更适合手动实验和容量优先场景。

如果显式指定了 --distributed-mode,那么 prefer 仍然可以帮你补默认的 memopt / optimizer_sharding 等参数,但不会覆盖你手工指定的模式。

Benchmark 自动记录与对比#

benchmark/benchmark_snn_distributed.py 现在会默认把结果追加到 benchmark/results/benchmark_snn_distributed.jsonl,并自动和同配置的上一条记录做对比。新版记录会显式区分 benchmark 口径与 batch 语义,统一保存:

  • benchmark_regimethroughput_weak_scaling / latency_strong_scaling / memory_capacity

  • global_batch_size

  • per_rank_batch_size

  • data_replicas

  • pp_memopt_stages

  • step_latency_ms

  • global_throughput_sps

  • per_device_throughput_sps

  • peak_allocated_mb

  • optimize_ms

  • forward_ms

  • backward_ms

  • optimizer_ms

  • reset_ms

  • materialize_ms

  • tp_all_reduce_calls

  • tp_all_reduce_mb

  • warning_count

  • recompile_count

  • graph_break_count

例如:

torchrun --nproc_per_node=4 \
  benchmark/benchmark_snn_distributed.py \
  --mode auto \
  --prefer speed \
  --model spikformer_ti \
  --backend inductor \
  --batch-size 4 \
  --T 8

当前建议避免的组合:

  • hybrid``(``DDP + TP):当前仍不支持;

  • 在大尺寸 Spikformer 工作负载上直接使用高 level memopt``(``level >= 2)做在线搜索:虽然功能上已经可用,但 optimize_ms 仍然很高,并且更容易触发 inductor 的额外重编译,建议先离线搜索、再固定策略。