spikingjelly.activation_based.triton_kernel.triton_utils 源代码
"""Borrowed from:
https://github.com/AllenYolk/flash-snn/tree/main/flashsnn/utils
https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
"""
import contextlib
import functools
import os
import tempfile
import threading
from typing import Callable
import torch
from packaging import version
from . import dummy
# Optional dependency: ``torch.library.triton_op`` may be missing from the
# installed PyTorch build. When the import fails, a placeholder object from the
# local ``dummy`` module is bound instead and the availability flag is cleared.
try:
    from torch.library import triton_op

    _TRITON_OP_AVAILABLE = True
except Exception:
    # Catch ``Exception`` rather than ``BaseException`` so that
    # KeyboardInterrupt / SystemExit raised while importing still propagate.
    triton_op = dummy.DummyImport()
    _TRITON_OP_AVAILABLE = False
# Optional dependency: import triton and build the torch-dtype lookup tables.
# Any ordinary failure inside the ``try`` (missing triton, or the device
# capability query raising) switches the whole module to the placeholder
# fallback below.
try:
    import triton
    import triton.language as tl

    # torch dtype -> triton.language dtype object.
    type_dict = {
        torch.bool: tl.int1,
        torch.float32: tl.float32,
        torch.float16: tl.float16,
    }
    # torch dtype -> the same triton dtype spelled as a source-code string.
    type_str_dict = {
        torch.bool: "tl.int1",
        torch.float32: "tl.float32",
        torch.float16: "tl.float16",
    }

    # check bfloat16 support
    dc = torch.cuda.get_device_capability()
    if dc[0] < 8 or not hasattr(tl, "bfloat16") or not hasattr(torch, "bfloat16"):
        print("bfloat16 is not supported on this device.")
    else:
        type_dict[torch.bfloat16] = tl.bfloat16
        type_str_dict[torch.bfloat16] = "tl.bfloat16"
except Exception as e:
    # ``Exception`` instead of ``BaseException``: KeyboardInterrupt and
    # SystemExit must not be silently converted into the fallback path.
    import logging

    logging.info(f"spikingjelly.activation_based.triton_kernel.triton_utils: {e}")
    # Placeholders keep later ``triton.*`` / ``tl.*`` references in this module
    # resolvable; the dtype tables are left empty.
    triton = dummy.DummyImport()
    tl = dummy.DummyImport()
    type_dict = {}
    type_str_dict = {}
@triton.jit
def convert_and_store(pointer, value, boundary_check):
    # Cast ``value`` to the element type behind ``pointer`` and store it.
    #
    # tl.store() does not cast implicitly when the destination is a block
    # pointer (as created by tl.make_block_ptr()), so the conversion is done
    # by hand first.
    # NOTE(review): the double ``element_ty`` presumably unwraps
    # pointer -> block -> scalar element type; confirm against Triton docs.
    casted = value.to(pointer.dtype.element_ty.element_ty)
    tl.store(pointer, casted, boundary_check=boundary_check)
def _env_flag_enabled(var_name: str) -> bool:
    """Return whether the environment switch ``var_name`` is turned on.

    An unset variable counts as enabled. A set variable is disabled only when
    its value, stripped of surrounding whitespace and lower-cased, is one of
    ``"0"``, ``"false"``, ``"off"`` or ``"no"``.
    """
    raw = os.environ.get(var_name)
    if raw is not None:
        normalized = raw.strip().lower()
        return normalized not in ("0", "false", "off", "no")
    return True
[文档]
def register_op(opname: str, mutates_args=()):
    """Build the operator-registration call for ``opname``.

    Uses ``triton_op`` when the ``SJ_USE_TRITON_OP`` environment switch is
    enabled and ``triton_op`` was importable; otherwise uses
    ``torch.library.custom_op``. ``mutates_args`` is forwarded unchanged.

    Returns:
        Whatever the selected registration function returns for
        ``(opname, mutates_args=mutates_args)``.
    """
    if not (_env_flag_enabled("SJ_USE_TRITON_OP") and _TRITON_OP_AVAILABLE):
        return torch.library.custom_op(opname, mutates_args=mutates_args)
    return triton_op(opname, mutates_args=mutates_args)
[文档]
def wrap_triton(kernel):
    """Wrap ``kernel`` with ``torch.library.wrap_triton`` when allowed.

    The kernel is wrapped only if ``triton_op`` was importable and both the
    ``SJ_USE_TRITON_OP`` and ``SJ_USE_WRAP_TRITON`` environment switches are
    enabled. In every other case ``kernel`` is returned untouched.
    """
    if not _TRITON_OP_AVAILABLE:
        return kernel
    if not _env_flag_enabled("SJ_USE_TRITON_OP"):
        return kernel
    if not _env_flag_enabled("SJ_USE_WRAP_TRITON"):
        return kernel
    return torch.library.wrap_triton(kernel)
