Distributed SNN Training (DTensor / FSDP2)#

中文版: SNN 分布式训练(DTensor / FSDP2)

This tutorial introduces the experimental distributed-training helpers in spikingjelly.activation_based.distributed. The current implementation focuses on:

  • DP: conventional data parallelism

  • TP: simple SNN-oriented tensor parallelism

  • FSDP2: DTensor-based parameter, gradient, and optimizer-state sharding

  • FSDP2 + TP: the recommended hybrid distributed strategy

  • PP: experimental pipeline parallelism (currently implemented first for CIFAR10DVSVGG and Spikformer)

In contrast, the traditional DDP + TP combination still runs into Tensor / DTensor synchronization issues in the current PyTorch version, so this implementation reports an actionable error and recommends FSDP2 + TP instead.

Quick Start#

The current low-level entry is configure_snn_distributed, which lives in spikingjelly.activation_based.distributed.dtensor. The unified public API is the newer analyze / plan / apply trio on spikingjelly.activation_based.distributed.

For example, pure FSDP2 on CIFAR10DVSVGG:

from spikingjelly.activation_based.distributed.dtensor import (
    SNNDistributedConfig,
    configure_snn_distributed,
)
from spikingjelly.activation_based.examples.memopt.models import CIFAR10DVSVGG
import torch.distributed as dist

world_size = dist.get_world_size()  # requires init_process_group()
model = CIFAR10DVSVGG(dropout=0.0, backend='inductor')
model, mesh, analysis = configure_snn_distributed(
    model,
    SNNDistributedConfig(
        device_type='cuda',
        mesh_shape=(world_size,),
        auto_tensor_parallel=False,
        enable_fsdp2=True,
        fsdp_shard_roots=['features', 'classifier'],
        fsdp_shard_module_root=True,
        dp_mesh_dim=0,
    ),
)

To enable FSDP2 + TP, use a 2D mesh:

model, mesh, analysis = configure_snn_distributed(
    model,
    SNNDistributedConfig(
        device_type='cuda',
        mesh_shape=(2, 2),   # (dp, tp)
        enable_fsdp2=True,
        fsdp_shard_roots=['features'],
        fsdp_shard_module_root=False,
        tensor_parallel_roots=['classifier'],
        auto_tensor_parallel=True,
        experimental_conv_tensor_parallel=True,
        conv_tensor_parallel_roots=['features'],
        dp_mesh_dim=0,
        tp_mesh_dim=1,
    ),
)

Training Script#

The repository now includes a real distributed training entry:

  • spikingjelly/activation_based/examples/memopt/train_distributed.py

Example:

torchrun --nproc_per_node=4 \
  spikingjelly/activation_based/examples/memopt/train_distributed.py \
  --data-dir /path/to/cifar10dvs \
  --distributed-mode fsdp2_tp \
  --mesh-shape 2 2 \
  --backend inductor \
  --batch-size 16 \
  --epochs 1 \
  --print-summary

Supported modes in the script:

  • none

  • dp

  • tp

  • fsdp2

  • fsdp2_tp

  • pp: currently exposed as an experimental training path aimed first at smoke benchmarks and structural validation

PP also exposes a set of scheduling and layout controls that move it closer to Megatron-style pipeline tuning:

  • --pp-schedule gpipe: the simplest GPipe schedule;

  • --pp-schedule 1f1b: standard 1F1B;

  • --pp-schedule interleaved: interleaved / VPP-style scheduling;

  • --pp-schedule zero_bubble: an experimental zero-bubble schedule driven by delayed wgrad;

  • --pp-virtual-stages N: keep N virtual chunks on each physical stage;

  • --pp-layout: explicitly describe the contiguous logical-stage split, for example 1|2|2|1;

  • --pp-delay-wgrad: explicitly enable delayed-wgrad style optimization when the chosen schedule supports it.

The historical hybrid (DDP + TP) combination is still unsupported and is no longer exposed in the script; use fsdp2_tp instead.

If you want to further shard optimizer state on the pure dp path, you can also enable ZeroRedundancyOptimizer:

torchrun --nproc_per_node=2 \
  benchmark/benchmark_snn_distributed.py \
  --model cifar10dvs_vgg \
  --mode dp \
  --optimizer-sharding zero \
  --backend inductor \
  --batch-size 2 \
  --T 10

