基于 STA 的 Transformer ANN2SNN 转换#

本页作者:黄一凡 (AllenYolk)

English version: STA-based Transformer ANN2SNN Conversion

本页介绍 spikingjelly.activation_based.ann2snn 中面向 Transformer 的 ANN2SNN 转换路径,核心对象是 STATransformerRecipe,它是一个基于 Spatio-Temporal Approximation (STA) [1] 的 training-free 转换 recipe。

如果要做经典 CNN 上的 ReLU-to-IFNode rate coding 转换,请阅读 ANN转换SNN。本页介绍的是独立的 Transformer 转换流程。

警告

STA 转换不是严格意义上的 fully spike-driven SNN 转换。在 mode="spiking_encoder" 中,模块输出的是“整数脉冲数乘以校准阈值”的 量化值;当残差为负时,这个整数脉冲数也可以为负。它不是二值脉冲 tensor。用严格 SNN 定义比较不同方法时,这一点很重要,也容易引起争议。

因此,STATransformerRecipe 应理解为一种 training-free Transformer ANN2SNN 近似转换流程,而不是完整 fully spike-driven LLM conversion 的承诺。 以整数 token 为输入的语言模型需要额外定义 input 和 embedding 的转换契约。 当前 STA 实现刻意采用在线、有状态形式;Layer-wise sequence TD execution 和更快 multi-step inference 属于后续执行后端工作,不是本教程中的默认路径。

STA 转换思想#

Transformer 模型包含仿射投影、LayerNorm、GELU、attention、残差加法、mask 和 tensor 常量等组件。直接套用 ReLU-to-IFNode rate coding 规则不足以覆盖这些组件。 STA 用在线时序近似来处理它们。

不带脉冲的在线差分#

先看不带 spike encoder 的在线等价路径,核心概念是累计激活。令 \(x\) 为原始 ANN 输入,构造输入序列:

\[x^{(0)} = x,\qquad x^{(t)} = 0,\quad t=1,\ldots,T-1.\]

\(t\) 个时间步后的累计输入为:

\[X^{(t)} = \sum_{\tau=0}^{t} x^{(\tau)}.\]

因此 \(X^{(0)} = X^{(1)} = \cdots = x\)。令 \(f\) 表示转换前 ANN 中的一个函数或模块,例如仿射投影、LayerNorm、GELU 或 attention 模块。 转换后的 STA 模块不是 \(f\) 本身。记转换后在单个时间步上执行的差分模块为 \(F_t\),它输出 \(f\) 在相邻累计输入上的差分:

\[F_t\left(X^{(t)}\right) = f\left(X^{(t)}\right) - f\left(X^{(t-1)}\right), \qquad f\left(X^{(-1)}\right) = 0.\]

在实现中,\(F_t\) 由包裹原操作的有状态包装模块实现:它计算或复用 \(f\) 的累计输出,保存上一时间步的累计输出,并只返回当前差分 \(\Delta y^{(t)} = F_t(X^{(t)})\)

累计输出满足:

\[\sum_{t=0}^{T-1} F_t\left(X^{(t)}\right) = f\left(X^{(T-1)}\right) - f\left(X^{(-1)}\right) = f(x).\]

这个恒等式解释了 STA 中在线等价部分为什么成立:如果每个转换模块都输出累计 结果的差分,并且常量与 bias 只计算一次,那么把所有时间步输出相加即可 恢复 ANN 模块的输出。

不带脉冲编码的在线等价路径的执行模式如下:

  • 第 0 个时间步输入原始 ANN 输入;

  • 后续时间步输入零值浮点 tensor;

  • 转换后的模型在内部执行 time_steps 次循环;

  • 有状态转换模块输出累计结果的差分;

  • 包装模块累加每步输出并返回累计结果。

高层执行流程如下:

y = 0
for t in range(time_steps):
    if t == 0:
        x_t = original_input
    else:
        x_t = zeros_like(original_input)
    y = y + converted_graph_step(x_t, static_control_tensors)
return y

这里的 converted_graph_step 表示转换后的 FX 计算图在一个内部时间步上的 一次执行。该计算图包含有状态模块,会记住上一时间步的累计输出,因此每次调用 只返回当前增量。

带脉冲的 spike encoder#

mode="spiking_encoder" 会在选中的增量后加入 spike encoder。对模拟增量 \(a^{(t)}\) 和阈值 \(V\),encoder 维护残差膜电位 \(r^{(t)}\)。 初始残差为 0:

\[r^{(-1)} = 0.\]

每个时间步,encoder 先累加模拟增量:

\[u^{(t)} = r^{(t-1)} + a^{(t)}.\]

然后计算可以发放多少个阈值单位:

\[n^{(t)} = \operatorname{trunc}\left(\frac{u^{(t)}}{V}\right), \qquad s^{(t)} = n^{(t)} V.\]

其中 \(s^{(t)}\) 是当前时间步的量化输出。下一时间步的残差为:

\[r^{(t)} = u^{(t)} - s^{(t)}.\]

从 SNN 角度看,\(r^{(t)}\) 是发放后保留下来的膜电位, \(n^{(t)}\) 是整数脉冲数;当残差为负时,它也可以为负。 \(s^{(t)}\) 是按阈值加权后的脉冲输出。更新式 \(r^{(t)} = u^{(t)} - s^{(t)}\) 相当于广义的软重置:它从膜电位中减去 本步输出的阈值加权值。当 \(n^{(t)} = 1\) 时,这退化为普通软重置; 更大的正整数或负整数表示在一个时间步内跨过多个阈值单位。

