spikingjelly.activation_based.triton_kernel.flexsn package
Function Info
- class spikingjelly.activation_based.triton_kernel.flexsn.info.FlexSNInfo(num_inputs, num_outputs, num_states, fwd_core_args, fwd_core_returns, fwd_core_recipients, fwd_kernel_returns, num_fwd_kernel_returns, c2k_return_mapping)
Bases: tuple
Create new instance of FlexSNInfo(num_inputs, num_outputs, num_states, fwd_core_args, fwd_core_returns, fwd_core_recipients, fwd_kernel_returns, num_fwd_kernel_returns, c2k_return_mapping)
- num_inputs
Alias for field number 0
- num_outputs
Alias for field number 1
- num_states
Alias for field number 2
- fwd_core_args
Alias for field number 3
- fwd_core_returns
Alias for field number 4
- fwd_core_recipients
Alias for field number 5
- fwd_kernel_returns
Alias for field number 6
- num_fwd_kernel_returns
Alias for field number 7
- c2k_return_mapping
Alias for field number 8
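Because FlexSNInfo is a named tuple, each field can be read either by name or by the field number listed above. A minimal stdlib sketch of that layout, using a stand-in class `FlexSNInfoSketch` with placeholder values (the real field contents are produced by extract_info and are not reproduced here):

```python
from collections import namedtuple

# Stand-in with the same field order as FlexSNInfo (field number = list index).
FlexSNInfoSketch = namedtuple(
    "FlexSNInfoSketch",
    [
        "num_inputs",              # field 0
        "num_outputs",             # field 1
        "num_states",              # field 2
        "fwd_core_args",           # field 3
        "fwd_core_returns",        # field 4
        "fwd_core_recipients",     # field 5
        "fwd_kernel_returns",      # field 6
        "num_fwd_kernel_returns",  # field 7
        "c2k_return_mapping",      # field 8
    ],
)

# Placeholder values for a hypothetical 1-input, 1-state, 1-output core.
info = FlexSNInfoSketch(
    num_inputs=1,
    num_outputs=1,
    num_states=1,
    fwd_core_args=("x", "v"),
    fwd_core_returns=("spike", "v_next"),
    fwd_core_recipients=("spike", "v_next"),
    fwd_kernel_returns=("spike", "v_next"),
    num_fwd_kernel_returns=2,
    c2k_return_mapping=None,
)

# Name-based and index-based access read the same field.
print(info.num_states == info[2], info._fields.index("c2k_return_mapping"))
```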
- spikingjelly.activation_based.triton_kernel.flexsn.info.extract_info(fwd_graph: Graph, num_inputs: int = 1, num_states: int = 0, num_outputs: int = 1) → FlexSNInfo [source]
Extract useful information from the forward graph. The forward graph should have the following signature:
[*inputs, *states] -> [*outputs, *states, *intermediates]
The extracted information includes:
* fwd_core_args: the core input argument names
* fwd_core_returns: the return value names of the forward graph
* fwd_core_recipients: the variable names receiving the core return values
* fwd_kernel_returns: the forward kernel return value names (no duplicates)
* num_fwd_kernel_returns: the length of fwd_kernel_returns
* c2k_return_mapping: mapping from intermediate results to kernel returns
- Parameters:
fwd_graph (Graph) -- Forward graph with the signature above
num_inputs (int) -- Number of input tensors (default 1)
num_states (int) -- Number of state tensors (default 0)
num_outputs (int) -- Number of output tensors (default 1)
- Returns:
The extracted FlexSN metadata
- Return type:
FlexSNInfo
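The documented signature fixes how the forward graph's flat return list is partitioned, and fwd_kernel_returns is described as duplicate-free. A minimal sketch of both ideas on plain name lists; `split_forward_returns` and `unique_in_order` are hypothetical helpers, not extract_info's actual internals:

```python
from typing import List, Sequence, Tuple


def split_forward_returns(
    returns: Sequence[str], num_outputs: int, num_states: int
) -> Tuple[List[str], List[str], List[str]]:
    """Split names laid out as [*outputs, *states, *intermediates]."""
    if len(returns) < num_outputs + num_states:
        raise ValueError("fewer returns than num_outputs + num_states")
    outputs = list(returns[:num_outputs])
    states = list(returns[num_outputs:num_outputs + num_states])
    intermediates = list(returns[num_outputs + num_states:])
    return outputs, states, intermediates


def unique_in_order(names: Sequence[str]) -> List[str]:
    """Drop repeated names while keeping first-seen order."""
    seen = set()
    result = []
    for name in names:
        if name not in seen:
            seen.add(name)
            result.append(name)
    return result


# Hypothetical graph returns: one output, one state, two saved intermediates,
# where the updated state is also saved as an intermediate.
returns = ["spike", "v_next", "h", "v_next"]
outs, states, inters = split_forward_returns(returns, num_outputs=1, num_states=1)
kernel_returns = unique_in_order(returns)
```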
Template
FlexSN Triton Kernel Templates.
Insert a single-step kernel into a multi-step kernel template.
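The idea of inserting a single-step core into a multi-step template can be illustrated with plain source-string substitution. This is a hypothetical, Python-only sketch: the real templates emit Triton kernels, and `MULTI_STEP_TEMPLATE` and `lif_step` are invented for illustration. It only shows how a `core_str` and `core_name` pair could be spliced into a time loop:

```python
from string import Template

# Hypothetical multi-step template: the core source is pasted in, then called
# once per time step inside a generated loop.
MULTI_STEP_TEMPLATE = Template(
    "$core_str\n"
    "def multi_step(xs, v):\n"
    "    ys = []\n"
    "    for x in xs:\n"
    "        y, v = $core_name(x, v)\n"
    "        ys.append(y)\n"
    "    return ys, v\n"
)

# Hypothetical single-step core given as source text.
core_str = (
    "def lif_step(x, v):\n"
    "    v = v + (x - v) / 2.0\n"
    "    s = 1.0 if v >= 1.0 else 0.0\n"
    "    v = v * (1.0 - s)\n"
    "    return s, v\n"
)

source = MULTI_STEP_TEMPLATE.substitute(core_str=core_str, core_name="lif_step")
namespace = {}
exec(source, namespace)  # compile the generated multi-step function
spikes, v_final = namespace["multi_step"]([1.5, 1.5, 1.5], 0.0)
```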
- spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_inference_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False) [source]
Compile a Triton kernel for FlexSN inference (no backward pass).
- Parameters:
core_str (str) -- Source code of the single-step core
core_name (str) -- Name of the core function defined in core_str
info (FlexSNInfo) -- FlexSN metadata
verbose (bool) -- Enable verbose output (default False)
- Returns:
Compiled Triton kernel executable
- Return type:
triton.runtime.JITFunction
- spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_inference_final_state_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False) [source]
Compile a Triton kernel for FlexSN inference that returns the final state.
- Parameters:
core_str (str) -- Source code of the single-step core
core_name (str) -- Name of the core function defined in core_str
info (FlexSNInfo) -- FlexSN metadata
verbose (bool) -- Enable verbose output (default False)
- Returns:
Compiled Triton kernel executable
- Return type:
triton.runtime.JITFunction
- spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_forward_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False) [source]
Compile a Triton kernel for the FlexSN forward pass (with state saving).
- Parameters:
core_str (str) -- Source code of the single-step core
core_name (str) -- Name of the core function defined in core_str
info (FlexSNInfo) -- FlexSN metadata
verbose (bool) -- Enable verbose output (default False)
- Returns:
Compiled Triton kernel executable
- Return type:
triton.runtime.JITFunction
- spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_backward_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False) [source]
Compile a Triton kernel for the FlexSN backward pass.
- Parameters:
core_str (str) -- Source code of the single-step core
core_name (str) -- Name of the core function defined in core_str
info (FlexSNInfo) -- FlexSN metadata
verbose (bool) -- Enable verbose output (default False)
- Returns:
Compiled Triton kernel executable
- Return type:
triton.runtime.JITFunction
Kernel Builders
Build Triton scan kernels for FlexSN's Triton/Inductor backends.
Three entry points:
* build_inference_kernel: no-grad fast path (make_fx, no PYTORCH_JIT=0 needed)
* build_inference_final_state_kernel: inference path that returns output sequences plus final states only (no state sequences)
* build_training_kernels: forward + backward kernels for full BPTT training
- spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_inference_kernel(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None) [source]
Build a single-pass scan inference Triton kernel for core_fn. The generated kernel wraps core_fn's per-step computation in a tl.static_range(T) loop, so one inference call launches exactly one kernel regardless of T.
- Parameters:
core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)
num_inputs (int) -- Number of per-step input tensors
num_states (int) -- Number of state tensors
num_outputs (int) -- Number of per-step output tensors
example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]. If None, unit-sized CUDA float32 tensors are created
- Returns:
(kernel, info), where kernel is the compiled Triton kernel and info is the FlexSNInfo metadata required by spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference()
- Return type:
Tuple[object, FlexSNInfo]
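The counts passed to the builder must match what core_fn actually returns: num_outputs + num_states values for num_inputs + num_states arguments. A minimal pure-Python sketch, assuming scalar floats in place of tensors and a hypothetical adaptive-LIF core with two states; `check_core_arity` is an illustrative helper, not part of SpikingJelly:

```python
from typing import Callable, Tuple


def alif_core(x: float, v: float, theta: float) -> Tuple[float, float, float]:
    """Hypothetical adaptive LIF step: 1 input, 2 states (v, theta), 1 output."""
    v = 0.5 * v + x                        # leaky integration
    s = 1.0 if v >= 1.0 + theta else 0.0   # threshold raised by theta
    v = v * (1.0 - s)                      # hard reset on spike
    theta = 0.9 * theta + 0.2 * s          # threshold adaptation
    return s, v, theta                     # (*outputs, *updated_states)


def check_core_arity(
    core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int
) -> None:
    """Call core_fn once on unit placeholders and check how many values it returns."""
    example = (1.0,) * (num_inputs + num_states)  # stands in for unit-sized tensors
    result = core_fn(*example)
    if len(result) != num_outputs + num_states:
        raise ValueError(
            f"expected {num_outputs + num_states} returns, got {len(result)}"
        )


check_core_arity(alif_core, num_inputs=1, num_states=2, num_outputs=1)
```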
- spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_inference_final_state_kernel(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None) [source]
Build an inference Triton kernel that returns output sequences and final states for core_fn. This variant traces core_fn like build_inference_kernel(), but materializes final state tensors instead of full state sequences.
- Parameters:
core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)
num_inputs (int) -- Number of per-step input tensors
num_states (int) -- Number of state tensors
num_outputs (int) -- Number of per-step output tensors
example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]
- Returns:
(kernel, info), produced by get_flexsn_inference_final_state_kernel and extract_info respectively
- Return type:
Tuple[object, FlexSNInfo]
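The difference between the two inference variants is only in what happens to the state: one stores every per-step state, the other keeps a single running state, and the final state always equals the last entry of the full state sequence. A minimal pure-Python sketch with a hypothetical LIF core (floats stand in for tensors; the function names are illustrative):

```python
from typing import Callable, List, Sequence, Tuple


def lif_core(x: float, v: float) -> Tuple[float, float]:
    """Hypothetical LIF step: 1 input, 1 state, 1 output."""
    v = 0.5 * v + x
    s = 1.0 if v >= 1.0 else 0.0
    return s, v * (1.0 - s)


def scan_with_state_seq(
    core_fn: Callable, xs: Sequence[float], v0: float
) -> Tuple[List[float], List[float]]:
    # Stores every per-step state: memory grows with T.
    ys, vs, v = [], [], v0
    for x in xs:
        y, v = core_fn(x, v)
        ys.append(y)
        vs.append(v)
    return ys, vs


def scan_with_final_state(
    core_fn: Callable, xs: Sequence[float], v0: float
) -> Tuple[List[float], float]:
    # Keeps only the running state: one state slot regardless of T.
    ys, v = [], v0
    for x in xs:
        y, v = core_fn(x, v)
        ys.append(y)
    return ys, v


xs = [0.6, 0.6, 0.6, 0.6]
ys_a, v_seq = scan_with_state_seq(lif_core, xs, 0.0)
ys_b, v_final = scan_with_final_state(lif_core, xs, 0.0)
```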
- spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_training_kernels(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None, requires_grad: Tuple[bool, ...] | None = None) [source]
Build FlexSN forward and backward Triton scan kernels for BPTT training. This function uses aot_function to trace both the forward and backward of core_fn and then produces:
* a forward scan kernel that saves the intermediates needed by backward;
* a backward scan kernel that runs the reverse-time pass.
When some outputs (for example, hard-threshold spike signals) are non-differentiable, AOT backward drops the corresponding gradient inputs. This function automatically generates a shim so that the kernel template's calling convention stays aligned.
- Parameters:
core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)
num_inputs (int) -- Number of per-step input tensors
num_states (int) -- Number of state tensors
num_outputs (int) -- Number of per-step output tensors
example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]
requires_grad (Optional[Tuple[bool, ...]]) -- Flags indicating whether each tensor in example_inputs should require gradients. If None, all floating-point and complex inputs are traced as differentiable
- Returns:
(fwd_kernel, bwd_kernel, info), ready to plug into FlexSN's shared custom-op execution path
- Return type:
Tuple[object, object, FlexSNInfo]
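The shim described above reconciles two argument lists: the template passes one gradient slot per return value, while the traced backward only accepts gradients for differentiable returns. A minimal sketch of that idea, assuming a per-return `differentiable` mask; `make_grad_shim` and the toy `aot_backward` are hypothetical and only illustrate argument alignment, not the generated code:

```python
from typing import Callable, Sequence


def make_grad_shim(aot_backward: Callable, differentiable: Sequence[bool]) -> Callable:
    """Accept one grad per return value, forward only the differentiable ones."""
    num_grads = len(differentiable)

    def shim(*args):
        grads, saved = args[:num_grads], args[num_grads:]
        kept = [g for g, d in zip(grads, differentiable) if d]
        return aot_backward(*kept, *saved)

    return shim


# Hypothetical traced backward: the spike output was non-differentiable, so it
# takes only the state gradient followed by one saved value.
def aot_backward(grad_v: float, saved_scale: float):
    return (grad_v * saved_scale,)


# Template-side call order: (grad_spike, grad_v, saved_scale).
shim = make_grad_shim(aot_backward, differentiable=[False, True])
```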
Wrapper
- spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_backward(f, info: FlexSNInfo, *args, input_templates: Tuple[Tensor, ...] | None = None, state_templates: Tuple[Tensor, ...] | None = None) → tuple [source]
Run the FlexSN training backward kernel.
- Parameters:
f -- Triton kernel callable
info -- FlexSN metadata
args (tuple) -- Gradients followed by any saved tensors accepted by the kernel. The leading info.num_outputs + info.num_states entries are output/state-sequence gradients; the remaining entries are saved tensors from the forward pass.
input_templates (Optional[Tuple[torch.Tensor, ...]]) -- Per-input-sequence templates used to allocate input gradients. They are required when all incoming gradients are None, and they ensure that state-only cores still return correctly shaped input-sequence gradients. When omitted, the fallback allocation path is only valid for single-input cores whose output-sequence gradients are present, because it infers the input gradient shape from the first non-None output-sequence gradient.
state_templates (Optional[Tuple[torch.Tensor, ...]]) -- Per-initial-state templates used to allocate initial-state gradients. When provided, the returned state gradients preserve the original initial-state shapes instead of inferring them from state-sequence gradients.
- Returns:
Gradients for inputs and initial states. When T == 0 or all incoming gradients are None, returns zero-filled gradients that follow the provided templates.
- Return type:
tuple
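The allocation rules for input gradients can be summarized on shapes alone. A minimal sketch, assuming shapes are plain tuples; `plan_input_grad_shapes` is a hypothetical helper that mirrors the documented rules, not the wrapper's actual code:

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def plan_input_grad_shapes(
    num_inputs: int,
    output_grad_shapes: Sequence[Optional[Shape]],
    input_template_shapes: Optional[Sequence[Shape]] = None,
) -> Tuple[Shape, ...]:
    """Choose input-gradient shapes following the documented rules."""
    if input_template_shapes is not None:
        # Templates always win, e.g. for state-only cores or all-None grads.
        return tuple(input_template_shapes)
    present = [s for s in output_grad_shapes if s is not None]
    if num_inputs == 1 and present:
        # Fallback: infer from the first non-None output-sequence gradient.
        return (present[0],)
    raise ValueError("input_templates is required for this call")
```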
- spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_backward_ncl_bucket(ncl: int) → int [source]
Bucket a flattened sequence size for backward-kernel tuning: map the flattened per-step element count NCL to the backward-kernel autotuning bucket.
- Parameters:
ncl (int) -- Flattened element count per time step
- Returns:
Bucket index in [0, 4]
- Return type:
int
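Bucketing means nearby problem sizes share one autotuning entry instead of each exact NCL getting its own. A minimal sketch of mapping a count to an index in [0, 4] with `bisect`; the boundaries in `NCL_BOUNDARIES` are invented for illustration and are not the thresholds SpikingJelly uses:

```python
import bisect

# Hypothetical bucket boundaries (four cut points give five buckets, 0..4).
NCL_BOUNDARIES = (2**12, 2**15, 2**18, 2**21)


def ncl_bucket(ncl: int) -> int:
    """Map a per-step element count to a bucket index in [0, 4]."""
    if ncl < 0:
        raise ValueError("ncl must be non-negative")
    return bisect.bisect_right(NCL_BOUNDARIES, ncl)
```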
- spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_forward(f, info: FlexSNInfo, *args) → tuple [source]
Run the FlexSN training forward kernel.
- Parameters:
f (object) -- Triton kernel callable
info (FlexSNInfo) -- FlexSN metadata
args -- Input/state sequences accepted by the kernel
- Returns:
Forward outputs plus any saved tensors required by backward. When T == 0, returns empty tensors following the expected templates.
- Return type:
tuple
- spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference(f, info: FlexSNInfo, *args) → tuple [source]
Run the FlexSN multi-step inference kernel.
- Parameters:
f (object) -- Triton kernel callable
info (FlexSNInfo) -- FlexSN metadata
args -- Input/state sequences accepted by the kernel
- Returns:
Output/state sequences. When T == 0, returns empty tensors with the expected templates.
- Return type:
tuple
- spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference_final_state(f, info: FlexSNInfo, *args) → tuple [source]
Run the FlexSN multi-step inference kernel and materialize final states.
- Parameters:
f (object) -- Triton kernel callable
info (FlexSNInfo) -- FlexSN metadata
args -- Input/state sequences accepted by the kernel
- Returns:
Output sequences followed by final states. When T == 0, output sequences are empty, provided initial states are cloned, and missing states are zero-filled.
- Return type:
tuple
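The T == 0 rule above (empty outputs, cloned initial states, zero-filled missing states) can be modelled with lists standing in for tensors. A minimal sketch; `final_state_for_empty_sequence` and its `state_sizes` parameter are hypothetical, since the real wrapper derives sizes from its templates:

```python
import copy
from typing import List, Optional, Sequence, Tuple


def final_state_for_empty_sequence(
    num_outputs: int,
    init_states: Sequence[Optional[List[float]]],
    state_sizes: Sequence[int],
) -> Tuple[List[float], ...]:
    """T == 0: empty outputs, cloned provided states, zero-filled missing states."""
    outputs = [[] for _ in range(num_outputs)]   # no time steps, nothing emitted
    finals = []
    for state, size in zip(init_states, state_sizes):
        if state is None:
            finals.append([0.0] * size)          # missing state -> zeros
        else:
            finals.append(copy.copy(state))      # clone, never alias the input
    return (*outputs, *finals)


v0 = [0.5, 0.2]
result = final_state_for_empty_sequence(1, [v0, None], state_sizes=[2, 3])
```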
Custom Ops
The custom_ops module provides the low-level opaque custom ops used by FlexSN's shared Triton execution path. It stores Triton kernels and metadata in a Python-side registry and exposes them to torch.compile / AOTAutograd through lightweight integer handle values.
Its main responsibilities are:
* maintaining the Python-side FlexSN kernel-handle registry;
* wrapping forward and backward as torch.library custom ops, so the compiler does not trace Python kernel objects or Triton launcher internals;
* providing helpers for fake tensors, autograd setup/backward, and final-state fast paths.
Most users do not need to call these functions directly; they are primarily used indirectly by spikingjelly.activation_based.neuron.flexsn.FlexSN and spikingjelly.activation_based.neuron.flexsn.FlexSNKernel.
- class spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.FlexSNKernelHandle(inference_kernel: object | None, inference_info: FlexSNInfo | None, inference_final_state_kernel: object | None, inference_final_state_info: FlexSNInfo | None, forward_kernel: object | None, backward_kernel: object | None, training_info: FlexSNInfo | None, owner_refs: int = 1, active_refs: int = 0) [source]
Bases: object
Store the Triton kernels, metadata, and reference-counting state associated with a FlexSN kernel handle. Instances are managed by the internal custom_ops registry and bind Python-side kernel objects to integer handle values.
- inference_info: FlexSNInfo | None
- inference_final_state_info: FlexSNInfo | None
- training_info: FlexSNInfo | None
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.register_flexsn_kernel_handle(*, inference_kernel, inference_info, inference_final_state_kernel, inference_final_state_info, forward_kernel, backward_kernel, training_info) → int [source]
Register a bundle of FlexSN Triton kernels and their metadata in the Python-side registry and return an integer handle. The returned handle can be referenced by later custom-op calls.
- Parameters:
inference_kernel -- Inference kernel
inference_info -- FlexSNInfo for the inference path
inference_final_state_kernel -- Inference kernel that returns final states only
inference_final_state_info -- FlexSNInfo for the inference-final-state path
forward_kernel -- Training forward kernel
backward_kernel -- Training backward kernel
training_info -- FlexSNInfo for the training path
- Returns:
Newly registered integer handle
- Return type:
int
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.retain_flexsn_kernel_handle(handle: int) → None [source]
Increase the active-reference count of the specified FlexSN kernel handle. This is typically called when an autograd context saves the handle, so that the associated kernels are not cleaned up before backward finishes.
- Parameters:
handle (int) -- FlexSN kernel handle
- Returns:
None
- Return type:
None
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.retain_owner_flexsn_kernel_handle(handle: int) → None [source]
Increase the owner-reference count of the specified FlexSN kernel handle. This is typically called when an object is copied or another owner takes over the handle.
- Parameters:
handle (int) -- FlexSN kernel handle
- Returns:
None
- Return type:
None
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.release_flexsn_kernel_handle(handle: int) → None [source]
Release one owner reference. When both owner and active references reach zero, the associated kernels are removed from the registry and cleanup is attempted.
- Parameters:
handle (int) -- FlexSN kernel handle
- Returns:
None
- Return type:
None
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.release_active_flexsn_kernel_handle(handle: int) → None [source]
Release one active reference. When both owner and active references reach zero, the associated kernels are removed from the registry and cleanup is attempted.
- Parameters:
handle (int) -- FlexSN kernel handle
- Returns:
None
- Return type:
None
- spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.attach_flexsn_handle_finalizer(owner, handle: int) [source]
Attach a weakref.finalize object to owner, so that the corresponding FlexSN kernel handle is released automatically when the owner is destroyed.
- Parameters:
owner -- Owner object of the handle
handle (int) -- FlexSN kernel handle
- Returns:
The attached finalizer
- Return type:
weakref.finalize
HigherOrderOp & Eager Scan
FlexSN time-step scan as a HigherOrderOperator.
Current progress:
M1:
* HOP definition with an eager Python time-step loop implementation.
* Eager autograd works via the natural computation graph: x[t] indexing and torch.stack are differentiable, so the per-step core_fn graph is correctly chained through time. Verified with gradcheck.
M2:
* AOTAutograd tracing (torch.fx.experimental.proxy_tensor.make_fx / torch._functorch.aot_autograd.aot_function) works by unrolling the scan into T copies of core_fn's aten ops.
M3:
* FlexSN(backend="hop") is available as an explicit backend.
* Dynamo recognizes flex_sn_scan via a compatibility registration and can rewrite the call into a HOP node with a traced GraphModule body.
* torch.compile(fullgraph=True) for the HOP backend is verified on the Linux CI/server environment, including lifted tensor freevars/closures.
M4:
* lowerable_scan re-expresses the FlexSN step function through PyTorch's built-in torch.ops.higher_order.scan when that API is available. It is kept as an explicit experimental helper for investigating a single-scan-node forward path instead of fully unrolling the body.
* lowerable_while_loop_scan provides an alternative experimental forward path based on torch.ops.higher_order.while_loop. On the Linux validation environment, its torch.compile(fullgraph=True) + no_grad path works after switching from x[t] indexing to fixed-shape queue carries.
* The experimental while-loop path is wired into FlexSN(backend="hop") via SJ_ENABLE_EXPERIMENTAL_LOWERABLE_WHILE_LOOP=1 for compile-time forward evaluation, and has been validated on:
  - a single FlexSN layer,
  - Linear -> FlexSN -> Linear,
  - SpikingVGG forward inference.
Current limitations:
* The custom Dynamo registration in this file is still a compatibility shim, not a true in-tree BaseHOP integration.
* lowerable_scan is currently an experimental helper, not the default compiled path. In the PyTorch versions we validate against, fake/proxy/export handling for this out-of-tree scan shape is not yet stable enough to enable it by default.
* Training and autograd still use the existing eager/unrolled path.
* The current while-loop lowering is functionally correct for the validated forward no_grad cases, but it is not yet faster than the current backend="inductor" custom-op compile path on the server benchmark.
* A true first-class Inductor lowering for flex_sn_scan itself does not exist yet; the current "less unrolled" path relies on PyTorch's built-in scan / while_loop decomposition.
Usage:
from spikingjelly.activation_based.triton_kernel.flexsn import flex_sn_scan
# inputs_seq: tuple of T-leading tensors, e.g. shape [T, N, ...]
# init_states: tuple of per-step state tensors, e.g. shape [N, ...]
# returns: (*output_seqs, *state_seqs) — each with shape [T, ...]
result = flex_sn_scan(
core_fn, num_inputs, num_states, num_outputs, *inputs_seq, *init_states
)
Captured tensor freevars from core_fn are appended after the
[*inputs_seq, *init_states] segment when Dynamo rewrites the HOP call.
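The flat-argument convention above can be modelled in plain Python. This sketch (`eager_scan_sketch` and `core` are hypothetical, and floats stand in for tensors) shows how the arguments split into input sequences, initial states, and lifted freevars that are passed unchanged to core_fn at every step:

```python
from typing import Any, Callable, List, Tuple


def eager_scan_sketch(
    core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Any
) -> Tuple[List[Any], ...]:
    """Plain-Python model of the layout [*inputs_seq, *init_states, *lifted]."""
    input_seqs = flat_args[:num_inputs]                            # each of length T
    states = list(flat_args[num_inputs:num_inputs + num_states])   # no time dim
    lifted = flat_args[num_inputs + num_states:]                   # same every step
    T = len(input_seqs[0]) if input_seqs else 0
    out_seqs: List[List[Any]] = [[] for _ in range(num_outputs)]
    state_seqs: List[List[Any]] = [[] for _ in range(num_states)]
    for t in range(T):
        step = core_fn(*(seq[t] for seq in input_seqs), *states, *lifted)
        states = list(step[num_outputs:])
        for seq, y in zip(out_seqs, step[:num_outputs]):
            seq.append(y)
        for seq, s in zip(state_seqs, states):
            seq.append(s)
    return (*out_seqs, *state_seqs)


def core(x: float, v: float, tau: float) -> Tuple[float, float]:
    """Hypothetical LIF step; tau stands in for a captured tensor freevar."""
    v = v + (x - v) / tau
    s = 1.0 if v >= 1.0 else 0.0
    return s, v * (1.0 - s)


spikes, v_seq = eager_scan_sketch(core, 1, 1, 1, [1.5, 1.5, 1.5], 0.0, 2.0)
```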
- spikingjelly.activation_based.triton_kernel.flexsn.hop.dynamo_hop_available() → bool [source]
Report whether the FlexSN-specific Dynamo HigherOrderOperator registration has been installed successfully.
- Returns:
True when the Dynamo compatibility shim for flex_sn_scan is registered; otherwise False.
- Return type:
bool
- spikingjelly.activation_based.triton_kernel.flexsn.hop.eager_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]
Run the FlexSN scan with an eager Python time-step loop. This helper is reused by both the HOP eager implementation and the Dynamo-friendly backend="inductor" path, so torch.compile can trace the unrolled loop into a standard FX graph. When T == 0, output_template_specs must describe the output sequence shapes/dtypes so that empty outputs can be materialized without executing core_fn.
- Parameters:
core_fn (Callable) -- Single-step core callable with signature (*step_inputs, *states, *lifted_args)
num_inputs (int) -- Number of T-leading input sequences
num_states (int) -- Number of initial-state tensors
num_outputs (int) -- Number of per-step outputs
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to build empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by num_states state sequences
- Return type:
Tuple[torch.Tensor, ...]
- spikingjelly.activation_based.triton_kernel.flexsn.hop.eager_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]
Run the eager scan and return output sequences plus final states. This variant of eager_scan() is used when store_state_seqs=False, so the HOP backend does not materialize full state sequences only to discard them. When T == 0, output_template_specs is used to build empty output sequences and the provided initial states are cloned into the returned final states.
- Parameters:
core_fn (Callable) -- Single-step core callable
num_inputs (int) -- Number of T-leading input sequences
num_states (int) -- Number of initial-state tensors
num_outputs (int) -- Number of per-step outputs
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by the final states
- Return type:
Tuple[torch.Tensor, ...]
- class spikingjelly.activation_based.triton_kernel.flexsn.hop.FlexSNScan [source]
Bases: HigherOrderOperator
HOP that runs a user-defined single-step core function over the leading time dimension of its inputs. It is a high-level wrapper for FlexSN's differentiable scan operation, intended to be compatible with PyTorch's native scan operation so that spiking neural networks can run efficient, differentiable scans over the time dimension. It provides an eager mode and a lowerable mode and selects the execution backend from context. The HOP is invoked with a flat argument list so that Dynamo / AOTAutograd can treat it uniformly.
Shapes/semantics:
* core_fn: callable with signature (*step_inputs, *states) -> (*step_outputs, *updated_states).
* num_inputs / num_states / num_outputs: int literals used to partition the flat tensor args.
* flat_args: the first num_inputs tensors are input sequences with leading time dim T; the next num_states tensors are initial states (no time dim); any remaining tensors are lifted freevars that are passed to core_fn unchanged at every time step.
Return:
num_outputs output sequences followed by num_states state sequences, all stacked along the leading time dim.
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) Tuple[Tensor, ...][源代码]#
Run the FlexSN scan through PyTorch's built-in scan HOP.
Keeps the FlexSN scan as a single higher-order op under tracing so that downstream compilers can lower it as a loop instead of fully unrolling the body T times. This remains an experimental helper rather than the default compiled path. When T == 0, output_template_specs must contain num_outputs items shaped as (shape, dtype) or (shape, dtype, device). When the device is omitted, it defaults at runtime to the device of the first input sequence.
- Parameters:
core_fn (Callable) -- Single-step core callable.
num_inputs (int) -- Number of T-leading input sequences.
num_states (int) -- Number of initial-state tensors.
num_outputs (int) -- Number of per-step outputs.
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by num_states state sequences.
- Return type:
Tuple[torch.Tensor, ...]
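The T == 0 rule for output_template_specs can be sketched as a normalization step. normalize_output_templates and empty_sequence_shape are hypothetical helpers, dtypes and devices are plain strings here, and the assumption that a template shape describes a single time step (so the empty sequence gets a leading dimension of 0) is not confirmed by the text above.

```python
from typing import List, Optional, Sequence, Tuple

Template = Tuple[Tuple[int, ...], str, str]


def normalize_output_templates(
    specs: Optional[Sequence[tuple]],
    num_outputs: int,
    default_device: str,
) -> List[Template]:
    # When T == 0 there is no step output to infer from, so one template per output is required.
    if specs is None or len(specs) != num_outputs:
        raise ValueError("output_template_specs must contain num_outputs items when T == 0")
    normalized: List[Template] = []
    for spec in specs:
        if len(spec) == 2:
            shape, dtype = spec
            device = default_device  # omitted device follows the first input sequence
        elif len(spec) == 3:
            shape, dtype, device = spec
        else:
            raise ValueError(f"expected (shape, dtype) or (shape, dtype, device), got {spec!r}")
        normalized.append((tuple(shape), dtype, device))
    return normalized


def empty_sequence_shape(step_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    # Assumption: the template describes one time step; prepend T = 0.
    return (0, *step_shape)


specs = [((4,), "float32"), ((4,), "float16", "cpu")]
print(normalize_output_templates(specs, 2, "cuda:0"))
print(empty_sequence_shape((4,)))  # (0, 4)
```

With PyTorch, each normalized triple could then be passed to torch.empty((0, *shape), dtype=dtype, device=device) to build an empty output sequence.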
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan_available() → bool
Report whether the current environment exposes PyTorch's built-in scan higher-order operator (HOP).
- Returns:
True when torch.ops.higher_order.scan is available; otherwise False.
- Return type:
bool
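A check like this can be written as an attribute probe. The sketch below is an assumption about the approach, not the library's code: hop_available is a hypothetical helper, and SimpleNamespace stands in for torch.ops.higher_order so that the example runs without PyTorch.

```python
from types import SimpleNamespace


def hop_available(namespace, name: str) -> bool:
    # hasattr returns False when the attribute lookup raises AttributeError.
    return hasattr(namespace, name)


# Stand-ins for torch.ops.higher_order in environments with and without the op.
with_scan = SimpleNamespace(scan=object())
without_scan = SimpleNamespace()

print(hop_available(with_scan, "scan"))     # True
print(hop_available(without_scan, "scan"))  # False
```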
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...]
Run the FlexSN scan through the built-in scan HOP and return the output sequences together with the final states only, instead of full state sequences.
Final-state variant of lowerable_scan(). When T == 0, output_template_specs materializes empty output sequences and the initial states are cloned into the returned final states.
- Parameters:
core_fn (Callable) -- Single-step core callable.
num_inputs (int) -- Number of T-leading input sequences.
num_states (int) -- Number of initial-state tensors.
num_outputs (int) -- Number of per-step outputs.
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by the final states.
- Return type:
Tuple[torch.Tensor, ...]
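The final-state contract, including the T == 0 case, can be sketched in pure Python. reference_scan_final_state and accumulate are hypothetical names, lists stand in for tensors, and copy.deepcopy stands in for cloning a tensor.

```python
import copy
from typing import Callable, Tuple


def reference_scan_final_state(
    core_fn: Callable[..., Tuple],
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args,
) -> Tuple:
    inputs = flat_args[:num_inputs]
    init_states = flat_args[num_inputs:num_inputs + num_states]
    freevars = flat_args[num_inputs + num_states:]
    T = len(inputs[0]) if num_inputs > 0 else 0

    # Start from clones, so T == 0 returns copies of the initial states.
    states = [copy.deepcopy(s) for s in init_states]
    out_seqs = [[] for _ in range(num_outputs)]
    for t in range(T):
        results = core_fn(*(x[t] for x in inputs), *states, *freevars)
        for seq, y in zip(out_seqs, results[:num_outputs]):
            seq.append(y)
        # Only the latest state is kept; no state sequence is stored.
        states = list(results[num_outputs:num_outputs + num_states])

    return (*out_seqs, *states)


def accumulate(x, s):
    s = s + x
    return s, s  # (output, updated state)


outputs, final = reference_scan_final_state(accumulate, 1, 1, 1, [1, 2, 3], 0)
print(outputs, final)  # [1, 3, 6] 6
```

Keeping only the latest state avoids storing a T-long state history when the caller needs just the end of the sequence.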
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...]
Run the FlexSN scan through PyTorch's built-in while_loop HOP.
Experimental helper for studying whether a first-class loop representation is a better fit than the current unrolled scan path. Current while-loop capture does not support symbolic x[t] indexing here, so this implementation keeps functional queue buffers. When T == 0, output_template_specs materializes empty output sequences without running core_fn.
- Parameters:
core_fn (Callable) -- Single-step core callable.
num_inputs (int) -- Number of T-leading input sequences.
num_states (int) -- Number of initial-state tensors.
num_outputs (int) -- Number of per-step outputs.
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by num_states state sequences.
- Return type:
Tuple[torch.Tensor, ...]
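One reading of "functional queue buffers" is a loop carry that never indexes x[t] with the step counter: each iteration reads the front of a fixed-length input queue, rotates it, and pushes results into fixed-length output queues. The sketch below assumes that reading. while_loop here is a plain-Python stand-in for the HOP's (cond_fn, body_fn, carry) shape, and queue_scan, accumulate and the fill placeholder are hypothetical.

```python
from typing import Callable, Sequence, Tuple


def while_loop(cond_fn: Callable, body_fn: Callable, carry: Tuple) -> Tuple:
    # Plain-Python stand-in for a while_loop HOP.
    while cond_fn(carry):
        carry = body_fn(carry)
    return carry


def queue_scan(core_fn, num_outputs: int, input_seqs: Sequence[Sequence],
               init_states: Sequence, fill=0.0) -> Tuple[tuple, ...]:
    T = len(input_seqs[0])

    def cond_fn(carry):
        return carry[0] < T  # the step counter only controls termination

    def body_fn(carry):
        step, in_qs, states, out_qs, st_qs = carry
        heads = tuple(q[0] for q in in_qs)            # read the queue front, never q[step]
        in_qs = tuple(q[1:] + q[:1] for q in in_qs)   # rotate; length stays T
        results = core_fn(*heads, *states)
        states = tuple(results[num_outputs:num_outputs + len(states)])
        # Push each new value at the back and drop one placeholder from the front.
        out_qs = tuple(q[1:] + (y,) for q, y in zip(out_qs, results[:num_outputs]))
        st_qs = tuple(q[1:] + (s,) for q, s in zip(st_qs, states))
        return step + 1, in_qs, states, out_qs, st_qs

    carry = (
        0,
        tuple(tuple(q) for q in input_seqs),
        tuple(init_states),
        tuple((fill,) * T for _ in range(num_outputs)),
        tuple((fill,) * T for _ in init_states),
    )
    _, _, _, out_qs, st_qs = while_loop(cond_fn, body_fn, carry)
    return (*out_qs, *st_qs)  # outputs first, then state sequences


def accumulate(x, s):
    s = s + x
    return s, s


print(queue_scan(accumulate, 1, [[1, 2, 3]], [0]))  # ((1, 3, 6), (1, 3, 6))
print(queue_scan(accumulate, 1, [[]], [0]))         # ((), ())
```

Keeping every queue at length T mirrors the fact that a traced while loop generally needs its carried values to keep the same shape on every iteration; with T == 0 the condition fails immediately, so core_fn is never called.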
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_available() → bool
Report whether the current environment exposes PyTorch's built-in while_loop higher-order operator (HOP).
- Returns:
True when torch.ops.higher_order.while_loop is available; otherwise False.
- Return type:
bool
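Because both lowerable helpers are described as experimental rather than the default compiled path, a caller may want to gate them on the two availability checks and fall back otherwise. The policy below is a hypothetical sketch: pick_loop_backend and the "eager" label are not part of the documented API, and the boolean arguments would come from lowerable_scan_available() and lowerable_while_loop_available().

```python
def pick_loop_backend(scan_ok: bool, while_loop_ok: bool, prefer: str = "scan") -> str:
    # Hypothetical selection policy; "eager" stands in for the default path.
    available = {"scan": scan_ok, "while_loop": while_loop_ok}
    if available.get(prefer, False):
        return prefer
    for name in ("scan", "while_loop"):
        if available[name]:
            return name
    return "eager"


print(pick_loop_backend(True, True))                       # scan
print(pick_loop_backend(False, True))                      # while_loop
print(pick_loop_backend(True, True, prefer="while_loop"))  # while_loop
print(pick_loop_backend(False, False))                     # eager
```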
- spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...]
Run the FlexSN scan through the while-loop HOP and return the output sequences plus the final states.
Final-state variant of lowerable_while_loop_scan(). When T == 0, output_template_specs is used to build empty output sequences and the provided initial states are cloned into the returned final states.
- Parameters:
core_fn (Callable) -- Single-step core callable.
num_inputs (int) -- Number of T-leading input sequences.
num_states (int) -- Number of initial-state tensors.
num_outputs (int) -- Number of per-step outputs.
flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.
output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.
- Returns:
num_outputs output sequences followed by the final states.
- Return type:
Tuple[torch.Tensor, ...]
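Assuming the two while-loop variants agree, each final state should be the last entry of the corresponding state sequence, or a clone of the initial state when T == 0. The helper below is a hypothetical sketch of that relationship. Lists stand in for tensors and copy.deepcopy stands in for cloning.

```python
import copy
from typing import Sequence, Tuple


def final_states_from_sequences(state_seqs: Sequence[Sequence], init_states: Sequence) -> Tuple:
    # Hypothetical helper: derive final states from per-step state sequences.
    if len(state_seqs) != len(init_states):
        raise ValueError("need one state sequence per initial state")
    finals = []
    for seq, init in zip(state_seqs, init_states):
        # T == 0: no steps ran, so return a clone of the initial state.
        finals.append(seq[-1] if len(seq) > 0 else copy.deepcopy(init))
    return tuple(finals)


print(final_states_from_sequences([(1, 3, 6)], [0]))  # (6,)
print(final_states_from_sequences([()], [0.5]))       # (0.5,)
```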