SNN 分布式训练(DTensor / FSDP2)#
English version: Distributed SNN Training (DTensor / FSDP2)
本教程介绍 spikingjelly.activation_based.distributed 中新增的实验性分布式训练工具。当前实现重点支持:
DP:常规数据并行
TP:面向 SNN 的简单 tensor parallel
FSDP2:基于 DTensor 的参数、梯度与优化器状态分片
FSDP2 + TP:推荐的混合分布式训练方案
PP:实验性的 pipeline parallel(当前先支持
CIFAR10DVSVGG和Spikformer)
其中,传统的 DDP + TP 组合在当前 PyTorch 版本下仍会在参数同步阶段遇到 Tensor / DTensor 混用问题,因此本实现会直接提示用户改用 FSDP2 + TP。
快速开始#
当前低层入口是 configure_snn_distributed,它位于 spikingjelly.activation_based.distributed.dtensor。统一的新公开接口则是 spikingjelly.activation_based.distributed 下的 analyze / plan / apply。
例如,对 CIFAR10DVSVGG 启用纯 FSDP2:
from spikingjelly.activation_based.distributed.dtensor import (
SNNDistributedConfig,
configure_snn_distributed,
)
from spikingjelly.activation_based.examples.memopt.models import CIFAR10DVSVGG
import torch.distributed as dist
world_size = dist.get_world_size() # 需要先 init_process_group()
model = CIFAR10DVSVGG(dropout=0.0, backend='inductor')
model, mesh, analysis = configure_snn_distributed(
model,
SNNDistributedConfig(
device_type='cuda',
mesh_shape=(world_size,),
auto_tensor_parallel=False,
enable_fsdp2=True,
fsdp_shard_roots=['features', 'classifier'],
fsdp_shard_module_root=True,
dp_mesh_dim=0,
),
)
若想启用 FSDP2 + TP,可使用 2D mesh:
model, mesh, analysis = configure_snn_distributed(
model,
SNNDistributedConfig(
device_type='cuda',
mesh_shape=(2, 2), # (dp, tp)
enable_fsdp2=True,
fsdp_shard_roots=['features'],
fsdp_shard_module_root=False,
tensor_parallel_roots=['classifier'],
auto_tensor_parallel=True,
experimental_conv_tensor_parallel=True,
conv_tensor_parallel_roots=['features'],
dp_mesh_dim=0,
tp_mesh_dim=1,
),
)
训练脚本#
仓库中提供了一个真实训练入口:
spikingjelly/activation_based/examples/memopt/train_distributed.py
示例:
torchrun --nproc_per_node=4 \
spikingjelly/activation_based/examples/memopt/train_distributed.py \
--data-dir /path/to/cifar10dvs \
--distributed-mode fsdp2_tp \
--mesh-shape 2 2 \
--backend inductor \
--batch-size 16 \
--epochs 1 \
--print-summary
该脚本支持的模式有:
nonedptpfsdp2fsdp2_tppp:当前是实验性训练入口,优先面向 smoke benchmark 和结构验证
PP 还支持一组更接近 Megatron 风格的调度与布局参数:
--pp-schedule gpipe:最简单的 GPipe 调度;--pp-schedule 1f1b:标准 1F1B;--pp-schedule interleaved:interleaved / VPP 风格调度;--pp-schedule zero_bubble:基于 delayed-wgrad的实验性 zero-bubble 调度;--pp-virtual-stages N:每个物理 stage 持有N个虚拟 chunk;--pp-layout:显式指定逻辑 stage 的连续切分,例如1|2|2|1;--pp-delay-wgrad:在可用调度下显式打开 delayed-wgrad风格优化。
历史上的 hybrid``(``DDP + TP)组合当前仍不支持,也没有在脚本里继续暴露;推荐直接使用 fsdp2_tp。
如果希望在纯 dp 路径上进一步压缩优化器状态,还可以启用 ZeroRedundancyOptimizer:
torchrun --nproc_per_node=2 \
benchmark/benchmark_snn_distributed.py \
--model cifar10dvs_vgg \
--mode dp \
--optimizer-sharding zero \
--backend inductor \
--batch-size 2 \
--T 10
当前实现范围#
线性层使用官方 tensor-parallel API。
逐元素脉冲神经元现在会显式跟随上游 shard:
对
[T, N, C]激活,按最后一维C切分;对
[T, N, C, H, W]激活,按通道维C切分。
这意味着神经元内部状态
v只保留本地 shard,而不是完整复制一份全局状态。CIFAR10DVSVGG的Conv + BN + Neuron主干支持实验性的 channel tensor parallel。FSDP2 + TP当前优先对features做 FSDP2 分片;当classifier已经启用 TP 时,不再额外对其做 root fully-shard,以避免跨 mesh 维度重复切分。传统
hybrid``(即 ``DDP + TP)当前显式不支持,接口会直接提示改用fsdp2_tp。PP当前通过手工 stage 切分实现,而不是依赖torch.export整图切分;这样可以兼容标准脉冲神经元的内部状态写入。PP在 microbatch 之间会显式重置每个 stage 内的神经元状态,避免不同样本的状态串扰。
服务器实测结果(小网络 smoke benchmark)#
以下数据来自单机多卡服务器(RTX 4090),网络为 CIFAR10DVSVGG,后端为 inductor,输入配置为 batch_size=2、T=10,指标为短步数训练 benchmark。表中的 global_samples/s 统一表示整个分布式作业的全局吞吐。这个工作负载非常小,更多用于 smoke test 和显存趋势对比,不适合作为最终扩展效率结论。
模式 |
GPU 数 |
|
|
|
备注 |
|---|---|---|---|---|---|
|
1 |
12.86 |
155.52 |
401.63 |
单卡基线 |
|
2 |
83.71 |
47.78 |
434.25 |
纯 DDP,小 batch 下通信开销占主导 |
|
2 |
96.79 |
41.33 |
410.22 |
纯 DDP + |
|
2 |
86.58 |
23.10 |
308.88 |
纯 TP,神经元按 shard 后特征/通道局部执行 |
|
2 |
97.11 |
41.19 |
400.61 |
纯 FSDP2 |
|
4 |
26.68 |
149.91 |
316.27 |
推荐的 |
|
4 |
当前显式不支持;请改用 |
从这组小网络 smoke benchmark 可以看出:
TP和FSDP2 + TP都已经可以在标准神经元backend='inductor'下完成真实 SNN 训练 step。显式 neuron shard 后,神经元状态会随特征/通道切分,而不再保持完整复制。
即使在很小的网络和 batch 上,
TP/FSDP2 + TP也已经能带来可见的单卡显存下降。DDP + TP目前仍不推荐,建议直接使用fsdp2_tp。
实验性 PP benchmark(服务器复测)#
当前 PP 已经支持:
基于 dry-run 实际耗时的 stage balance,而不是简单按层数均分;
自动选择更积极的
pp_microbatches``(优先选择 ``batch_size的可整除值);gpipe / 1f1b / interleaved / zero_bubble多种调度;显式
pp_layout覆盖自动切分;更轻量的 microbatch reset 逻辑,减少每次调度的遍历开销。
下面的结果来自重新在服务器上跑的 schedule 对比。它们更适合回答“当前哪种 PP 调度更值得默认推荐”,而不是把 PP 说成已经是吞吐主力。
CIFAR10DVSVGG,backend='inductor',2 张 GPU,batch_size=8,T=4:
调度 |
|
|
|
|
|
|
|---|---|---|---|---|---|---|
|
1 |
0 |
0.0 |
93.70 |
85.38 |
507.84 |
|
1 |
0 |
0.0 |
102.65 |
77.93 |
259.09 |
|
2 |
0 |
0.0 |
87.63 |
91.29 |
361.45 |
|
2 |
1 |
1521.40 |
84.39 |
94.79 |
361.45 |
|
2 |
0 |
0.0 |
145.17 |
55.11 |
452.67 |
|
2 |
1 |
1535.38 |
118.00 |
67.80 |
452.67 |
spikformer_ti,backend='inductor',2 张 GPU,batch_size=4,T=8,image_size=224:
调度 |
|
|
|
|
|
|
|---|---|---|---|---|---|---|
|
1 |
0 |
0.0 |
423.64 |
9.44 |
1286.03 |
|
1 |
0 |
0.0 |
461.92 |
8.66 |
679.22 |
|
2 |
0 |
0.0 |
394.63 |
10.14 |
1389.71 |
|
2 |
1 |
112.83 |
423.73 |
9.44 |
541.91 |
|
2 |
0 |
0.0 |
455.79 |
8.78 |
1356.73 |
|
2 |
1 |
164.35 |
473.41 |
8.45 |
483.31 |
这组服务器复测说明:
PP与标准神经元backend='inductor'已经能够真实训练;对
CIFAR10DVSVGG这类小型卷积 SNN,interleaved目前是最好的默认调度,吞吐最好;1f1b的优势更多体现在显存;对
spikformer_ti,interleaved同样是当前最好的默认调度;如果叠加memopt level=1,可以把peak_allocated_mb从约1.39 GB压到约0.54 GB;zero_bubble已经能在CIFAR10DVSVGG和spikformer_ti上功能跑通,但当前吞吐都还不占优;对
spikformer_ti,zero_bubble + memopt level=1现在也已经可用,并能把peak_allocated_mb压到约0.48 GB;不过
zero_bubble仍会伴随额外的inductor重编译告警,因此当前更适合手动实验和容量优先场景,而不是默认推荐。
Spikformer 与 memopt 组合结果#
在更接近 ImageNet 训练设置的 spikformer_ti 上,TP 和 FSDP2 + TP 也已经可以和 memopt level=1 结合使用。下面的实验使用:
模型:
spikformer_ti输入:
224x224batch_size=4T=8后端:
inductorGPU:RTX 4090
模式 |
|
|
|
|
|
|
|---|---|---|---|---|---|---|
|
|
|
0.0 |
36.70 |
109.00 |
2070.34 |
|
|
|
26852.97 |
57.35 |
69.74 |
1298.16 |
|
|
|
0.0 |
126.56 |
63.21 |
2070.93 |
|
|
|
0.0 |
122.28 |
65.42 |
2055.70 |
|
|
|
22591.25 |
134.48 |
59.49 |
1315.71 |
|
|
|
23030.79 |
149.21 |
53.61 |
1297.59 |
|
|
|
0.0 |
111.08 |
72.02 |
2033.86 |
|
|
|
22919.87 |
132.65 |
60.31 |
1272.13 |
|
|
|
0.0 |
196.41 |
20.37 |
1321.38 |
|
|
|
26913.14 |
173.65 |
23.03 |
767.51 |
|
|
|
0.0 |
131.90 |
60.65 |
1319.68 |
|
|
|
26403.47 |
103.95 |
76.96 |
761.26 |
可以看到:
memopt level=1与none / dp / fsdp2 / tp / fsdp2_tp都已经可以组合使用;tp / fsdp2_tp / pp上更高 level 的memopt``(``level >= 2)现在也已经打通,做法是在 TP/FSDP2/PP 物化之前先完成 split-search;不过这类搜索开销很大,更适合离线调优或小规模 smoke 验证;对
Spikformer这类更大的 SNN,TP/FSDP2 + TP在inductor神经元下已经能明显降低单卡峰值显存;再叠加
memopt level=1后,tp与fsdp2_tp的单卡峰值显存都可以压到约0.76 GB;这组 benchmark 里,
fsdp2_tp + memopt level=1同时拿到了更低显存和更好的吞吐;dp + zero是否优于纯dp取决于工作负载,在较大模型上更值得尝试。
推荐组合#
如果你的目标比较明确,可以按下面的经验规则选择:
吞吐优先,显存压力不大:
对小模型或单卡训练,先看
none;对更大的分布式 workload,优先尝试
fsdp2或fsdp2_tp;dp + zero可以作为纯数据并行路线的一个可选增强,但收益和 workload 强相关。
单卡显存优先,尤其是 ImageNet / Transformer 型 SNN:
优先尝试
tp + memopt level=1或fsdp2_tp + memopt level=1;当前实测里,这两种组合都能把
Spikformer的单卡峰值显存压到约0.76 GB。
希望在速度和显存之间取得折中:
fsdp2_tp仍然是最稳妥的主推荐;如果你的工作负载与这里的
Spikformerbenchmark 接近,可以直接试fsdp2_tp + memopt level=1;如果显存已经足够,则保留
fsdp2_tp而不开memopt,可以减少优化前处理时间。
只想要最省心、最稳妥的分布式训练入口:
从
dp开始;如果要进一步扩展到更大模型,再迁移到
fsdp2或fsdp2_tp。
如果你不想自己手工挑模式,现在训练脚本和 benchmark 也支持高层自动推荐器:
torchrun --nproc_per_node=4 \
spikingjelly/activation_based/examples/memopt/train_distributed.py \
--data-dir /path/to/cifar10dvs \
--distributed-mode auto \
--prefer memory \
--backend inductor \
--batch-size 16
其中:
--prefer speed倾向于选择吞吐优先的组合;--prefer memory倾向于选择单卡显存更低的组合;--prefer capacity倾向于选择更容易放下大模型的组合(优先考虑PP)。
当 prefer=capacity 且环境允许时,自动推荐器会优先选择:
mode=pppp_virtual_stages=2pp_schedule=interleavedmemopt level=1
zero_bubble 仍然作为显式可选项保留在命令行里。它现在已经能稳定跑通,但当前默认仍建议优先使用更稳、更快的 interleaved;zero_bubble 更适合手动实验和容量优先场景。
如果显式指定了 --distributed-mode,那么 prefer 仍然可以帮你补默认的 memopt / optimizer_sharding 等参数,但不会覆盖你手工指定的模式。
Benchmark 自动记录与对比#
benchmark/benchmark_snn_distributed.py 现在会默认把结果追加到 benchmark/results/benchmark_snn_distributed.jsonl,并自动和同配置的上一条记录做对比。新版记录会显式区分 benchmark 口径与 batch 语义,统一保存:
benchmark_regime:throughput_weak_scaling/latency_strong_scaling/memory_capacityglobal_batch_sizeper_rank_batch_sizedata_replicaspp_memopt_stagesstep_latency_msglobal_throughput_spsper_device_throughput_spspeak_allocated_mboptimize_msforward_msbackward_msoptimizer_msreset_msmaterialize_mstp_all_reduce_callstp_all_reduce_mbwarning_countrecompile_countgraph_break_count
例如:
torchrun --nproc_per_node=4 \
benchmark/benchmark_snn_distributed.py \
--mode auto \
--prefer speed \
--model spikformer_ti \
--backend inductor \
--batch-size 4 \
--T 8
当前建议避免的组合:
hybrid``(``DDP + TP):当前仍不支持;在大尺寸
Spikformer工作负载上直接使用高 levelmemopt``(``level >= 2)做在线搜索:虽然功能上已经可用,但optimize_ms仍然很高,并且更容易触发inductor的额外重编译,建议先离线搜索、再固定策略。