spikingjelly.activation_based.triton_kernel.torch2triton package
PyTorch to Graph
- spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_inference_graph(fn: Callable, example_inputs: tuple) -> Graph
Generate an optimized inference FX graph for a PyTorch callable.
- Parameters:
fn (Callable) -- Callable to trace.
example_inputs (tuple) -- Example inputs used for tracing.
- Returns:
Optimized forward FX graph.
- Return type:
torch.fx.Graph
- Raises:
ValueError -- Raised when the traced callable fails to produce a forward graph.
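To make "FX graph" concrete, the sketch below traces a small element-wise callable with the standard `torch.fx.symbolic_trace` and lists the resulting nodes. It only illustrates what a `torch.fx.Graph` looks like; it is an assumption that the graph returned by `generate_inference_graph` has a similar shape, since that function additionally optimizes the graph and may use a different tracing mechanism.

```python
import torch
import torch.fx


def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A small element-wise callable of the kind one might want to fuse into one kernel.
    return torch.sigmoid(x) * y + 1.0


# symbolic_trace returns a GraphModule; its .graph attribute is a torch.fx.Graph.
graph_module = torch.fx.symbolic_trace(f)
graph: torch.fx.Graph = graph_module.graph

# Each node records its opcode (placeholder, call_function, output, ...) and its target.
for node in graph.nodes:
    print(node.op, node.name, node.target)

# The traced module computes the same values as the original callable.
x, y = torch.randn(4), torch.randn(4)
assert torch.allclose(graph_module(x, y), f(x, y))
```

The two inputs appear as `placeholder` nodes, the three operations as `call_function` nodes, and the return value as a single `output` node.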
- spikingjelly.activation_based.triton_kernel.torch2triton.torch2graph.generate_forward_and_backward_graph(fn: Callable, example_inputs: tuple, requires_grad: Sequence[bool] | None = None) -> Tuple[Graph, Graph]
Generate optimized forward and backward FX graphs for a PyTorch callable.
- Parameters:
fn (Callable) -- Callable to trace.
example_inputs (tuple) -- Example inputs used for tracing.
requires_grad (Sequence[bool] | None, optional) -- Per-input flags indicating which inputs require gradients; when given, its length must match example_inputs. Defaults to None.
- Returns:
Optimized forward and backward FX graphs.
- Return type:
Tuple[torch.fx.Graph, torch.fx.Graph]
- Raises:
ValueError -- Raised when the length of requires_grad does not match example_inputs, when the callable does not return a tensor, list, or tuple, or when no differentiable output exists.
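The three `ValueError` conditions listed above can be mirrored by a small pre-flight check. The helper `check_forward_backward_request` below is hypothetical and not part of spikingjelly, and treating `requires_grad=None` as "keep each input's current `requires_grad` flag" is an assumption, since the reference does not say what `None` means.

```python
from typing import Callable, Optional, Sequence, Tuple

import torch


def check_forward_backward_request(
    fn: Callable,
    example_inputs: tuple,
    requires_grad: Optional[Sequence[bool]] = None,
) -> Tuple[torch.Tensor, ...]:
    """Hypothetical check mirroring the documented ValueError cases."""
    if requires_grad is None:
        # Assumption: None keeps each input's existing requires_grad flag.
        requires_grad = [isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs]
    elif len(requires_grad) != len(example_inputs):
        raise ValueError("requires_grad must have one entry per example input")

    # Re-create tensor inputs as leaf tensors carrying the requested gradient flags.
    inputs = [
        x.detach().requires_grad_(flag) if isinstance(x, torch.Tensor) else x
        for x, flag in zip(example_inputs, requires_grad)
    ]
    out = fn(*inputs)

    # Normalize the return value; anything other than a tensor/list/tuple is rejected.
    if isinstance(out, torch.Tensor):
        outputs = (out,)
    elif isinstance(out, (list, tuple)):
        outputs = tuple(out)
    else:
        raise ValueError("fn must return a tensor, list, or tuple")

    # At least one output must be differentiable for a backward graph to exist.
    if not any(isinstance(o, torch.Tensor) and o.requires_grad for o in outputs):
        raise ValueError("fn has no differentiable output")
    return outputs
```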
Graph to Triton
- spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.generate_triton_code_str(graph: Graph, fn_name: str, verbose: bool = False) -> Tuple[str, str]
Given a torch.fx.Graph, generate its corresponding Triton code string.
- Parameters:
graph (torch.fx.Graph) -- The FX graph to translate.
fn_name (str)
verbose (bool, optional) -- Defaults to False.
- Returns:
The generated Triton code string and the name of the Triton function.
- Return type:
Tuple[str, str]
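As a rough picture of what turning an FX graph into source text involves, the toy emitter below walks the nodes of a traced graph and returns a `(code, name)` pair, matching the return shape of `generate_triton_code_str`. The function `toy_graph_to_code_str` is hypothetical and deliberately simplified: it emits a plain scalar Python function rather than Triton code, and it supports only `operator.add`, `operator.sub`, and `operator.mul`.

```python
import operator
from typing import Tuple

import torch.fx

# Hypothetical source templates for the few FX targets this toy emitter supports.
_TEMPLATES = {
    operator.add: "({0} + {1})",
    operator.sub: "({0} - {1})",
    operator.mul: "({0} * {1})",
}


def toy_graph_to_code_str(graph: torch.fx.Graph, fn_name: str) -> Tuple[str, str]:
    """Emit a scalar Python function (not Triton) for a tiny subset of FX ops."""
    params, body = [], []
    for node in graph.nodes:
        if node.op == "placeholder":
            # Each graph input becomes a function parameter.
            params.append(node.name)
        elif node.op == "call_function":
            # Unsupported targets raise KeyError; a real backend needs a full op table.
            template = _TEMPLATES[node.target]
            operands = [a.name if isinstance(a, torch.fx.Node) else repr(a) for a in node.args]
            body.append(f"    {node.name} = {template.format(*operands)}")
        elif node.op == "output":
            # This toy only handles a single returned value.
            body.append(f"    return {node.args[0].name}")
        else:
            raise NotImplementedError(f"unsupported node op: {node.op}")
    code = "\n".join([f"def {fn_name}({', '.join(params)}):", *body]) + "\n"
    return code, fn_name


# Usage: trace a small callable, emit its source, then execute that source.
def f(x, y):
    return x * y + 1.0 - y


graph = torch.fx.symbolic_trace(f).graph
code, name = toy_graph_to_code_str(graph, "fused_fn")
print(code)
namespace = {}
exec(code, namespace)
```

Each intermediate node becomes one assignment named after the node, so the emitted function follows the graph's topological order line by line.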
- spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton.compile_triton_code_str(triton_code: str, kernel_name: str, verbose: bool = False, name_space: dict | None = None)
Compile a Triton code string into a runnable Triton JIT function.
Materializes the Triton code under the persistent codegen cache, loads or reuses the matching module object, and extracts the requested JIT function.
- Parameters:
triton_code (str) -- The Triton code string to compile and cache.
kernel_name (str) -- The name of the Triton function to extract.
verbose (bool, optional) -- If True, print whether the cached source was written or reused, along with its path. Defaults to False.
name_space (Optional[dict], optional) -- Optional globals injected before execution. When provided, it is updated with the symbols defined by the compiled module. Calls without name_space reuse a cached module keyed by the generated source hash; calls with name_space reload the module so that injected symbols stay fresh.
- Returns:
The compiled Triton JIT function.
- Return type:
triton.JITFunction
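The caching flow described for `compile_triton_code_str` can be sketched with the standard library alone. In the hypothetical `toy_compile_code_str` below, the cache directory, the file naming, and the plain Python function standing in for a `triton.JITFunction` are all assumptions; only the overall sequence (hash the source, write or reuse the file, reuse the module unless `name_space` is given, then extract the function by name) follows the description above.

```python
import hashlib
import importlib.util
import os
import tempfile
from types import ModuleType
from typing import Callable, Dict, Optional

# Hypothetical cache directory; the real package keeps its own persistent codegen cache.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "toy_codegen_cache")
# In-process cache of loaded modules, keyed by the SHA-256 of the source text.
_MODULE_CACHE: Dict[str, ModuleType] = {}


def toy_compile_code_str(
    code: str,
    kernel_name: str,
    verbose: bool = False,
    name_space: Optional[dict] = None,
) -> Callable:
    key = hashlib.sha256(code.encode("utf-8")).hexdigest()
    os.makedirs(CACHE_DIR, exist_ok=True)
    path = os.path.join(CACHE_DIR, f"kernel_{key}.py")

    # Materialize the source on disk once; identical source maps to the same file.
    if os.path.exists(path):
        action = "reused"
    else:
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(code)
        action = "written"
    if verbose:
        print(f"source {action}: {path}")

    if name_space is None and key in _MODULE_CACHE:
        # No injected globals: reuse the module loaded earlier for this source hash.
        module = _MODULE_CACHE[key]
    else:
        spec = importlib.util.spec_from_file_location(f"toy_kernel_{key[:16]}", path)
        module = importlib.util.module_from_spec(spec)
        if name_space is not None:
            # Inject caller-supplied globals before the module body runs.
            module.__dict__.update(name_space)
        spec.loader.exec_module(module)
        if name_space is None:
            _MODULE_CACHE[key] = module
        else:
            # Hand the module's symbols back to the caller, skipping dunder names.
            name_space.update(
                {k: v for k, v in module.__dict__.items() if not k.startswith("__")}
            )
    return getattr(module, kernel_name)


# Usage: a function that reads a global injected through name_space.
scale = toy_compile_code_str(
    "def scale(x):\n    return x * FACTOR\n", "scale", name_space={"FACTOR": 2}
)
```

Keying the in-process cache on the source hash lets repeated requests for identical code skip re-execution, while an explicit `name_space` forces a fresh load so the function always sees the globals that were just injected.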