# Source code for spikingjelly.activation_based.triton_kernel.flexsn.info

from collections import namedtuple

import torch.fx as fx

# Names exported by ``from <this module> import *``.
__all__ = ["FlexSNInfo", "extract_info"]


# Immutable record of the metadata that ``extract_info`` derives from a
# forward graph. Field order is part of the tuple's identity; do not reorder.
FlexSNInfo = namedtuple(
    "FlexSNInfo",
    (
        # Counts passed in by the caller of ``extract_info``.
        "num_inputs",
        "num_outputs",
        "num_states",
        # Names of the graph's placeholder nodes.
        "fwd_core_args",
        # Names of the values returned by the graph's output node.
        "fwd_core_returns",
        # Variable names receiving each core return ("_" for duplicates).
        "fwd_core_recipients",
        # De-duplicated names of the values returned by the forward kernel.
        "fwd_kernel_returns",
        # len(fwd_kernel_returns).
        "num_fwd_kernel_returns",
        # For each intermediate core return, its index in fwd_kernel_returns.
        "c2k_return_mapping",
    ),
)


def extract_info(
    fwd_graph: fx.Graph,
    num_inputs: int = 1,
    num_states: int = 0,
    num_outputs: int = 1,
) -> FlexSNInfo:
    r"""
    Extract the metadata needed by FlexSN from a forward graph.

    The forward graph is expected to have the signature
    ``[*inputs, *states] -> [*outputs, *states, *intermediates]``.

    The returned :class:`FlexSNInfo` contains:

    * ``fwd_core_args``: names of the graph's placeholder nodes
    * ``fwd_core_returns``: names of the values returned by the graph
    * ``fwd_core_recipients``: variable names receiving the core return
      values; ``"_"`` marks an intermediate that duplicates an earlier return
    * ``fwd_kernel_returns``: forward kernel return names (no duplicates);
      outputs are named ``s{i}``, states ``v{i}`` and new intermediates
      ``res{n}_f``
    * ``num_fwd_kernel_returns``: the length of ``fwd_kernel_returns``
    * ``c2k_return_mapping``: for each intermediate core return, the index of
      the corresponding entry in ``fwd_kernel_returns``

    :param fwd_graph: The forward computational graph
    :type fwd_graph: fx.Graph

    :param num_inputs: Number of inputs. Default: 1
    :type num_inputs: int

    :param num_states: Number of states. Default: 0
    :type num_states: int

    :param num_outputs: Number of outputs. Default: 1
    :type num_outputs: int

    :return: The extracted FlexSN metadata
    :rtype: FlexSNInfo

    :raises AssertionError: if the number of placeholders differs from
        ``num_inputs + num_states``, or the graph returns fewer than
        ``num_outputs + num_states`` values
    """
    fwd_core_args = [n.name for n in fwd_graph.find_nodes(op="placeholder")]
    fwd_core_returns = [
        a.name for n in fwd_graph.find_nodes(op="output") for a in n.args[0]
    ]

    num_args = num_inputs + num_states
    num_outputs_states = num_outputs + num_states
    # Raised explicitly (rather than via ``assert``) so the checks survive
    # ``python -O`` while keeping the exception type callers may rely on.
    if len(fwd_core_args) != num_args:
        raise AssertionError(
            f"forward graph has {len(fwd_core_args)} placeholder(s), expected "
            f"num_inputs + num_states = {num_args}"
        )
    if len(fwd_core_returns) < num_outputs_states:
        raise AssertionError(
            f"forward graph returns {len(fwd_core_returns)} value(s), expected "
            f"at least num_outputs + num_states = {num_outputs_states}"
        )

    # Core variable name -> index of its kernel symbol in fwd_kernel_returns.
    # Kernel symbols are unique, so the index can be recorded when the symbol
    # is appended instead of searching the list afterwards.
    kernel_slot = {}
    fwd_kernel_returns = []
    fwd_core_recipients = []

    # 1. outputs -> s{i}; 2. states -> v{i}. States are also returned by the
    # kernel. If one core name appears several times here, every occurrence
    # still gets its own kernel symbol and the last one wins in the lookup.
    leading = [
        (name, f"s{i}") for i, name in enumerate(fwd_core_returns[:num_outputs])
    ]
    leading += [
        (name, f"v{i}")
        for i, name in enumerate(fwd_core_returns[num_outputs:num_outputs_states])
    ]
    for core_name, kernel_name in leading:
        kernel_slot[core_name] = len(fwd_kernel_returns)
        fwd_kernel_returns.append(kernel_name)
        fwd_core_recipients.append(kernel_name)

    # 3. intermediates -> res{n}_f, de-duplicated against everything above.
    num_new_intermediates = 0
    c2k_return_mapping = []
    for core_name in fwd_core_returns[num_outputs_states:]:
        if core_name in kernel_slot:
            # Duplicated core return: discard the value, reuse the kernel slot.
            fwd_core_recipients.append("_")
        else:
            kernel_name = f"res{num_new_intermediates}_f"
            num_new_intermediates += 1
            kernel_slot[core_name] = len(fwd_kernel_returns)
            fwd_kernel_returns.append(kernel_name)
            fwd_core_recipients.append(kernel_name)
        c2k_return_mapping.append(kernel_slot[core_name])

    return FlexSNInfo(
        num_inputs=num_inputs,
        num_outputs=num_outputs,
        num_states=num_states,
        fwd_core_args=fwd_core_args,
        fwd_core_returns=fwd_core_returns,
        fwd_core_recipients=fwd_core_recipients,
        fwd_kernel_returns=fwd_kernel_returns,
        num_fwd_kernel_returns=len(fwd_kernel_returns),
        c2k_return_mapping=c2k_return_mapping,
    )