spikingjelly.activation_based.triton_kernel.torch2triton package#

PyTorch to Graph#

spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_inference_graph(fn, example_inputs)[源代码]#

Generate an optimized inference FX graph.

API Language - 中文 | English


  • 中文

生成推理计算图

参数:
  • fn (Callable) -- EN: Callable to trace. Chinese: 待追踪的可调用对象。

  • example_inputs (tuple) -- EN: Example inputs used for tracing. Chinese: 用于追踪的示例输入。

返回:

EN: Optimized forward FX graph. Chinese: 优化后的前向 FX 图。

返回类型:

Graph

抛出:

ValueError -- EN: Raised when the traced callable fails to produce a forward graph. Chinese: 当被追踪函数未能产生前向 FX 图时抛出。

Chinese:

为给定的 PyTorch 函数生成优化后的推理 FX 图。

English:

Generate an optimized inference FX graph for a PyTorch callable.


  • English

Generate inference graph

抛出:

ValueError -- EN: Raised when the traced callable fails to produce a forward graph. Chinese: 当被追踪函数未能产生前向 FX 图时抛出。

返回类型:

Graph

参数:
  • fn (Callable)

  • example_inputs (tuple)

spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_forward_and_backward_graph(fn, example_inputs, requires_grad=None)[源代码]#

Generate optimized forward/backward FX graphs.

API Language - 中文 | English


  • 中文

生成前向和反向计算图

参数:
  • fn (Callable) -- EN: Callable to trace. Chinese: 待追踪的可调用对象。

  • example_inputs (tuple) -- EN: Example inputs used for tracing. Chinese: 用于追踪的示例输入。

  • requires_grad (Optional[Sequence[bool]]) -- EN: Optional gradient-requirement flags for each example input. Chinese: 每个示例输入对应的可选求导标志。

返回:

EN: Optimized forward and backward FX graphs. Chinese: 优化后的前向与反向 FX 图。

返回类型:

Tuple[Graph, Graph]

抛出:

ValueError -- EN: Raised when requires_grad length mismatches example_inputs, when the callable does not return a tensor/list/tuple, or when no differentiable output exists. Chinese: 当 requires_grad 长度与 example_inputs 不匹配、函数返回值不是张量/列表/元组、或不存在可求导输出时抛出。

Chinese:

为给定的 PyTorch 函数生成优化后的前向与反向 FX 图。

English:

Generate optimized forward and backward FX graphs for a PyTorch callable.


  • English

Generate forward and backward graphs

抛出:

ValueError -- EN: Raised when requires_grad length mismatches example_inputs, when the callable does not return a tensor/list/tuple, or when no differentiable output exists. Chinese: 当 requires_grad 长度与 example_inputs 不匹配、函数返回值不是张量/列表/元组、或不存在可求导输出时抛出。

返回类型:

Tuple[Graph, Graph]

参数:
  • fn (Callable)

  • example_inputs (tuple)

  • requires_grad (Optional[Sequence[bool]])

Graph to Triton#

spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.generate_triton_code_str(graph, fn_name, verbose=False)[源代码]#

Given a fx.Graph, generate its corresponding Triton code string.

API Language - 中文 | English


  • 中文

生成Triton代码字符串

参数:
  • graph (fx.Graph)

  • fn_name (str) -- name of the original PyTorch function. For generating the Triton kernel name.

  • verbose (bool, optional) -- Defaults to False.

返回:

the generated Triton code string and the name of the Triton function.

返回类型:

Tuple[str, str]


  • English

Generate Triton code string

spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.compile_triton_code_str(triton_code, kernel_name, verbose=False, name_space=None)[源代码]#

Compile a Triton code string into a runnable Triton JIT function.

API Language - 中文 | English


  • 中文

编译Triton代码字符串

将 Triton 代码写入持久化 codegen cache,加载或复用匹配的模块对象,并返回指定的 JIT 函数。

参数:
  • triton_code (str) -- The Triton code string to compile/cache.

  • kernel_name (str) -- The name of the Triton function to extract.

  • verbose (bool) -- If True, print whether the cached source was written or reused.

  • name_space (Optional[dict]) -- Optional globals injected before execution. When provided, it is updated with symbols defined by the compiled module.

返回:

The compiled Triton JIT function.

返回类型:

triton.JITFunction


  • English

Compile Triton code string