基于 STA 的 Transformer ANN2SNN 转换#
本页作者:黄一凡 (AllenYolk)
English version: STA-based Transformer ANN2SNN Conversion
本页介绍 spikingjelly.activation_based.ann2snn 中面向 Transformer 的
ANN2SNN 转换路径,核心对象是 STATransformerRecipe,它是一个基于
Spatio-Temporal Approximation (STA) [1] 的 training-free 转换 recipe。
如果要做经典 CNN 上的 ReLU-to-IFNode rate coding 转换,请阅读 ANN转换SNN。本页介绍的是独立的 Transformer 转换流程。
警告
STA 转换不是严格意义上的 fully spike-driven SNN 转换。在
mode="spiking_encoder" 中,模块输出的是“整数脉冲数乘以校准阈值”的
量化值;当残差为负时,这个整数脉冲数也可以为负。它不是二值脉冲
tensor。用严格 SNN 定义比较不同方法时,这一点很重要,也容易引起争议。
因此,STATransformerRecipe 应理解为一种 training-free Transformer
ANN2SNN 近似转换流程,而不是完整 fully spike-driven LLM conversion 的承诺。
以整数 token 为输入的语言模型需要额外定义 input 和 embedding 的转换契约。
当前 STA 实现刻意采用在线、有状态形式;Layer-wise sequence TD execution
和更快 multi-step inference 属于后续执行后端工作,不是本教程中的默认路径。
STA 转换思想#
Transformer 模型包含仿射投影、LayerNorm、GELU、attention、残差加法、mask 和 tensor 常量等组件。直接套用 ReLU-to-IFNode rate coding 规则不足以覆盖这些组件。 STA 用在线时序近似来处理它们。
不带脉冲的在线差分#
先看不带 spike encoder 的在线等价路径,核心概念是累计激活。令 \(x\) 为原始 ANN 输入,构造输入序列:
第 \(t\) 个时间步后的累计输入为:
因此 \(X^{(0)} = X^{(1)} = \cdots = x\)。令 \(f\) 表示转换前 ANN 中的一个函数或模块,例如仿射投影、LayerNorm、GELU 或 attention 模块。 转换后的 STA 模块不是 \(f\) 本身。记转换后在单个时间步上执行的差分模块为 \(F_t\),它输出 \(f\) 在相邻累计输入上的差分:
在实现中,\(F_t\) 由包裹原操作的有状态包装模块实现:它计算或复用 \(f\) 的累计输出,保存上一时间步的累计输出,并只返回当前差分 \(\Delta y^{(t)} = F_t(X^{(t)})\)。
累计输出满足:
这个恒等式解释了 STA 中在线等价部分为什么成立:如果每个转换模块都输出累计 结果的差分,并且常量与 bias 只计算一次,那么把所有时间步输出相加即可 恢复 ANN 模块的输出。
不带脉冲编码的在线等价路径的执行模式如下:
第 0 个时间步输入原始 ANN 输入;
后续时间步输入零值浮点 tensor;
转换后的模型在内部执行
time_steps次循环;有状态转换模块输出累计结果的差分;
包装模块累加每步输出并返回累计结果。
高层执行流程如下:
y = 0
for t in range(time_steps):
if t == 0:
x_t = original_input
else:
x_t = zeros_like(original_input)
y = y + converted_graph_step(x_t, static_control_tensors)
return y
这里的 converted_graph_step 表示转换后的 FX 计算图在一个内部时间步上的
一次执行。该计算图包含有状态模块,会记住上一时间步的累计输出,因此每次调用
只返回当前增量。
带脉冲的 spike encoder#
mode="spiking_encoder" 会在选中的增量后加入 spike encoder。对模拟增量
\(a^{(t)}\) 和阈值 \(V\),encoder 维护残差膜电位 \(r^{(t)}\)。
初始残差为 0:
每个时间步,encoder 先累加模拟增量:
然后计算可以发放多少个阈值单位:
其中 \(s^{(t)}\) 是当前时间步的量化输出。下一时间步的残差为:
从 SNN 角度看,\(r^{(t)}\) 是发放后保留下来的膜电位, \(n^{(t)}\) 是整数脉冲数;当残差为负时,它也可以为负。 \(s^{(t)}\) 是按阈值加权后的脉冲输出。更新式 \(r^{(t)} = u^{(t)} - s^{(t)}\) 相当于广义的软重置:它从膜电位中减去 本步输出的阈值加权值。当 \(n^{(t)} = 1\) 时,这退化为普通软重置; 更大的正整数或负整数表示在一个时间步内跨过多个阈值单位。
经过 \(T\) 个时间步后:
也就是说,encoder 的输出和模拟增量之和只差最终残差。如果
\(a^{(t)}\) 就是 STA 差分 \(F_t(X^{(t)})\),那么模拟增量之和就是
ANN 模块输出 \(f(x)\),脉冲编码后的结果与它之间的差异就是最终残差。
由于 STA 校准阈值时会使用 time_steps,在激活范围固定时,更大的
\(T\) 对应更细的时间量化。
这一点对 Transformer 很关键:LayerNorm、GELU 和 attention 都不是简单的 ReLU rate coding 层。在线累计差分视角允许转换模型在累计输入上计算这些函数, 再对输出增量做编码,从而在保留算子语义的同时在选中的输出处引入脉冲式的时序通信。
仿射模块、LayerNorm、GELU、MultiheadAttention 和浮点
FX tensor 常量都维护在线累计差分状态。Bias 和图中的常量只注入一次。
静态 attention mask 等控制 tensor 会在各时间步保留,而不是被置零。
本教程推荐的公开路径是 mode="spiking_encoder"。该模式会在
LayerNorm、GELU 和 MultiheadAttention 的输出侧加入校准后的
有状态 spike encoder,同时保持主干 affine projection 在线等价。阈值来自
dataloader 校准,并依赖 time_steps。
使用 STATransformerRecipe#
最小 Python API 和其它 ANN2SNN recipe 一样,遵循 Recipe + Converter 结构:
from spikingjelly.activation_based import ann2snn
recipe = ann2snn.STATransformerRecipe(
dataloader=calibration_loader,
time_steps=8,
mode="spiking_encoder",
threshold_mode="mse",
threshold_scale=0.5,
)
converted = ann2snn.Converter(recipe=recipe, device="cuda:0").convert(model)
converted.eval()
time_steps 属于 recipe 参数,因为它既参与阈值校准,也参与转换后模型内部
的推理循环。与 rate-coded CNN 转换不同,用户不需要在 Python 外层按时间步
反复调用转换后的 Transformer;包装模块自己持有这个循环。
STA 当前有三种模式:
equivalent:无需校准的在线累计差分基线;spiking_encoder:在非线性和 attention 输出上使用校准 spike encoder;spiking_affine:高级路径,会进一步把选中的 affine 模块替换为 spiking affine 模块。
本教程的模型级结果使用 spiking_encoder。
与 TransformerSpikeEquivalentRecipe 的关系#
TransformerSpikeEquivalentRecipe 是一个不需要 dataloader 的替换路径,
用于将当前支持的 Transformer 算子替换为 TD / spike-equivalent 模块。它适合
作为算子级转换基线,但不进行 STA 校准,也不持有内部时间步循环。
STATransformerRecipe 是模型级 STA 流程。启用 spike encoder 时它需要校准,
并返回一个内部执行 time_steps 循环的包装模块。
ViT-B/16 ImageNet 示例#
完整可运行示例是
spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta。该脚本使用
torchvision.models.vit_b_16、
ViT_B_16_Weights.DEFAULT,以及可被
torchvision.datasets.ImageFolder 读取的 ImageNet 验证集目录。
下面命令假设 /path/to/imagenet/val 直接包含各类别文件夹,需要 CUDA。
CUDA_VISIBLE_DEVICES=0 python -m spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta \
--data-root /path/to/imagenet/val \
--device cuda:0 \
--batch-size 64 \
--num-workers 8 \
--calib-samples 2048 \
--time-steps 8 \
--threshold-scale 0.5
若只想快速检查环境,可使用少量验证样本:
CUDA_VISIBLE_DEVICES=0 python -m spikingjelly.activation_based.ann2snn.examples.imagenet_vit_sta \
--data-root /path/to/imagenet/val \
--device cuda:0 \
--batch-size 8 \
--num-workers 2 \
--calib-samples 32 \
--eval-samples 32 \
--time-steps 8 \
--threshold-scale 0.5
下表数据在 NVIDIA A100-SXM4-80GB 上,使用完整 50000 张 ImageNet 验证集测得:
方法 |
校准样本 |
验证样本 |
时间步 |
Top-1 (%) |
Top-5 (%) |
|---|---|---|---|---|---|
ANN |
50000 |
81.068 |
95.318 |
||
STA |
2048 |
50000 |
8 |
80.590 |
95.074 |
Top-1 下降 0.478 个百分点。原始运行中 ANN baseline 推理耗时约 115.4 秒, STA 转换模型约 1197.1 秒;一次精度相同的重跑中分别为 250.8 秒和 2613.1 秒, wall-clock time 对运行时环境比较敏感。
关键 stdout 行如下:
BASELINE {"top1": 0.81068, "top5": 0.95318, "total": 50000, "seconds": 115.39487862586975}
STA_SPIKING_ENCODER_T8_S05 {"top1": 0.8059, "top5": 0.95074, "total": 50000, "seconds": 1197.0657494068146}
DROP 0.0047800000000000065