from typing import Optional, Tuple
import torch
try:
import triton
except Exception as e:
import logging
from .. import dummy
logging.info(f"spikingjelly.activation_based.triton_kernel.flexsn.wrapper: {e}")
triton = dummy.DummyImport()
from ..triton_utils import type_dict
from .info import FlexSNInfo
__all__ = [
"flexsn_backward",
"flexsn_backward_ncl_bucket",
"flexsn_forward",
"flexsn_inference",
"flexsn_inference_final_state",
]
_BACKWARD_SMALL_MAX_NCL = 1 << 12
_BACKWARD_MEDIUM_MAX_NCL = 1 << 17
_BACKWARD_LARGE_MAX_NCL = 1 << 20
_BACKWARD_XLARGE_MAX_NCL = 1 << 23
[文档]
def flexsn_backward_ncl_bucket(ncl: int) -> int:
"""Bucket a flattened sequence size for backward-kernel tuning.
**API Language:**
:ref:`中文 <flexsn_backward_ncl_bucket-cn>` | :ref:`English <flexsn_backward_ncl_bucket-en>`
----
.. _flexsn_backward_ncl_bucket-cn:
* **中文**
flexsn backward ncl bucket 函数
:param ncl: EN: Flattened element count per time step. Chinese: 单个时间步展平后的元素数。
:type ncl: int
:return: EN: Bucket index in ``[0, 4]``. Chinese: ``[0, 4]`` 范围内的分桶索引。
:rtype: int
Chinese:
将展平后的单步元素数 ``NCL`` 映射到 backward kernel 的调优分桶。
English:
Map the flattened per-step element count ``NCL`` to the backward-kernel
autotuning bucket.
----
.. _flexsn_backward_ncl_bucket-en:
* **English**
Flexsn Backward Ncl Bucket function
:type ncl: int
:rtype: int
"""
if ncl <= _BACKWARD_SMALL_MAX_NCL:
return 0
if ncl <= _BACKWARD_MEDIUM_MAX_NCL:
return 1
if ncl <= _BACKWARD_LARGE_MAX_NCL:
return 2
if ncl <= _BACKWARD_XLARGE_MAX_NCL:
return 3
return 4
def _num_elements_per_step(x: torch.Tensor) -> int:
n = 1
for dim in x.shape[1:]:
n *= dim
return n
def _make_grid(ncl: int):
def grid(meta):
return (triton.cdiv(ncl, meta["BLOCK_NCL"]),)
return grid
def _first_non_none_tensor(tensors):
for tensor in tensors:
if tensor is not None:
return tensor
return None
def _allocate_state_grad(
i: int,
T: int,
state_templates: Optional[Tuple[torch.Tensor, ...]],
grad_state_seq_examples,
grad_example: torch.Tensor,
) -> torch.Tensor:
if state_templates is not None:
return (
torch.zeros_like(state_templates[i])
if T == 0
else torch.empty_like(state_templates[i])
)
if i < len(grad_state_seq_examples) and grad_state_seq_examples[i] is not None:
example = grad_state_seq_examples[i]
return (
example.new_zeros(example.shape[1:])
if T == 0
else example.new_empty(example.shape[1:])
)
return (
grad_example.new_zeros(grad_example.shape[1:])
if T == 0
else grad_example.new_empty(grad_example.shape[1:])
)
[文档]
def flexsn_inference(f, info: FlexSNInfo, *args) -> tuple:
"""Run the inference kernel for a multi-step FlexSN core.
**API Language:**
:ref:`中文 <flexsn_inference-cn>` | :ref:`English <flexsn_inference-en>`
----
.. _flexsn_inference-cn:
* **中文**
flexsn inference 函数
:param f: EN: Triton kernel callable. Chinese: Triton kernel 可调用对象。
:param info: EN: FlexSN metadata. Chinese: FlexSN 元信息。
:param args: EN: Input/state sequences accepted by the kernel. Chinese: kernel 接收的输入/状态序列。
:return: EN: Output/state sequences. When ``T == 0``, returns empty tensors with the expected templates. Chinese: 输出/状态序列;当 ``T == 0`` 时, 返回符合模板的空张量。
:rtype: tuple
Chinese:
执行 FlexSN 多步推理 kernel。
English:
Execute the FlexSN multi-step inference kernel.
----
.. _flexsn_inference-en:
* **English**
Execute the FlexSN multi-step inference kernel.
:param f: Triton kernel callable
:type f: object
:param info: FlexSN metadata
:type info: FlexSNInfo
:param args: Input/state sequences accepted by the kernel
:return: Output/state sequences. When ``T == 0``, returns empty tensors with the expected templates.
:rtype: tuple
"""
x_example = args[0]
T = x_example.shape[0]
NCL = _num_elements_per_step(x_example)
dtype = x_example.dtype
outputs = [
torch.empty_like(x_example) for _ in range(info.num_outputs + info.num_states)
]
if T == 0:
return tuple(outputs)
grid = _make_grid(NCL)
f[grid](
*args,
*outputs,
T=T,
NCL=NCL,
dtype=type_dict[dtype],
)
return tuple(outputs)
[文档]
def flexsn_inference_final_state(f, info: FlexSNInfo, *args) -> tuple:
"""Run the inference kernel and materialize final states.
**API Language:**
:ref:`中文 <flexsn_inference_final_state-cn>` | :ref:`English <flexsn_inference_final_state-en>`
----
.. _flexsn_inference_final_state-cn:
* **中文**
flexsn inference final state 函数
:param f: EN: Triton kernel callable. Chinese: Triton kernel 可调用对象。
:param info: EN: FlexSN metadata. Chinese: FlexSN 元信息。
:param args: EN: Input/state sequences accepted by the kernel. Chinese: kernel 接收的输入/状态序列。
:return: EN: Output sequences followed by final states. When ``T == 0``, output sequences are empty, provided initial states are cloned, and missing states are zero-filled. Chinese: 输出序列后接最终状态;当 ``T == 0`` 时, 输出序列为空, 已提供的初始状态会被克隆, 缺失状态会以零填充。
:rtype: tuple
Chinese:
执行带最终状态物化的 FlexSN 多步推理 kernel。
English:
Execute the FlexSN inference kernel and materialize final states.
----
.. _flexsn_inference_final_state-en:
* **English**
Execute the FlexSN inference kernel and materialize final states.
:param f: Triton kernel callable
:type f: object
:param info: FlexSN metadata
:type info: FlexSNInfo
:param args: Input/state sequences accepted by the kernel
:return: Output sequences followed by final states. When ``T == 0``, output sequences are empty, provided initial states are cloned, and missing states are zero-filled.
:rtype: tuple
"""
x_example = args[0]
T = x_example.shape[0]
NCL = _num_elements_per_step(x_example)
dtype = x_example.dtype
output_seqs = [torch.empty_like(x_example) for _ in range(info.num_outputs)]
init_states = args[info.num_inputs : info.num_inputs + info.num_states]
final_states = [
init_states[i].new_empty(init_states[i].shape)
if i < len(init_states)
else x_example.new_empty(x_example.shape[1:])
for i in range(info.num_states)
]
if T == 0:
final_states = [
(
init_states[i].clone()
if i < len(init_states)
else x_example.new_zeros(x_example.shape[1:])
)
for i in range(info.num_states)
]
return tuple([*output_seqs, *final_states])
grid = _make_grid(NCL)
f[grid](
*args,
*output_seqs,
*final_states,
T=T,
NCL=NCL,
dtype=type_dict[dtype],
)
return tuple([*output_seqs, *final_states])
[文档]
def flexsn_forward(f, info: FlexSNInfo, *args) -> tuple:
"""Run the training forward kernel for FlexSN.
**API Language:**
:ref:`中文 <flexsn_forward-cn>` | :ref:`English <flexsn_forward-en>`
----
.. _flexsn_forward-cn:
* **中文**
flexsn forward 函数
:param f: EN: Triton kernel callable. Chinese: Triton kernel 可调用对象。
:param info: EN: FlexSN metadata. Chinese: FlexSN 元信息。
:param args: EN: Input/state sequences accepted by the kernel. Chinese: kernel 接收的输入/状态序列。
:return: EN: Forward outputs plus any saved tensors required by backward. When ``T == 0``, returns empty tensors following the expected templates. Chinese: 前向输出以及 backward 所需的保存张量;当 ``T == 0`` 时, 返回符合模板的空张量。
:rtype: tuple
Chinese:
执行 FlexSN 训练前向 kernel。
English:
Execute the FlexSN training forward kernel.
----
.. _flexsn_forward-en:
* **English**
Execute the FlexSN training forward kernel.
:param f: Triton kernel callable
:type f: object
:param info: FlexSN metadata
:type info: FlexSNInfo
:param args: Input/state sequences accepted by the kernel
:return: Forward outputs plus any saved tensors required by backward. When ``T == 0``, returns empty tensors following the expected templates.
:rtype: tuple
"""
x_example = args[0]
T = x_example.shape[0]
NCL = _num_elements_per_step(x_example)
returns = [torch.empty_like(x_example) for _ in range(info.num_fwd_kernel_returns)]
dtype = x_example.dtype
if T == 0:
return tuple(returns)
grid = _make_grid(NCL)
f[grid](
*args,
*returns,
T=T,
NCL=NCL,
dtype=type_dict[dtype],
)
return tuple(returns)
[文档]
def flexsn_backward(
f,
info: FlexSNInfo,
*args,
input_templates: Optional[Tuple[torch.Tensor, ...]] = None,
state_templates: Optional[Tuple[torch.Tensor, ...]] = None,
) -> tuple:
"""Run the training backward kernel for FlexSN.
**API Language:**
:ref:`中文 <flexsn_backward-cn>` | :ref:`English <flexsn_backward-en>`
----
.. _flexsn_backward-cn:
* **中文**
flexsn backward 函数
:param f: EN: Triton kernel callable. Chinese: Triton kernel 可调用对象。
:param info: EN: FlexSN metadata. Chinese: FlexSN 元信息。
:param args: EN: Gradients followed by any saved tensors accepted by the kernel.
:type args: tuple
:param input_templates: EN: Per-input-sequence templates used to allocate input
:type input_templates: Optional[Tuple[torch.Tensor, ...]]
:param state_templates: EN: Per-initial-state templates used to allocate initial
:type state_templates: Optional[Tuple[torch.Tensor, ...]]
:return: EN: Gradients for inputs and initial states. When ``T == 0`` or all
:rtype: tuple
Chinese:
执行 FlexSN 训练反向 kernel。
English:
Execute the FlexSN training backward kernel.
The leading ``info.num_outputs + info.num_states`` entries correspond to
output/state-sequence gradients, and the remaining entries are saved tensors
from the forward pass. Chinese: kernel 接收的梯度与保存张量。前
``info.num_outputs + info.num_states`` 个参数对应输出/状态序列梯度, 其余参数
为前向保存张量。
gradients. They are required when all incoming gradients are ``None`` and
ensure that state-only cores still return correctly shaped input-sequence
gradients. When omitted, the fallback allocation path is only valid for
single-input cores whose output-sequence gradients are present, because
it infers the input gradient shape from the first non-``None``
output-sequence gradient. Chinese: 每个输入序列的模板, 用于分配输入梯度。
当所有传入梯度都为 ``None`` 时必须提供, 并确保仅有状态输出的 core 仍返回
形状正确的输入梯度。若省略该参数, 当前仅支持单输入且存在输出序列梯度的 core,
因为回退路径会从第一个非 ``None`` 的输出序列梯度推断输入梯度形状。
state gradients. When provided, the returned state gradients preserve the
original initial-state shapes instead of inferring them from state-sequence
gradients. Chinese: 每个初始状态的模板, 用于分配初始状态梯度。提供后, 返回的
状态梯度会保持初始状态原始形状, 而不是从状态序列梯度中推断。
incoming gradients are ``None``, returns zero-filled gradients that follow
the provided templates. Chinese: 输入与初始状态的梯度;当 ``T == 0`` 或所有传入
梯度都为 ``None`` 时, 返回符合模板的零梯度。
----
.. _flexsn_backward-en:
* **English**
Flexsn Backward function
:param args: EN: Gradients followed by any saved tensors accepted by the kernel.
:param input_templates: EN: Per-input-sequence templates used to allocate input
:param state_templates: EN: Per-initial-state templates used to allocate initial
:type args: tuple
:type input_templates: Optional[Tuple[torch.Tensor, ...]]
:type state_templates: Optional[Tuple[torch.Tensor, ...]]
:return: EN: Gradients for inputs and initial states. When ``T == 0`` or all
:rtype: tuple
"""
required_grad_count = info.num_outputs + info.num_states
grad_output_args = args[:required_grad_count]
grad_output_example = _first_non_none_tensor(grad_output_args[: info.num_outputs])
grad_example = _first_non_none_tensor(grad_output_args)
if input_templates is None:
if grad_example is None:
raise ValueError(
"input_templates are required when all incoming FlexSN gradients are None"
)
if info.num_inputs != 1:
raise ValueError(
"input_templates are required when FlexSN has multiple input sequences"
)
if grad_output_example is None:
raise ValueError(
"input_templates are required when FlexSN output-sequence gradients "
"are all None"
)
input_templates = tuple(grad_output_example for _ in range(info.num_inputs))
if len(input_templates) != info.num_inputs:
raise ValueError(
"input_templates must provide one template per FlexSN input sequence"
)
if state_templates is not None and len(state_templates) != info.num_states:
raise ValueError(
"state_templates must provide one template per FlexSN initial state"
)
if grad_example is None:
if state_templates is None and info.num_states > 0:
raise ValueError(
"state_templates are required when all incoming FlexSN gradients are None"
)
grad_inputs = [torch.zeros_like(template) for template in input_templates]
if state_templates is not None:
grad_inputs.extend(
torch.zeros_like(template) for template in state_templates
)
return tuple(grad_inputs)
T = grad_example.shape[0]
NCL = _num_elements_per_step(grad_example)
grad_inputs = [
(
torch.zeros_like(input_templates[i])
if T == 0
else torch.empty_like(input_templates[i])
)
for i in range(info.num_inputs)
]
grad_state_seq_examples = grad_output_args[
info.num_outputs : info.num_outputs + info.num_states
]
if state_templates is None and any(
grad is None for grad in grad_state_seq_examples
):
raise ValueError(
"state_templates are required when any incoming FlexSN "
"state-sequence gradient is None"
)
grad_kernel_args = [
grad if grad is not None else torch.zeros_like(grad_example)
for grad in grad_output_args
]
# State-sequence gradients include the leading time dimension. The wrapper
# returns gradients for the initial states, so their templates are shape[1:].
grad_inputs += [
_allocate_state_grad(
i,
T,
state_templates,
grad_state_seq_examples,
grad_example,
)
for i in range(info.num_states)
]
dtype = grad_example.dtype
if T == 0:
return tuple(grad_inputs)
grid = _make_grid(NCL)
f[grid](
*grad_kernel_args,
*args[required_grad_count:],
*grad_inputs,
T=T,
NCL=NCL,
dtype=type_dict[dtype],
)
return tuple(grad_inputs)