spikingjelly.activation_based.triton_kernel.torch2triton package

PyTorch to Graph

spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_inference_graph(fn: Callable, example_inputs: tuple) → Graph

Generate an optimized inference FX graph for a PyTorch callable.


Parameters:
  • fn (Callable) -- Callable to trace.

  • example_inputs (tuple) -- Example inputs used for tracing.

Returns:

Optimized forward FX graph.

Return type:

torch.fx.Graph

Raises:

ValueError -- Raised when the traced callable fails to produce a forward graph.
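An FX graph records the operations a callable applies to its inputs, one node per operation. The toy tracer below illustrates that idea in plain Python by passing stand-in objects through the callable and logging every `+` and `*`. It is a hypothetical sketch: `ToyGraph`, `Proxy`, and `toy_trace` are not part of SpikingJelly or `torch.fx`, and the sketch says nothing about which tracing mechanism or graph optimizations `generate_inference_graph` actually uses.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Tuple


@dataclass
class Node:
    name: str              # unique value name, e.g. "v2"
    op: str                # "placeholder", "add", "mul" or "output"
    args: Tuple[str, ...]  # names of the values this node consumes


@dataclass
class ToyGraph:
    nodes: List[Node] = field(default_factory=list)

    def add_node(self, op: str, args: Tuple[str, ...]) -> "Proxy":
        # Append a node and hand back a proxy for the value it produces.
        node = Node(f"v{len(self.nodes)}", op, args)
        self.nodes.append(node)
        return Proxy(self, node.name)


class Proxy:
    """Stand-in value that records each arithmetic op applied to it."""

    def __init__(self, graph: ToyGraph, name: str):
        self.graph = graph
        self.name = name

    def __add__(self, other: "Proxy") -> "Proxy":
        return self.graph.add_node("add", (self.name, other.name))

    def __mul__(self, other: "Proxy") -> "Proxy":
        return self.graph.add_node("mul", (self.name, other.name))


def toy_trace(fn: Callable, example_inputs: tuple) -> ToyGraph:
    graph = ToyGraph()
    # Replace every example input with a placeholder proxy; only the count is used here.
    proxies = [graph.add_node("placeholder", ()) for _ in example_inputs]
    result = fn(*proxies)
    graph.add_node("output", (result.name,))
    return graph


g = toy_trace(lambda x, y: x * y + x, (2.0, 3.0))
for n in g.nodes:
    print(n.name, n.op, n.args)
```

Here the example values are only counted. A real tracer would typically also use them to determine shapes and dtypes.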

spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_forward_and_backward_graph(fn: Callable, example_inputs: tuple, requires_grad: Sequence[bool] | None = None) → Tuple[Graph, Graph]

Generate optimized forward and backward FX graphs for a PyTorch callable.


Parameters:
  • fn (Callable) -- Callable to trace.

  • example_inputs (tuple) -- Example inputs used for tracing.

  • requires_grad (Optional[Sequence[bool]]) -- Optional per-input flags indicating which example inputs require gradients.

Returns:

Optimized forward and backward FX graphs.

Return type:

Tuple[torch.fx.Graph, torch.fx.Graph]

Raises:

ValueError -- Raised when the length of requires_grad does not match that of example_inputs, when the callable does not return a tensor, list, or tuple, or when no differentiable output exists.
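The three ValueError conditions above amount to simple input and output checks. The sketch below expresses them in plain Python. It assumes a hypothetical `FakeTensor` stand-in for `torch.Tensor` so that it runs without PyTorch, and uses hypothetical helper names `check_requires_grad` and `check_outputs`. The fallback when `requires_grad` is None (reading each input's own `requires_grad` attribute) is also an assumption, not documented behavior.

```python
from typing import Any, List, Optional, Sequence


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch."""

    def __init__(self, requires_grad: bool = False):
        self.requires_grad = requires_grad


def check_requires_grad(example_inputs: tuple,
                        requires_grad: Optional[Sequence[bool]] = None) -> List[bool]:
    if requires_grad is None:
        # Assumption: default to each input's own requires_grad flag.
        return [bool(getattr(x, "requires_grad", False)) for x in example_inputs]
    if len(requires_grad) != len(example_inputs):
        raise ValueError(
            f"requires_grad has {len(requires_grad)} entries, "
            f"but example_inputs has {len(example_inputs)}"
        )
    return [bool(flag) for flag in requires_grad]


def check_outputs(out: Any) -> List[Any]:
    # A single tensor is normalized to a one-element list.
    if isinstance(out, FakeTensor):
        outs = [out]
    elif isinstance(out, (list, tuple)):
        outs = list(out)
    else:
        raise ValueError(f"expected a tensor, list or tuple, got {type(out).__name__}")
    # At least one output must carry a gradient for a backward graph to exist.
    if not any(isinstance(o, FakeTensor) and o.requires_grad for o in outs):
        raise ValueError("no differentiable output")
    return outs
```

Running these checks before tracing lets a mismatch surface as a clear error instead of a malformed backward graph.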

Graph to Triton

spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.generate_triton_code_str(graph: Graph, fn_name: str, verbose: bool = False) → Tuple[str, str]

Given an fx.Graph, generate its corresponding Triton code string.


Parameters:
  • graph (fx.Graph) -- The FX graph to translate into Triton code.

  • fn_name (str) -- Name of the original PyTorch function, used to generate the Triton kernel name.

  • verbose (bool, optional) -- Whether to enable verbose output. Defaults to False.

Returns:

The generated Triton code string and the name of the Triton function.

Return type:

Tuple[str, str]
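To make the `(code, kernel_name)` return pair concrete, here is a toy code generator for elementwise graphs. It assumes a hypothetical node format of `(output_name, op, input_names)` tuples in topological order and a hypothetical `{fn_name}_kernel` naming scheme. The real function consumes a `torch.fx.Graph`, and its naming and code layout may differ. The emitted source uses standard `triton.language` calls (`tl.program_id`, `tl.arange`, `tl.load`, `tl.store`), but the generator itself only builds a string and does not require Triton.

```python
from typing import List, Tuple

# Hypothetical node format: (output_name, op, input_names), topologically ordered.
Node = Tuple[str, str, Tuple[str, ...]]

# Elementwise ops this toy generator understands, as Python expression templates.
_BINARY_OPS = {"add": "{} + {}", "mul": "{} * {}"}


def toy_generate_code(inputs: List[str], nodes: List[Node], output: str,
                      fn_name: str) -> Tuple[str, str]:
    kernel_name = f"{fn_name}_kernel"  # hypothetical naming scheme
    ptr_args = ", ".join(f"{name}_ptr" for name in inputs)
    lines = [
        "import triton",
        "import triton.language as tl",
        "",
        "@triton.jit",
        f"def {kernel_name}({ptr_args}, out_ptr, n, BLOCK: tl.constexpr):",
        "    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)",
        "    mask = offs < n",
    ]
    # Load each input block, replay the graph nodes, then store the result.
    lines += [f"    {name} = tl.load({name}_ptr + offs, mask=mask)" for name in inputs]
    for out, op, args in nodes:
        lines.append(f"    {out} = " + _BINARY_OPS[op].format(*args))
    lines.append(f"    tl.store(out_ptr + offs, {output}, mask=mask)")
    return "\n".join(lines) + "\n", kernel_name


code, name = toy_generate_code(
    ["x", "y"], [("t0", "mul", ("x", "y")), ("t1", "add", ("t0", "x"))], "t1", "mul_add"
)
print(name)
print(code)
```

Returning plain source text means the result can be handed to compile_triton_code_str, which caches compiled modules by source hash.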


spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.compile_triton_code_str(triton_code: str, kernel_name: str, verbose: bool = False, name_space: dict | None = None)

Compile a Triton code string into a runnable Triton JIT function.


Materializes the Triton code under the persistent codegen cache, loads or reuses the matching module object, and extracts the requested JIT function.

Parameters:
  • triton_code (str) -- The Triton code string to compile and cache.

  • kernel_name (str) -- The name of the Triton function to extract.

  • verbose (bool, optional) -- If True, print whether the cached source was written or reused, along with its path. Defaults to False.

  • name_space (Optional[dict], optional) -- Optional globals injected before execution. When provided, it is updated with the symbols defined by the compiled module. Calls without name_space reuse a cached module keyed by the hash of the generated source; calls with name_space reload the module so that the injected symbols stay fresh.

Returns:

The compiled Triton JIT function.

Return type:

triton.JITFunction
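The caching behavior described for name_space can be sketched with the standard library. The steps are: hash the source, write it once to a cache directory, execute it, and pull out the requested function. In this hypothetical sketch, `toy_compile_code_str`, the cache directory, and the in-memory `_MODULES` dict are assumptions. Plain Python functions stand in for Triton JIT functions, and a globals dict stands in for the module object.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Callable, Dict, Optional

# Hypothetical cache location; not the path SpikingJelly uses.
CACHE_DIR = Path(tempfile.gettempdir()) / "toy_codegen_cache"
# In-memory cache: source hash -> globals of the executed module.
_MODULES: Dict[str, dict] = {}


def toy_compile_code_str(code: str, kernel_name: str, verbose: bool = False,
                         name_space: Optional[dict] = None) -> Callable:
    key = hashlib.sha256(code.encode("utf-8")).hexdigest()
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    path = CACHE_DIR / f"{key}.py"
    # Materialize the source once; identical code maps to the same file.
    status = "reused" if path.exists() else "written"
    if status == "written":
        path.write_text(code, encoding="utf-8")
    if verbose:
        print(f"cached source {status}: {path}")

    if name_space is None and key in _MODULES:
        module_globals = _MODULES[key]  # reuse the module keyed by the source hash
    else:
        # Inject the caller's globals first, then execute the cached file.
        module_globals = dict(name_space or {})
        module_globals["__file__"] = str(path)
        source = path.read_text(encoding="utf-8")
        exec(compile(source, str(path), "exec"), module_globals)
        if name_space is None:
            _MODULES[key] = module_globals
        else:
            # Expose the module's public symbols to the caller.
            name_space.update(
                {k: v for k, v in module_globals.items() if not k.startswith("__")}
            )
    return module_globals[kernel_name]


src = "def scale(x):\n    return x * FACTOR\n"
scale = toy_compile_code_str(src, "scale", verbose=True, name_space={"FACTOR": 3})
print(scale(2))
```

When name_space is omitted, the cached module is reused, so identical source is not re-executed. When name_space is passed, the source is re-executed, so the function always sees the latest injected globals (here, `FACTOR`).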

