spikingjelly.activation_based.triton_kernel.flexsn package

Function Info

class spikingjelly.activation_based.triton_kernel.flexsn.info.FlexSNInfo(num_inputs, num_outputs, num_states, fwd_core_args, fwd_core_returns, fwd_core_recipients, fwd_kernel_returns, num_fwd_kernel_returns, c2k_return_mapping)

Bases: tuple

Create new instance of FlexSNInfo(num_inputs, num_outputs, num_states, fwd_core_args, fwd_core_returns, fwd_core_recipients, fwd_kernel_returns, num_fwd_kernel_returns, c2k_return_mapping)

Fields, in tuple order:

  • num_inputs: alias for field number 0

  • num_outputs: alias for field number 1

  • num_states: alias for field number 2

  • fwd_core_args: alias for field number 3

  • fwd_core_returns: alias for field number 4

  • fwd_core_recipients: alias for field number 5

  • fwd_kernel_returns: alias for field number 6

  • num_fwd_kernel_returns: alias for field number 7

  • c2k_return_mapping: alias for field number 8
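The "Create new instance of" and "Alias for field number" docstrings are the ones collections.namedtuple generates, so FlexSNInfo most likely behaves like a namedtuple. The sketch below builds a stand-in with the same field order to show that a field can be read by name or by index; FlexSNInfoSketch and all field values are placeholders, not output of extract_info().

```python
from collections import namedtuple

# Stand-in with the same field order as FlexSNInfo (illustrative only).
FlexSNInfoSketch = namedtuple(
    "FlexSNInfoSketch",
    [
        "num_inputs", "num_outputs", "num_states",
        "fwd_core_args", "fwd_core_returns", "fwd_core_recipients",
        "fwd_kernel_returns", "num_fwd_kernel_returns", "c2k_return_mapping",
    ],
)

# Placeholder values for a hypothetical 1-input, 1-state, 1-output core.
info = FlexSNInfoSketch(
    num_inputs=1, num_outputs=1, num_states=1,
    fwd_core_args=["x", "v"],
    fwd_core_returns=["spike", "v_next"],
    fwd_core_recipients=["spike", "v_next"],
    fwd_kernel_returns=["spike", "v_next"],
    num_fwd_kernel_returns=2,
    c2k_return_mapping={},
)

# Named access and positional access refer to the same field.
assert info.num_states == info[2]
print(info._fields.index("c2k_return_mapping"))  # 8
```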

spikingjelly.activation_based.triton_kernel.flexsn.info.extract_info(fwd_graph: Graph, num_inputs: int = 1, num_states: int = 0, num_outputs: int = 1) → FlexSNInfo


Extract useful information from the forward graph. The forward graph should have the following signature: [*inputs, *states] -> [*outputs, *states, *intermediates]

The extracted information includes:

  • fwd_core_args: the core input argument names

  • fwd_core_returns: the return value names of the forward graph

  • fwd_core_recipients: the variable names receiving the core return values

  • fwd_kernel_returns: the forward kernel return value names (no duplicates)

  • num_fwd_kernel_returns: the length of fwd_kernel_returns

  • c2k_return_mapping: mapping from intermediate results to kernel returns

Parameters:
  • fwd_graph (fx.Graph) -- The forward computational graph

  • num_inputs (int) -- Number of inputs. Default: 1

  • num_states (int) -- Number of states. Default: 0

  • num_outputs (int) -- Number of outputs. Default: 1

Returns:

The extracted FlexSN metadata

Return type:

FlexSNInfo
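The signature [*inputs, *states] -> [*outputs, *states, *intermediates] means the graph's returns can be partitioned purely by count, and the "no duplicates" rule for fwd_kernel_returns implies an index mapping back to the core returns. The sketch below illustrates both ideas on plain name lists; split_core_returns and dedupe_kernel_returns are hypothetical helpers, and the exact layout of the real c2k_return_mapping is an assumption here.

```python
from typing import Dict, List, Sequence, Tuple


def split_core_returns(returns: Sequence[str], num_outputs: int,
                       num_states: int) -> Tuple[List[str], List[str], List[str]]:
    """Hypothetical helper: split [*outputs, *states, *intermediates] by count."""
    outputs = list(returns[:num_outputs])
    states = list(returns[num_outputs:num_outputs + num_states])
    intermediates = list(returns[num_outputs + num_states:])
    return outputs, states, intermediates


def dedupe_kernel_returns(returns: Sequence[str]) -> Tuple[List[str], Dict[int, int]]:
    """Hypothetical helper: keep first occurrences in order and map every
    core-return position to its index among the de-duplicated kernel returns."""
    kernel_returns: List[str] = []
    first_index: Dict[str, int] = {}
    mapping: Dict[int, int] = {}
    for pos, name in enumerate(returns):
        if name not in first_index:
            first_index[name] = len(kernel_returns)
            kernel_returns.append(name)
        mapping[pos] = first_index[name]
    return kernel_returns, mapping


# One output, one state, one intermediate; the spike is also kept as an
# intermediate for backward, so the same node name appears twice.
core_returns = ["spike", "v_next", "spike"]
print(split_core_returns(core_returns, num_outputs=1, num_states=1))
# (['spike'], ['v_next'], ['spike'])
print(dedupe_kernel_returns(core_returns))
# (['spike', 'v_next'], {0: 0, 1: 1, 2: 0})
```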

Template

FlexSN Triton Kernel Templates.

Insert a single-step kernel into a multi-step kernel template.
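The idea of inserting a single-step core into a multi-step template can be illustrated with plain text substitution. The sketch below uses string.Template and exec on ordinary Python source rather than Triton code; MULTI_STEP_TEMPLATE, lif_core and the generated lif_multistep are hypothetical names, and the real templates (which also deal with pointers, masks and saved tensors) are not reproduced here.

```python
from string import Template

# Hypothetical multi-step template: the single-step core is called once per
# time step inside a loop. Only the text-substitution idea is shown.
MULTI_STEP_TEMPLATE = Template(
    "$core_src\n"
    "def ${name}_multistep(x_seq, v):\n"
    "    outs = []\n"
    "    for x in x_seq:\n"
    "        s, v = $core_name(x, v)\n"
    "        outs.append(s)\n"
    "    return outs, v\n"
)

# Single-step core source, analogous to core_str.
core_src = (
    "def lif_core(x, v):\n"
    "    v = v + x\n"
    "    s = 1.0 if v >= 1.0 else 0.0\n"
    "    return s, v * (1.0 - s)\n"
)

src = MULTI_STEP_TEMPLATE.substitute(core_src=core_src, name="lif", core_name="lif_core")
namespace: dict = {}
exec(src, namespace)  # "compile" the generated multi-step source
spikes, v_final = namespace["lif_multistep"]([0.6, 0.6, 0.6], 0.0)
print(spikes)  # [0.0, 1.0, 0.0]
```

The generated function reuses the core unchanged, so only the template decides how time steps, storage and launches are organized.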

spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_inference_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False)

Compile a Triton kernel for FlexSN inference (no backward).

Insert the single-step core source core_str into the multi-step inference template and compile the result under the name core_name.

Parameters:
  • core_str (str) -- Core kernel source code as a string

  • core_name (str) -- Unique name for the compiled kernel

  • info (FlexSNInfo) -- FlexSN kernel metadata

  • verbose (bool) -- If True, print compilation info

Returns:

Compiled Triton kernel executable

Return type:

triton.runtime.JITFunction

spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_inference_final_state_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False)

Compile a Triton kernel for FlexSN inference returning the final state.

Insert the single-step core source core_str into the inference template variant that writes final states instead of full state sequences, and compile the result under the name core_name.

Parameters:
  • core_str (str) -- Core kernel source code as a string

  • core_name (str) -- Unique name for the compiled kernel

  • info (FlexSNInfo) -- FlexSN kernel metadata

  • verbose (bool) -- If True, print compilation info

Returns:

Compiled Triton kernel executable

Return type:

triton.runtime.JITFunction

spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_forward_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False)

Compile a Triton kernel for FlexSN forward pass (with state saving).

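Insert the single-step core source core_str into the training forward template, which also stores the tensors needed by the backward kernel, and compile the result under the name core_name.

Parameters:
  • core_str (str) -- Core kernel source code as a string

  • core_name (str) -- Unique name for the compiled kernel

  • info (FlexSNInfo) -- FlexSN kernel metadata

  • verbose (bool) -- If True, print compilation info

Returns:

Compiled Triton kernel executable

Return type:

triton.runtime.JITFunction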

spikingjelly.activation_based.triton_kernel.flexsn.template.get_flexsn_backward_kernel(core_str: str, core_name: str, info: FlexSNInfo, verbose: bool = False)

Compile a Triton kernel for FlexSN backward pass.

Insert the single-step core source core_str into the backward template, which runs the scan in reverse time, and compile the result under the name core_name.

Parameters:
  • core_str (str) -- Core kernel source code as a string

  • core_name (str) -- Unique name for the compiled kernel

  • info (FlexSNInfo) -- FlexSN kernel metadata

  • verbose (bool) -- If True, print compilation info

Returns:

Compiled Triton kernel executable

Return type:

triton.runtime.JITFunction

Kernel Builders

Build Triton scan kernels for FlexSN's Triton/Inductor backends.

Three entry points:

  • build_inference_kernel: no-grad fast path (make_fx, no PYTORCH_JIT=0 needed)

  • build_inference_final_state_kernel: inference path that returns output sequences plus final states only, rather than full state sequences

  • build_training_kernels: forward + backward kernels for full BPTT training

spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_inference_kernel(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None)

Build a single-pass scan inference Triton kernel for core_fn. The generated kernel wraps core_fn's per-step computation in a tl.static_range(T) loop, so one inference call launches exactly one kernel regardless of T.

Parameters:
  • core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)

  • num_inputs (int) -- Number of per-step input tensors

  • num_states (int) -- Number of state tensors

  • num_outputs (int) -- Number of per-step output tensors

  • example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]. If None, unit-sized CUDA float32 tensors are created

Returns:

(kernel, info) where kernel is the compiled Triton kernel and info is the FlexSNInfo metadata required by spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference()

Return type:

Tuple[object, FlexSNInfo]
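To make the signature contract concrete, the sketch below models what one scan computes in plain Python: core_fn is called once per time step, its first num_outputs results are collected as output sequences, and the remaining results become the next states. reference_scan and lif_core are hypothetical illustrations on Python floats; the generated kernel runs the same loop on GPU tensors inside a single launch.

```python
from typing import Callable, List, Tuple


def reference_scan(core_fn: Callable, num_inputs: int, num_states: int,
                   num_outputs: int, *flat_args) -> Tuple[List[float], ...]:
    """Pure-Python model of a multi-step scan over [*input_seqs, *init_states]."""
    input_seqs = flat_args[:num_inputs]
    states = list(flat_args[num_inputs:num_inputs + num_states])
    output_seqs: List[List[float]] = [[] for _ in range(num_outputs)]
    state_seqs: List[List[float]] = [[] for _ in range(num_states)]
    for t in range(len(input_seqs[0])):  # one core_fn call per time step
        results = core_fn(*(seq[t] for seq in input_seqs), *states)
        states = list(results[num_outputs:])  # updated states feed the next step
        for k in range(num_outputs):
            output_seqs[k].append(results[k])
        for k in range(num_states):
            state_seqs[k].append(states[k])
    return (*output_seqs, *state_seqs)


def lif_core(x: float, v: float) -> Tuple[float, float]:
    """(*inputs, *states) -> (*outputs, *updated_states) with hard reset."""
    v = v + x
    spike = 1.0 if v >= 1.0 else 0.0
    return spike, v * (1.0 - spike)


spike_seq, v_seq = reference_scan(lif_core, 1, 1, 1, [0.6, 0.6, 0.6], 0.0)
print(spike_seq)  # [0.0, 1.0, 0.0]
print(v_seq)      # [0.6, 0.0, 0.6]
```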

spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_inference_final_state_kernel(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None)

Build an inference Triton kernel that returns output sequences and final states for core_fn. This variant traces core_fn like build_inference_kernel(), but materializes final state tensors instead of full state sequences.

Parameters:
  • core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)

  • num_inputs (int) -- Number of per-step input tensors

  • num_states (int) -- Number of state tensors

  • num_outputs (int) -- Number of per-step output tensors

  • example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]

Returns:

(kernel, info) produced by get_flexsn_inference_final_state_kernel() and extract_info()

Return type:

Tuple[object, FlexSNInfo]
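The difference from build_inference_kernel() lies in what is written back: instead of storing the state after every step, only the last state survives. The plain-Python sketch below (reference_scan_final_state and lif_core are hypothetical, operating on floats) shows that the final state equals the last element of the full state sequence, and that with T == 0 the initial state is returned unchanged.

```python
from typing import Callable, List, Tuple


def reference_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int,
                               num_outputs: int, *flat_args) -> Tuple:
    """Return (*output_seqs, *final_states); intermediate states are discarded."""
    input_seqs = flat_args[:num_inputs]
    states = list(flat_args[num_inputs:num_inputs + num_states])
    output_seqs: List[List[float]] = [[] for _ in range(num_outputs)]
    for t in range(len(input_seqs[0])):
        results = core_fn(*(seq[t] for seq in input_seqs), *states)
        for k in range(num_outputs):
            output_seqs[k].append(results[k])
        states = list(results[num_outputs:])  # overwrite: keep only the latest
    return (*output_seqs, *states)


def lif_core(x: float, v: float) -> Tuple[float, float]:
    v = v + x
    spike = 1.0 if v >= 1.0 else 0.0
    return spike, v * (1.0 - spike)


spike_seq, v_final = reference_scan_final_state(lif_core, 1, 1, 1, [0.6, 0.6, 0.6], 0.0)
print(spike_seq, v_final)  # [0.0, 1.0, 0.0] 0.6

# T == 0: no step runs, so the initial state is returned as the final state.
empty_seq, v0 = reference_scan_final_state(lif_core, 1, 1, 1, [], 0.25)
print(empty_seq, v0)       # [] 0.25
```

Keeping a single state buffer instead of T of them is what lets this path skip materializing the full state sequence.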

spikingjelly.activation_based.triton_kernel.flexsn.kernel.build_training_kernels(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, example_inputs: Tuple[Tensor, ...] | None = None, requires_grad: Tuple[bool, ...] | None = None)

Build FlexSN forward and backward Triton scan kernels for BPTT training. This function uses aot_function to trace both the forward and backward of core_fn and then produces:

  • a forward scan kernel that saves intermediates needed by backward

  • a backward scan kernel that runs the reverse-time pass

When some outputs (for example, hard-threshold spike signals) are non-differentiable, AOT backward drops the corresponding gradient inputs. This function automatically generates a shim so the kernel template calling convention stays aligned.

Parameters:
  • core_fn (Callable) -- Single-step dynamics callable with signature (*inputs, *states) -> (*outputs, *updated_states)

  • num_inputs (int) -- Number of per-step input tensors

  • num_states (int) -- Number of state tensors

  • num_outputs (int) -- Number of per-step output tensors

  • example_inputs (Optional[Tuple[torch.Tensor, ...]]) -- Optional example tensors [*inputs, *states]

  • requires_grad (Optional[Tuple[bool, ...]]) -- Flags indicating whether each example input should require gradients. If None, all floating-point and complex inputs are traced as differentiable

Returns:

(fwd_kernel, bwd_kernel, info) suitable for FlexSN's shared custom-op execution path

Return type:

Tuple[object, object, FlexSNInfo]
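The shim described above is needed because the template passes one gradient slot per output and state, while the traced backward only accepts gradients for differentiable results. The sketch below shows that alignment step in plain Python; make_grad_shim, traced_bwd and the 0.5 local derivative are hypothetical, and the real shim is generated code rather than a Python closure.

```python
from typing import Callable, Sequence, Tuple


def make_grad_shim(traced_bwd: Callable, differentiable: Sequence[bool]) -> Callable:
    """Wrap a backward that omits gradients of non-differentiable results so it
    can be called with one gradient slot per output/state, in template order."""
    def shim(*grads):
        if len(grads) != len(differentiable):
            raise ValueError("expected one gradient per output/state")
        kept = [g for g, needed in zip(grads, differentiable) if needed]
        return traced_bwd(*kept)
    return shim


def traced_bwd(grad_v: float) -> Tuple[float]:
    # Hypothetical traced backward of a core returning (spike, v): the
    # hard-threshold spike has no gradient input, only v does.
    return (grad_v * 0.5,)


shim = make_grad_shim(traced_bwd, differentiable=[False, True])  # order: spike, v
print(shim(None, 2.0))  # (1.0,)
```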

Wrapper

spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_backward(f, info: FlexSNInfo, *args, input_templates: Tuple[Tensor, ...] | None = None, state_templates: Tuple[Tensor, ...] | None = None) → tuple

Run the training backward kernel for FlexSN.

Parameters:
  • f (object) -- Triton kernel callable

  • info (FlexSNInfo) -- FlexSN metadata

  • args (tuple) -- Gradients followed by any saved tensors accepted by the kernel. The leading info.num_outputs + info.num_states entries correspond to output/state-sequence gradients, and the remaining entries are saved tensors from the forward pass.

  • input_templates (Optional[Tuple[torch.Tensor, ...]]) -- Per-input-sequence templates used to allocate input gradients. They are required when all incoming gradients are None, and they ensure that state-only cores still return correctly shaped input-sequence gradients. When omitted, the fallback allocation path is only valid for single-input cores whose output-sequence gradients are present, because it infers the input gradient shape from the first non-None output-sequence gradient.

  • state_templates (Optional[Tuple[torch.Tensor, ...]]) -- Per-initial-state templates used to allocate initial-state gradients. When provided, the returned state gradients preserve the original initial-state shapes instead of being inferred from the state-sequence gradients.

Returns:

Gradients for inputs and initial states. When T == 0 or all incoming gradients are None, returns zero-filled gradients that follow the provided templates.

Return type:

tuple
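The layout of args and the zero-gradient fallback can be modelled on plain Python lists. In the sketch below, split_backward_args and zero_grads_from_templates are hypothetical helpers and lists stand in for tensors; only the documented rules are modelled: the first num_outputs + num_states entries are gradients, and when all of them are None, zero gradients are shaped after the templates.

```python
from typing import List, Sequence, Tuple


def split_backward_args(num_outputs: int, num_states: int, *args) -> Tuple[tuple, tuple]:
    """Leading output/state-sequence gradients, then saved forward tensors."""
    n_grads = num_outputs + num_states
    return args[:n_grads], args[n_grads:]


def zero_grads_from_templates(input_templates: Sequence[List[float]],
                              state_templates: Sequence[List[float]]) -> Tuple[List[float], ...]:
    """One zero 'tensor' per template, in (*inputs, *initial_states) order."""
    return tuple([0.0] * len(t) for t in (*input_templates, *state_templates))


grads, saved = split_backward_args(1, 1, None, None, "saved_h", "saved_v")
print(grads, saved)  # (None, None) ('saved_h', 'saved_v')

if all(g is None for g in grads):
    # Without templates, the input-gradient shape could not be inferred here.
    print(zero_grads_from_templates([[0.1, 0.2, 0.3]], [[0.0, 0.0]]))
    # ([0.0, 0.0, 0.0], [0.0, 0.0])
```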

spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_backward_ncl_bucket(ncl: int) → int

Bucket a flattened sequence size for backward-kernel tuning: map the flattened per-step element count NCL to a backward-kernel autotuning bucket.

Parameters:

ncl (int) -- Flattened element count per time step

Returns:

Bucket index in [0, 4]

Return type:

int
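Bucketing is useful because triton.autotune re-benchmarks its configs whenever the values of its key arguments change; presumably keying on a small bucket index rather than the raw NCL keeps the number of tuning runs bounded. The sketch below shows one way to map NCL to [0, 4]; ncl_bucket_sketch and its power-of-two edges are hypothetical, since the actual thresholds are not documented here.

```python
import bisect

# Hypothetical bucket edges (powers of two); the real thresholds may differ.
_NCL_EDGES = (1 << 12, 1 << 15, 1 << 18, 1 << 21)


def ncl_bucket_sketch(ncl: int) -> int:
    """Map a flattened per-step element count to a bucket index in [0, 4]."""
    if ncl < 0:
        raise ValueError("ncl must be non-negative")
    # bisect_right counts how many edges are <= ncl, giving 0..len(_NCL_EDGES).
    return bisect.bisect_right(_NCL_EDGES, ncl)


print([ncl_bucket_sketch(n) for n in (1, 4096, 100_000, 10**7)])  # [0, 1, 2, 4]
```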

spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_forward(f, info: FlexSNInfo, *args) → tuple

Run the training forward kernel for FlexSN.

Parameters:
  • f (object) -- Triton kernel callable

  • info (FlexSNInfo) -- FlexSN metadata

  • args -- Input/state sequences accepted by the kernel

Returns:

Forward outputs plus any saved tensors required by backward. When T == 0, returns empty tensors following the expected templates.

Return type:

tuple

spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference(f, info: FlexSNInfo, *args) → tuple

Run the inference kernel for a multi-step FlexSN core.

Parameters:
  • f (object) -- Triton kernel callable

  • info (FlexSNInfo) -- FlexSN metadata

  • args -- Input/state sequences accepted by the kernel

Returns:

Output/state sequences. When T == 0, returns empty tensors with the expected templates.

Return type:

tuple

spikingjelly.activation_based.triton_kernel.flexsn.wrapper.flexsn_inference_final_state(f, info: FlexSNInfo, *args) → tuple

Run the inference kernel and materialize final states.

Parameters:
  • f (object) -- Triton kernel callable

  • info (FlexSNInfo) -- FlexSN metadata

  • args -- Input/state sequences accepted by the kernel

Returns:

Output sequences followed by final states. When T == 0, output sequences are empty, provided initial states are cloned, and missing states are zero-filled.

Return type:

tuple

Custom Ops

The custom_ops module provides the low-level opaque custom ops used by FlexSN's shared Triton execution path. It stores Triton kernels and metadata in a Python-side registry and exposes them to torch.compile / AOTAutograd through lightweight integer handle values.

Its main responsibilities are:

  • maintaining the Python-side FlexSN kernel-handle registry;

  • wrapping forward and backward as torch.library custom ops so the compiler does not trace Python kernel objects or Triton launcher internals;

  • providing helpers for fake tensors, autograd setup/backward, and final-state fast paths.

Most users do not need to call these functions directly; they are primarily used indirectly by spikingjelly.activation_based.neuron.flexsn.FlexSN and spikingjelly.activation_based.neuron.flexsn.FlexSNKernel.

class spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.FlexSNKernelHandle(inference_kernel: object | None, inference_info: FlexSNInfo | None, inference_final_state_kernel: object | None, inference_final_state_info: FlexSNInfo | None, forward_kernel: object | None, backward_kernel: object | None, training_info: FlexSNInfo | None, owner_refs: int = 1, active_refs: int = 0)

Bases: object

Store the Triton kernels, metadata, and reference-counting state associated with a FlexSN kernel handle. Instances are managed by the internal custom_ops registry and bind Python-side kernel objects to integer handle values.

inference_kernel: object | None
inference_info: FlexSNInfo | None
inference_final_state_kernel: object | None
inference_final_state_info: FlexSNInfo | None
forward_kernel: object | None
backward_kernel: object | None
training_info: FlexSNInfo | None
owner_refs: int = 1
active_refs: int = 0

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.register_flexsn_kernel_handle(*, inference_kernel, inference_info, inference_final_state_kernel, inference_final_state_info, forward_kernel, backward_kernel, training_info) → int


Register a bundle of FlexSN Triton kernels and their metadata in the Python-side registry and return an integer handle. The returned handle can be referenced by later custom-op calls.

Parameters:
  • inference_kernel -- Inference kernel

  • inference_info -- FlexSNInfo for the inference path

  • inference_final_state_kernel -- Inference kernel that returns final states only

  • inference_final_state_info -- FlexSNInfo for the inference-final-state path

  • forward_kernel -- Training forward kernel

  • backward_kernel -- Training backward kernel

  • training_info -- FlexSNInfo for the training path

Returns:

Newly registered integer handle

Return type:

int

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.retain_flexsn_kernel_handle(handle: int) → None


Increase the active-reference count of the specified FlexSN kernel handle. This is typically used when an autograd context needs to keep the handle alive until backward finishes.

Parameters:

handle (int) -- FlexSN kernel handle

Returns:

None

Return type:

None

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.retain_owner_flexsn_kernel_handle(handle: int) → None


Increase the owner-reference count of the specified FlexSN kernel handle. This is typically used when an object copy or another owner takes over the handle.

Parameters:

handle (int) -- FlexSN kernel handle

Returns:

None

Return type:

None

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.release_flexsn_kernel_handle(handle: int) → None


Release one owner reference. When both owner and active references reach zero, the associated kernels are removed from the registry and cleaned up.

Parameters:

handle (int) -- FlexSN kernel handle

Returns:

None

Return type:

None

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.release_active_flexsn_kernel_handle(handle: int) → None


Release one active reference. When both owner and active references reach zero, the associated kernels are removed from the registry and cleaned up.

Parameters:

handle (int) -- FlexSN kernel handle

Returns:

None

Return type:

None

spikingjelly.activation_based.triton_kernel.flexsn.custom_ops.attach_flexsn_handle_finalizer(owner, handle: int)

Attach a weakref.finalize object to owner so the corresponding FlexSN kernel handle is released automatically when the owner is destroyed.

Parameters:
  • owner -- Owner object of the handle

  • handle (int) -- FlexSN kernel handle

Returns:

Attached finalizer

Return type:

weakref.finalize
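weakref.finalize calls its callback once the owner is garbage-collected (or at interpreter exit if the owner is still alive), which is what makes automatic handle release possible. The sketch below uses a hypothetical release_handle stand-in and a plain Owner class; the important design point is that the callback and its arguments must not reference the owner itself, otherwise the owner would never be collected.

```python
import gc
import weakref

released = []


def release_handle(handle: int) -> None:
    # Hypothetical stand-in for release_flexsn_kernel_handle.
    released.append(handle)


class Owner:  # e.g. a neuron module that owns handle 7
    pass


owner = Owner()
fin = weakref.finalize(owner, release_handle, 7)  # args must not capture `owner`
print(fin.alive)   # True
del owner          # last strong reference dropped
gc.collect()       # harmless here; CPython already ran the finalizer
print(released)    # [7]
print(fin.alive)   # False
```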

HigherOrderOp & Eager Scan

FlexSN time-step scan as a HigherOrderOperator.

Current progress:

M1:

  • HOP definition with an eager Python time-step loop implementation.

  • Eager autograd works via the natural computation graph (x[t] indexing and torch.stack are differentiable, so the per-step core_fn graph is correctly chained through time). Verified with gradcheck.

M2:

  • AOTAutograd tracing (torch.fx.experimental.proxy_tensor.make_fx / torch._functorch.aot_autograd.aot_function) works by unrolling the scan into T copies of core_fn's aten ops.

M3:

  • FlexSN(backend="hop") is available as an explicit backend.

  • Dynamo recognizes flex_sn_scan via a compatibility registration and can rewrite the call into a HOP node with a traced GraphModule body.

  • torch.compile(fullgraph=True) for the HOP backend is verified on the Linux CI/server environment, including lifted tensor freevars/closures.

M4:

  • lowerable_scan re-expresses the FlexSN step function through PyTorch's built-in torch.ops.higher_order.scan when that API is available.

  • It is kept as an explicit experimental helper for investigating a single-scan-node forward path instead of fully unrolling the body.

  • lowerable_while_loop_scan provides an alternative experimental forward path based on torch.ops.higher_order.while_loop. On the Linux validation environment, its torch.compile(fullgraph=True) + no_grad path works after switching to fixed-shape queue carries instead of x[t] indexing.

  • The experimental while-loop path is wired into FlexSN(backend="hop") via SJ_ENABLE_EXPERIMENTAL_LOWERABLE_WHILE_LOOP=1 for compile-time forward evaluation, and has been validated on:

      - a single FlexSN layer,

      - Linear -> FlexSN -> Linear,

      - SpikingVGG forward inference.

Current limitations:

  • The custom Dynamo registration in this file is still a compatibility shim, not a true in-tree BaseHOP integration.

  • lowerable_scan is currently an experimental helper, not the default compiled path. In the PyTorch versions we validate against, fake/proxy/export handling for this out-of-tree scan shape is not yet stable enough to enable it by default.

  • Training and autograd still use the existing eager/unrolled path.

  • The current while-loop lowering is functionally correct for the validated forward no_grad cases, but it is not yet faster than the current backend="inductor" custom-op compile path on the server benchmark.

  • A true first-class Inductor lowering for flex_sn_scan itself does not exist yet; the current "less unrolled" path relies on PyTorch's built-in scan / while_loop decomposition.

Usage:

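import torch
from spikingjelly.activation_based.triton_kernel.flexsn import flex_sn_scan

def core_fn(x, v):  # (*inputs, *states) -> (*outputs, *updated_states)
    v = v + x
    spike = (v >= 1.0).to(x.dtype)
    return spike, v * (1.0 - spike)

inputs_seq = (torch.rand(4, 8),)   # tuple of T-leading tensors, shape [T, N, ...]
init_states = (torch.zeros(8),)    # tuple of per-step state tensors, shape [N, ...]
# num_inputs=1, num_states=1, num_outputs=1; returns (*output_seqs, *state_seqs), each [T, ...]
result = flex_sn_scan(core_fn, 1, 1, 1, *inputs_seq, *init_states)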

Captured tensor freevars from core_fn are appended after the [*inputs_seq, *init_states] segment when Dynamo rewrites the HOP call.

spikingjelly.activation_based.triton_kernel.flexsn.hop.dynamo_hop_available() → bool

Report whether the FlexSN Dynamo HOP registration succeeded, that is, whether the FlexSN-specific Dynamo HigherOrderOperator registration has been installed.

Returns:

True when the Dynamo compatibility shim for flex_sn_scan is registered; otherwise False.

Return type:

bool

spikingjelly.activation_based.triton_kernel.flexsn.hop.eager_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...]

Run the FlexSN scan with an eager Python time-step loop.

This helper is reused by both the HOP eager implementation and the Dynamo-friendly backend="inductor" path, so torch.compile can trace the unrolled loop into a standard FX graph. When T == 0, output_template_specs must describe the output sequence shapes/dtypes so empty outputs can be materialized without executing core_fn.

Parameters:
  • core_fn (Callable) -- Single-step core callable with signature (*step_inputs, *states, *lifted_args)

  • num_inputs (int) -- Number of T-leading input sequences

  • num_states (int) -- Number of initial-state tensors

  • num_outputs (int) -- Number of per-step outputs

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to build empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by num_states state sequences

Return type:

Tuple[torch.Tensor, ...]
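The flat_args layout (input sequences, then initial states, then lifted tensor freevars) and the per-step call core_fn(*step_inputs, *states, *lifted_args) can be sketched in plain Python. eager_scan_sketch and leaky_core are hypothetical; decay stands in for a captured tensor that Dynamo would lift into flat_args, and the T == 0 template path is left out.

```python
from typing import Callable, List, Tuple


def eager_scan_sketch(core_fn: Callable, num_inputs: int, num_states: int,
                      num_outputs: int, *flat_args) -> Tuple[List[float], ...]:
    """flat_args = (*input_seqs, *init_states, *lifted_args)."""
    input_seqs = flat_args[:num_inputs]
    states = list(flat_args[num_inputs:num_inputs + num_states])
    lifted = flat_args[num_inputs + num_states:]  # passed unchanged to every step
    output_seqs: List[List[float]] = [[] for _ in range(num_outputs)]
    state_seqs: List[List[float]] = [[] for _ in range(num_states)]
    for t in range(len(input_seqs[0])):
        results = core_fn(*(seq[t] for seq in input_seqs), *states, *lifted)
        states = list(results[num_outputs:])
        for k in range(num_outputs):
            output_seqs[k].append(results[k])
        for k in range(num_states):
            state_seqs[k].append(states[k])
    return (*output_seqs, *state_seqs)


def leaky_core(x: float, v: float, decay: float) -> Tuple[float, float]:
    v = decay * v + x
    return v, v  # (output, updated state)


y_seq, v_seq = eager_scan_sketch(leaky_core, 1, 1, 1, [1.0, 1.0, 1.0], 0.0, 0.5)
print(y_seq)  # [1.0, 1.5, 1.75]
```

Because the lifted arguments sit after the states, num_inputs and num_states alone are enough to recover all three segments.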

spikingjelly.activation_based.triton_kernel.flexsn.hop.eager_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) Tuple[Tensor, ...][源代码]#

Run the eager scan and return output sequences plus final states.


  • Chinese

Eager scan final state function.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]

Variant of eager_scan() used when store_state_seqs=False, so that the HOP backend does not materialize full state sequences only to discard them. When T == 0, output_template_specs is used to build empty output sequences, and the provided initial states are cloned into the returned final states.
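The final-state variant differs only in what it keeps: the latest states instead of a full per-step history, and copies of the initial states when T == 0. A pure-Python sketch under the same assumptions as a reference model (lists stand in for tensors, `copy.deepcopy` stands in for tensor cloning, num_inputs >= 1); `eager_scan_final_state_sketch` and the `accumulate` core are hypothetical names, not library API.

```python
import copy
from typing import Any, Callable, List, Tuple


def eager_scan_final_state_sketch(
    core_fn: Callable[..., Tuple[Any, ...]],
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: Any,
) -> Tuple[Any, ...]:
    inputs = flat_args[:num_inputs]
    # Copy the initial states so that, when T == 0, the returned "final"
    # states are copies rather than the caller's own objects.
    states = tuple(copy.deepcopy(s) for s in flat_args[num_inputs:num_inputs + num_states])
    lifted = flat_args[num_inputs + num_states:]
    T = len(inputs[0])  # assumes num_inputs >= 1
    out_seqs: List[List[Any]] = [[] for _ in range(num_outputs)]
    for t in range(T):
        results = core_fn(*(x[t] for x in inputs), *states, *lifted)
        for seq, y in zip(out_seqs, results[:num_outputs]):
            seq.append(y)
        # Keep only the latest states; no per-step state history is stored.
        states = tuple(results[num_outputs:num_outputs + num_states])
    return (*out_seqs, *states)


def accumulate(x: int, acc: int) -> Tuple[int, int]:
    # Hypothetical core: the output and the new state are both the running sum.
    acc = acc + x
    return acc, acc


print(eager_scan_final_state_sketch(accumulate, 1, 1, 1, [1, 2, 3], 10))  # ([11, 13, 16], 16)
print(eager_scan_final_state_sketch(accumulate, 1, 1, 1, [], 10))         # ([], 10)
```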


  • English

Eager scan final state function: runs the eager scan and returns output sequences plus final states.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]

class spikingjelly.activation_based.triton_kernel.flexsn.hop.FlexSNScan [source]

Bases: HigherOrderOperator

HOP that runs a user-defined single-step core function over the leading time dimension of its inputs. The HOP is invoked with a flat argument list so that Dynamo / AOTAutograd can treat it uniformly.

Shapes/semantics:

  • core_fn: callable with signature (*step_inputs, *states) -> (*step_outputs, *updated_states).

  • num_inputs / num_states / num_outputs: int literals used to partition the flat tensor args.

  • flat_args: the first num_inputs tensors are input sequences with leading time dim T; the next num_states tensors are initial states (no time dim); any remaining tensors are lifted freevars that are passed through to core_fn unchanged at every time step.

Returns: num_outputs output sequences followed by num_states state sequences, all stacked along the leading time dim.


  • Chinese

High-level wrapper around the FlexSN differentiable scanning operation.

Provides a scan function compatible with PyTorch's native scan operation, enabling efficient differentiable scanning along the time dimension in spiking neural networks. It includes an eager mode and a lowerable mode and automatically selects a suitable execution backend based on context. Its core callable must follow the signature [*inputs, *states] -> [*outputs, *states, *intermediates].

Return type:

None
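The flat argument layout of FlexSNScan can be made concrete with a small partitioning helper. This is a hypothetical illustration (`partition_flat_args` is not part of the library) that uses strings as placeholders for tensors; it only restates the documented rule: num_inputs sequences, then num_states initial states, then lifted freevars.

```python
from typing import Any, Sequence, Tuple


def partition_flat_args(
    flat_args: Sequence[Any], num_inputs: int, num_states: int
) -> Tuple[Tuple[Any, ...], Tuple[Any, ...], Tuple[Any, ...]]:
    """Split flat_args into (input sequences, initial states, lifted freevars)."""
    if num_inputs < 0 or num_states < 0 or num_inputs + num_states > len(flat_args):
        raise ValueError("num_inputs + num_states exceeds the number of flat args")
    inputs = tuple(flat_args[:num_inputs])  # T-leading sequences
    states = tuple(flat_args[num_inputs:num_inputs + num_states])  # no time dim
    lifted = tuple(flat_args[num_inputs + num_states:])  # passed through every step
    return inputs, states, lifted


inputs, states, lifted = partition_flat_args(["x_seq", "v0", "w", "b"], 1, 1)
print(inputs, states, lifted)  # ('x_seq',) ('v0',) ('w', 'b')
```

Keeping the counts as separate integer literals is what lets a single flat tensor list carry all three groups through tracing.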


  • English

FlexSNScan higher-order operator.

Returns:

None

Return type:

None

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]

Run the FlexSN scan through PyTorch's built-in scan HOP.


  • Chinese

Lowerable scan function.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by num_states state sequences.

Return type:

Tuple[torch.Tensor, ...]

Keep the FlexSN scan as a single higher-order op under tracing so that downstream compilers can lower it as a loop instead of fully unrolling the body T times. This remains an experimental helper rather than the default compiled path. When T == 0, output_template_specs must contain num_outputs items shaped as (shape, dtype) or (shape, dtype, device); omitted devices default to the device of the first input sequence at runtime.
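The T == 0 rule above (num_outputs templates, each with an optional device that falls back to the first input) can be sketched as a validation step. This is a hypothetical helper, not library code: dtypes and devices are plain strings, and it assumes each template shape is the per-step output shape, so the empty sequence shape is (0, *shape). That assumption is not confirmed by the documentation; check the source before relying on it.

```python
from typing import Optional, Sequence, Tuple

# (shape, dtype) or (shape, dtype, device); dtype/device are plain strings here.
TemplateSpec = Tuple[object, ...]


def resolve_empty_outputs(
    specs: Optional[Sequence[TemplateSpec]],
    num_outputs: int,
    first_input_device: str,
) -> Tuple[Tuple[Tuple[int, ...], str, str], ...]:
    """Return (sequence_shape, dtype, device) for each empty output when T == 0."""
    if specs is None or len(specs) != num_outputs:
        raise ValueError("T == 0 requires exactly num_outputs template specs")
    resolved = []
    for spec in specs:
        if len(spec) == 2:
            shape, dtype = spec
            device = first_input_device  # omitted device follows the first input
        elif len(spec) == 3:
            shape, dtype, device = spec
        else:
            raise ValueError(f"invalid template spec: {spec!r}")
        # Assumption: `shape` is the per-step shape, so the empty sequence is (0, *shape).
        # With PyTorch one would then allocate torch.empty((0, *shape), dtype=..., device=...).
        resolved.append(((0, *tuple(shape)), dtype, device))
    return tuple(resolved)


print(resolve_empty_outputs([((4, 8), "float32"), ((4,), "int64", "cpu")], 2, "cuda:0"))
```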


  • English

Lowerable scan function: runs the FlexSN scan through PyTorch's built-in scan HOP.

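Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by num_states state sequences.

Return type:

Tuple[torch.Tensor, ...]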

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan_available() → bool [source]

Report whether PyTorch's built-in scan HOP is available.


  • Chinese

Lowerable scan availability function.

Returns:

True when torch.ops.higher_order.scan is available; otherwise False.

Return type:

bool

Return whether the current environment exposes PyTorch's built-in scan higher-order operator.
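A probe of this kind can be written as an attribute lookup on torch.ops.higher_order, which raises AttributeError for unregistered operators in recent PyTorch releases. The sketch below is a hypothetical stand-in (`hop_available` is not the library function); it returns False when PyTorch is missing and treats any lookup failure in older builds as "not available".

```python
def hop_available(name: str) -> bool:
    """Return True if torch.ops.higher_order exposes the higher-order op `name`."""
    try:
        import torch
    except (ImportError, OSError):
        return False  # PyTorch is not installed or cannot be loaded
    try:
        return getattr(torch.ops.higher_order, name, None) is not None
    except Exception:
        # Older PyTorch builds may raise non-AttributeError errors on lookup.
        return False


print(hop_available("scan"), hop_available("while_loop"))
```

The same probe covers both `scan` and `while_loop`, since only the operator name differs.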


  • English

Lowerable scan availability function: reports whether PyTorch's built-in scan HOP is available.

Returns:

True when torch.ops.higher_order.scan is available; otherwise False.

Return type:

bool

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]

Run the built-in scan HOP and return output sequences plus final states (instead of full state sequences).


  • Chinese

Lowerable scan final state function.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]

Final-state variant of lowerable_scan(). When T == 0, output_template_specs is used to materialize empty output sequences, and the initial states are cloned into the returned final states.
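Given the documented return layouts, a final-state variant should agree with the last step of the corresponding full-sequence variant, and with the initial states when T == 0. The hypothetical helper below derives the expected final states from state sequences, which can serve as a cheap consistency check in tests; lists stand in for tensors, and the equivalence itself is an inference from the return descriptions rather than a documented guarantee.

```python
from typing import Any, Sequence, Tuple


def final_states_from_sequences(
    state_seqs: Sequence[Sequence[Any]], init_states: Sequence[Any]
) -> Tuple[Any, ...]:
    """Expected final states: last step of each state sequence, or the initial state if T == 0."""
    return tuple(
        seq[-1] if len(seq) > 0 else init
        for seq, init in zip(state_seqs, init_states)
    )


print(final_states_from_sequences([[1.0, 0.0, 0.5]], [0.0]))  # (0.5,)
print(final_states_from_sequences([[]], [0.25]))              # (0.25,)
```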


  • English

Lowerable scan final state function: runs the built-in scan HOP and returns output sequences plus final states.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_scan(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]

Run the FlexSN scan through PyTorch's built-in while_loop HOP.


  • Chinese

Lowerable while-loop scan function.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by num_states state sequences.

Return type:

Tuple[torch.Tensor, ...]

Experimental helper for studying whether a first-class loop representation is a better fit than the current unrolled scan path. The current while-loop capture does not support symbolic x[t] indexing here, so this implementation keeps functional queue buffers instead. When T == 0, output_template_specs is used to materialize empty output sequences without running core_fn.
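One plausible reading of "functional queue buffers" is that the loop carries the not-yet-consumed inputs and the outputs produced so far as loop state, taking the head of the queue at each iteration instead of indexing x[t] with a symbolic counter. The sketch below illustrates that idea in pure Python. `py_while_loop` is a hypothetical eager emulator of a while_loop(cond_fn, body_fn, carried_inputs) contract, and `queue_scan` handles one input and one state only; neither reflects the library's actual buffer layout.

```python
from typing import Any, Callable, List, Tuple


def py_while_loop(
    cond_fn: Callable[..., bool],
    body_fn: Callable[..., Tuple[Any, ...]],
    carried: Tuple[Any, ...],
) -> Tuple[Any, ...]:
    # Eager stand-in for a while_loop HOP: body_fn maps the carried values
    # to new values with the same structure until cond_fn returns False.
    while cond_fn(*carried):
        carried = tuple(body_fn(*carried))
    return carried


def queue_scan(
    step_fn: Callable[[Any, Any], Tuple[Any, Any]], xs: List[Any], init_state: Any
) -> Tuple[List[Any], Any]:
    """Single-input, single-state scan that never indexes xs with a loop counter."""

    def cond_fn(queue: Tuple[Any, ...], state: Any, outputs: Tuple[Any, ...]) -> bool:
        return len(queue) > 0

    def body_fn(queue: Tuple[Any, ...], state: Any, outputs: Tuple[Any, ...]):
        head, rest = queue[0], queue[1:]  # consume the front of the input queue
        y, new_state = step_fn(head, state)
        return rest, new_state, outputs + (y,)  # build a new output buffer

    _, final_state, outputs = py_while_loop(cond_fn, body_fn, (tuple(xs), init_state, ()))
    return list(outputs), final_state


print(queue_scan(lambda x, s: (s + x, s + x), [1, 2, 3], 10))  # ([11, 13, 16], 16)
```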


  • English

Lowerable while-loop scan function: runs the FlexSN scan through PyTorch's built-in while_loop HOP.

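Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by num_states state sequences.

Return type:

Tuple[torch.Tensor, ...]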

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_available() → bool [source]

Report whether PyTorch's built-in while_loop HOP is available.


  • Chinese

Lowerable while-loop availability function.

Returns:

True when torch.ops.higher_order.while_loop is available; otherwise False.

Return type:

bool

Return whether the current environment exposes PyTorch's built-in while_loop higher-order operator.
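Because both lowerable helpers are described as experimental, a caller might gate them behind the two availability checks and fall back to the unrolled eager path otherwise. The policy below is purely hypothetical (`pick_scan_path`, its preference order, and the `experimental` flag are not library API); it only shows how the boolean results of lowerable_scan_available() and lowerable_while_loop_available() could be combined.

```python
def pick_scan_path(scan_ok: bool, while_loop_ok: bool, experimental: bool = False) -> str:
    """Hypothetical policy for choosing among the scan helpers documented here."""
    if not experimental:
        # The lowerable helpers are experimental; stay on the unrolled eager path.
        return "eager_scan"
    if scan_ok:
        return "lowerable_scan"  # loop kept as a single scan HOP
    if while_loop_ok:
        return "lowerable_while_loop_scan"  # loop kept as a while_loop HOP
    return "eager_scan"  # neither HOP available: fall back to unrolling


print(pick_scan_path(scan_ok=False, while_loop_ok=True, experimental=True))
```

In practice, scan_ok and while_loop_ok would come from lowerable_scan_available() and lowerable_while_loop_available().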


  • English

Lowerable while-loop availability function: reports whether PyTorch's built-in while_loop HOP is available.

Returns:

True when torch.ops.higher_order.while_loop is available; otherwise False.

Return type:

bool

spikingjelly.activation_based.triton_kernel.flexsn.hop.lowerable_while_loop_scan_final_state(core_fn: Callable, num_inputs: int, num_states: int, num_outputs: int, *flat_args: Tensor, output_template_specs: Tuple[Tuple[Tuple[int, ...], dtype] | Tuple[Tuple[int, ...], dtype, device], ...] | None = None) → Tuple[Tensor, ...] [source]

Run the while-loop HOP and return output sequences plus final states.


  • Chinese

Lowerable while-loop scan final state function.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]

Final-state variant of lowerable_while_loop_scan(). When T == 0, output_template_specs is used to build empty output sequences, and the provided initial states are cloned into the returned final states.


  • English

Lowerable while-loop scan final state function: runs the while-loop HOP and returns output sequences plus final states.

Parameters:
  • core_fn (Callable) -- Single-step core callable.

  • num_inputs (int) -- Number of T-leading input sequences.

  • num_states (int) -- Number of initial-state tensors.

  • num_outputs (int) -- Number of per-step outputs.

  • flat_args (torch.Tensor) -- Flattened input sequences, initial states, then lifted tensor freevars.

  • output_template_specs (Optional[OutputTemplateSpecs]) -- Optional (shape, dtype) or (shape, dtype, device) templates used to materialize empty output sequences when T == 0. Omitted devices follow the first input sequence at runtime.

Returns:

num_outputs output sequences followed by the final states.

Return type:

Tuple[torch.Tensor, ...]