[文档]
def contiguous_and_device_guard(f: Callable) -> Callable:
    """Decorate ``f`` so tensor arguments are contiguous and the device is set.

    Every positional and keyword argument that is a ``torch.Tensor`` is
    replaced by ``tensor.contiguous()``; other arguments pass through as-is.
    The first tensor found (positional arguments first, then keyword values)
    decides the device: if it lives on CUDA, ``f`` runs inside
    ``torch.cuda.device(<its index>)``; otherwise no device context is entered.
    """

    def _as_contiguous(obj):
        # Only tensors are touched; everything else is forwarded unchanged.
        return obj.contiguous() if isinstance(obj, torch.Tensor) else obj

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # Generator on purpose: positional tensors are converted lazily, when
        # the arguments are unpacked for the call to ``f`` below.
        positional = (_as_contiguous(item) for item in args)
        keyword = {name: _as_contiguous(item) for name, item in kwargs.items()}

        # First tensor among the original arguments, or None if there is none.
        anchor = next(
            (
                candidate
                for candidate in (*args, *kwargs.values())
                if isinstance(candidate, torch.Tensor)
            ),
            None,
        )

        if anchor is None or anchor.device.type != "cuda":
            guard = contextlib.nullcontext()
        else:
            guard = torch.cuda.device(anchor.device.index)

        with guard:
            return f(*positional, **keyword)

    return wrapper
# Held for the full duration of every wrapped call: the wrapper replaces the
# module attribute ``tempfile.NamedTemporaryFile``, so calls are serialized.
_TMP_PY_LOCK = threading.Lock()
# Thread-local holder; ``.paths`` is the list of ".py" temp files recorded for
# the wrapped call currently running on this thread.
_TMP_PY_TRACKER = threading.local()


def ensure_cleanup_tmp_python_files(f: Callable) -> Callable:
    """Decorate ``f`` so temporary ``.py`` files it produces are deleted.

    While ``f`` runs, ``tempfile.NamedTemporaryFile`` is replaced by a tracking
    version that records the name of every created file ending in ``.py``.
    The return value of ``f`` is also inspected: a string ending in ``.py`` or
    a ``tempfile._TemporaryFileWrapper`` contributes its path as well. After
    ``f`` finishes (normally or by raising), the original factory is restored
    and every recorded path that still exists is removed. ``OSError`` during
    removal is ignored. The result of ``f`` is returned unchanged.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        with _TMP_PY_LOCK:
            recorded = []
            _TMP_PY_TRACKER.paths = recorded
            real_factory = tempfile.NamedTemporaryFile

            def tracking_named_temporary_file(*ntf_args, **ntf_kwargs):
                handle = real_factory(*ntf_args, **ntf_kwargs)
                name = getattr(handle, "name", None)
                if not (isinstance(name, str) and name.endswith(".py")):
                    return handle
                # Threads that are not inside a wrapped call have no ``paths``
                # attribute, so their temp files are not recorded.
                bucket = getattr(_TMP_PY_TRACKER, "paths", None)
                if bucket is not None:
                    bucket.append(name)
                return handle

            tempfile.NamedTemporaryFile = tracking_named_temporary_file
            try:
                result = f(*args, **kwargs)
                if isinstance(result, str):
                    if result.endswith(".py"):
                        recorded.append(result)
                elif isinstance(result, tempfile._TemporaryFileWrapper):
                    recorded.append(result.name)
                return result
            finally:
                # Undo the patch first, then delete on a best-effort basis.
                tempfile.NamedTemporaryFile = real_factory
                for path in recorded:
                    with contextlib.suppress(OSError):
                        if path and os.path.exists(path):
                            os.remove(path)
                _TMP_PY_TRACKER.paths = []

    return wrapper
@functools.lru_cache(maxsize=None)
def _check_pytorch_version(version_s: str = "2.4") -> bool:
    """Return True if the installed torch version is at least ``version_s``.

    Both ``torch.__version__`` and ``version_s`` are parsed with
    ``packaging.version.parse`` before comparison. Results are memoized per
    argument by ``functools.lru_cache``.
    """
    required = version.parse(version_s)
    installed = version.parse(torch.__version__)
    return installed >= required
# Pick the autocast decorators once at import time. Below PyTorch 2.4 the
# ``torch.cuda.amp`` versions are used directly; from 2.4 on, the
# ``torch.amp`` versions are used with ``device_type`` pre-bound to "cuda".
if not _check_pytorch_version("2.4"):
    amp_custom_fwd = torch.cuda.amp.custom_fwd
    amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    amp_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
    amp_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")