spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton 源代码

from typing import Optional, Tuple
import errno
import importlib.util
import linecache
import os
from pathlib import Path
import hashlib
import re
import stat
import sys
import tempfile
import threading
import types

import torch
import torch.fx as fx

try:
    import triton
    import triton.language as tl
except BaseException as e:
    import logging

    from .. import dummy

    logging.info(
        f"spikingjelly.activation_based.triton_kernel.torch2triton.graph2triton: {e}"
    )
    triton = dummy.DummyImport()
    tl = dummy.DummyImport()

from ..triton_utils import type_str_dict


# Public API of this module: FX-graph -> Triton source generation, and
# compilation of such source into a Triton JIT function.
__all__ = [
    "generate_triton_code_str",
    "compile_triton_code_str",
]


# Serializes lookup/creation of entries in ``_MODULE_CACHE_LOCKS``.
_MODULE_CACHE_LOCK_GUARD = threading.Lock()
# Generated module name -> lock held while that module is looked up in or
# loaded into ``sys.modules`` by ``compile_triton_code_str``. Entries are
# created on demand and never removed in the code shown here.
_MODULE_CACHE_LOCKS = {}
# Resolved codegen cache directory; ``None`` until ``_codegen_cache_dir`` is
# first called. Initialization is guarded by ``_CODEGEN_CACHE_DIR_LOCK``.
_CODEGEN_CACHE_DIR = None
_CODEGEN_CACHE_DIR_LOCK = threading.Lock()
# Module-metadata dunder names that are never copied from a caller's
# ``name_space`` into a generated module, nor exported back from it.
_NAMESPACE_METADATA_KEYS = {
    "__name__",
    "__spec__",
    "__loader__",
    "__package__",
    "__path__",
    "__file__",
    "__cached__",
    "__builtins__",
    "__doc__",
}


def _get_module_cache_lock(module_name: str) -> threading.Lock:
    """Return the lock dedicated to ``module_name``, creating it on first use.

    The get-or-create step runs under ``_MODULE_CACHE_LOCK_GUARD``, so every
    caller asking for the same name receives the same lock object.
    """
    with _MODULE_CACHE_LOCK_GUARD:
        lock = _MODULE_CACHE_LOCKS.get(module_name)
        if lock is None:
            lock = threading.Lock()
            _MODULE_CACHE_LOCKS[module_name] = lock
        return lock


def _generate_hash(s: str, w: int = 8) -> str:
    """Return the first ``w`` hex digits of the SHA-256 digest of ``s`` (UTF-8)."""
    digest = hashlib.sha256()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()[:w]


def _safe_codegen_stem(kernel_name: str) -> str:
    """Turn ``kernel_name`` into a filesystem-safe file stem.

    Steps:

    * keep only the last path component, treating both ``/`` and ``\\`` as
      separators;
    * collapse every run of characters outside ``[0-9A-Za-z_.-]`` into ``_``;
    * strip leading and trailing ``.`` / ``_``;
    * fall back to ``"kernel"`` when nothing is left;
    * truncate the result to 128 characters.
    """
    _, _, basename = str(kernel_name).replace("\\", "/").rpartition("/")
    cleaned = re.sub(r"[^0-9A-Za-z_.-]+", "_", basename).strip("._")
    if not cleaned:
        cleaned = "kernel"
    return cleaned[:128]


def _has_real_triton_runtime() -> bool:
    """Return True only if both ``triton`` and ``tl`` are genuine module objects.

    If importing Triton fails at the top of the file, both names are bound to
    ``dummy.DummyImport()`` placeholders. This check is meant to reject those
    placeholders (presumably ``DummyImport`` is not a ``ModuleType``).
    """
    return all(isinstance(mod, types.ModuleType) for mod in (triton, tl))