Current Scope#

  1. Linear layers use PyTorch's official tensor-parallel APIs.

  2. Element-wise spiking neurons now explicitly follow upstream sharding:

    • for [T, N, C] activations, they shard along the last feature dimension C;

    • for [T, N, C, H, W] activations, they shard along the channel dimension C.

    In other words, the neuron state v now only keeps the local shard rather than a full replicated global state.

  3. The Conv + BN + Neuron backbone in CIFAR10DVSVGG supports experimental channel tensor parallelism.

  4. Under FSDP2 + TP, the current implementation prioritizes FSDP2 sharding on features; when the classifier is already tensor-parallelized, it is not additionally root-sharded, avoiding repeated cross-mesh sharding.

  5. Traditional hybrid (that is, DDP + TP) is explicitly unsupported for now and will recommend fsdp2_tp instead.

  6. PP currently uses manual stage splitting rather than torch.export full-graph partitioning, which makes it compatible with standard spiking neurons that mutate internal state.

  7. PP explicitly resets neuron state between microbatches inside each stage so that different microbatches do not leak state into each other.

Server Benchmarks (small-network smoke benchmark)#

The following numbers were collected on a multi-GPU RTX 4090 server with CIFAR10DVSVGG, backend='inductor', batch_size=2, and T=10 using a short training-step benchmark. The global_samples/s column is the unified global throughput of the whole distributed job. This workload is intentionally tiny and should be read as a smoke benchmark plus a memory-trend probe rather than a definitive scaling study.

Mode

#GPUs

step_ms

global_samples/s

peak_allocated_mb

Notes

none

1

12.86

155.52

401.63

single-GPU baseline

dp

2

83.71

47.78

434.25

pure DDP; communication dominates at this tiny batch size

dp + zero

2

96.79

41.33

410.22

pure DDP + ZeroRedundancyOptimizer

tp

2

86.58

23.10

308.88

pure TP with neuron states kept on local feature/channel shards

fsdp2

2

97.11

41.19

400.61

pure FSDP2

fsdp2_tp

4

26.68

149.91

316.27

recommended FSDP2 + TP strategy

hybrid (DDP + TP)

4

explicitly unsupported for now; use fsdp2_tp instead

This small-network smoke benchmark shows that:

  • both TP and FSDP2 + TP can now execute real SNN training steps with standard neurons running on backend='inductor';

  • explicit neuron sharding keeps neuron states aligned with local feature/channel shards instead of fully replicating them;

  • even at this tiny scale, TP / FSDP2 + TP already reduce per-GPU peak allocated memory;

  • DDP + TP is still not recommended, and fsdp2_tp should be used instead.

Experimental PP benchmark (server rerun)#

PP now includes:

  • cost-aware stage balancing based on dry-run module timings rather than simple layer counts;

  • a more aggressive automatic pp_microbatches heuristic that prefers values dividing batch_size cleanly;

  • multiple schedules: gpipe, 1f1b, interleaved, and zero_bubble;

  • explicit pp_layout overrides for manual stage placement;

  • lighter microbatch reset handling to avoid repeated full-module tree scans.

The following numbers come from rerunning the larger PP benchmarks on the server. They are more useful for deciding which schedule should be the default recommendation than for claiming that PP is already the primary throughput path.

CIFAR10DVSVGG, backend='inductor', 2 GPUs, batch_size=8, T=4:

Schedule

pp_virtual_stages

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

gpipe

1

0

0.0

93.70

85.38

507.84

1f1b

1

0

0.0

102.65

77.93

259.09

interleaved

2

0

0.0

87.63

91.29

361.45

interleaved

2

1

1521.40

84.39

94.79

361.45

zero_bubble

2

0

0.0

145.17

55.11

452.67

zero_bubble

2

1

1535.38

118.00

67.80

452.67

spikformer_ti, backend='inductor', 2 GPUs, batch_size=4, T=8, image_size=224:

Schedule

pp_virtual_stages

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

gpipe

1

0

0.0

423.64

9.44

1286.03

1f1b

1

0

0.0

461.92

8.66

679.22

interleaved

2

0

0.0

394.63

10.14

1389.71

interleaved

2

1

112.83

423.73

9.44

541.91

zero_bubble

2

0

0.0

455.79

8.78

1356.73

zero_bubble

2

1

164.35

473.41

8.45

483.31