经过 \(T\) 个时间步后:

\[\sum_{t=0}^{T-1} s^{(t)} = \sum_{t=0}^{T-1} a^{(t)} - r^{(T-1)}\]

也就是说,encoder 的输出和模拟增量之和只差最终残差。如果 \(a^{(t)}\) 就是 STA 差分 \(F_t(X^{(t)})\),那么模拟增量之和就是 ANN 模块输出 \(f(x)\),脉冲编码后的结果与它之间的差异就是最终残差。 由于 STA 校准阈值时会使用 time_steps,在激活范围固定时,更大的 \(T\) 对应更细的时间量化。

这一点对 Transformer 很关键:LayerNorm、GELU 和 attention 都不是简单的 ReLU rate coding 层。在线累计差分视角允许转换模型在累计输入上计算这些函数, 再对输出增量做编码,从而在保留算子语义的同时在选中的输出处引入脉冲式的时序通信。

仿射模块、LayerNormGELUMultiheadAttention 和浮点 FX tensor 常量都维护在线累计差分状态。Bias 和图中的常量只注入一次。 静态 attention mask 等控制 tensor 会在各时间步保留,而不是被置零。

本教程推荐的公开路径是 mode="spiking_encoder"。该模式会在 LayerNormGELUMultiheadAttention 的输出侧加入校准后的 有状态 spike encoder,同时保持主干 affine projection 在线等价。阈值来自 dataloader 校准,并依赖 time_steps

使用 STATransformerRecipe#

最小 Python API 和其它 ANN2SNN recipe 一样,遵循 Recipe + Converter 结构:

from spikingjelly.activation_based import ann2snn

recipe = ann2snn.STATransformerRecipe(
    dataloader=calibration_loader,
    time_steps=8,
    mode="spiking_encoder",
    threshold_mode="mse",
    threshold_scale=0.5,
)
converted = ann2snn.Converter(recipe=recipe, device="cuda:0").convert(model)
converted.eval()

time_steps 属于 recipe 参数,因为它既参与阈值校准,也参与转换后模型内部 的推理循环。与 rate-coded CNN 转换不同,用户不需要在 Python 外层按时间步 反复调用转换后的 Transformer;包装模块自己持有这个循环。

STA 当前有三种模式:

  • equivalent:无需校准的在线累计差分基线;

  • spiking_encoder:在非线性和 attention 输出上使用校准 spike encoder;

  • spiking_affine:高级路径,会进一步把选中的 affine 模块替换为 spiking affine 模块。

本教程的模型级结果使用 spiking_encoder

与 TransformerSpikeEquivalentRecipe 的关系#

TransformerSpikeEquivalentRecipe 是一个不需要 dataloader 的替换路径, 用于将当前支持的 Transformer 算子替换为 TD / spike-equivalent 模块。它适合 作为算子级转换基线,但不进行 STA 校准,也不持有内部时间步循环。

STATransformerRecipe 是模型级 STA 流程。启用 spike encoder 时它需要校准, 并返回一个内部执行 time_steps 循环的包装模块。

ViT-B/16 ImageNet 示例#

完整可运行示例是 spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta。该脚本使用 torchvision.models.vit_b_16ViT_B_16_Weights.DEFAULT,以及可被 torchvision.datasets.ImageFolder 读取的 ImageNet 验证集目录。

下面命令假设 /path/to/imagenet/val 直接包含各类别文件夹,需要 CUDA。

CUDA_VISIBLE_DEVICES=0 python -m spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta \
  --data-root /path/to/imagenet/val \
  --device cuda:0 \
  --batch-size 64 \
  --num-workers 8 \
  --calib-samples 2048 \
  --time-steps 8 \
  --threshold-scale 0.5

若只想快速检查环境,可使用少量验证样本:

CUDA_VISIBLE_DEVICES=0 python -m spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta \
  --data-root /path/to/imagenet/val \
  --device cuda:0 \
  --batch-size 8 \
  --num-workers 2 \
  --calib-samples 32 \
  --eval-samples 32 \
  --time-steps 8 \
  --threshold-scale 0.5

下表数据在 NVIDIA A100-SXM4-80GB 上,使用完整 50000 张 ImageNet 验证集测得:

ViT-B/16 ImageNet STA 转换结果#

方法

校准样本

验证样本

时间步

Top-1 (%)

Top-5 (%)

ANN

50000

81.068

95.318

STA spiking_encoder

2048

50000

8

80.590

95.074

Top-1 下降 0.478 个百分点。原始运行中 ANN baseline 推理耗时约 115.4 秒, STA 转换模型约 1197.1 秒;一次精度相同的重跑中分别为 250.8 秒和 2613.1 秒, wall-clock time 对运行时环境比较敏感。

关键 stdout 行如下:

BASELINE {"top1": 0.81068, "top5": 0.95318, "total": 50000, "seconds": 115.39487862586975}
STA_SPIKING_ENCODER_T8_S05 {"top1": 0.8059, "top5": 0.95074, "total": 50000, "seconds": 1197.0657494068146}
DROP 0.0047800000000000065