def _codegen_cache_dir() -> Path:
    """Return the codegen cache directory, resolving it at most once per process.

    The fast path reads the module-level cache without locking. On a miss, the
    value is re-checked under ``_CODEGEN_CACHE_DIR_LOCK`` so that concurrent
    first callers run ``_resolve_codegen_cache_dir`` only once. If resolution
    raises, nothing is cached and the next call retries.
    """
    global _CODEGEN_CACHE_DIR
    resolved = _CODEGEN_CACHE_DIR
    if resolved is None:
        with _CODEGEN_CACHE_DIR_LOCK:
            resolved = _CODEGEN_CACHE_DIR
            if resolved is None:
                resolved = _resolve_codegen_cache_dir()
                _CODEGEN_CACHE_DIR = resolved
    return resolved


def _resolve_codegen_cache_dir() -> Path:
    """Pick and prepare a private, writable directory for generated Triton code.

    Candidates, tried in order:

    1. ``~/.spikingjelly/triton_codegen`` (skipped if the home directory
       cannot be determined);
    2. ``<tempdir>/spikingjelly_triton_codegen_<uid>`` (without the ``_<uid>``
       suffix on platforms lacking ``os.getuid``).

    On platforms with ``os.getuid``, a candidate is accepted only if the path
    entry itself is a real directory (not a symlink), owned by the current
    user, user-writable, and free of group/other permission bits.

    Generated modules are later written to and executed from this directory.
    A symlink planted in a shared temp directory therefore must not be
    followed: it can pass a ``stat()``-based check (its target is a directory
    the user owns), and its owner can repoint it afterwards. Every candidate
    must also allow creating a temporary file.

    Returns:
        Path: the first candidate that passes all checks.

    Raises:
        OSError: immediately for any ``OSError`` other than EACCES/EROFS/EPERM;
            otherwise the last such error if no candidate was accepted.
        RuntimeError: if no candidate was accepted and none raised a tolerated
            ``OSError`` (e.g. all were rejected by the ownership checks).
    """
    candidates = []
    uid = getattr(os, "getuid", lambda: None)()
    try:
        candidates.append(Path.home() / ".spikingjelly" / "triton_codegen")
    except RuntimeError:
        pass
    temp_suffix = f"_{uid}" if uid is not None else ""
    candidates.append(
        Path(tempfile.gettempdir()) / f"spikingjelly_triton_codegen{temp_suffix}"
    )
    last_error = None
    for cache_dir in candidates:
        try:
            # exist_ok=True also accepts a symlink to a directory; that case
            # is rejected by the lstat check below.
            cache_dir.mkdir(parents=True, exist_ok=True, mode=0o700)
            if uid is not None:
                # lstat: inspect the path entry itself, never a symlink target.
                st = os.lstat(cache_dir)
                if not stat.S_ISDIR(st.st_mode):
                    continue
                if st.st_uid == uid:
                    try:
                        os.chmod(cache_dir, 0o700)
                    except OSError:
                        pass
                    st = os.lstat(cache_dir)
                mode = stat.S_IMODE(st.st_mode)
                if st.st_uid != uid or not (mode & stat.S_IWUSR) or (mode & 0o077):
                    continue
            # Probe that files can actually be created here.
            with tempfile.NamedTemporaryFile(dir=cache_dir, delete=True):
                pass
            return cache_dir
        except OSError as e:
            last_error = e
            if e.errno not in (errno.EACCES, errno.EROFS, errno.EPERM):
                raise
    if last_error is not None:
        raise last_error
    raise RuntimeError("Failed to initialize Triton codegen cache directory")


def _uw(arg) -> str:
    """Unwrap an argument to its string representation for Triton code generation.

    * ``fx.Node`` -> the node's name, i.e. the local variable holding its
      value in the generated kernel;
    * ``torch.dtype`` -> the Triton dtype spelling from ``type_str_dict``;
    * anything else (Python scalars, strings, ``None``) -> ``str(arg)``.
    """
    if isinstance(arg, fx.Node):
        return arg.name
    elif isinstance(arg, torch.dtype):
        return type_str_dict[arg]
    return str(arg)