These rerun results show that:

  • PP already works with standard neurons running on backend='inductor';

  • for a small convolutional SNN such as CIFAR10DVSVGG, interleaved is currently the best default schedule for throughput, while 1f1b is more attractive when memory is the main concern;

  • for spikformer_ti, interleaved is also the strongest default schedule, and adding memopt level=1 can reduce peak_allocated_mb from about 1.39 GB to about 0.54 GB;

  • zero_bubble now runs successfully on both CIFAR10DVSVGG and spikformer_ti, but it is still not the strongest throughput option on either workload;

  • for spikformer_ti, zero_bubble + memopt level=1 now also works and reduces peak_allocated_mb to about 0.48 GB;

  • however, zero_bubble still comes with extra inductor recompilation warnings, so it is best viewed as a manual experimental or capacity-oriented option rather than the default recommendation today.

Spikformer + memopt Results#

On spikformer_ti in a more ImageNet-like setting, TP and FSDP2 + TP can now also be combined with memopt level=1. The following experiment uses:

  • model: spikformer_ti

  • input resolution: 224x224

  • batch_size=4

  • T=8

  • backend: inductor

  • GPU: RTX 4090

Mode

optimizer_sharding

memopt

optimize_ms

step_ms

global_samples/s

peak_allocated_mb

none

none

0

0.0

36.70

109.00

2070.34

none

none

1

26852.97

57.35

69.74

1298.16

dp

none

0

0.0

126.56

63.21

2070.93

dp

zero

0

0.0

122.28

65.42

2055.70

dp

none

1

22591.25

134.48

59.49

1315.71

dp

zero

1

23030.79

149.21

53.61

1297.59

fsdp2

none

0

0.0

111.08

72.02

2033.86

fsdp2

none

1

22919.87

132.65

60.31

1272.13

tp

none

0

0.0

196.41

20.37

1321.38

tp

none

1

26913.14

173.65

23.03

767.51

fsdp2_tp

none

0

0.0

131.90

60.65

1319.68

fsdp2_tp

none

1

26403.47

103.95

76.96

761.26

These numbers highlight that:

  • memopt level=1 now works with none / dp / fsdp2 / tp / fsdp2_tp;

  • higher memopt levels on tp / fsdp2_tp / pp (level >= 2) are now functional as well by running split-search before TP/FSDP2/PP materialization; however, this path is still expensive and is better suited to offline tuning or smoke validation than to frequent online retuning;

  • for larger SNNs such as Spikformer, TP / FSDP2 + TP already give a clear per-GPU memory reduction when the neurons use the standard inductor backend;

  • adding memopt level=1 reduces tp / fsdp2_tp further to about 0.76 GB per GPU in this benchmark;

  • in this particular benchmark, fsdp2_tp + memopt level=1 improves both memory usage and throughput;

  • whether dp + zero beats plain dp remains workload-dependent, but it is still a useful option when optimizer state becomes meaningful.

Automatic Benchmark Logging and Comparison#

benchmark/benchmark_snn_distributed.py now appends results to benchmark/results/benchmark_snn_distributed.jsonl by default and automatically compares each run against the most recent earlier run with the same configuration. The newer records also make the benchmark regime and batch semantics explicit. Each record stores:

  • benchmark_regime: throughput_weak_scaling / latency_strong_scaling / memory_capacity

  • global_batch_size

  • per_rank_batch_size

  • data_replicas

  • pp_memopt_stages

  • step_latency_ms

  • global_throughput_sps

  • per_device_throughput_sps

  • peak_allocated_mb

  • optimize_ms

  • forward_ms

  • backward_ms

  • optimizer_ms

  • reset_ms

  • materialize_ms

  • tp_all_reduce_calls

  • tp_all_reduce_mb

  • warning_count

  • recompile_count

  • graph_break_count

For example:

torchrun --nproc_per_node=4 \
  benchmark/benchmark_snn_distributed.py \
  --mode auto \
  --prefer speed \
  --model spikformer_ti \
  --backend inductor \
  --batch-size 4 \
  --T 8

Combinations that should still be avoided for now:

  • hybrid (DDP + TP): still unsupported;

  • running high-level memopt (level >= 2) online on large Spikformer-like workloads: it now works functionally, but the search cost is still high and it is more likely to trigger extra inductor recompiles, so it is best treated as an offline tuning workflow for now.