spikingjelly.activation_based.ann2snn package

Converter

class spikingjelly.activation_based.ann2snn.converter.Converter(dataloader, device=None, mode='Max', momentum=0.1, fuse_flag=True, rules=None, neuron_factory=None, threshold_optimizer=None) [source]

Bases: object


Converter is an ANN2SNN conversion driver, not a torch.nn.Module for inference. It provides explicit conversion methods: convert_to_spiking_neurons() for the traditional ReLU-to-IFNode calibrated conversion, and replace_by_td_operators() for Transformer TD core operator replacement.

A tutorial on ANN2SNN conversion is available here: ANN2SNN.

The ReLU-to-IFNode path currently supports three conversion modes, selected with the mode parameter.

In the ReLU-to-IFNode path, ReLU modules will be removed, and new modules needed by SNN, such as VoltageScaler and IFNode, will be created and stored in the parent module snn tailor. The TD operator replacement path does not use calibration data; it only replaces supported ANN modules with TD-equivalent modules.

Because the converted model is an fx.GraphModule, you can inspect the generated computation graph and its forward data flow with print(model.graph), where model is the returned GraphModule. See GraphModule for the full API.

Warning

Make sure that ReLU is used as a module (nn.ReLU) rather than as a function (F.relu), so that the conversion rules can match it.

Prefer average pooling to max pooling in the ANN. Max pooling may significantly degrade the performance of the converted SNN.
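The sketch below shows why the module form matters. It uses only torch.fx and two toy models, ModuleReLUNet and FunctionalReLUNet, whose names are illustrative. Symbolic tracing records an nn.ReLU submodule as a call_module node, which a module-based rule can look up by name. F.relu is recorded as a call_function node instead, and that node has no submodule attached. The sketch also shows how to print a traced graph.

```python
import torch.nn as nn
import torch.nn.functional as F
import torch.fx as fx


class ModuleReLUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.relu = nn.ReLU()  # module form: traced as a call_module node

    def forward(self, x):
        return self.relu(self.fc(x))


class FunctionalReLUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))  # function form: traced as a call_function node


def relu_nodes(model: nn.Module) -> list:
    """Return (op, name) for every ReLU found in the traced graph."""
    gm = fx.symbolic_trace(model)
    modules = dict(gm.named_modules())
    found = []
    for node in gm.graph.nodes:
        if node.op == "call_module" and isinstance(modules[node.target], nn.ReLU):
            found.append(("call_module", node.target))
        elif node.op == "call_function" and node.target is F.relu:
            found.append(("call_function", "relu"))
    return found


print(fx.symbolic_trace(ModuleReLUNet()).graph)  # human-readable graph dump
print(relu_nodes(ModuleReLUNet()))      # [('call_module', 'relu')]
print(relu_nodes(FunctionalReLUNet()))  # [('call_function', 'relu')]
```

Only the first form exposes an nn.ReLU instance in named_modules(). Rules that inspect module types can match only that form.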

Parameters:
  • dataloader (Iterable) -- Dataloader used for calibration. Each yielded batch must support data[0] as the input tensor, for example (input, label) or (input,).

  • device (device or str or None) -- Device used for calibration.

  • mode (str, float) -- Conversion mode. Three modes are supported: MaxNorm (mode="Max", the default), RobustNorm (mode="99.9%"), and scaling mode (mode=x with 0 < x <= 1, which uses x times the maximum activation).

  • momentum (float) -- Momentum value used by modules.VoltageHook.

  • fuse_flag (bool) -- If True (default), fuse each Conv layer with the BatchNorm layer that follows it. If False, skip fusion.

  • rules (Optional[List[ActivationRule]]) -- List of activation conversion rules. Each rule must implement match, insert_hooks, find_replacements and replace_with_neurons. Defaults to [ReLURule()].

  • neuron_factory (Optional[NeuronFactory]) -- Neuron factory. Defaults to NeuronFactory() (IFNode, threshold=1.0).

  • threshold_optimizer (Optional[ThresholdOptimizer]) -- Threshold optimizer. Defaults to ThresholdOptimizer(strategy="fixed").

convert_to_spiking_neurons(ann) [source]


Convert an ANN with ReLU modules to an SNN GraphModule. This method performs FX tracing, optional Conv-BN fusion, VoltageHook calibration, and neuron replacement.

Parameters:

ann (Module) -- ANN to be converted.

Returns:

Converted SNN.

Return type:

GraphModule

replace_by_td_operators(ann) [source]


Replace supported core modules and a narrow subset of attention operations in an ANN with temporal-difference (TD) equivalent operators, and return a GraphModule. The following are currently replaced automatically:

- torch.nn.Linear
- torch.nn.LayerNorm
- torch.nn.GELU
- torch.nn.functional.scaled_dot_product_attention() calls with a literal dropout_p=0.0
- torch.nn.MultiheadAttention calls with dropout=0.0, batch_first=True and need_weights=False

This method does not insert VoltageHook and does not run dataloader calibration. The returned model preserves the training/eval state of the input model and of the replaced modules.

This conversion path targets complete time-sequence inputs. Converted models conventionally use dimension 0 as the time dimension, with shape [T, ...] and T > 0. TD operators output floating-point differential values; they are not binary spikes and do not represent fully spike-driven online execution. Dtype, device, and backend behavior follow the PyTorch implementation of each replaced operator; there is no CuPy / Triton specific path currently. This method does not mutate the input model itself; it returns a traced GraphModule.
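The time-dimension convention can be made concrete with a self-contained sketch. It re-implements the cumulative-difference rule in plain PyTorch and chains a Linear and a GELU. td_apply is a hypothetical helper, not a SpikingJelly function. The sketch spreads a static input evenly over T steps, which is only one possible encoding; this page does not prescribe one. Each TD stage keeps the relation "cumulative output = ANN applied to cumulative input", so summing the final output over time recovers the ANN output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def td_apply(fn, x_seq: torch.Tensor) -> torch.Tensor:
    """Cumulative-difference wrapper: accumulate over dim 0, apply fn, difference."""
    if x_seq.dim() < 2 or x_seq.shape[0] == 0:
        raise ValueError("expected a [T, ...] sequence with T > 0")
    y_cum = fn(x_seq.cumsum(dim=0))  # fn acts on each time step independently
    return torch.cat([y_cum[:1], y_cum[1:] - y_cum[:-1]], dim=0)


torch.manual_seed(0)
T, batch, features = 8, 2, 5
linear = nn.Linear(features, features)
x = torch.randn(batch, features)

# Assumed encoding: spread the static input evenly over T steps.
x_seq = (x / T).unsqueeze(0).expand(T, batch, features)

y_seq = td_apply(F.gelu, td_apply(linear, x_seq))  # TD Linear -> TD GELU
ann_out = F.gelu(linear(x))

# Summing the TD output over time recovers the ANN output (up to float rounding).
print(torch.allclose(y_seq.sum(dim=0), ann_out, atol=1e-5))  # True
```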

Parameters:

ann (Module) -- ANN to be converted.

Returns:

GraphModule with core TD operators replaced.

Return type:

GraphModule

Raises:

ValueError -- If the FX graph contains unsupported TD attention configurations, such as nonzero SDPA dropout, dynamic SDPA configuration, enable_gqa=True, nn.MultiheadAttention with dropout != 0, batch_first=False, need_weights=True, key_padding_mask, or non-packed q/k/v parameters.

static fuse(fx_model, fuse_flag=True) [source]


Fuse each Conv layer with the BatchNorm layer that follows it.

Parameters:
  • fx_model (GraphModule) -- Original model.

  • fuse_flag (bool) -- If True (default), fuse Conv and BatchNorm layers. If False, return the model without fusion.

Returns:

Model with Conv and BatchNorm layers fused.

Return type:

GraphModule
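For reference, eval-mode Conv-BN fusion folds the BatchNorm statistics into the convolution. For each output channel, the fused weight is W * gamma / sqrt(running_var + eps). The fused bias is (b - running_mean) * gamma / sqrt(running_var + eps) + beta. The sketch below implements that folding by hand and checks it against the unfused pair. fuse_conv_bn is a hypothetical name, and the sketch is not necessarily how Converter.fuse is implemented internally.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold an eval-mode BatchNorm2d into the preceding Conv2d."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    scale = bn.weight * torch.rsqrt(bn.running_var + bn.eps)  # per output channel
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    conv_bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((conv_bias - bn.running_mean) * scale + bn.bias)
    return fused


torch.manual_seed(0)
conv = nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(4)
with torch.no_grad():  # give BatchNorm non-trivial statistics
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 8, 8)
fused = fuse_conv_bn(conv, bn).eval()
print(torch.allclose(fused(x), bn(conv(x)), atol=1e-5))  # True
```

PyTorch also provides this folding as torch.nn.utils.fusion.fuse_conv_bn_eval. The manual version is shown only to make the arithmetic explicit.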

static set_voltagehook(fx_model, mode='Max', momentum=0.1, rules=None) [source]


set_voltagehook inserts VoltageHook modules into fx_model for calibration. It supports the same three modes as Converter.

Parameters:
  • fx_model (GraphModule) -- Original model.

  • mode (str, float) -- Conversion mode. Three modes are supported: MaxNorm ("Max"), RobustNorm ("99.9%"), and scaling mode (a float in (0, 1]).

  • momentum (float) -- Momentum value used by VoltageHook.

  • rules (Optional[List[ActivationRule]]) -- Optional list of activation matching rules. When None (default), [ReLURule()] is used, which matches ReLU and wraps it with VoltageHook. Pass custom rules to match additional activation types or to change where hooks are inserted.

Returns:

fx_model with VoltageHook.

Return type:

GraphModule

replace_by_neurons(fx_model) [source]


Replace activations matched by self.rules with spiking neurons, applying the VoltageScaler(1/v_threshold) -> Neuron -> VoltageScaler(v_threshold) transformation using the threshold computed by self.threshold_optimizer. With the default rule, neuron factory and threshold optimizer this reproduces the original replace_by_ifnode ReLU -> IFNode replacement semantics.
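The VoltageScaler(1/v_threshold) -> Neuron -> VoltageScaler(v_threshold) sandwich approximates a ReLU for the following reason. An integrate-and-fire neuron with soft (subtractive) reset and threshold 1, driven by a constant input, fires at a rate close to clamp(input, 0, 1). The sketch below uses a hand-written IF loop, not SpikingJelly's IFNode; if_rate and converted_relu are illustrative names. It shows that the rescaled firing rate approaches clamp(ReLU(x), max=s), where s is the calibrated scale.

```python
import torch


def if_rate(x: torch.Tensor, T: int, v_threshold: float = 1.0) -> torch.Tensor:
    """Mean spike count of an IF neuron with soft reset, constant input x, T steps."""
    v = torch.zeros_like(x)
    spikes = torch.zeros_like(x)
    for _ in range(T):
        v = v + x                             # integrate
        s = (v >= v_threshold).to(x.dtype)    # fire when v reaches the threshold
        v = v - s * v_threshold               # soft (subtractive) reset
        spikes = spikes + s
    return spikes / T


def converted_relu(x: torch.Tensor, scale: float, T: int) -> torch.Tensor:
    # VoltageScaler(1/scale) -> IF neuron (threshold 1) -> VoltageScaler(scale)
    return if_rate(x / scale, T) * scale


x = torch.tensor([-1.0, 0.0, 0.25, 0.5, 1.0, 2.0])
scale = 1.0  # e.g. the calibrated maximum activation
approx = converted_relu(x, scale, T=1000)
print(approx)                     # close to clamp(ReLU(x), max=scale)
print(torch.clamp(x, 0.0, scale))
```

Inputs above s saturate at s because the neuron emits at most one spike per step. This is why the calibrated scale affects accuracy.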

Parameters:

fx_model (GraphModule) -- GraphModule with calibration hooks already inserted.

Returns:

Model with activations replaced by spiking neurons.

Return type:

GraphModule

static replace_by_ifnode(fx_model) [source]

Replace ReLU with IF neurons (legacy API, use replace_by_neurons() instead).

Deprecated:

Use replace_by_neurons() instead.

Parameters:

fx_model (GraphModule) -- Model with calibration hooks inserted.

Returns:

Model with ReLU replaced by IF neurons.

Return type:

GraphModule

Extension Points

class spikingjelly.activation_based.ann2snn.rules.ActivationRule(*args, **kwargs) [source]

Bases: Protocol


Protocol for activation-to-neuron conversion rules. Implement this protocol to plug a new ANN→SNN algorithm into the converter. A rule must:

  1. decide whether it handles a given fx.Node via match();

  2. insert a calibration hook after the node via insert_hooks();

  3. enumerate (activation_node, hook_node) pairs to replace via find_replacements();

  4. replace the activation + hook pair with spiking neurons via replace_with_neurons().

match(node, modules) [source]


Return True if this rule handles the given graph node.

Parameters:
  • node (fx.Node) -- The fx.Node to check.

  • modules (Dict[str, nn.Module]) -- Module-name dictionary obtained from fx.GraphModule.named_modules().

Returns:

True if this rule handles the node.

Return type:

bool
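The sketch below illustrates the match() contract with a hypothetical rule fragment, ReLU6MatchRule, that claims call_module nodes whose submodule is an nn.ReLU6. Only match() is shown. A complete rule would also need insert_hooks(), find_replacements() and replace_with_neurons(). The modules dictionary is built as the parameter description says, from named_modules() of the traced GraphModule.

```python
import torch.nn as nn
import torch.fx as fx


class ReLU6MatchRule:
    """Hypothetical rule fragment: only match() is implemented here."""

    def match(self, node: fx.Node, modules: dict) -> bool:
        # Claim only module calls whose submodule is an nn.ReLU6.
        return node.op == "call_module" and isinstance(modules.get(node.target), nn.ReLU6)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU6(), nn.Linear(4, 2), nn.ReLU())
gm = fx.symbolic_trace(model)
modules = dict(gm.named_modules())
rule = ReLU6MatchRule()
matched = [node.target for node in gm.graph.nodes if rule.match(node, modules)]
print(matched)  # ['1'], the nn.ReLU6 at index 1 of the Sequential
```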

insert_hooks(fx_model, node, hook_factory, hook_counts_per_prefix) [source]


Insert a calibration hook created by hook_factory after node and register the new node inside fx_model. hook_counts_per_prefix is used to generate unique hook target names when multiple hooks are inserted.

Parameters:
  • fx_model (fx.GraphModule) -- The GraphModule to modify.

  • node (fx.Node) -- The fx.Node after which the hook is inserted.

  • hook_factory (HookFactory) -- Hook factory used to build the calibration hook.

  • hook_counts_per_prefix (Dict[str, int]) -- Per-prefix counters used to build unique hook target names.

Returns:

The newly inserted hook node.

Return type:

fx.Node
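The graph surgery behind insert_hooks() can be sketched with plain torch.fx in three steps:

1. Register a new submodule under a unique name.
2. Create a call_module node right after node.
3. Reroute the node's users to the new node.

MaxRecorder, insert_hook_after and the "calib_hook" prefix are hypothetical names used only here. Neither VoltageHook nor the library's own naming scheme is reproduced.

```python
import torch
import torch.nn as nn
import torch.fx as fx


class MaxRecorder(nn.Module):
    """Hypothetical identity hook that records the largest value it has seen."""

    def __init__(self):
        super().__init__()
        self.register_buffer("max_seen", torch.tensor(float("-inf")))

    def forward(self, x):
        self.max_seen = torch.maximum(self.max_seen, x.detach().max())
        return x  # identity: the network output is unchanged


def insert_hook_after(gm: fx.GraphModule, node: fx.Node, hook: nn.Module,
                      counts: dict, prefix: str = "calib_hook") -> fx.Node:
    idx = counts.get(prefix, 0)
    counts[prefix] = idx + 1
    target = f"{prefix}_{idx}"  # unique submodule name per prefix
    gm.add_submodule(target, hook)
    with gm.graph.inserting_after(node):
        hook_node = gm.graph.call_module(target, args=(node,))
    node.replace_all_uses_with(hook_node)  # this also rewires hook_node's own input,
    hook_node.args = (node,)               # so point it back at node
    gm.recompile()
    return hook_node


model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 1))
gm = fx.symbolic_trace(model)
counts = {}
for node in list(gm.graph.nodes):
    if node.op == "call_module" and isinstance(gm.get_submodule(node.target), nn.ReLU):
        insert_hook_after(gm, node, MaxRecorder(), counts)

x = torch.randn(5, 3)
print(torch.allclose(gm(x), model(x)))  # True: the hook is an identity
print(gm.calib_hook_0.max_seen)         # largest post-ReLU activation for x
```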

find_replacements(fx_model, modules) [source]


Iterate over fx_model and yield (activation_node, hook_node) pairs to replace. Rules with non-standard graph patterns should override this method with their own traversal.

Parameters:
  • fx_model (fx.GraphModule) -- GraphModule with calibration hooks already inserted.

  • modules (Dict[str, nn.Module]) -- Module-name dictionary obtained from fx.GraphModule.named_modules().

Returns:

Iterator of (activation_node, hook_node) pairs.

Return type:

Iterator[Tuple[fx.Node, fx.Node]]

replace_with_neurons(fx_model, activation_node, hook_node, neuron_factory, threshold_optimizer) [source]


Replace the activation + hook pair with the corresponding spiking neuron structure. The threshold is computed by threshold_optimizer from the calibration hook; the neuron is built by neuron_factory.

Parameters:
  • fx_model (fx.GraphModule) -- The GraphModule to modify.

  • activation_node (fx.Node) -- The activation node.

  • hook_node (fx.Node) -- The calibration hook node.

  • neuron_factory (NeuronFactory) -- Spiking-neuron factory.

  • threshold_optimizer (ThresholdOptimizer) -- Threshold optimizer.

Return type:

None

class spikingjelly.activation_based.ann2snn.rules.ReLURule [source]

Bases: object


Conversion rule for nn.ReLU modules. Reproduces the original SpikingJelly behaviour: each nn.ReLU is replaced by VoltageScaler(1/s) -> IFNode -> VoltageScaler(s), where s is computed by ThresholdOptimizer from the VoltageHook calibration data.

match(node, modules) [source]
Parameters:
Return type:

bool

insert_hooks(fx_model, node, hook_factory, hook_counts_per_prefix) [source]
Parameters:
Return type:

Node

find_replacements(fx_model, modules) [source]
Parameters:
Return type:

Iterator[Tuple[Node, Node]]

replace_with_neurons(fx_model, activation_node, hook_node, neuron_factory, threshold_optimizer) [source]
Parameters:
Return type:

None

class spikingjelly.activation_based.ann2snn.factories.NeuronFactory(neuron_type=<class 'spikingjelly.activation_based.neuron.integrate_and_fire.IFNode'>, v_threshold=1.0, v_reset=None, **kwargs) [source]

Bases: object


Factory that creates spiking-neuron modules used to replace ANN activation functions. By default it instantiates spikingjelly.activation_based.neuron.IFNode with v_threshold=1.0 and v_reset=None to preserve the original ANN2SNN behaviour. The default conversion handles the activation scale with VoltageScaler, so the default factory does not copy scale into the neuron threshold. Custom factories may derive thresholds or other neuron parameters from scale.

Parameters:
  • neuron_type (Type[nn.Module]) -- Neuron class to instantiate. Must accept v_threshold and v_reset keyword arguments. Defaults to spikingjelly.activation_based.neuron.IFNode.

  • v_threshold (float) -- Firing threshold passed to the neuron constructor.

  • v_reset (Optional[float]) -- Membrane reset value. None means soft reset (subtractive reset). Defaults to None.

  • kwargs -- Additional keyword arguments forwarded to the neuron constructor.

create(scale) [source]


Instantiate a spiking-neuron module with the configured parameters. scale is the calibrated activation scale of the current layer; the default implementation does not use it directly, but subclasses can derive thresholds or other neuron parameters from it.

Parameters:

scale (float) -- Calibration scale for the layer.

Returns:

A spiking-neuron module.

Return type:

nn.Module

class spikingjelly.activation_based.ann2snn.factories.HookFactory(mode='Max', momentum=0.1) [source]

Bases: object


Factory that creates VoltageHook instances used during calibration. Each matched activation node receives an independent hook instance.

Parameters:
  • mode (str, float) -- Calibration mode forwarded to VoltageHook. "Max" records the maximum activation; "99.9%" records the 99.9th percentile; a float in (0, 1] records max * mode.

  • momentum (float) -- EMA momentum for VoltageHook.
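The three calibration modes can be summarised as a per-batch statistic s_t that is then smoothed across batches. The sketch below computes s_t for each documented mode with plain PyTorch and smooths it with an exponential moving average controlled by momentum. batch_scale and update_scale are hypothetical helpers.

The EMA form is an assumption about VoltageHook and is not stated on this page:

- The first batch initialises the scale.
- Each later batch uses (1 - momentum) * scale + momentum * s_t.

```python
from typing import Optional, Union

import torch


def batch_scale(x: torch.Tensor, mode: Union[str, float] = "Max") -> torch.Tensor:
    """Per-batch activation statistic for the three documented modes."""
    x = x.detach().flatten().float()
    if isinstance(mode, str) and mode.endswith("%"):   # e.g. "99.9%"
        return torch.quantile(x, float(mode[:-1]) / 100.0)
    if isinstance(mode, str) and mode.lower() == "max":
        return x.max()
    if isinstance(mode, float) and 0.0 < mode <= 1.0:  # scaling mode
        return x.max() * mode
    raise ValueError(f"unsupported mode: {mode!r}")


def update_scale(scale: Optional[torch.Tensor], s_t: torch.Tensor,
                 momentum: float = 0.1) -> torch.Tensor:
    # Assumed EMA: the first batch initialises, later batches are blended in.
    return s_t if scale is None else (1 - momentum) * scale + momentum * s_t


acts = torch.arange(1, 1001, dtype=torch.float32)  # stand-in post-ReLU activations
print(batch_scale(acts, "Max"))    # tensor(1000.)
print(batch_scale(acts, 0.5))      # tensor(500.)
print(batch_scale(acts, "99.9%"))  # just below 1000 (linear interpolation)

scale = None
for s_t in [torch.tensor(4.0), torch.tensor(2.0)]:
    scale = update_scale(scale, s_t, momentum=0.1)
print(scale)                       # 0.9 * 4 + 0.1 * 2 = 3.8
```

The percentile mode ignores a small fraction of outliers, so the resulting scale is never larger than the one from "Max".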

create() [source]


Create a new VoltageHook instance.

Returns:

A configured VoltageHook.

Return type:

VoltageHook

class spikingjelly.activation_based.ann2snn.threshold.ThresholdOptimizer(strategy='fixed') [source]

Bases: object


Threshold optimizer. Computes the neuron threshold for a layer from the scale recorded by VoltageHook during calibration. Built-in strategy:

  • "fixed": threshold equals the calibrated scale (default, matches the original SpikingJelly behaviour).

Additional strategies should be implemented by subclassing and overriding compute_threshold(). The base class accepts any strategy name but only implements "fixed" itself.

Parameters:

strategy (str) -- Name of the threshold computation strategy.
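A custom strategy only needs to override compute_threshold(hook). The sketch below is standalone, so it runs without SpikingJelly installed:

- ScaledThresholdOptimizer and its factor parameter are hypothetical.
- SimpleNamespace stands in for a calibrated VoltageHook that exposes scale.

In real use, the class would subclass spikingjelly.activation_based.ann2snn.threshold.ThresholdOptimizer and be passed to Converter as threshold_optimizer.

```python
from types import SimpleNamespace


class ScaledThresholdOptimizer:
    """Hypothetical strategy: threshold = factor * calibrated scale.

    Standalone for illustration. With SpikingJelly installed, it would subclass
    ThresholdOptimizer and call super().__init__(strategy="scaled").
    """

    def __init__(self, factor: float = 0.9):
        self.strategy = "scaled"
        self.factor = factor

    def compute_threshold(self, hook) -> float:
        # hook.scale holds the activation-range statistic recorded during calibration
        return float(hook.scale) * self.factor


fake_hook = SimpleNamespace(scale=2.0)   # stand-in for a calibrated VoltageHook
opt = ScaledThresholdOptimizer(factor=0.9)
print(opt.compute_threshold(fake_hook))  # 1.8
```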

compute_threshold(hook) [source]


Return the spiking-neuron threshold for the layer represented by hook. With strategy="fixed" this returns the scale stored in the hook; other strategies should be implemented by subclasses.

Parameters:

hook (VoltageHook) -- A calibrated VoltageHook whose scale attribute holds the activation range statistic.

Returns:

The neuron threshold.

Return type:

float

Raises:

Helper Modules and Functions

class spikingjelly.activation_based.ann2snn.operators.TDSoftmax(dim=-1) [source]

Bases: Module


Temporal-difference (TD) Softmax operator. The input must be a complete time sequence whose time dimension is fixed at dimension 0, with shape [T, ...]. This module first accumulates the input over time, applies torch.softmax along dim to each cumulative input, and returns the temporal difference of the cumulative outputs.

The output contains floating-point differential values and may contain negative values. It is not a binary spike tensor and does not represent a fully spike-driven Softmax. The output dtype matches the input dtype; float32, float16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd.

The mechanism follows the cumulative-difference equivalence idea for Transformer nonlinear operators in SpikeZIP-TF: Conversion is All You Need for Transformer-based SNN. This implementation provides only a tensor-level operator: it still calls torch.softmax, requires a complete time sequence, is not a step-wise online operator, and is not a fully spike-driven Softmax for neuromorphic hardware.

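import torch
from spikingjelly.activation_based.ann2snn.operators import TDSoftmax

op = TDSoftmax(dim=-1)
x_seq = torch.randn(4, 2, 3)  # [T, N, C] with T = 4 time steps
y_seq = op(x_seq)             # same shape as x_seq

Parameters: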

dim (int) -- Softmax normalization dimension. It must not be dimension 0, which is reserved as the time dimension.

forward(x_seq) [source]


Apply TD Softmax to a complete time sequence:

\[X_{cum}[t] = \sum_{i=0}^{t} X[i]\]
\[Y_{cum}[t] = \operatorname{Softmax}(X_{cum}[t])\]
\[Y[0] = Y_{cum}[0], \quad Y[t] = Y_{cum}[t] - Y_{cum}[t-1]\]

Thus, Y.cumsum(dim=0) matches ANN Softmax applied to X.cumsum(dim=0) at each time step. The output contains floating-point differential values, may be negative, and is not a binary spike tensor. When T = 1, Y[0] is exactly torch.softmax(X[0], dim=dim). The output dtype matches the input dtype, and the operator is transparent to autograd.

Parameters:

x_seq (Tensor) -- Input time sequence with shape [T, ...] and T > 0.

Returns:

TD Softmax differential sequence with the same shape as x_seq.

Return type:

Tensor

Raises:

ValueError -- If x_seq has fewer than 2 dimensions, the time dimension is empty, or dim refers to the time dimension.
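The equations above can be checked numerically. The sketch below re-implements them with torch.cumsum and torch.diff; td_softmax is an illustrative re-implementation, not the library's code. It verifies two of the stated properties. First, the running sum of the output equals Softmax of the running sum of the input. Second, the differential output can be negative.

```python
import torch


def td_softmax(x_seq: torch.Tensor, dim: int = -1) -> torch.Tensor:
    if x_seq.dim() < 2 or x_seq.shape[0] == 0:
        raise ValueError("x_seq must have shape [T, ...] with T > 0")
    if dim % x_seq.dim() == 0:
        raise ValueError("dim 0 is reserved for time")
    y_cum = torch.softmax(x_seq.cumsum(dim=0), dim=dim)
    # Prepend a zero step so that Y[0] = Y_cum[0] and Y[t] = Y_cum[t] - Y_cum[t-1].
    return torch.diff(y_cum, dim=0, prepend=torch.zeros_like(y_cum[:1]))


torch.manual_seed(0)
x_seq = torch.randn(4, 2, 3)
y_seq = td_softmax(x_seq, dim=-1)

recovered = y_seq.cumsum(dim=0)
reference = torch.softmax(x_seq.cumsum(dim=0), dim=-1)
print(torch.allclose(recovered, reference, atol=1e-6))  # True
# True: each difference of two softmax vectors sums to 0 along dim,
# so any nonzero difference has negative entries.
print(bool((y_seq[1:] < 0).any()))
```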

extra_repr() [source]
Return type:

str

class spikingjelly.activation_based.ann2snn.operators.TDLayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, bias=True, device=None, dtype=None) [source]

Bases: Module


Temporal-difference (TD) LayerNorm operator. The input must be a complete time sequence whose time dimension is fixed at dimension 0, with shape [T, ...]. This module first accumulates the input over time, applies torch.nn.functional.layer_norm() to each cumulative input, and returns the temporal difference of the cumulative outputs.

The output contains floating-point differential values and may contain negative values. It is not a binary spike tensor and does not represent a fully spike-driven LayerNorm. The output dtype matches the input dtype; float32, float16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd. The operator is stateless, and repeated forward calls do not require reset.

The mechanism follows the cumulative-difference equivalence idea for Transformer nonlinear operators in SpikeZIP-TF: Conversion is All You Need for Transformer-based SNN. This implementation provides only a tensor-level operator: it still calls torch.nn.functional.layer_norm(), requires a complete time sequence, is not a step-wise online operator, and is not a fully spike-driven LayerNorm for neuromorphic hardware.

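import torch
from spikingjelly.activation_based.ann2snn.operators import TDLayerNorm

op = TDLayerNorm(normalized_shape=3)
x_seq = torch.randn(4, 2, 3)  # [T, N, C]; the trailing dim matches normalized_shape
y_seq = op(x_seq)             # same shape as x_seq

Parameters: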
  • normalized_shape (int or list[int] or Size) -- Input trailing shape to normalize, with the same semantics as normalized_shape in torch.nn.LayerNorm.

  • eps (float) -- Value added to the variance for numerical stability.

  • elementwise_affine (bool) -- If True, use learnable per-element affine parameters.

  • bias (bool) -- If both elementwise_affine and bias are True, use a learnable bias parameter. If elementwise_affine is False, bias is ignored.

  • device (device or str or None) -- Device used to initialize parameters.

  • dtype (dtype or None) -- Dtype used to initialize parameters.

reset_parameters() [source]
Return type:

None

forward(x_seq) [source]


Apply TD LayerNorm to a complete time sequence:

\[X_{cum}[t] = \sum_{i=0}^{t} X[i]\]
\[Y_{cum}[t] = \operatorname{LayerNorm}(X_{cum}[t])\]
\[Y[0] = Y_{cum}[0], \quad Y[t] = Y_{cum}[t] - Y_{cum}[t-1]\]

Thus, Y.cumsum(dim=0) matches ANN LayerNorm applied to X.cumsum(dim=0) at each time step. The output contains floating-point differential values, may be negative, and is not a binary spike tensor. When T = 1, Y[0] is exactly LayerNorm applied to X[0]. The output dtype matches the input dtype, and the operator is transparent to autograd.

Parameters:

x_seq (Tensor) -- Input time sequence with shape [T, ...] and T > 0. The trailing shape must match normalized_shape.

Returns:

TD LayerNorm differential sequence with the same shape as x_seq.

Return type:

Tensor

Raises:

ValueError -- If x_seq has fewer than 2 dimensions, the time dimension is empty, or the trailing shape does not match.
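The accumulation step matters because LayerNorm is not additive. Applying LayerNorm to each differential step separately does not reproduce LayerNorm of the accumulated input. The sketch below contrasts that naive approach with the cumulative-difference form. td_layer_norm is an illustrative re-implementation built on torch.nn.functional.layer_norm, without affine parameters.

```python
import torch
import torch.nn.functional as F


def td_layer_norm(x_seq: torch.Tensor, normalized_shape, eps: float = 1e-5) -> torch.Tensor:
    y_cum = F.layer_norm(x_seq.cumsum(dim=0), normalized_shape, eps=eps)
    return torch.diff(y_cum, dim=0, prepend=torch.zeros_like(y_cum[:1]))


torch.manual_seed(0)
x_seq = torch.randn(4, 2, 3)
target = F.layer_norm(x_seq.cumsum(dim=0), (3,))  # ANN LayerNorm on the cumulative input

td = td_layer_norm(x_seq, (3,)).cumsum(dim=0)
naive = F.layer_norm(x_seq, (3,)).cumsum(dim=0)   # LayerNorm on each step, then sum

print(torch.allclose(td, target, atol=1e-5))     # True
print(torch.allclose(naive, target, atol=1e-5))  # False: LayerNorm is not additive
```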

extra_repr() [source]
Return type:

str

class spikingjelly.activation_based.ann2snn.operators.TDGELU(approximate='none') [source]

Bases: Module


Temporal-difference (TD) GELU (Gaussian Error Linear Unit) operator. The input must be a complete time sequence whose time dimension is fixed at dimension 0, with shape [T, ...]. This module first accumulates the input over time, applies torch.nn.functional.gelu() to each cumulative input, and returns the temporal difference of the cumulative outputs.

The output contains floating-point differential values and may contain negative values. It is not a binary spike tensor and does not represent a fully spike-driven GELU. The output dtype matches the input dtype; float32, float16, bfloat16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd. The operator is stateless, and repeated forward calls do not require reset. It only depends on torch.nn.functional.gelu(), supports CPU and CUDA, follows the torch backend behavior, and has no CuPy / Triton specific path.

The mechanism follows the cumulative-difference equivalence idea for Transformer nonlinear operators in SpikeZIP-TF: Conversion is All You Need for Transformer-based SNN. This implementation provides only a tensor-level operator: it still calls torch.nn.functional.gelu(), requires a complete time sequence, is not a step-wise online operator, and is not a fully spike-driven GELU for neuromorphic hardware.

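import torch
from spikingjelly.activation_based.ann2snn.operators import TDGELU

op = TDGELU(approximate="none")
x_seq = torch.randn(4, 2, 3)  # [T, N, C] with T = 4 time steps
y_seq = op(x_seq)             # same shape as x_seq

Parameters: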

approximate (Literal["none", "tanh"]) -- GELU approximation mode, with the same semantics as approximate in torch.nn.GELU.

Raises:

ValueError -- If approximate is not "none" or "tanh".

forward(x_seq) [source]


Apply TD GELU to a complete time sequence:

\[X_{cum}[t] = \sum_{i=0}^{t} X[i]\]
\[Y_{cum}[t] = \operatorname{GELU}(X_{cum}[t])\]
\[Y[0] = Y_{cum}[0], \quad Y[t] = Y_{cum}[t] - Y_{cum}[t-1]\]

Thus, Y.cumsum(dim=0) matches ANN GELU applied to X.cumsum(dim=0) at each time step. The output contains floating-point differential values, may be negative, and is not a binary spike tensor. When T = 1, Y[0] is exactly GELU applied to X[0]. The output dtype matches the input dtype, and the operator is transparent to autograd.

Parameters:

x_seq (Tensor) -- Input time sequence with shape [T, ...] and T > 0.

Returns:

TD GELU differential sequence with the same shape as x_seq.

Return type:

Tensor

Raises:

ValueError -- If x_seq has fewer than 2 dimensions or the time dimension is empty.
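The operator is built only from cumsum, GELU and subtraction, so gradients flow through it like any other PyTorch expression. The sketch below re-implements the documented formula; td_gelu is an illustrative name, not the library's code. It then uses torch.autograd.gradcheck in float64 to confirm that analytical and numerical gradients agree for both approximate modes.

```python
import torch
import torch.nn.functional as F


def td_gelu(x_seq: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    if approximate not in ("none", "tanh"):
        raise ValueError("approximate must be 'none' or 'tanh'")
    y_cum = F.gelu(x_seq.cumsum(dim=0), approximate=approximate)
    return torch.diff(y_cum, dim=0, prepend=torch.zeros_like(y_cum[:1]))


torch.manual_seed(0)
x_seq = torch.randn(3, 2, 4, dtype=torch.float64, requires_grad=True)
for mode in ("none", "tanh"):
    ok = torch.autograd.gradcheck(lambda x: td_gelu(x, mode), (x_seq,))
    print(mode, ok)  # gradcheck returns True when the gradients match
```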

extra_repr() [source]
Return type:

str

class spikingjelly.activation_based.ann2snn.operators.TDLinear(in_features, out_features, bias=True, device=None, dtype=None) [source]

Bases: Module

Temporal-difference (TD) Linear operator. The input must be a complete time sequence whose time dimension is fixed at dimension 0, with shape [T, ..., in_features]. This module first accumulates the input over time, applies torch.nn.functional.linear(), and returns the temporal difference of the cumulative outputs.

The output contains floating-point differential values and may contain negative values. It is not a binary spike tensor and does not represent a fully spike-driven Linear. The output dtype follows PyTorch Linear; float32, float16, bfloat16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd. The operator is stateless, and repeated forward calls do not require reset. It only depends on PyTorch Linear, supports CPU and CUDA, follows the torch backend behavior, and has no CuPy / Triton specific path.

This operator handles affine projections with bias. Applying ordinary torch.nn.Linear directly to a TD differential sequence would accumulate the bias as T * bias. TD Linear applies Linear to the cumulative input and then differences the cumulative output, preserving W @ x_cum + bias.

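import torch
from spikingjelly.activation_based.ann2snn.operators import TDLinear

op = TDLinear(3, 5)
x_seq = torch.randn(4, 2, 3)  # [T, batch, in_features]
y_seq = op(x_seq)             # [T, batch, out_features]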
Parameters:
  • in_features (int) -- Number of input features.

  • out_features (int) -- Number of output features.

  • bias (bool) -- If True, use a learnable bias parameter.

  • device (device or str or None) -- Device used to initialize parameters.

  • dtype (dtype or None) -- Dtype used to initialize parameters.
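The bias pitfall described above can be reproduced with `torch.nn.functional.linear` alone. The helper name in this sketch is hypothetical, and the weights are random float64 values chosen for tight comparisons. Suppose an ordinary Linear is applied to each differential step and the results are accumulated. By step t the bias has then been added t + 1 times, so the final step carries T * bias. The TD recipe avoids this by applying Linear to the cumulative input and differencing, which keeps exactly one bias.

```python
import torch
import torch.nn.functional as F


def naive_and_td_cumulative(x_seq, weight, bias):
    """Hypothetical comparison helper returning (naive_cum, td_cum, ann_cum)."""
    x_cum = x_seq.cumsum(dim=0)
    ann_cum = F.linear(x_cum, weight, bias)  # ANN Linear on the cumulative input
    # Ordinary Linear applied to every differential step, then accumulated.
    naive_cum = F.linear(x_seq, weight, bias).cumsum(dim=0)
    # TD recipe: Linear on the cumulative input, then temporal difference.
    td_seq = torch.cat([ann_cum[:1], ann_cum[1:] - ann_cum[:-1]], dim=0)
    return naive_cum, td_seq.cumsum(dim=0), ann_cum


torch.manual_seed(0)
T = 4
weight = torch.randn(5, 3, dtype=torch.float64)
bias = torch.randn(5, dtype=torch.float64)
x_seq = torch.randn(T, 2, 3, dtype=torch.float64)
naive_cum, td_cum, ann_cum = naive_and_td_cumulative(x_seq, weight, bias)

# At step t the naive path carries (t + 1) * bias instead of a single bias.
extra_bias = naive_cum - ann_cum
```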

reset_parameters()
Return type:

None

forward(x_seq)

Apply TD Linear to a complete time sequence:

\[X_{cum}[t] = \sum_{i=0}^{t} X[i]\]
\[Y_{cum}[t] = X_{cum}[t] W^T + b\]
\[Y[0] = Y_{cum}[0], \quad Y[t] = Y_{cum}[t] - Y_{cum}[t-1]\]

Thus, Y.cumsum(dim=0) matches ANN Linear applied to X.cumsum(dim=0) at each time step. The output contains floating-point differential values, may be negative, and is not a binary spike tensor. When T = 1, Y[0] is exactly Linear applied to X[0]. The output dtype follows PyTorch Linear, and the operator is transparent to autograd.

Parameters:

x_seq (Tensor) -- Input time sequence with shape [T, ..., in_features] and T > 0.

Returns:

TD Linear differential sequence with shape [T, ..., out_features].

Return type:

Tensor

Raises:

ValueError -- If x_seq has fewer than 2 dimensions or the time dimension is empty.
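The cumulative sum of the TD output equals ANN Linear on the cumulative input. It follows that the gradients of any loss defined on the cumulative output should coincide with the ANN gradients. The sketch below checks this with a hypothetical `td_linear_reference`, which is pure PyTorch and not the library's TDLinear. The input carries an extra sequence dimension, which also exercises the [T, ..., in_features] to [T, ..., out_features] shape rule.

```python
import torch
import torch.nn.functional as F


def td_linear_reference(x_seq, weight, bias=None):
    """Hypothetical pure-PyTorch reference for the TD Linear formulas."""
    if x_seq.dim() < 2 or x_seq.shape[0] == 0:
        raise ValueError("x_seq must have shape [T, ..., in_features] with T > 0")
    y_cum = F.linear(x_seq.cumsum(dim=0), weight, bias)  # Y_cum[t] = X_cum[t] W^T + b
    return torch.cat([y_cum[:1], y_cum[1:] - y_cum[:-1]], dim=0)


def param_grads(loss_fn, weight, bias):
    """Return (d loss / d weight, d loss / d bias) for a scalar-valued loss_fn()."""
    weight.grad, bias.grad = None, None
    loss_fn().backward()
    return weight.grad.clone(), bias.grad.clone()


torch.manual_seed(0)
x_seq = torch.randn(4, 2, 7, 3, dtype=torch.float64)  # [T, batch, seq, in_features]
weight = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
bias = torch.randn(5, dtype=torch.float64, requires_grad=True)

y_seq = td_linear_reference(x_seq, weight, bias)  # [T, batch, seq, out_features]
td_grads = param_grads(
    lambda: td_linear_reference(x_seq, weight, bias).cumsum(dim=0).pow(2).sum(), weight, bias)
ann_grads = param_grads(
    lambda: F.linear(x_seq.cumsum(dim=0), weight, bias).pow(2).sum(), weight, bias)
```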

extra_repr()
Return type:

str

class spikingjelly.activation_based.ann2snn.operators.TDScaledDotProductAttention(is_causal=False, scale=None)

Bases: Module

Temporal-difference (TD) scaled dot-product attention operator. The inputs must be complete time sequences whose time dimension is fixed at dimension 0. query_seq has shape [T, ..., L, E], key_seq has shape [T, ..., S, E], and value_seq has shape [T, ..., S, Ev]. This module first accumulates query, key, and value over time, calls torch.nn.functional.scaled_dot_product_attention(), and returns the temporal difference of the cumulative outputs.

The output contains floating-point differential values and may contain negative values. It is not a binary spike tensor and does not represent fully spike-driven attention. Dtype, device, and mask broadcasting follow torch.nn.functional.scaled_dot_product_attention(); float32, float16, bfloat16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd. The operator is stateless, and repeated forward calls do not require reset. It only depends on PyTorch SDPA, supports CPU and CUDA, follows the torch backend behavior, and has no CuPy / Triton specific path.

The mechanism follows the cumulative-difference equivalence idea for Transformer operators in SpikeZIP-TF: Conversion is All You Need for Transformer-based SNN. This implementation provides only a tensor-level minimal primitive: it still calls PyTorch SDPA, requires a complete time sequence, is not a step-wise online operator, and is not fully spike-driven attention for neuromorphic hardware. This implementation fixes dropout_p=0.0 and does not expose enable_gqa. When composing TD Transformer blocks, ordinary torch.nn.Linear layers with bias must not be applied directly to differential sequences, because the bias would be accumulated repeatedly; use bias=False or a dedicated TD Linear.

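import torch
from spikingjelly.activation_based.ann2snn.operators import TDScaledDotProductAttention

op = TDScaledDotProductAttention()
q_seq = torch.randn(4, 2, 3, 8)  # [T, batch, L, E]
k_seq = torch.randn(4, 2, 5, 8)  # [T, batch, S, E]
v_seq = torch.randn(4, 2, 5, 6)  # [T, batch, S, Ev]
y_seq = op(q_seq, k_seq, v_seq)  # [T, batch, L, Ev]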
Parameters:
  • is_causal (bool) -- Whether to apply causal attention masking. If True, attn_mask must not be passed to forward.

  • scale (Optional[float]) -- Attention scale. If None, use the PyTorch SDPA default.
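The advice to use bias=False follows from linearity. A bias-free projection commutes with the cumulative sum, so it can be applied directly to a differential sequence; a biased one cannot. The quick check below uses plain `torch.nn.Linear` layers with random float64 weights, with shapes chosen to match the example above.

```python
import torch
from torch import nn

torch.manual_seed(0)
x_seq = torch.randn(4, 2, 5, 8, dtype=torch.float64)  # [T, batch, seq, embed]
no_bias = nn.Linear(8, 8, bias=False, dtype=torch.float64)
with_bias = nn.Linear(8, 8, bias=True, dtype=torch.float64)

with torch.no_grad():
    # Bias-free: sum_i (W x_i) == W (sum_i x_i), so direct application is safe.
    direct_ok = torch.allclose(no_bias(x_seq).cumsum(dim=0), no_bias(x_seq.cumsum(dim=0)))
    # Biased: the direct path accumulates the bias once per time step.
    direct_biased = torch.allclose(with_bias(x_seq).cumsum(dim=0), with_bias(x_seq.cumsum(dim=0)))
```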

forward(query_seq, key_seq, value_seq, attn_mask=None)

Apply TD scaled dot-product attention to complete query, key, and value time sequences:

\[Q_{cum}[t] = \sum_{i=0}^{t} Q[i], \quad K_{cum}[t] = \sum_{i=0}^{t} K[i], \quad V_{cum}[t] = \sum_{i=0}^{t} V[i]\]
\[Y_{cum}[t] = \operatorname{SDPA}(Q_{cum}[t], K_{cum}[t], V_{cum}[t])\]
\[Y[0] = Y_{cum}[0], \quad Y[t] = Y_{cum}[t] - Y_{cum}[t-1]\]

Thus, Y.cumsum(dim=0) matches ANN SDPA applied to cumulative query, key, and value at each time step. The output contains floating-point differential values, may be negative, and is not a binary spike tensor. When T = 1, Y[0] is exactly SDPA applied to the first query, key, and value step. The output dtype follows PyTorch SDPA, and the operator is transparent to autograd.

Parameters:
  • query_seq (Tensor) -- Query time sequence with shape [T, ..., L, E] and T > 0.

  • key_seq (Tensor) -- Key time sequence with shape [T, ..., S, E]. Its time dimension length must match query_seq.

  • value_seq (Tensor) -- Value time sequence with shape [T, ..., S, Ev]. Its time dimension length must match query_seq.

  • attn_mask (Tensor or None) -- Attention mask with the same broadcast semantics as PyTorch SDPA.

Returns:

TD scaled dot-product attention differential sequence with shape [T, ..., L, Ev].

Return type:

Tensor

Raises:

ValueError -- If any input has fewer than 3 dimensions, any time dimension is empty, the time lengths differ, or attn_mask is passed when is_causal=True.

extra_repr()
Return type:

str

class spikingjelly.activation_based.ann2snn.operators.TDMultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, batch_first=True, device=None, dtype=None)

Bases: Module

Temporal-difference (TD) MultiheadAttention covering a narrow subset of torch.nn.MultiheadAttention. The input must be a complete time sequence whose time dimension is fixed at dimension 0, with shape [T, batch, seq, embed_dim]. This module uses TDLinear for the q/k/v projections, applies TD scaled dot-product attention, and then applies a TDLinear output projection.

The return value is (attn_output_seq, None) to match the tuple structure of torch.nn.MultiheadAttention when need_weights=False. The output contains floating-point differential values, is not a binary spike tensor, and is not fully spike-driven attention. The output dtype follows PyTorch Linear / SDPA; float32, float16, bfloat16 and float64 inputs are recommended. The operator is composed entirely of differentiable PyTorch operations and is transparent to autograd. The operator is stateless, and repeated forward calls do not require reset. It supports CPU and CUDA, follows the torch backend behavior, and has no CuPy / Triton specific path. Currently only dropout=0.0, batch_first=True and need_weights=False are supported.

The mechanism follows the cumulative-difference equivalence idea in SpikeZIP-TF: Conversion is All You Need for Transformer-based SNN. This implementation is a narrow TD wrapper: it still uses floating-point TDLinear and PyTorch SDPA, is not step-wise online attention, and is not fully spike-driven MultiheadAttention for neuromorphic hardware. When bias=True, projection biases are handled by TDLinear on cumulative inputs, avoiding the repeated bias accumulation that would occur if ordinary nn.Linear were applied directly to differential sequences.

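import torch
from spikingjelly.activation_based.ann2snn.operators import TDMultiheadAttention

op = TDMultiheadAttention(embed_dim=8, num_heads=2)
x_seq = torch.randn(4, 2, 5, 8)  # [T, batch, seq, embed_dim]
y_seq, weights = op(x_seq, x_seq, x_seq, need_weights=False)  # weights is None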
Parameters:
  • embed_dim (int) -- Input and output embedding dimension.

  • num_heads (int) -- Number of attention heads. Must divide embed_dim.

  • dropout (float) -- Attention dropout. Must currently be 0.0.

  • bias (bool) -- If True, use bias in q/k/v and output projections.

  • batch_first (bool) -- Must currently be True, so each time step has shape [batch, seq, embed_dim].

  • device (device or str or None) -- Device used to initialize parameters.

  • dtype (dtype or None) -- Dtype used to initialize parameters.

Raises:

ValueError -- If embed_dim is not divisible by num_heads, or unsupported dropout / batch_first is passed.
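The class description chains TD projections, TD attention and a TD output projection. The chaining is sound because each TD operator's output accumulates back to its ANN value, so the next TD operator receives the correct cumulative input. The sketch below checks this with a hypothetical `td` wrapper, a generic cumulative-difference helper that is not a SpikingJelly API. It confirms that composing two TD operators (a biased projection followed by self-attention) equals a single TD wrapper around the composed ANN function.

```python
import torch
import torch.nn.functional as F


def td(fn):
    """Hypothetical helper: wrap a per-step ANN function in the cumulative-difference scheme."""
    def wrapped(*seqs):
        y_cum = fn(*(s.cumsum(dim=0) for s in seqs))
        return torch.cat([y_cum[:1], y_cum[1:] - y_cum[:-1]], dim=0)
    return wrapped


torch.manual_seed(0)
weight = torch.randn(8, 8, dtype=torch.float64)
bias = torch.randn(8, dtype=torch.float64)


def proj(x):
    return F.linear(x, weight, bias)  # affine projection with bias


def attn(x):
    return F.scaled_dot_product_attention(x, x, x)  # self-attention on one time step


x_seq = torch.randn(4, 2, 5, 8, dtype=torch.float64)  # [T, batch, seq, embed]
chained = td(attn)(td(proj)(x_seq))                   # TD projection, then TD attention
fused = td(lambda x: attn(proj(x)))(x_seq)            # one TD wrapper around the ANN composition
```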

forward(query_seq, key_seq, value_seq, key_padding_mask=None, need_weights=False, attn_mask=None, average_attn_weights=True, is_causal=False)

Apply TD multi-head attention to complete query/key/value time sequences. Inputs have shape [T, batch, seq, embed_dim] with T > 0. When need_weights=False, this method returns (attn_output_seq, None). The output contains floating-point differential values, and attn_output_seq.cumsum(dim=0) matches ANN MultiheadAttention in the supported subset applied to cumulative inputs at each time step. When T = 1, attn_output_seq[0] equals the output of ANN MultiheadAttention in the supported subset applied to the first input step. The output dtype follows PyTorch Linear / SDPA, and the operator is transparent to autograd.

Parameters:
  • query_seq (Tensor) -- Query sequence with shape [T, batch, target_len, embed_dim] and T > 0.

  • key_seq (Tensor) -- Key sequence with shape [T, batch, source_len, embed_dim].

  • value_seq (Tensor) -- Value sequence with shape [T, batch, source_len, embed_dim].

  • key_padding_mask (Tensor or None) -- Currently unsupported; must be None.

  • need_weights (bool) -- Must be False. Attention weights are not implemented.

  • attn_mask (Tensor or None) -- Optional attention mask with the same semantics as torch.nn.MultiheadAttention; True values in a bool mask disallow attention.

  • average_attn_weights (bool) -- Kept for torch.nn.MultiheadAttention signature compatibility. It must be True because attention weights are not returned.

  • is_causal (bool) -- Whether to apply causal masking.

Returns:

(attn_output_seq, None) where attn_output_seq has shape [T, batch, target_len, embed_dim].

Return type:

Tuple[Tensor, None]

Raises:

ValueError -- If unsupported masks/options or invalid shapes are passed.
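For the supported subset (batch_first=True, dropout=0.0, need_weights=False), the cumulative-sum property can be reproduced with the stock `torch.nn.MultiheadAttention` by folding the time dimension into the batch dimension. The hypothetical `td_mha_reference` below is a sketch under those assumptions, not the library's implementation. It omits masks for brevity and follows the same (output, None) return convention.

```python
import torch
from torch import nn


def td_mha_reference(mha, query_seq, key_seq, value_seq):
    """Hypothetical TD wrapper around an ANN nn.MultiheadAttention(batch_first=True)."""
    T, B = query_seq.shape[:2]
    # Cumulate each input over time, then fold [T, batch] into a single batch axis.
    q_cum, k_cum, v_cum = (s.cumsum(dim=0).flatten(0, 1) for s in (query_seq, key_seq, value_seq))
    out_cum, _ = mha(q_cum, k_cum, v_cum, need_weights=False)
    out_cum = out_cum.reshape(T, B, *out_cum.shape[1:])
    # Temporal difference: Y[0] = Y_cum[0], Y[t] = Y_cum[t] - Y_cum[t-1].
    return torch.cat([out_cum[:1], out_cum[1:] - out_cum[:-1]], dim=0), None


torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True, dtype=torch.float64)
x_seq = torch.randn(4, 2, 5, 8, dtype=torch.float64)  # [T, batch, seq, embed_dim]
y_seq, weights = td_mha_reference(mha, x_seq, x_seq, x_seq)
```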

extra_repr()
Return type:

str

class spikingjelly.activation_based.ann2snn.modules.VoltageHook(scale=1.0, momentum=0.1, mode='Max')

Bases: Module

Constructor of VoltageHook.

Parameters:
  • scale (float) -- Initial scaling value.

  • momentum (float) -- Momentum used to update the recorded activation scale across batches.

  • mode (str, float) -- Recording mode. "Max" records the maximum ANN activation; "99.9%" records the 99.9th percentile of the activations; a float between 0 and 1 records that fraction of the maximum activation.

forward(x)

Forward function. It returns the input unchanged and only records the ReLU activation values, which are used to determine the ANN activation range.

Parameters:

x (Tensor) -- Input tensor.

Returns:

The original input tensor.

Return type:

Tensor
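Below is a minimal sketch of a hook with the behaviour documented above. The class name is hypothetical, and the update rule is an assumption: the first batch sets the scale, and later batches are blended in with momentum, in the style of BatchNorm running statistics. Percentiles here use `torch.quantile`. Treat it as an illustration of the mode semantics, not as the library's code.

```python
import torch
from torch import nn


class ActivationScaleHookSketch(nn.Module):
    """Hypothetical identity module that tracks an activation scale (not SpikingJelly's code)."""

    def __init__(self, scale=1.0, momentum=0.1, mode="Max"):
        super().__init__()
        self.register_buffer("scale", torch.tensor(float(scale)))
        self.momentum = momentum
        self.mode = mode
        self.num_batches_tracked = 0

    def _statistic(self, x: torch.Tensor) -> torch.Tensor:
        x = x.detach()
        if isinstance(self.mode, str) and self.mode.endswith("%"):
            # e.g. "99.9%" -> 0.999 quantile of all activations in the batch
            return torch.quantile(x.flatten(), float(self.mode[:-1]) / 100.0)
        if isinstance(self.mode, str) and self.mode.lower() == "max":
            return x.max()
        if isinstance(self.mode, (int, float)) and 0.0 < self.mode <= 1.0:
            return x.max() * self.mode  # fraction of the maximum activation
        raise NotImplementedError(f"unsupported mode: {self.mode!r}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s_t = self._statistic(x).to(self.scale.dtype)
        if self.num_batches_tracked == 0:
            self.scale = s_t  # first batch initialises the scale
        else:
            # assumed exponential moving average controlled by momentum
            self.scale = (1 - self.momentum) * self.scale + self.momentum * s_t
        self.num_batches_tracked += 1
        return x  # activations pass through unchanged
```

With `mode="Max"` and `momentum=0.1`, a first batch whose maximum is 4.0 sets the scale to 4.0. A second batch whose maximum is 2.0 then moves it to 0.9 * 4.0 + 0.1 * 2.0 = 3.8.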

class spikingjelly.activation_based.ann2snn.modules.VoltageScaler(scale=1.0)

Bases: Module

Constructor of VoltageScaler, which scales the input current during SNN inference.

Parameters:

scale (float) -- Scaling factor.

forward(x)

Forward function. Scales the input current.

Parameters:

x (Tensor) -- Input tensor, i.e. the input current.

Returns:

The scaled current.

Return type:

Tensor
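This sketch assumes a common ReLU-to-IF scaling pattern; the exact layout that Converter produces is not shown in this section. In the pattern, s is the recorded activation scale. A scaler by 1/s precedes an IF neuron with threshold 1.0, and a scaler by s follows it. The pure-PyTorch sketch below uses a hypothetical `ScalerSketch` in place of VoltageScaler and a hand-written soft-reset (subtractive) IF neuron. Averaged over T steps, the output approximates ReLU clipped to [0, s].

```python
import torch
from torch import nn


class ScalerSketch(nn.Module):
    """Hypothetical stand-in for VoltageScaler: multiplies its input by a fixed scale."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.register_buffer("scale", torch.tensor(float(scale)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.scale


def if_rate_sketch(x: torch.Tensor, s: float, T: int = 200) -> torch.Tensor:
    """Scaler(1/s) -> soft-reset IF neuron (threshold 1.0) -> Scaler(s), averaged over T steps."""
    pre, post = ScalerSketch(1.0 / s), ScalerSketch(s)
    v = torch.zeros_like(x)
    total = torch.zeros_like(x)
    for _ in range(T):
        v = v + pre(x)                  # integrate the scaled input current
        spike = (v >= 1.0).to(x.dtype)  # fire when the membrane reaches threshold 1.0
        v = v - spike                   # soft reset: subtract the threshold
        total = total + post(spike)     # rescale each spike back to the ANN range
    return total / T


x = torch.tensor([-0.5, 0.3, 0.8, 2.0, 3.0])
approx = if_rate_sketch(x, s=2.0)  # about [0.0, 0.3, 0.8, 2.0, 2.0], i.e. clamp(x, 0, s)
```

Inputs above s saturate because the IF neuron can fire at most once per step, which is why the recorded scale s acts as the clipping point.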

spikingjelly.activation_based.ann2snn.utils.download_url(url, dst)

Download a file from a given URL and save it to a destination path. Supports resuming interrupted downloads.

Parameters:
  • url (str) -- Download URL of the file.

  • dst (str) -- Destination path for the saved file.

Returns:

Total file size in bytes.

Return type:

int
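Resuming an interrupted download usually relies on an HTTP Range request. The sketch below shows one way such a helper could work, using only the standard library. It is not SpikingJelly's implementation. `download_with_resume_sketch` and its `opener` parameter are hypothetical; the parameter lets the logic be exercised without network access. If the destination already holds a partial file, the helper requests only the remaining bytes and appends them when the server answers 206 Partial Content. For any other response it starts over from byte 0.

```python
import os
import urllib.request


def download_with_resume_sketch(url, dst, chunk_size=1 << 16, opener=urllib.request.urlopen):
    """Hypothetical resumable download (not SpikingJelly's download_url)."""
    start = os.path.getsize(dst) if os.path.exists(dst) else 0
    headers = {"Range": f"bytes={start}-"} if start else {}
    request = urllib.request.Request(url, headers=headers)
    with opener(request) as response:
        # 206 Partial Content: the server honoured the Range header, so append.
        # Any other status: the full body is coming, so overwrite from byte 0.
        append = start > 0 and getattr(response, "status", None) == 206
        with open(dst, "ab" if append else "wb") as f:
            while True:
                chunk = response.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
    return os.path.getsize(dst)  # bytes now on disk
```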

Examples