def _alpha(args, kwargs):
    """Return the ``alpha`` multiplier of an ATen add/sub/rsub call (default 1).

    The ``.Scalar`` overloads (e.g. ``add.Scalar(self, other, alpha=1)``) may
    carry ``alpha`` as a third positional argument. The ``.Tensor`` overloads
    take it as a keyword. Both spellings are honoured.
    """
    if len(args) > 2:
        return args[2]
    return kwargs.get("alpha", 1)


def _scaled(lhs, op: str, rhs, alpha) -> str:
    """Render ``lhs <op> rhs``, multiplying ``rhs`` by ``alpha`` unless it is 1."""
    if alpha == 1:
        return f"{_uw(lhs)} {op} {_uw(rhs)}"
    return f"{_uw(lhs)} {op} ({_uw(alpha)} * {_uw(rhs)})"


def _to_copy_rule(args, kwargs) -> str:
    """``_to_copy``: cast to ``kwargs["dtype"]``.

    Without a dtype, the copy leaves values unchanged, so the input is
    passed through instead of failing with ``KeyError``.
    """
    dtype = kwargs.get("dtype")
    if dtype is None:
        return _uw(args[0])
    return f"{_uw(args[0])}.to({_uw(dtype)})"


# code generation rules: FX/ATen op name -> callable(args, kwargs) returning
# the Triton expression assigned to the node's variable.
FX_TO_TRITON = {
    "add": lambda args, kwargs: f"{_uw(args[0])} + {_uw(args[1])}",
    # ATen add/sub/rsub scale the second operand by ``alpha``.
    "add.Scalar": lambda args, kwargs: _scaled(
        args[0], "+", args[1], _alpha(args, kwargs)
    ),
    "add.Tensor": lambda args, kwargs: _scaled(
        args[0], "+", args[1], _alpha(args, kwargs)
    ),
    "sub": lambda args, kwargs: f"{_uw(args[0])} - {_uw(args[1])}",
    "sub.Tensor": lambda args, kwargs: _scaled(
        args[0], "-", args[1], _alpha(args, kwargs)
    ),
    "sub.Scalar": lambda args, kwargs: _scaled(
        args[0], "-", args[1], _alpha(args, kwargs)
    ),
    # rsub(self, other, alpha) computes other - alpha * self
    "rsub.Scalar": lambda args, kwargs: _scaled(
        args[1], "-", args[0], _alpha(args, kwargs)
    ),
    "mul": lambda args, kwargs: f"{_uw(args[0])} * {_uw(args[1])}",
    "mul.Tensor": lambda args, kwargs: f"{_uw(args[0])} * {_uw(args[1])}",
    "mul.Scalar": lambda args, kwargs: f"{_uw(args[0])} * {_uw(args[1])}",
    "div": lambda args, kwargs: f"{_uw(args[0])} / {_uw(args[1])}",
    "div.Tensor": lambda args, kwargs: f"{_uw(args[0])} / {_uw(args[1])}",
    "div.Scalar": lambda args, kwargs: f"{_uw(args[0])} / {_uw(args[1])}",
    "bitwise_and.Tensor": lambda args, kwargs: f"{_uw(args[0])} & {_uw(args[1])}",
    "bitwise_or.Tensor": lambda args, kwargs: f"{_uw(args[0])} | {_uw(args[1])}",
    "bitwise_not.default": lambda args, kwargs: f"~{_uw(args[0])}",
    # logical_* follow ATen truthiness: non-zero = True; bitwise ops would give
    # wrong results for numeric inputs (e.g. logical_not(2) → False, but ~2 = -3)
    "logical_and.default": lambda args, kwargs: (
        f"({_uw(args[0])} != 0) & ({_uw(args[1])} != 0)"
    ),
    "logical_or.default": lambda args, kwargs: (
        f"({_uw(args[0])} != 0) | ({_uw(args[1])} != 0)"
    ),
    "logical_not.default": lambda args, kwargs: f"({_uw(args[0])} == 0)",
    "eq.Tensor": lambda args, kwargs: f"{_uw(args[0])} == {_uw(args[1])}",
    "eq.Scalar": lambda args, kwargs: f"{_uw(args[0])} == {_uw(args[1])}",
    "ge.Tensor": lambda args, kwargs: f"{_uw(args[0])} >= {_uw(args[1])}",
    "ge.Scalar": lambda args, kwargs: f"{_uw(args[0])} >= {_uw(args[1])}",
    "le.Tensor": lambda args, kwargs: f"{_uw(args[0])} <= {_uw(args[1])}",
    "le.Scalar": lambda args, kwargs: f"{_uw(args[0])} <= {_uw(args[1])}",
    "gt.Tensor": lambda args, kwargs: f"{_uw(args[0])} > {_uw(args[1])}",
    "gt.Scalar": lambda args, kwargs: f"{_uw(args[0])} > {_uw(args[1])}",
    "lt.Tensor": lambda args, kwargs: f"{_uw(args[0])} < {_uw(args[1])}",
    "lt.Scalar": lambda args, kwargs: f"{_uw(args[0])} < {_uw(args[1])}",
    # 1. / x may change the dtype; cast back to the input dtype
    "reciprocal.default": lambda args, kwargs: (
        f"(1. / {_uw(args[0])}).to({_uw(args[0])}.dtype)"
    ),
    "neg.default": lambda args, kwargs: f"-{_uw(args[0])}",
    "spike_fn.default": lambda args, kwargs: (
        f"({_uw(args[0])} >= 0.).to({_uw(args[0])}.dtype)"
    ),
    "detach.default": lambda args, kwargs: f"{_uw(args[0])}",
    # triton does not support exponential operations on fp16
    "sigmoid.default": lambda args, kwargs: (
        f"tl.sigmoid({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    # args[0]=grad_out, args[1]=sigmoid output
    "sigmoid_backward.default": lambda args, kwargs: (
        f"{_uw(args[0])} * {_uw(args[1])} * (1 - {_uw(args[1])})"
    ),
    # args[0]=grad_out, args[1]=tanh output
    "tanh_backward.default": lambda args, kwargs: (
        f"{_uw(args[0])} * (1 - {_uw(args[1])} * {_uw(args[1])})"
    ),
    # args: grad, input, threshold
    "threshold_backward.default": lambda args, kwargs: (
        f"tl.where({_uw(args[1])} > {_uw(args[2])}, {_uw(args[0])}, 0.0)"
    ),
    "_to_copy.default": _to_copy_rule,
    "scalar_tensor.default": lambda args, kwargs: (
        f"tl.full([], {_uw(args[0])}, {_uw(kwargs['dtype'])})"
    ),
    "where.self": lambda args, kwargs: (
        f"tl.where({_uw(args[0])}.to(tl.int1), {_uw(args[1])}, {_uw(args[2])})"
    ),
    # ---------- unary math (upcast fp16→fp32 for transcendentals) ----------
    "exp.default": lambda args, kwargs: (
        f"tl.exp({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "log.default": lambda args, kwargs: (
        f"tl.log({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "log2.default": lambda args, kwargs: (
        f"tl.log2({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "sqrt.default": lambda args, kwargs: (
        f"tl.sqrt({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "rsqrt.default": lambda args, kwargs: (
        f"tl.rsqrt({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "abs.default": lambda args, kwargs: f"tl.abs({_uw(args[0])})",
    "tanh.default": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.tanh("
        f"{_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "sin.default": lambda args, kwargs: (
        f"tl.math.sin({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "cos.default": lambda args, kwargs: (
        f"tl.math.cos({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "erf.default": lambda args, kwargs: (
        f"tl.math.erf({_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    # ---------- rounding ----------
    # NOTE(review): tl.floor / tl.ceil are emitted without the fp32 upcast the
    # other math ops use; they may reject fp16 inputs — confirm.
    "floor.default": lambda args, kwargs: f"tl.floor({_uw(args[0])})",
    "ceil.default": lambda args, kwargs: f"tl.ceil({_uw(args[0])})",
    # torch.round rounds half to even (round(2.5) == 2). libdevice.round
    # rounds half away from zero, so use nearbyint (round-to-nearest-even).
    "round.default": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.nearbyint("
        f"{_uw(args[0])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    # ---------- activation ----------
    "relu.default": lambda args, kwargs: f"tl.maximum({_uw(args[0])}, 0.0)",
    "sign.default": lambda args, kwargs: (
        f"({_uw(args[0])} > 0.).to({_uw(args[0])}.dtype)"
        f" - ({_uw(args[0])} < 0.).to({_uw(args[0])}.dtype)"
    ),
    # complex sign; for real tensors same as sign
    "sgn.default": lambda args, kwargs: (
        f"({_uw(args[0])} > 0.).to({_uw(args[0])}.dtype)"
        f" - ({_uw(args[0])} < 0.).to({_uw(args[0])}.dtype)"
    ),
    # ---------- binary element-wise ----------
    "minimum.default": lambda args, kwargs: (
        f"tl.minimum({_uw(args[0])}, {_uw(args[1])})"
    ),
    "maximum.default": lambda args, kwargs: (
        f"tl.maximum({_uw(args[0])}, {_uw(args[1])})"
    ),
    "ne.Scalar": lambda args, kwargs: f"{_uw(args[0])} != {_uw(args[1])}",
    "ne.Tensor": lambda args, kwargs: f"{_uw(args[0])} != {_uw(args[1])}",
    "fmod.Scalar": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.fmod("
        f"{_uw(args[0])}.to(tl.float32),"
        f" tl.full([], {_uw(args[1])}, tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "fmod.Tensor": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.fmod("
        f"{_uw(args[0])}.to(tl.float32),"
        f" {_uw(args[1])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "pow.Tensor_Scalar": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.pow("
        f"{_uw(args[0])}.to(tl.float32),"
        f" tl.full([], {_uw(args[1])}, tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    "pow.Tensor_Tensor": lambda args, kwargs: (
        f"tl.extra.cuda.libdevice.pow("
        f"{_uw(args[0])}.to(tl.float32),"
        f" {_uw(args[1])}.to(tl.float32)).to({_uw(args[0])}.dtype)"
    ),
    # ---------- clamp ----------
    "clamp.default": lambda args, kwargs: (
        # args: (tensor, min_val, max_val) — both optional
        f"tl.minimum(tl.maximum({_uw(args[0])}, {_uw(args[1])}), {_uw(args[2])})"
        if len(args) >= 3 and args[1] is not None and args[2] is not None
        else f"tl.maximum({_uw(args[0])}, {_uw(args[1])})"
        if len(args) >= 2 and args[1] is not None
        else f"tl.minimum({_uw(args[0])}, {_uw(args[2])})"
        if len(args) >= 3 and args[2] is not None
        else _uw(args[0])
    ),
    "clamp_min.default": lambda args, kwargs: (
        f"tl.maximum({_uw(args[0])}, {_uw(args[1])})"
    ),
    "clamp_max.default": lambda args, kwargs: (
        f"tl.minimum({_uw(args[0])}, {_uw(args[1])})"
    ),
    # ---------- misc ----------
    "clone.default": lambda args, kwargs: f"{_uw(args[0])}",
    # Use tl.full to avoid propagating NaN/Inf from input values
    "zeros_like.default": lambda args, kwargs: (
        f"tl.full({_uw(args[0])}.shape, 0, {_uw(args[0])}.dtype)"
    ),
    "ones_like.default": lambda args, kwargs: (
        f"tl.full({_uw(args[0])}.shape, 1, {_uw(args[0])}.dtype)"
    ),
    # masked_fill(tensor, mask, value): fill where mask=True with value
    "masked_fill.Scalar": lambda args, kwargs: (
        f"tl.where({_uw(args[1])}.to(tl.int1), {_uw(args[2])}, {_uw(args[0])})"
    ),
    "masked_fill.Tensor": lambda args, kwargs: (
        f"tl.where({_uw(args[1])}.to(tl.int1), {_uw(args[2])}, {_uw(args[0])})"
    ),
}

# Indent unit for statements inside the generated kernel body (four spaces).
INDENTATION = "    "


def generate_triton_code_str(
    graph: fx.Graph,
    fn_name: str,
    verbose: bool = False,
) -> Tuple[str, str]:
    """Given a fx.Graph, generate its corresponding Triton code string.

    .. _generate_triton_code_str-cn:

    .. _generate_triton_code_str-en:

    Translation rules:

    * every ``placeholder`` node becomes a kernel parameter;
    * every ``call_function`` / ``call_method`` node becomes one assignment,
      translated through ``FX_TO_TRITON``;
    * the ``output`` node becomes the ``return`` statement.

    The kernel is named ``<fn_name>_<hash>``, where the hash covers both the
    parameter list and the body. Kernels generated from different graphs that
    share ``fn_name`` therefore get distinct names.

    Args:
        graph (fx.Graph): the graph to transpile.
        fn_name (str): name of the original PyTorch function. Used as the
            prefix of the Triton kernel name.
        verbose (bool, optional): if True, print ``graph`` before code
            generation. Defaults to False.

    Returns:
        Tuple[str, str]: the generated Triton code string (including the
        ``triton`` / ``triton.language as tl`` imports and the ``@triton.jit``
        decorator) and the name of the Triton function.

    Raises:
        NotImplementedError: if the graph calls an op without an
            ``FX_TO_TRITON`` rule, or contains a node kind other than
            placeholder / call_function / call_method / output.
    """
    if verbose:
        print(graph)

    inputs = []
    body_lines = []
    for node in graph.nodes:
        if node.op == "placeholder":  # function inputs
            inputs.append(node.name)
        elif node.op in ("call_function", "call_method"):
            # e.g. mul.Tensor, spike_fn.default, rsub.Scalar, ...
            op_name = (
                node.target.__name__ if node.op == "call_function" else node.target
            )
            rule = FX_TO_TRITON.get(op_name)
            if rule is None:
                raise NotImplementedError(
                    f"{node.op} {op_name} has not yet been implemented "
                    f"in FX_TO_TRITON mapping."
                )
            body_lines.append(f"{node.name} = {rule(node.args, node.kwargs)}")
        elif node.op == "output":
            result = node.args[0]
            if isinstance(result, (tuple, list)):  # multiple return values
                things = ", ".join(_uw(arg) for arg in result)
            else:  # single return value (a node or a constant)
                things = _uw(result)
            body_lines.append(f"return {things}")
        else:
            raise NotImplementedError(
                f"Operation {node.op} has not yet been implemented."
            )

    body = f"{INDENTATION}" + f"\n{INDENTATION}".join(body_lines)
    params = ", ".join(inputs)
    # Hash the parameters together with the body: graphs whose bodies match
    # but whose inputs differ (order, unused inputs) must not share a name.
    kernel_hash = _generate_hash(params + "\n" + body)
    fn_name = f"{fn_name}_{kernel_hash}"
    signature = f"@triton.jit\ndef {fn_name}({params}):"
    prefix = "import triton\nimport triton.language as tl"
    return f"{prefix}\n\n{signature}\n{body}", fn_name
def _build_module_globals(name_space: Optional[dict], kernel_name: str) -> dict:
    """Return the globals used to seed a freshly created codegen module."""
    if name_space is None:
        return {"triton": triton, "tl": tl}
    module_globals = {
        key: value
        for key, value in name_space.items()
        if key not in _NAMESPACE_METADATA_KEYS
    }
    # Drop any stale symbol named like the kernel. The lookup after execution
    # can then only find what the generated code itself defines.
    module_globals.pop(kernel_name, None)
    module_globals.setdefault("triton", triton)
    module_globals.setdefault("tl", tl)
    return module_globals


def _materialize_codegen_source(fpath: Path, triton_code: str) -> bool:
    """Write ``triton_code`` to ``fpath`` unless the file already exists.

    The code is first written to a temporary file in the same directory and
    then moved into place with ``os.replace``. On POSIX, that rename is
    atomic, so other readers never see a half-written file.

    Returns:
        bool: True if the file was written, False if an existing file was reused.
    """
    if fpath.exists():
        return False
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            "w", encoding="utf-8", dir=fpath.parent, delete=False, suffix=".tmp"
        ) as tmp_file:
            tmp_path = Path(tmp_file.name)
            tmp_file.write(triton_code)
        os.replace(tmp_path, fpath)
        try:
            os.chmod(fpath, 0o600)
        except OSError:
            pass
    except Exception:
        if tmp_path is not None:
            # Best-effort cleanup; never let it mask the original error.
            try:
                tmp_path.unlink()
            except OSError:
                pass
        raise
    return True


def compile_triton_code_str(
    triton_code: str,
    kernel_name: str,
    verbose: bool = False,
    name_space: Optional[dict] = None,
):
    """Compile a Triton code string into a runnable Triton JIT function.

    .. _compile_triton_code_str-cn:

    .. _compile_triton_code_str-en:

    Steps:

    1. Materialize the code as ``<safe kernel name>_<hash>.py`` under the
       persistent codegen cache directory (written if missing, reused
       otherwise).
    2. Load the matching module object, or reuse it from ``sys.modules``.
    3. Extract the requested JIT function from that module.

    Args:
        triton_code (str): The Triton code string to compile/cache.
        kernel_name (str): The name of the Triton function to extract.
        verbose (bool, optional): If True, print whether the cached source was
            written or reused, along with its path. Defaults to False.
        name_space (Optional[dict], optional): Optional globals injected
            before execution. When provided, it will be updated with symbols
            defined by the compiled module. Calls without ``name_space`` reuse
            a cached module keyed by the generated source hash. Calls with
            ``name_space`` re-execute the module so injected symbols stay
            fresh.

    Returns:
        triton.JITFunction: The compiled Triton JIT function, i.e. whatever
        the generated code binds to ``kernel_name``.

    Raises:
        ImportError: if no real Triton installation is available, or no
            import spec can be created for the cached file.
        ValueError: if the executed module does not define ``kernel_name``.
    """
    if not _has_real_triton_runtime():
        raise ImportError(
            "compile_triton_code_str requires a real Triton installation; "
            "the imported triton/tl modules are unavailable."
        )
    caller_namespace = name_space
    cacheable = caller_namespace is None
    module_globals = _build_module_globals(caller_namespace, kernel_name)

    safe_kernel_name = _safe_codegen_stem(kernel_name)
    module_hash = _generate_hash(f"{kernel_name}\n{triton_code}", w=16)
    stem = f"{safe_kernel_name}_{module_hash}"
    module_name = f"spikingjelly.activation_based.triton_kernel.codegen.{stem}"
    fpath = _codegen_cache_dir() / f"{stem}.py"

    needs_write = _materialize_codegen_source(fpath, triton_code)
    if verbose:
        action = "written to" if needs_write else "loaded from cache"
        print(f"Triton code `{kernel_name}` {action} {fpath}")
    # Invalidate stale linecache entries for this path. Presumably this lets
    # source introspection of the generated kernel see the current file —
    # confirm.
    linecache.checkcache(str(fpath))

    with _get_module_cache_lock(module_name):
        module = sys.modules.get(module_name) if cacheable else None
        if module is None:
            spec = importlib.util.spec_from_file_location(module_name, fpath)
            if spec is None or spec.loader is None:
                raise ImportError(f"Could not create import spec for {fpath}")
            module = importlib.util.module_from_spec(spec)
            module.__dict__.update(module_globals)
            spec.loader.exec_module(module)
            if cacheable:
                sys.modules[module_name] = module
        if caller_namespace is not None:
            exported_symbols = {
                key: value
                for key, value in module.__dict__.items()
                if key not in _NAMESPACE_METADATA_KEYS
            }
            caller_namespace.update(exported_symbols)
        if kernel_name in module.__dict__:
            return module.__dict__[kernel_name]
        raise ValueError(f"Function {kernel_name} not found in compiled namespace")