"""FlexSN time-step scan as a HigherOrderOperator.
Current progress:
M1:
* HOP definition with an eager Python time-step loop impl.
* Eager autograd works via the natural computation graph (``x[t]`` indexing
and ``torch.stack`` are differentiable, so the per-step ``core_fn`` graph
is correctly chained through time). Verified with ``gradcheck``.
M2:
* AOTAutograd tracing (``torch.fx.experimental.proxy_tensor.make_fx`` /
``torch._functorch.aot_autograd.aot_function``) works by unrolling the scan
into T copies of ``core_fn``'s aten ops.
M3:
* ``FlexSN(backend="hop")`` is available as an explicit backend.
* Dynamo recognizes ``flex_sn_scan`` via a compatibility registration and can
rewrite the call into a HOP node with a traced ``GraphModule`` body.
* ``torch.compile(fullgraph=True)`` for the HOP backend is verified on the
Linux CI/server environment, including tensor lifted freevars/closures.
M4:
* ``lowerable_scan`` re-expresses the FlexSN step function through PyTorch's
built-in ``torch.ops.higher_order.scan`` when that API is available.
* It is kept as an explicit experimental helper for investigating a
single-scan-node forward path instead of fully unrolling the body.
* ``lowerable_while_loop_scan`` provides an alternative experimental forward
path based on ``torch.ops.higher_order.while_loop``. On the Linux validation
environment, its ``torch.compile(fullgraph=True) + no_grad`` path is working
after switching to fixed-shape queue carries instead of ``x[t]`` indexing.
* The experimental while-loop path is wired into ``FlexSN(backend="hop")`` via
``SJ_ENABLE_EXPERIMENTAL_LOWERABLE_WHILE_LOOP=1`` for compile-time forward
evaluation, and has been validated on:
- a single FlexSN layer,
- ``Linear -> FlexSN -> Linear``,
- ``SpikingVGG`` forward inference.
Current limitations:
* The custom Dynamo registration in this file is still a compatibility shim,
not a true in-tree ``BaseHOP`` integration.
* ``lowerable_scan`` is currently an experimental helper, not the default
compiled path. In the PyTorch versions we validate against, fake/proxy/export
handling for this out-of-tree scan shape is not yet stable enough to enable
it by default.
* Training and autograd still use the existing eager/unrolled path.
* The current while-loop lowering is functionally correct for the validated
forward ``no_grad`` cases, but it is not yet faster than the current
``backend="inductor"`` custom-op compile path on the server benchmark.
* A true first-class Inductor lowering for ``flex_sn_scan`` itself does not
exist yet; the current "less unrolled" path relies on PyTorch's built-in
scan / while_loop decomposition.
Usage::
from spikingjelly.activation_based.triton_kernel.flexsn import flex_sn_scan
# inputs_seq: tuple of T-leading tensors, e.g. shape [T, N, ...]
# init_states: tuple of per-step state tensors, e.g. shape [N, ...]
# returns: (*output_seqs, *state_seqs) — each with shape [T, ...]
result = flex_sn_scan(
core_fn, num_inputs, num_states, num_outputs, *inputs_seq, *init_states
)
Captured tensor freevars from ``core_fn`` are appended after the
``[*inputs_seq, *init_states]`` segment when Dynamo rewrites the HOP call.
"""
from __future__ import annotations
import inspect
import warnings
from typing import Callable, Optional, Tuple, Union
import torch
import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
try:
from torch._higher_order_ops.scan import scan_op as _torch_scan_op
except (ImportError, AttributeError):
_torch_scan_op = None
try:
from torch._higher_order_ops.while_loop import while_loop as _torch_while_loop
except (ImportError, AttributeError):
_torch_while_loop = None
# Names exported by ``from <this module> import *``: the HOP class and its
# singleton, the scan implementations, and the availability probes.
__all__ = [
    "dynamo_hop_available",
    "eager_scan",
    "eager_scan_final_state",
    "flex_sn_scan",
    "FlexSNScan",
    "lowerable_scan",
    "lowerable_scan_available",
    "lowerable_scan_final_state",
    "lowerable_while_loop_scan",
    "lowerable_while_loop_available",
    "lowerable_while_loop_scan_final_state",
]
[文档]
class FlexSNScan(HigherOrderOperator):
"""HOP that runs a user-defined single-step ``core`` function over the
**API Language:**
:ref:`中文 <FlexSNScan-cn>` | :ref:`English <FlexSNScan-en>`
----
.. _FlexSNScan-cn:
* **中文**
FlexSN 可微扫描操作(differentiable scanning operation)的高层封装。
提供与 PyTorch 原生 ``scan`` 操作兼容的扫描函数,支持在脉冲神经网络中
高效执行可微的时间维扫描计算。包含 eager 模式和可降图(lowerable)模式,
可根据上下文自动选择合适的执行后端。其核心可调用对象需遵循
``[*inputs, *states] -> [*outputs, *states, *intermediates]`` 签名。
:rtype: None
leading time dimension of its inputs.
The HOP is invoked with a flat argument list so that Dynamo / AOTAutograd
can treat it uniformly. Shapes/semantics:
* ``core_fn``: callable with signature
``(*step_inputs, *states) -> (*step_outputs, *updated_states)``.
* ``num_inputs`` / ``num_states`` / ``num_outputs``: int literals used to
partition the flat tensor args.
* ``flat_args``: first ``num_inputs`` tensors are input sequences with
leading time dim ``T``; the next ``num_states`` tensors are initial
states (no time dim); any remaining tensors are lifted freevars that are
passed through to ``core_fn`` unchanged at every time step.
Return: ``num_outputs`` output sequences followed by ``num_states`` state
sequences, all stacked along the leading time dim.
----
.. _FlexSNScan-en:
* **English**
Flexsnscan function
:return: None
:rtype: None
"""
def __init__(self) -> None:
super().__init__("flex_sn_scan")
def __call__(
self,
core_fn: Callable,
num_inputs: int,
num_states: int,
num_outputs: int,
*flat_args: torch.Tensor,
output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
"""Invoke the FlexSN scan HigherOrderOperator.
Chinese:
调用 FlexSN scan HigherOrderOperator。
English:
Invoke the FlexSN scan HigherOrderOperator with flattened
input-sequence, initial-state, and lifted tensor arguments.
:param core_fn: EN: Single-step core callable with signature
``(*step_inputs, *states, *lifted_args)``.
Chinese: 单步 ``core`` 可调用对象, 签名为
``(*step_inputs, *states, *lifted_args)``。
:type core_fn: ``Callable``
:param num_inputs: EN: Number of T-leading input sequences.
Chinese: 带时间维 ``T`` 的输入序列数量。
:type num_inputs: int
:param num_states: EN: Number of initial-state tensors without a time
dimension. Chinese: 不带时间维的初始状态张量数量。
:type num_states: int
:param num_outputs: EN: Number of per-step outputs produced by
``core_fn``. Chinese: ``core_fn`` 每个时间步产生的输出数量。
:type num_outputs: int
:param flat_args: EN: Flattened tensor arguments: first the
``num_inputs`` input sequences ``[T, ...]``, then the
``num_states`` initial states, then any lifted tensor freevars.
Chinese: 展平后的张量参数: 先是 ``num_inputs`` 个输入序列 ``[T, ...]``,
再是 ``num_states`` 个初始状态, 最后是提升出来的张量自由变量。
:type flat_args: ``torch.Tensor``
:param output_template_specs: EN: Optional output templates used when
``T == 0`` to materialize empty output sequences without executing
``core_fn``. Each item is ``(shape, dtype)`` or
``(shape, dtype, device)``; when ``device`` is omitted, the runtime
device follows the first input sequence. Chinese: 可选输出模板, 在
``T == 0`` 时用于在不执行 ``core_fn`` 的情况下构造空输出序列。每个模板
为 ``(shape, dtype)`` 或 ``(shape, dtype, device)``;省略 ``device``
时, 运行时设备跟随第一个输入序列。
:type output_template_specs: Optional[OutputTemplateSpecs]
:return: EN: ``num_outputs`` output sequences followed by ``num_states``
state sequences, all stacked along the leading time dimension.
Chinese: 先返回 ``num_outputs`` 个输出序列, 再返回 ``num_states`` 个
状态序列, 均沿首个时间维进行堆叠。
:rtype: Tuple[torch.Tensor, ...]
"""
return super().__call__(
core_fn,
num_inputs,
num_states,
num_outputs,
*flat_args,
output_template_specs=output_template_specs,
)
# Module-level singleton of the HOP; implementations are attached to it
# further down via ``flex_sn_scan.py_impl(...)``.
flex_sn_scan = FlexSNScan()
# Value reported by ``dynamo_hop_available()``; initialised to False here.
# NOTE(review): presumably set to True by registration code outside this
# excerpt - confirm.
_DYNAMO_HOP_REGISTERED = False
# One output template: ``(shape, dtype)`` or ``(shape, dtype, device)``.
# ``shape`` excludes the time dimension; a leading 0 is prepended when an
# empty output sequence is built from the template.
OutputTemplateSpec = Union[
    Tuple[Tuple[int, ...], torch.dtype],
    Tuple[Tuple[int, ...], torch.dtype, torch.device],
]
# One template per output sequence, in output order.
OutputTemplateSpecs = Tuple[OutputTemplateSpec, ...]
def _as_tuple(outputs):
    """Normalize ``outputs`` to a tuple.

    A single ``torch.Tensor`` is wrapped in a one-element tuple (it is not
    iterated); any other value is passed through ``tuple()``.
    """
    return (outputs,) if isinstance(outputs, torch.Tensor) else tuple(outputs)
def _empty_outputs_from_template(
    input_seqs: Tuple[torch.Tensor, ...],
    num_outputs: int,
    output_template_specs: Optional[OutputTemplateSpecs],
) -> Tuple[torch.Tensor, ...]:
    """Build zero-length output sequences from shape/dtype templates.

    :param input_seqs: Input sequences; the first one supplies the default
        device when a template omits it.
    :param num_outputs: Number of output sequences to build. When it is 0
        an empty tuple is returned and the templates are not inspected.
    :param output_template_specs: One ``(shape, dtype)`` or
        ``(shape, dtype, device)`` entry per output.
    :return: One empty tensor of shape ``(0, *shape)`` per template.
    :raises ValueError: If templates are required but missing, or their
        count differs from ``num_outputs``.
    """
    if num_outputs == 0:
        return ()
    if output_template_specs is None:
        raise ValueError(
            "FlexSN HOP empty scans require output_template_specs so output "
            "shapes and dtypes can be built without executing core_fn."
        )
    if len(output_template_specs) != num_outputs:
        raise ValueError(
            f"expected {num_outputs} output template specs, got "
            f"{len(output_template_specs)}"
        )

    reference = input_seqs[0]
    empties = []
    for template in output_template_specs:
        if len(template) == 2:
            shape, dtype = template
            device = reference.device
        else:
            shape, dtype, device = template
        empty_shape = (0, *shape)
        if device != reference.device:
            empties.append(torch.empty(empty_shape, dtype=dtype, device=device))
        else:
            # Same device as the reference input: allocate through it.
            empties.append(reference.new_empty(empty_shape, dtype=dtype))
    return tuple(empties)
def _flatten_dynamo_body_result(value) -> Tuple[object, ...]:
    """Recursively flatten a traced body result into a tuple of leaves.

    Tensors are leaves. Tuples and lists are flattened recursively, as are
    objects whose ``items`` attribute is itself a tuple or list. Anything
    else (including a ``dict``, whose ``items`` is a method) is returned as
    a single leaf.
    """
    if isinstance(value, torch.Tensor):
        return (value,)

    if isinstance(value, (tuple, list)):
        children = value
    else:
        children = getattr(value, "items", None)
        if not isinstance(children, (tuple, list)):
            return (value,)

    flat = []
    for child in children:
        flat.extend(_flatten_dynamo_body_result(child))
    return tuple(flat)
def _dynamo_leaf_example_value(value):
    """Return the example value attached to a traced leaf, if any.

    A tensor is returned as is. Otherwise, if ``value`` has a callable
    ``as_proxy``, the result is ``proxy.node.meta["example_value"]`` when
    ``meta`` is a dict. ``None`` is returned in every other case, including
    when ``as_proxy()`` raises (best-effort lookup).
    """
    if isinstance(value, torch.Tensor):
        return value

    as_proxy = getattr(value, "as_proxy", None)
    if not callable(as_proxy):
        return None

    try:
        proxy = as_proxy()
    except Exception:
        # Best effort: a leaf that cannot produce a proxy has no example.
        return None

    meta = getattr(getattr(proxy, "node", None), "meta", None)
    return meta.get("example_value") if isinstance(meta, dict) else None
def _output_template_specs_from_dynamo_body_result(
    body_result,
    num_outputs: int,
) -> Optional[OutputTemplateSpecs]:
    """Derive ``(shape, dtype)`` output templates from a traced body result.

    The body result is flattened and the first ``num_outputs`` leaves are
    inspected for tensor example values.

    :return: A tuple of ``(shape, dtype)`` pairs, or ``None`` when there
        are fewer than ``num_outputs`` leaves or one of them has no tensor
        example value.
    """
    leaves = _flatten_dynamo_body_result(body_result)
    if len(leaves) < num_outputs:
        return None

    templates = []
    for output_leaf in leaves[:num_outputs]:
        example = _dynamo_leaf_example_value(output_leaf)
        if not isinstance(example, torch.Tensor):
            # Stop at the first leaf that cannot be described.
            return None
        templates.append((tuple(example.shape), example.dtype))
    return tuple(templates)
def lowerable_scan_available() -> bool:
    """Report whether PyTorch's built-in ``scan`` HOP could be imported.

    :return: ``True`` when ``torch._higher_order_ops.scan.scan_op`` was
        imported successfully at module load; otherwise ``False``.
    :rtype: bool
    """
    scan_op = _torch_scan_op
    return scan_op is not None
def dynamo_hop_available() -> bool:
    """Report whether the FlexSN Dynamo HOP registration succeeded.

    :return: The current value of the module-level
        ``_DYNAMO_HOP_REGISTERED`` flag, i.e. ``True`` once the Dynamo
        compatibility shim for ``flex_sn_scan`` has been registered and
        ``False`` otherwise.
    :rtype: bool
    """
    registered = _DYNAMO_HOP_REGISTERED
    return registered
def lowerable_while_loop_available() -> bool:
    """Report whether PyTorch's built-in ``while_loop`` HOP could be imported.

    :return: ``True`` when ``torch._higher_order_ops.while_loop.while_loop``
        was imported successfully at module load; otherwise ``False``.
    :rtype: bool
    """
    while_loop_op = _torch_while_loop
    return while_loop_op is not None
def _callable_positional_arg_range(
    fn: Callable,
) -> Optional[Tuple[int, Optional[int]]]:
    """Return how many positional arguments ``fn`` can be called with.

    For an ``nn.Module`` the signature of its ``forward`` is inspected.

    :return: ``(min_required, capacity)``, where ``capacity`` is ``None``
        if ``fn`` takes ``*args``. ``None`` is returned when the signature
        cannot be inspected or when ``fn`` has a required keyword-only
        parameter (it can never be called with positional arguments only).
    """
    target = fn.forward if isinstance(fn, torch.nn.Module) else fn
    try:
        parameters = list(inspect.signature(target).parameters.values())
    except (TypeError, ValueError):
        return None

    empty = inspect.Parameter.empty
    if any(
        p.kind == inspect.Parameter.KEYWORD_ONLY and p.default is empty
        for p in parameters
    ):
        return None

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional = [p for p in parameters if p.kind in positional_kinds]
    min_required = sum(1 for p in positional if p.default is empty)

    # ``*args`` always follows the named positional parameters, so the
    # required count is complete by the time it is seen.
    if any(p.kind == inspect.Parameter.VAR_POSITIONAL for p in parameters):
        return min_required, None
    return min_required, len(positional)
def _callable_accepts_positional_args(fn: Callable, n_args: int) -> bool | None:
    """Tell whether ``fn`` can be called with exactly ``n_args`` positionals.

    :return: ``True`` / ``False`` when the answer is known, ``None`` when
        the signature of ``fn`` cannot be inspected. A required
        keyword-only parameter yields ``False``.
    """
    arg_range = _callable_positional_arg_range(fn)
    if arg_range is not None:
        min_required, capacity = arg_range
        if n_args < min_required:
            return False
        return capacity is None or n_args <= capacity

    # No usable range: distinguish "uninspectable" (None) from "has a
    # required keyword-only parameter" (False).
    target = fn.forward if isinstance(fn, torch.nn.Module) else fn
    try:
        signature = inspect.signature(target)
    except (TypeError, ValueError):
        return None
    for parameter in signature.parameters.values():
        if (
            parameter.kind == inspect.Parameter.KEYWORD_ONLY
            and parameter.default is inspect.Parameter.empty
        ):
            return False
    return None
def _reorder_placeholders_to_canonical_args(
    graph: torch.fx.Graph, canonical_arg_names: Tuple[str, ...]
) -> Tuple[torch.fx.Node, ...]:
    """Reorder the placeholder nodes of ``graph`` in place.

    Placeholders whose ``name`` appears in ``canonical_arg_names`` come
    first, in that order; the remaining placeholders follow in their
    current relative order.

    :param graph: FX graph whose placeholders are reordered in place.
    :param canonical_arg_names: Desired leading placeholder names; names
        that match no placeholder are ignored.
    :return: The placeholder nodes in their final graph order (empty tuple
        when the graph has none).
    """
    placeholders = [node for node in graph.nodes if node.op == "placeholder"]
    if not placeholders:
        return ()

    by_name = {node.name: node for node in placeholders}
    ordered = [by_name[name] for name in canonical_arg_names if name in by_name]
    # Set membership keeps the "remaining placeholders" pass linear.
    named = set(ordered)
    ordered.extend(node for node in placeholders if node not in named)

    if ordered != placeholders:
        first_non_placeholder = next(
            (node for node in graph.nodes if node.op != "placeholder"), None
        )
        if first_non_placeholder is not None:
            # Moving each node just before the anchor, in order, leaves
            # them in ``ordered`` order ahead of the anchor.
            for node in ordered:
                first_non_placeholder.prepend(node)
        else:
            # Placeholder-only graph: there is no anchor, so move each
            # node behind the current tail instead.
            tail = placeholders[-1]
            for node in ordered:
                if node is not tail:
                    tail.append(node)
                tail = node

    return tuple(node for node in graph.nodes if node.op == "placeholder")
def _check_lifted_arg_arity(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    lifted_args: Tuple[torch.Tensor, ...],
    *,
    skip_check: bool = False,
) -> None:
    """Verify that ``core_fn`` can take the inputs, states and lifted args.

    ``core_fn`` is called with ``num_inputs + num_states +
    len(lifted_args)`` positional arguments. Nothing is raised when the
    signature cannot be inspected.

    :param skip_check: When true, return immediately without inspecting
        ``core_fn``.
    :raises ValueError: If ``core_fn`` is known not to accept that many
        positional arguments.
    """
    if skip_check:
        return

    expected = num_inputs + num_states
    total = expected + len(lifted_args)
    # Only a definite ``False`` is an error; ``None`` means "unknown".
    if _callable_accepts_positional_args(core_fn, total) is False:
        raise ValueError(
            f"flex_sn_scan expected {expected} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {total}"
        )
def eager_scan(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run the FlexSN scan with an eager Python time-step loop.

    ``flat_args`` is split into ``num_inputs`` input sequences, then
    ``num_states`` initial states, then any remaining lifted tensors. At
    every step ``t`` the core is called as
    ``core_fn(*[x[t] for x in inputs], *states, *lifted_args)`` and must
    return ``num_outputs + num_states`` values; the trailing
    ``num_states`` values become the states for the next step.

    This helper is reused by both the HOP eager implementation and the
    Dynamo-friendly ``backend="inductor"`` path, so ``torch.compile`` can
    trace the unrolled loop into a standard FX graph.

    When ``T == 0``, ``core_fn`` is not executed: empty output sequences
    are built from ``output_template_specs`` and empty state sequences
    from the initial states.

    :param core_fn: Single-step core callable with signature
        ``(*step_inputs, *states, *lifted_args)``.
    :type core_fn: ``Callable``
    :param num_inputs: Number of input sequences with leading time
        dimension ``T``; must be at least 1.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then
        lifted tensor free variables.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to build empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences followed by ``num_states``
        state sequences, each stacked along dim 0.
    :rtype: Tuple[torch.Tensor, ...]
    :raises ValueError: If too few tensors are given, ``core_fn`` cannot
        accept the arguments, ``num_inputs`` is 0, the input sequences
        disagree on ``T``, or ``core_fn`` returns the wrong number of
        values.
    """
    num_leading = num_inputs + num_states
    if len(flat_args) < num_leading:
        raise ValueError(
            f"flex_sn_scan expected at least {num_leading} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )

    inputs_seq = flat_args[:num_inputs]
    carried = list(flat_args[num_inputs:num_leading])
    lifted_args = tuple(flat_args[num_leading:])

    # The arity check is skipped for an empty scan over a GraphModule core.
    empty_graph_module_scan = (
        num_inputs > 0
        and inputs_seq[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=empty_graph_module_scan,
    )

    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    num_steps = inputs_seq[0].shape[0]
    for index, seq in enumerate(inputs_seq):
        if seq.shape[0] != num_steps:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {num_steps}"
            )

    if num_steps == 0:
        empty_outputs = _empty_outputs_from_template(
            inputs_seq, num_outputs, output_template_specs
        )
        empty_states = tuple(
            state.new_empty((0, *state.shape)) for state in carried
        )
        return (*empty_outputs, *empty_states)

    num_results = num_outputs + num_states
    output_history = [[] for _ in range(num_outputs)]
    state_history = [[] for _ in range(num_states)]
    for step in range(num_steps):
        step_inputs = tuple(seq[step] for seq in inputs_seq)
        step_results = core_fn(*step_inputs, *carried, *lifted_args)
        if not isinstance(step_results, (tuple, list)):
            step_results = (step_results,)
        if len(step_results) != num_results:
            raise ValueError(
                f"core returned {len(step_results)} values, "
                f"expected num_outputs + num_states "
                f"= {num_results}"
            )
        # The trailing results are the states fed into the next step.
        carried = list(step_results[num_outputs:])
        for history, value in zip(output_history, step_results[:num_outputs]):
            history.append(value)
        for history, value in zip(state_history, carried):
            history.append(value)

    output_seqs = tuple(torch.stack(history, dim=0) for history in output_history)
    state_seqs = tuple(torch.stack(history, dim=0) for history in state_history)
    return (*output_seqs, *state_seqs)
# Use ``eager_scan`` as the implementation of the HOP for the
# CompositeExplicitAutograd dispatch key.
flex_sn_scan.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd)(eager_scan)
# HOPs route every tensor call through the Autograd dispatch key even when
# ``requires_grad=False``. Re-entering ``eager_scan`` from Autograd is
# correct: the inner ``core_fn`` invocations build a standard per-timestep
# autograd graph which is chained via ``torch.stack``/indexing, giving a
# full BPTT graph. AOTAutograd (``aot_function`` / ``make_fx``) traces this
# graph natively by unrolling; see module docstring.
flex_sn_scan.py_impl(torch._C.DispatchKey.Autograd)(eager_scan)
def eager_scan_final_state(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run the eager scan and return output sequences plus final states.

    Variant of :func:`eager_scan` used when ``store_state_seqs=False`` so
    the HOP backend does not materialize full state sequences only to
    discard them. Only the states returned by the last step are kept.

    When ``T == 0``, ``core_fn`` is not executed: empty output sequences
    are built from ``output_template_specs`` and clones of the initial
    states are returned as the final states.

    :param core_fn: Single-step core callable with signature
        ``(*step_inputs, *states, *lifted_args)``.
    :type core_fn: ``Callable``
    :param num_inputs: Number of input sequences with leading time
        dimension ``T``; must be at least 1.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then
        lifted tensor free variables.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to build empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences (stacked along dim 0)
        followed by the ``num_states`` final states.
    :rtype: Tuple[torch.Tensor, ...]
    :raises ValueError: If too few tensors are given, ``core_fn`` cannot
        accept the arguments, ``num_inputs`` is 0, the input sequences
        disagree on ``T``, or ``core_fn`` returns the wrong number of
        values.
    """
    num_leading = num_inputs + num_states
    if len(flat_args) < num_leading:
        raise ValueError(
            f"flex_sn_scan expected at least {num_leading} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )

    inputs_seq = flat_args[:num_inputs]
    carried = list(flat_args[num_inputs:num_leading])
    lifted_args = tuple(flat_args[num_leading:])

    # The arity check is skipped for an empty scan over a GraphModule core.
    empty_graph_module_scan = (
        num_inputs > 0
        and inputs_seq[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=empty_graph_module_scan,
    )

    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    num_steps = inputs_seq[0].shape[0]
    for index, seq in enumerate(inputs_seq):
        if seq.shape[0] != num_steps:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {num_steps}"
            )

    if num_steps == 0:
        empty_outputs = _empty_outputs_from_template(
            inputs_seq, num_outputs, output_template_specs
        )
        # Clone so the returned final states are new tensors.
        final_states = tuple(state.clone() for state in carried)
        return (*empty_outputs, *final_states)

    num_results = num_outputs + num_states
    output_history = [[] for _ in range(num_outputs)]
    for step in range(num_steps):
        step_inputs = tuple(seq[step] for seq in inputs_seq)
        step_results = core_fn(*step_inputs, *carried, *lifted_args)
        if not isinstance(step_results, (tuple, list)):
            step_results = (step_results,)
        if len(step_results) != num_results:
            raise ValueError(
                f"core returned {len(step_results)} values, "
                f"expected num_outputs + num_states "
                f"= {num_results}"
            )
        carried = list(step_results[num_outputs:])
        for history, value in zip(output_history, step_results[:num_outputs]):
            history.append(value)

    output_seqs = tuple(torch.stack(history, dim=0) for history in output_history)
    return (*output_seqs, *tuple(carried))
def lowerable_scan(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run the FlexSN scan through PyTorch's built-in ``scan`` HOP.

    Keeps the FlexSN scan as a single higher-order op under tracing so
    downstream compilers can lower it as a loop instead of fully unrolling
    the body ``T`` times. This remains an experimental helper rather than
    the default compiled path.

    The states are the scan carry; each step emits the outputs plus a
    clone of every updated state, so full state sequences are returned.

    When ``T == 0`` the scan op is not invoked: empty output sequences are
    built from ``output_template_specs`` (which must then contain
    ``num_outputs`` items) and empty state sequences from the initial
    states.

    :param core_fn: Single-step core callable with signature
        ``(*step_inputs, *states, *lifted_args)``.
    :type core_fn: ``Callable``
    :param num_inputs: Number of input sequences with leading time
        dimension ``T``; must be at least 1.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then
        lifted tensor free variables.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to build empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences followed by ``num_states``
        state sequences.
    :rtype: Tuple[torch.Tensor, ...]
    :raises RuntimeError: If PyTorch's ``scan`` HOP is unavailable.
    :raises ValueError: If too few tensors are given, ``num_inputs`` is 0,
        ``core_fn`` cannot accept the arguments, the input sequences
        disagree on ``T``, or ``core_fn`` returns the wrong number of
        values.
    """
    if _torch_scan_op is None:
        raise RuntimeError("PyTorch scan HOP is unavailable in this environment")

    num_leading = num_inputs + num_states
    if len(flat_args) < num_leading:
        raise ValueError(
            f"flex_sn_scan expected at least {num_leading} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )
    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    input_seqs = flat_args[:num_inputs]
    init_states = flat_args[num_inputs:num_leading]
    lifted_args = tuple(flat_args[num_leading:])

    # The arity check is skipped for an empty scan over a GraphModule core.
    empty_graph_module_scan = (
        num_inputs > 0
        and input_seqs[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=empty_graph_module_scan,
    )

    num_steps = input_seqs[0].shape[0]
    for index, seq in enumerate(input_seqs):
        if seq.shape[0] != num_steps:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {num_steps}"
            )

    if num_steps == 0:
        empty_outputs = _empty_outputs_from_template(
            input_seqs, num_outputs, output_template_specs
        )
        empty_states = tuple(
            state.new_empty((0, *state.shape)) for state in init_states
        )
        return (*empty_outputs, *empty_states)

    num_results = num_outputs + num_states

    def combine_fn(carry, step_inputs, additional_inputs):
        # One scan step: returns (next carry, per-step emitted values).
        step_results = core_fn(
            *tuple(step_inputs), *tuple(carry), *tuple(additional_inputs)
        )
        if isinstance(step_results, torch.Tensor):
            step_results = (step_results,)
        else:
            step_results = tuple(step_results)
        if len(step_results) != num_results:
            raise ValueError(
                f"core returned {len(step_results)} values, "
                f"expected num_outputs + num_states = {num_results}"
            )
        step_outputs = list(step_results[:num_outputs])
        next_states = list(step_results[num_outputs:])
        # Emit clones of the states so the stacked state sequence does not
        # share tensors with the carry.
        emitted_states = [state.clone() for state in next_states]
        return next_states, [*step_outputs, *emitted_states]

    leaves_init = list(init_states)
    leaves_xs = list(input_seqs)
    _, spec_init = pytree.tree_flatten(leaves_init)
    _, spec_xs = pytree.tree_flatten(leaves_xs)
    carry_end = len(leaves_init)
    xs_end = carry_end + len(leaves_xs)
    num_flat = xs_end + len(lifted_args)

    def wrapped_combine_fn(*args):
        # Adapter from the flat ``(*carry, *xs, *additional)`` calling
        # convention to ``combine_fn``.
        if len(args) != num_flat:
            raise ValueError(
                f"scan combine_fn expected {num_flat} flattened args, got {len(args)}"
            )
        carry = pytree.tree_unflatten(args[:carry_end], spec_init)
        xs = pytree.tree_unflatten(args[carry_end:xs_end], spec_xs)
        return combine_fn(carry, xs, tuple(args[xs_end:]))

    scan_result = tuple(
        _torch_scan_op(
            wrapped_combine_fn,
            leaves_init,
            leaves_xs,
            additional_inputs=lifted_args,
        )
    )
    # PyTorch scan returns final carry first, followed by the stacked outputs.
    return scan_result[num_states:]
def lowerable_scan_final_state(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run the built-in ``scan`` HOP and return outputs plus final states.

    Final-state variant of :func:`lowerable_scan`: the states are the scan
    carry and only the per-step outputs are emitted, so no state sequences
    are materialized.

    When ``T == 0`` the scan op is not invoked: empty output sequences are
    built from ``output_template_specs`` and clones of the initial states
    are returned as the final states.

    :param core_fn: Single-step core callable with signature
        ``(*step_inputs, *states, *lifted_args)``.
    :type core_fn: ``Callable``
    :param num_inputs: Number of input sequences with leading time
        dimension ``T``; must be at least 1.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then
        lifted tensor free variables.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to build empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences followed by the
        ``num_states`` final states.
    :rtype: Tuple[torch.Tensor, ...]
    :raises RuntimeError: If PyTorch's ``scan`` HOP is unavailable.
    :raises ValueError: If too few tensors are given, ``num_inputs`` is 0,
        ``core_fn`` cannot accept the arguments, the input sequences
        disagree on ``T``, or ``core_fn`` returns the wrong number of
        values.
    """
    if _torch_scan_op is None:
        raise RuntimeError("PyTorch scan HOP is unavailable in this environment")

    num_leading = num_inputs + num_states
    if len(flat_args) < num_leading:
        raise ValueError(
            f"flex_sn_scan expected at least {num_leading} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )
    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    input_seqs = flat_args[:num_inputs]
    init_states = flat_args[num_inputs:num_leading]
    lifted_args = tuple(flat_args[num_leading:])

    # The arity check is skipped for an empty scan over a GraphModule core.
    empty_graph_module_scan = (
        num_inputs > 0
        and input_seqs[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=empty_graph_module_scan,
    )

    num_steps = input_seqs[0].shape[0]
    for index, seq in enumerate(input_seqs):
        if seq.shape[0] != num_steps:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {num_steps}"
            )

    if num_steps == 0:
        empty_outputs = _empty_outputs_from_template(
            input_seqs, num_outputs, output_template_specs
        )
        # Clone so the returned final states are new tensors.
        final_states = tuple(state.clone() for state in init_states)
        return (*empty_outputs, *final_states)

    num_results = num_outputs + num_states

    def combine_fn(carry, step_inputs, additional_inputs):
        # One scan step: returns (next carry, per-step outputs).
        step_results = core_fn(
            *tuple(step_inputs), *tuple(carry), *tuple(additional_inputs)
        )
        if isinstance(step_results, torch.Tensor):
            step_results = (step_results,)
        else:
            step_results = tuple(step_results)
        if len(step_results) != num_results:
            raise ValueError(
                f"core returned {len(step_results)} values, "
                f"expected num_outputs + num_states = {num_results}"
            )
        step_outputs = list(step_results[:num_outputs])
        next_states = list(step_results[num_outputs:])
        return next_states, step_outputs

    leaves_init = list(init_states)
    leaves_xs = list(input_seqs)
    _, spec_init = pytree.tree_flatten(leaves_init)
    _, spec_xs = pytree.tree_flatten(leaves_xs)
    carry_end = len(leaves_init)
    xs_end = carry_end + len(leaves_xs)
    num_flat = xs_end + len(lifted_args)

    def wrapped_combine_fn(*args):
        # Adapter from the flat ``(*carry, *xs, *additional)`` calling
        # convention to ``combine_fn``.
        if len(args) != num_flat:
            raise ValueError(
                f"scan combine_fn expected {num_flat} flattened args, got {len(args)}"
            )
        carry = pytree.tree_unflatten(args[:carry_end], spec_init)
        xs = pytree.tree_unflatten(args[carry_end:xs_end], spec_xs)
        return combine_fn(carry, xs, tuple(args[xs_end:]))

    scan_result = tuple(
        _torch_scan_op(
            wrapped_combine_fn,
            leaves_init,
            leaves_xs,
            additional_inputs=lifted_args,
        )
    )
    # PyTorch scan returns final carry first; keep that as the final states.
    final_states = scan_result[:num_states]
    output_seqs = scan_result[num_states:]
    return (*output_seqs, *final_states)
def _ensure_contiguous(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` in a contiguous memory layout.

    Tensors with four or more dimensions request
    ``torch.contiguous_format`` explicitly; lower-rank tensors go through the
    plain ``contiguous()`` call.

    :param tensor: Tensor to normalize.
    :type tensor: torch.Tensor
    :return: The result of ``tensor.contiguous(...)``.
    :rtype: torch.Tensor
    """
    if tensor.dim() < 4:
        return tensor.contiguous()
    return tensor.contiguous(memory_format=torch.contiguous_format)
def _carry_device(*tensor_groups) -> torch.device:
    """Return the device of the first tensor found in ``tensor_groups``.

    Groups are inspected in the order given; empty groups are skipped. When
    every group is empty (or no group is passed), the CPU device is returned.

    :param tensor_groups: Iterables of tensors, in priority order.
    :return: Device of the first tensor encountered, else ``cpu``.
    :rtype: torch.device
    """
    for group in tensor_groups:
        members = tuple(group)
        if members:
            return members[0].device
    return torch.device("cpu")
def _append_to_tail(buffer: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Drop the first row of ``buffer`` and append ``value`` as the last row.

    ``value`` is made contiguous and given a new leading dimension of size 1
    before being concatenated along dim 0, so a non-empty ``buffer`` keeps its
    leading length. A new tensor is returned; ``buffer`` is not modified.

    :param buffer: Queue-like tensor whose dim 0 indexes slots.
    :type buffer: torch.Tensor
    :param value: Entry to place in the last slot.
    :type value: torch.Tensor
    :return: The shifted buffer with ``value`` at the tail.
    :rtype: torch.Tensor
    """
    kept_rows = buffer[1:]
    new_row = _ensure_contiguous(value).unsqueeze(0)
    return torch.cat((kept_rows, new_row), dim=0)
def _shift_input_queue(queue: torch.Tensor) -> torch.Tensor:
    """Advance ``queue`` by one slot along dim 0 without changing its length.

    The head entry is discarded and the last entry is repeated at the tail,
    so ``result[0]`` is the former ``queue[1]``.

    :param queue: Tensor whose dim 0 indexes pending entries.
    :type queue: torch.Tensor
    :return: A new tensor with the same shape as ``queue``.
    :rtype: torch.Tensor
    """
    remaining = queue[1:]
    repeated_tail = queue[-1:]
    return torch.cat((remaining, repeated_tail), dim=0)
# [docs]
def lowerable_while_loop_scan(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run FlexSN scan through PyTorch's built-in ``while_loop`` HOP.

    Experimental helper for studying whether a first-class loop
    representation is a better fit than the current unrolled scan path.
    Current while-loop capture does not support symbolic ``x[t]`` indexing
    here, so this implementation keeps functional queue buffers: each carried
    queue is advanced by dropping its head and appending at its tail.

    Step 0 is evaluated by calling ``core_fn`` directly before the loop; its
    results seed the state carries and the ``T``-long accumulation buffers.
    The remaining steps run inside the ``while_loop`` HOP, starting from the
    counter value 1.

    When ``T == 0``, ``core_fn`` is not run: ``output_template_specs``
    materializes the empty output sequences and every state sequence is an
    empty tensor of shape ``(0, *state.shape)``.

    :param core_fn: Single-step core callable, invoked as
        ``core_fn(*step_inputs, *states, *lifted_args)``. It must return
        ``num_outputs`` outputs followed by ``num_states`` next states; a
        bare tensor is treated as a one-element result.
    :type core_fn: ``Callable``
    :param num_inputs: Number of T-leading input sequences.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then lifted
        tensor freevars.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to materialize empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence at runtime.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences followed by ``num_states``
        state sequences.
    :rtype: Tuple[torch.Tensor, ...]
    :raises RuntimeError: If the ``while_loop`` HOP is unavailable.
    :raises ValueError: If too few tensor arguments are given, if
        ``num_inputs`` is zero, if the input sequences disagree on their
        leading dimension, or if ``core_fn`` returns an unexpected number of
        values.
    """
    if _torch_while_loop is None:
        raise RuntimeError("PyTorch while_loop HOP is unavailable in this environment")

    expected = num_inputs + num_states
    if len(flat_args) < expected:
        raise ValueError(
            f"flex_sn_scan expected at least {expected} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )
    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    # Split the flat argument list into its three logical groups.
    input_seqs = tuple(flat_args[:num_inputs])
    init_states = tuple(flat_args[num_inputs:expected])
    lifted_args = tuple(flat_args[expected:])

    skip_arity_check = (
        num_inputs > 0
        and input_seqs[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=skip_arity_check,
    )
    lifted_args = tuple(_ensure_contiguous(freevar) for freevar in lifted_args)

    T = input_seqs[0].shape[0]
    for index, seq in enumerate(input_seqs):
        if seq.shape[0] != T:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {T}"
            )

    if T == 0:
        # Nothing to iterate over: build empty results without calling core_fn.
        empty_outputs = _empty_outputs_from_template(
            input_seqs, num_outputs, output_template_specs
        )
        empty_states = tuple(
            init.new_empty((0, *init.shape)) for init in init_states
        )
        return (*empty_outputs, *empty_states)

    input_seqs = tuple(_ensure_contiguous(seq) for seq in input_seqs)
    init_states = tuple(_ensure_contiguous(init) for init in init_states)

    # Step 0 runs outside the loop; its results size the carried buffers.
    step0_inputs = tuple(_ensure_contiguous(seq[0]) for seq in input_seqs)
    step0_results = core_fn(*step0_inputs, *init_states, *lifted_args)
    if isinstance(step0_results, torch.Tensor):
        step0_results = (step0_results,)
    else:
        step0_results = tuple(step0_results)
    if len(step0_results) != num_outputs + num_states:
        raise ValueError(
            f"core returned {len(step0_results)} values, "
            f"expected num_outputs + num_states = {num_outputs + num_states}"
        )
    step0_outputs = tuple(
        _ensure_contiguous(res) for res in step0_results[:num_outputs]
    )
    step0_states = tuple(
        _ensure_contiguous(res) for res in step0_results[num_outputs:]
    )

    # Accumulators start as T zero rows with the step-0 value pushed at the
    # tail; every later step pushes one more row and drops one from the head.
    output_buffers = tuple(
        _append_to_tail(out.new_zeros((T, *out.shape)), out) for out in step0_outputs
    )
    state_buffers = tuple(
        _append_to_tail(st.new_zeros((T, *st.shape)), st) for st in step0_states
    )
    # Shift once so that slot 0 of every queue holds the step-1 input.
    pending_inputs = tuple(_shift_input_queue(seq) for seq in input_seqs)
    t0 = torch.tensor(
        1,
        dtype=torch.int64,
        device=_carry_device(step0_states, input_seqs, step0_outputs),
    )

    # Carry layout after the leading counter ``t``:
    #   [input queues | states | output accumulators | lifted | state accumulators]
    expected_results = len(step0_results)
    queues_stop = num_inputs
    states_stop = queues_stop + num_states
    outputs_stop = states_stop + num_outputs
    lifted_stop = outputs_stop + len(lifted_args)

    def cond_fn(t, *carry):
        return t < T

    def body_fn(t, *carry):
        input_queues = carry[:queues_stop]
        states = carry[queues_stop:states_stop]
        outputs_acc = carry[states_stop:outputs_stop]
        lifted = carry[outputs_stop:lifted_stop]
        states_acc = carry[lifted_stop:]

        # The current step's input always sits at the head of each queue.
        step_inputs = tuple(_ensure_contiguous(queue[0]) for queue in input_queues)
        results = core_fn(*step_inputs, *states, *lifted)
        if isinstance(results, torch.Tensor):
            results = (results,)
        else:
            results = tuple(results)
        if len(results) != expected_results:
            raise ValueError(
                f"core returned {len(results)} values at runtime, "
                f"expected {expected_results}"
            )
        outputs = tuple(_ensure_contiguous(res) for res in results[:num_outputs])
        next_states = tuple(_ensure_contiguous(res) for res in results[num_outputs:])
        next_queues = tuple(_shift_input_queue(queue) for queue in input_queues)

        if len(outputs_acc) != len(outputs):
            raise ValueError(
                f"core returned {len(outputs)} outputs at runtime, "
                f"expected {len(outputs_acc)}"
            )
        next_outputs_acc = tuple(
            _append_to_tail(acc, out) for acc, out in zip(outputs_acc, outputs)
        )
        if len(states_acc) != len(next_states):
            raise ValueError(
                f"core returned {len(next_states)} states at runtime, "
                f"expected {len(states_acc)}"
            )
        next_states_acc = tuple(
            _append_to_tail(acc, st) for acc, st in zip(states_acc, next_states)
        )
        return (
            t + 1,
            *next_queues,
            *next_states,
            *next_outputs_acc,
            *lifted,
            *next_states_acc,
        )

    initial_carry = (
        t0,
        *pending_inputs,
        *step0_states,
        *output_buffers,
        *lifted_args,
        *state_buffers,
    )
    final = tuple(_torch_while_loop(cond_fn, body_fn, initial_carry))

    # ``final`` still carries the counter in slot 0, so every offset moves by 1.
    final_output_buffers = final[1 + states_stop : 1 + outputs_stop]
    final_state_buffers = final[1 + lifted_stop :]
    return (*final_output_buffers, *final_state_buffers)
# [docs]
def lowerable_while_loop_scan_final_state(
    core_fn: Callable,
    num_inputs: int,
    num_states: int,
    num_outputs: int,
    *flat_args: torch.Tensor,
    output_template_specs: Optional[OutputTemplateSpecs] = None,
) -> Tuple[torch.Tensor, ...]:
    """Run the while-loop HOP and return output sequences plus final states.

    Final-state variant of :func:`lowerable_while_loop_scan`: per-step outputs
    are accumulated into ``T``-long buffers, while states are only carried
    from one step to the next and returned as they stand after the last step.

    Step 0 is evaluated by calling ``core_fn`` directly before the loop; the
    remaining steps run inside the ``while_loop`` HOP, starting from the
    counter value 1.

    When ``T == 0``, ``output_template_specs`` is used to build empty output
    sequences and the provided initial states are cloned into the returned
    final states.

    :param core_fn: Single-step core callable, invoked as
        ``core_fn(*step_inputs, *states, *lifted_args)``. It must return
        ``num_outputs`` outputs followed by ``num_states`` next states; a
        bare tensor is treated as a one-element result.
    :type core_fn: ``Callable``
    :param num_inputs: Number of T-leading input sequences.
    :type num_inputs: int
    :param num_states: Number of initial-state tensors.
    :type num_states: int
    :param num_outputs: Number of per-step outputs.
    :type num_outputs: int
    :param flat_args: Flattened input sequences, initial states, then lifted
        tensor freevars.
    :type flat_args: ``torch.Tensor``
    :param output_template_specs: Optional ``(shape, dtype)`` or
        ``(shape, dtype, device)`` templates used to materialize empty output
        sequences when ``T == 0``. Omitted devices follow the first input
        sequence at runtime.
    :type output_template_specs: Optional[OutputTemplateSpecs]
    :return: ``num_outputs`` output sequences followed by the final states.
    :rtype: Tuple[torch.Tensor, ...]
    :raises RuntimeError: If the ``while_loop`` HOP is unavailable.
    :raises ValueError: If too few tensor arguments are given, if
        ``num_inputs`` is zero, if the input sequences disagree on their
        leading dimension, or if ``core_fn`` returns an unexpected number of
        values.
    """
    if _torch_while_loop is None:
        raise RuntimeError("PyTorch while_loop HOP is unavailable in this environment")

    expected = num_inputs + num_states
    if len(flat_args) < expected:
        raise ValueError(
            f"flex_sn_scan expected at least {expected} tensor args "
            f"(num_inputs={num_inputs} + num_states={num_states}), "
            f"got {len(flat_args)}"
        )
    if num_inputs == 0:
        raise ValueError("flex_sn_scan requires at least one input sequence")

    # Split the flat argument list into its three logical groups.
    input_seqs = tuple(flat_args[:num_inputs])
    init_states = tuple(flat_args[num_inputs:expected])
    lifted_args = tuple(flat_args[expected:])

    skip_arity_check = (
        num_inputs > 0
        and input_seqs[0].shape[0] == 0
        and isinstance(core_fn, torch.fx.GraphModule)
    )
    _check_lifted_arg_arity(
        core_fn,
        num_inputs,
        num_states,
        lifted_args,
        skip_check=skip_arity_check,
    )
    lifted_args = tuple(_ensure_contiguous(freevar) for freevar in lifted_args)

    T = input_seqs[0].shape[0]
    for index, seq in enumerate(input_seqs):
        if seq.shape[0] != T:
            raise ValueError(
                f"input {index} has leading dim {seq.shape[0]}, expected {T}"
            )

    if T == 0:
        # Nothing to iterate over: empty outputs, untouched (cloned) states.
        empty_outputs = _empty_outputs_from_template(
            input_seqs, num_outputs, output_template_specs
        )
        return (*empty_outputs, *(init.clone() for init in init_states))

    input_seqs = tuple(_ensure_contiguous(seq) for seq in input_seqs)
    init_states = tuple(_ensure_contiguous(init) for init in init_states)

    # Step 0 runs outside the loop; its results size the carried buffers.
    step0_inputs = tuple(_ensure_contiguous(seq[0]) for seq in input_seqs)
    step0_results = core_fn(*step0_inputs, *init_states, *lifted_args)
    if isinstance(step0_results, torch.Tensor):
        step0_results = (step0_results,)
    else:
        step0_results = tuple(step0_results)
    if len(step0_results) != num_outputs + num_states:
        raise ValueError(
            f"core returned {len(step0_results)} values, "
            f"expected num_outputs + num_states = {num_outputs + num_states}"
        )
    step0_outputs = tuple(
        _ensure_contiguous(res) for res in step0_results[:num_outputs]
    )
    step0_states = tuple(
        _ensure_contiguous(res) for res in step0_results[num_outputs:]
    )

    # Accumulators start as T zero rows with the step-0 value pushed at the
    # tail; every later step pushes one more row and drops one from the head.
    output_buffers = tuple(
        _append_to_tail(out.new_zeros((T, *out.shape)), out) for out in step0_outputs
    )
    # Shift once so that slot 0 of every queue holds the step-1 input.
    pending_inputs = tuple(_shift_input_queue(seq) for seq in input_seqs)
    t0 = torch.tensor(
        1,
        dtype=torch.int64,
        device=_carry_device(step0_states, input_seqs, step0_outputs),
    )

    # Carry layout after the leading counter ``t``:
    #   [input queues | states | output accumulators | lifted]
    expected_results = len(step0_results)
    queues_stop = num_inputs
    states_stop = queues_stop + num_states
    outputs_stop = states_stop + num_outputs
    lifted_stop = outputs_stop + len(lifted_args)

    def cond_fn(t, *carry):
        return t < T

    def body_fn(t, *carry):
        input_queues = carry[:queues_stop]
        states = carry[queues_stop:states_stop]
        outputs_acc = carry[states_stop:outputs_stop]
        lifted = carry[outputs_stop:lifted_stop]

        # The current step's input always sits at the head of each queue.
        step_inputs = tuple(_ensure_contiguous(queue[0]) for queue in input_queues)
        results = core_fn(*step_inputs, *states, *lifted)
        if isinstance(results, torch.Tensor):
            results = (results,)
        else:
            results = tuple(results)
        if len(results) != expected_results:
            raise ValueError(
                f"core returned {len(results)} values at runtime, "
                f"expected {expected_results}"
            )
        outputs = tuple(_ensure_contiguous(res) for res in results[:num_outputs])
        next_states = tuple(_ensure_contiguous(res) for res in results[num_outputs:])
        next_queues = tuple(_shift_input_queue(queue) for queue in input_queues)

        if len(outputs_acc) != len(outputs):
            raise ValueError(
                f"core returned {len(outputs)} outputs at runtime, "
                f"expected {len(outputs_acc)}"
            )
        next_outputs_acc = tuple(
            _append_to_tail(acc, out) for acc, out in zip(outputs_acc, outputs)
        )
        return (
            t + 1,
            *next_queues,
            *next_states,
            *next_outputs_acc,
            *lifted,
        )

    initial_carry = (
        t0,
        *pending_inputs,
        *step0_states,
        *output_buffers,
        *lifted_args,
    )
    final = tuple(_torch_while_loop(cond_fn, body_fn, initial_carry))

    # ``final`` still carries the counter in slot 0, so every offset moves by 1.
    final_states = final[1 + queues_stop : 1 + states_stop]
    final_output_buffers = final[1 + states_stop : 1 + outputs_stop]
    return (*final_output_buffers, *final_states)
def _register_dynamo_hop() -> None:
    """Patch TorchDynamo so that ``flex_sn_scan`` is captured as a HOP call.

    The function imports the private Dynamo variable-tracker modules it needs,
    defines a ``TorchHigherOrderOperatorVariable`` subclass that handles
    ``flex_sn_scan`` calls, and replaces
    ``TorchHigherOrderOperatorVariable.make`` with a wrapper that returns that
    subclass for ``flex_sn_scan`` and delegates every other value to the
    original ``make``.

    The registration is best-effort:

    * If the imports fail with ``ImportError`` / ``ModuleNotFoundError`` /
      ``AttributeError`` the function returns silently.
    * Any other exception during import is reported through
      ``warnings.warn`` and the function returns.
    * If ``make`` already carries the ``_spikingjelly_flexsn_hop`` marker the
      function only sets ``_DYNAMO_HOP_REGISTERED`` and returns, so calling
      it repeatedly does not stack patches.

    Side effects on success: ``TorchHigherOrderOperatorVariable.make`` is
    replaced and the module-level ``_DYNAMO_HOP_REGISTERED`` flag is set to
    ``True``.
    """
    global _DYNAMO_HOP_REGISTERED
    try:
        from torch._dynamo.variables import higher_order_ops as hop_vars
        from torch._dynamo.variables.builder import wrap_fx_proxy
        from torch._dynamo.variables.constant import ConstantVariable
        from torch._dynamo.variables.functions import (
            NestedUserFunctionVariable,
            UserFunctionVariable,
        )
        from torch._dynamo.variables.higher_order_ops import (
            TorchHigherOrderOperatorVariable,
            make_attr,
            speculate_subgraph,
        )
        from torch._dynamo.variables.tensor import TensorVariable
    except (ImportError, ModuleNotFoundError, AttributeError):
        # Expected failure modes when these private modules/names are absent:
        # skip registration without any noise.
        return
    except Exception as e:
        # Import-time registration must never break package import on
        # unsupported or drifting Torch internals; warn and leave the HOP
        # available through its eager fallback instead.
        warnings.warn(
            f"FlexSN HOP Dynamo registration failed unexpectedly: {type(e).__name__}: {e}",
            stacklevel=2,
        )
        return
    # Read ``make`` from the class ``__dict__`` first so that a
    # classmethod/staticmethod wrapper is seen as such instead of being
    # resolved by attribute access.
    make_descriptor = TorchHigherOrderOperatorVariable.__dict__.get("make")
    # ``True`` means ``make`` is not defined directly on the class and was
    # obtained through normal attribute lookup below.
    original_make_is_bound = make_descriptor is None
    if make_descriptor is None:
        make_descriptor = TorchHigherOrderOperatorVariable.make
    # Unwrap to the underlying function so the marker attribute can be read.
    make_func = (
        make_descriptor.__func__
        if isinstance(make_descriptor, (classmethod, staticmethod))
        else make_descriptor
    )
    if getattr(make_func, "_spikingjelly_flexsn_hop", False):
        # Already patched by an earlier call: just record the fact.
        _DYNAMO_HOP_REGISTERED = True
        return
    original_make = make_descriptor
    # Prefer the module-level ``add_subgraph`` helper when this Torch build
    # exposes one; otherwise fall back to ``tx.output.install_subgraph`` behind
    # the same four-argument signature.
    install_subgraph = getattr(hop_vars, "add_subgraph", None)
    if install_subgraph is None:

        def install_subgraph(tx, source, name, gm):
            """Install ``gm`` under ``name``; ``source`` is accepted but unused."""
            return tx.output.install_subgraph(name, gm)

    class FlexSNScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
        """Dynamo variable tracker that handles calls to ``flex_sn_scan``."""

        # NOTE(review): these two class attributes are presumably read by the
        # Dynamo base class; nothing in this function reads them - confirm
        # against the targeted Torch version.
        _HOP_NAME = "spikingjelly.flex_sn_scan"
        _ALLOW_FALLBACK_TO_EAGER = False

        def call_function(self, tx, args, kwargs):
            """Trace one ``flex_sn_scan(body_fn, n_in, n_state, n_out, *tensors)`` call.

            The body function is speculated into a subgraph using single-step
            stand-ins for the input sequences, installed as a ``GraphModule``,
            and a ``call_function`` proxy targeting ``self.value`` is emitted
            with the body module, the three counts, the tensor arguments and
            any lifted tensor freevars.

            Every unsupported shape of call is reported through
            ``hop_vars.unimplemented`` (presumably Dynamo's "unsupported"
            signal - TODO confirm for the targeted Torch version).
            """
            # ``output_template_specs`` is the only keyword accepted here.
            output_template_specs_arg = kwargs.pop("output_template_specs", None)
            if kwargs:
                raise hop_vars.unimplemented(
                    "flex_sn_scan only supports output_template_specs as a kwarg"
                )
            explicit_output_template_specs = None
            if output_template_specs_arg is not None:
                try:
                    explicit_output_template_specs = (
                        output_template_specs_arg.as_python_constant()
                    )
                except Exception as e:
                    raise hop_vars.unimplemented(
                        "flex_sn_scan output_template_specs must be a Python constant"
                    ) from e
            if len(args) < 4:
                raise hop_vars.unimplemented(
                    "flex_sn_scan expects body_fn, num_inputs, num_states, "
                    "num_outputs, and tensor arguments"
                )
            body_fn = args[0]
            if not isinstance(
                body_fn, (UserFunctionVariable, NestedUserFunctionVariable)
            ):
                raise hop_vars.unimplemented(
                    "flex_sn_scan expects a user-defined Python function body"
                )
            # The three counts must be compile-time constants.
            const_args = args[1:4]
            if not all(isinstance(arg, ConstantVariable) for arg in const_args):
                raise hop_vars.unimplemented(
                    "flex_sn_scan expects num_inputs/num_states/num_outputs to be constants"
                )
            num_inputs, num_states, num_outputs = [
                arg.as_python_constant() for arg in const_args
            ]
            flat_args = args[4:]
            expected = num_inputs + num_states
            if len(flat_args) < expected:
                raise hop_vars.unimplemented(
                    f"flex_sn_scan expected at least {expected} tensor args, got {len(flat_args)}"
                )
            if num_inputs == 0:
                raise hop_vars.unimplemented(
                    "flex_sn_scan requires at least one input sequence"
                )
            if not all(isinstance(arg, TensorVariable) for arg in flat_args):
                raise hop_vars.unimplemented(
                    "flex_sn_scan only supports tensor inputs and states"
                )

            def _make_step_template(arg: TensorVariable):
                """Return a single-step stand-in (leading dim removed) for ``arg``."""
                example_value = arg.as_proxy().node.meta["example_value"]
                if example_value.shape[0] > 0:
                    # Non-empty sequence: the first slice is a valid template.
                    return arg.call_method(tx, "__getitem__", [ConstantVariable(0)], {})
                # Empty sequence: indexing ``[0]`` is not possible, so emit a
                # ``new_empty`` node with the per-step shape instead.
                shape_without_t = tuple(example_value.shape[1:])
                proxy = tx.output.create_proxy(
                    "call_function",
                    torch.ops.aten.new_empty.default,
                    args=(arg.as_proxy(), shape_without_t),
                    kwargs={},
                )
                return wrap_fx_proxy(
                    tx=tx,
                    proxy=proxy,
                    example_value=example_value.new_empty(shape_without_t),
                )

            # The body is traced with per-step inputs followed by every
            # remaining tensor argument unchanged.
            step_inputs = [_make_step_template(arg) for arg in flat_args[:num_inputs]]
            body_args = [*step_inputs, *flat_args[num_inputs:]]
            # Proxy node names of the explicit body arguments; passed to the
            # placeholder-reordering helper below.
            canonical_body_arg_names = tuple(
                arg.as_proxy().node.name for arg in body_args
            )
            speculated = speculate_subgraph(
                tx,
                body_fn,
                body_args,
                {},
                "flex_sn_scan",
                source_target=self.value,
            )
            # Both a 4-tuple and a 3-tuple result are accepted (presumably the
            # arity differs between Torch versions - TODO confirm).
            if len(speculated) == 4:
                (
                    _body_r,
                    body_graph,
                    body_lifted_freevars,
                    _parent_proxy_map,
                ) = speculated
            elif len(speculated) == 3:
                _body_r, body_graph, body_lifted_freevars = speculated
            else:
                raise hop_vars.unimplemented(
                    "flex_sn_scan received an unsupported speculate_subgraph result"
                )
            # The lifted freevars may be a mapping (keys are used) or a plain
            # iterable.
            if hasattr(body_lifted_freevars, "keys"):
                lifted_freevars = tuple(body_lifted_freevars.keys())
            else:
                lifted_freevars = tuple(body_lifted_freevars)
            if lifted_freevars and not all(
                isinstance(freevar, torch.fx.Proxy) for freevar in lifted_freevars
            ):
                raise hop_vars.unimplemented(
                    "flex_sn_scan only supports tensor lifted freevars"
                )
            # Each lifted proxy must carry a tensor example value.
            for freevar in lifted_freevars:
                example_value = freevar.node.meta.get("example_value")
                if not isinstance(example_value, torch.Tensor):
                    raise hop_vars.unimplemented(
                        "flex_sn_scan only supports tensor lifted freevars"
                    )
            placeholders = _reorder_placeholders_to_canonical_args(
                body_graph, canonical_body_arg_names
            )
            # Placeholders beyond the explicit body arguments are treated as
            # lifted freevars; match them by node name so the extra call
            # arguments follow the placeholder order.
            placeholder_freevar_names = tuple(
                node.name for node in placeholders[len(body_args) :]
            )
            if placeholder_freevar_names:
                freevars_by_name = {
                    freevar.node.name: freevar for freevar in lifted_freevars
                }
                missing = [
                    name
                    for name in placeholder_freevar_names
                    if name not in freevars_by_name
                ]
                if missing:
                    raise hop_vars.unimplemented(
                        "flex_sn_scan could not map lifted tensor freevars"
                    )
                lifted_freevars = tuple(
                    freevars_by_name[name] for name in placeholder_freevar_names
                )
            else:
                lifted_freevars = ()
            # Materialize the traced body and register it with the output graph.
            body_gm = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
            body_name = install_subgraph(tx, self.source, "flex_sn_scan_body", body_gm)
            body_node = make_attr(tx, body_name)
            # Templates derived from the traced body result are the default; an
            # explicitly passed constant overrides them.
            output_template_specs = _output_template_specs_from_dynamo_body_result(
                _body_r,
                num_outputs,
            )
            if explicit_output_template_specs is not None:
                output_template_specs = explicit_output_template_specs
            # The kwarg is only forwarded when there is something to forward.
            proxy_kwargs = (
                {}
                if output_template_specs is None
                else {"output_template_specs": output_template_specs}
            )
            proxy = tx.output.create_proxy(
                "call_function",
                self.value,
                args=(
                    body_node,
                    num_inputs,
                    num_states,
                    num_outputs,
                    *(arg.as_proxy() for arg in flat_args),
                    *lifted_freevars,
                ),
                kwargs=proxy_kwargs,
            )
            # Build example values for the call result: one ``(T, *leaf.shape)``
            # tensor per body leaf, for the first num_outputs + num_states
            # leaves.
            body_leaves = _flatten_dynamo_body_result(_body_r)
            example_value = []
            T = flat_args[0].as_proxy().node.meta["example_value"].shape[0]
            for i in range(num_outputs + num_states):
                if i >= len(body_leaves):
                    # Too few leaves: give up on the cheap path.
                    example_value = None
                    break
                leaf_ev = _dynamo_leaf_example_value(body_leaves[i])
                if not isinstance(leaf_ev, torch.Tensor):
                    # No tensor example value for this leaf: give up as well.
                    example_value = None
                    break
                example_value.append(leaf_ev.new_empty((T, *leaf_ev.shape)))
            if example_value is None:
                # Fallback: obtain the example values by running ``eager_scan``
                # on the traced body with the recorded example inputs.
                example_value = eager_scan(
                    body_gm,
                    num_inputs,
                    num_states,
                    num_outputs,
                    *(arg.as_proxy().node.meta["example_value"] for arg in flat_args),
                    *(
                        freevar.node.meta["example_value"]
                        for freevar in lifted_freevars
                    ),
                    output_template_specs=output_template_specs,
                )
            else:
                example_value = tuple(example_value)
            return wrap_fx_proxy(tx=tx, proxy=proxy, example_value=example_value)

    def patched_make(cls, value, source=None, **kwargs):
        """Return the FlexSN variable for ``flex_sn_scan``; otherwise delegate.

        The delegation branch depends on how the original ``make`` was
        obtained above (classmethod, staticmethod, resolved attribute, or
        plain function).
        """
        if value is flex_sn_scan:
            return FlexSNScanHigherOrderVariable(value, source, **kwargs)
        if isinstance(original_make, classmethod):
            return original_make.__func__(cls, value, source=source, **kwargs)
        if isinstance(original_make, staticmethod):
            return original_make.__func__(value, source=source, **kwargs)
        if original_make_is_bound:
            # ``make`` came from attribute lookup: call the underlying
            # function with ``cls`` when there is one, so the receiving class
            # is preserved; otherwise call the resolved object as-is.
            original_make_func = getattr(original_make, "__func__", None)
            if original_make_func is not None:
                return original_make_func(cls, value, source=source, **kwargs)
            return original_make(value, source=source, **kwargs)
        # Plain function stored in the class ``__dict__``.
        return original_make(cls, value, source=source, **kwargs)

    # Marker checked at the top of this function to detect an earlier patch.
    patched_make._spikingjelly_flexsn_hop = True
    # Dynamo does not currently expose a public registry for this HOP hook.
    # Patch only the flex_sn_scan dispatch and delegate every other operator
    # back to PyTorch's original implementation.
    TorchHigherOrderOperatorVariable.make = classmethod(patched_make)
    _DYNAMO_HOP_REGISTERED = True
# Attempt the Dynamo registration when this module is imported; the function
# handles its own import failures, so this call is not wrapped in a guard.
_register_dynamo_hop()