spikingjelly.activation_based.cuda_kernel.auto_cuda.base 源代码

import numpy as np
import logging

try:
    import cupy
except BaseException as e:
    logging.info(f"spikingjelly.activation_based.cuda_kernel.auto_cuda.base: {e}")
    cupy = None

import torch
import sys
import logging

from .. import cuda_utils
from .... import configure


[文档] def wrap_with_comment(code: str, comment: str): if logging.DEBUG >= logging.root.level: return ( "\n//------" + comment + " start------\n" + code + "\n//------" + comment + " end--------\n\n" ) else: return code
[文档] def startswiths(x: str, prefixes: tuple): ret = False for prefix in prefixes: if x.startswith(prefix): ret = True return ret
[文档] class CKernel: def __init__(self, kernel_name: str): r""" **API Language:** :ref:`中文 <ckernel-init-cn>` | :ref:`English <ckernel-init-en>` ---- .. _ckernel-init-cn: * **中文** 自定义 CUDA kernel 的基础封装类。它维护 kernel 形参表 ``cparams``、保留变量名 ``reserved_cnames``,以及可拼接的代码片段(declaration/head/core/tail)。 :param kernel_name: CUDA kernel 名称 :type kernel_name: str ---- .. _ckernel-init-en: * **English** Base wrapper for custom CUDA kernels. It stores kernel parameter metadata (``cparams``), reserved C variable names (``reserved_cnames``), and code segments (declaration/head/core/tail). :param kernel_name: CUDA kernel name :type kernel_name: str """ self.cparams = {} self.reserved_cnames = [] self.kernel_name = kernel_name self._core = ""
[文档] def check_attributes(self, **kwargs): r""" **API Language:** :ref:`中文 <ckernel-check-attributes-cn>` | :ref:`English <ckernel-check-attributes-en>` ---- .. _ckernel-check-attributes-cn: * **中文** 检查 ``kwargs`` 中给定属性值是否与当前对象属性一致。 :param kwargs: 待检查的属性键值对 :type kwargs: dict :return: 全部属性一致时返回 ``True``,否则返回 ``False`` :rtype: bool ---- .. _ckernel-check-attributes-en: * **English** Check whether provided attribute values in ``kwargs`` match current attributes on this object. :param kwargs: Attribute key-value pairs to check :type kwargs: dict :return: ``True`` if all attributes match; otherwise ``False`` :rtype: bool """ for key, value in kwargs.items(): if value != self.__getattribute__(key): return False else: return True
@property def core(self): return self._core @core.setter def core(self, value): self._core = value
[文档] def set_contiguous(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-set-contiguous-cn>` | :ref:`English <ckernel-set-contiguous-en>` ---- .. _ckernel-set-contiguous-cn: * **中文** 将 ``py_dict`` 中的 ``torch.Tensor``/``cupy.ndarray`` 转为连续内存;若出现 其他类型则抛出异常。 :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel-set-contiguous-en: * **English** Make ``torch.Tensor``/``cupy.ndarray`` values in ``py_dict`` contiguous. Raise an error for unsupported value types. :param py_dict: Kernel argument dictionary :type py_dict: dict """ # get contiguous for key, value in py_dict.items(): if isinstance(value, torch.Tensor): value = value.contiguous() elif isinstance(value, cupy.ndarray): value = cupy.ascontiguousarray(value) else: raise TypeError(type(value)) py_dict[key] = value
[文档] def get_device(self, py_dict: dict) -> int: r""" **API Language:** :ref:`中文 <ckernel-get-device-cn>` | :ref:`English <ckernel-get-device-en>` ---- .. _ckernel-get-device-cn: * **中文** 遍历 ``py_dict``,返回首个张量对象所在 CUDA 设备编号。 :param py_dict: kernel 参数字典 :type py_dict: dict :return: CUDA 设备编号 :rtype: int :raise ValueError: 当 ``py_dict`` 中没有张量参数时 ---- .. _ckernel-get-device-en: * **English** Traverse ``py_dict`` and return the CUDA device id of the first tensor-like value. :param py_dict: Kernel argument dictionary :type py_dict: dict :return: CUDA device id :rtype: int :raise ValueError: If no tensor-like value is found """ for item in py_dict.values(): if isinstance(item, torch.Tensor): return item.get_device() elif isinstance(item, cupy.ndarray): return item.device.id raise ValueError
[文档] def check_device(self, device: int, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-check-device-cn>` | :ref:`English <ckernel-check-device-en>` ---- .. _ckernel-check-device-cn: * **中文** 检查 ``py_dict`` 中所有张量是否都位于 ``device`` 指定的 CUDA 设备上。 :param device: 目标 CUDA 设备编号 :type device: int :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel-check-device-en: * **English** Validate that all tensor-like values in ``py_dict`` are on the target CUDA device ``device``. :param device: Target CUDA device id :type device: int :param py_dict: Kernel argument dictionary :type py_dict: dict """ for item in py_dict.values(): if isinstance(item, torch.Tensor): assert item.get_device() == device elif isinstance(item, cupy.ndarray): assert item.device.id == device
[文档] def check_keys(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-check-keys-cn>` | :ref:`English <ckernel-check-keys-en>` ---- .. _ckernel-check-keys-cn: * **中文** 检查 ``py_dict`` 的键集合是否与 ``self.cparams`` 一致,不一致时抛出异常。 :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel-check-keys-en: * **English** Check whether keys in ``py_dict`` exactly match keys in ``self.cparams``. Raise an error on mismatch. :param py_dict: Kernel argument dictionary :type py_dict: dict """ if py_dict.keys() != self.cparams.keys(): missed_keys = (py_dict.keys() | self.cparams.keys()) - ( py_dict.keys() & self.cparams.keys() ) if missed_keys.__len__() > 0: if (missed_keys & py_dict.keys()).__len__() > 0: msg = f"{missed_keys} is in py_dict but not in cparams!" else: msg = f"{missed_keys} is in cparams but not in py_dict!" raise ValueError(msg)
[文档] def check_ctypes(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-check-ctypes-cn>` | :ref:`English <ckernel-check-ctypes-en>` ---- .. _ckernel-check-ctypes-cn: * **中文** 检查 ``py_dict`` 中各参数的数据类型是否与 ``self.cparams`` 中声明的 CUDA C 参数类型匹配(例如 ``float`` / ``half2`` / ``int``)。 :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel-check-ctypes-en: * **English** Validate that runtime value dtypes in ``py_dict`` match declared CUDA C parameter types in ``self.cparams`` (e.g., ``float``/``half2``/``int``). :param py_dict: Kernel argument dictionary :type py_dict: dict """ for key, value in py_dict.items(): ctype: str = self.cparams[key] if isinstance(value, torch.Tensor): if value.dtype == torch.float: assert startswiths(ctype, ("const float", "float")) elif value.dtype == torch.half: assert startswiths(ctype, ("const half2", "half2")) if isinstance(value, cupy.ndarray): if value.dtype == np.float32: assert startswiths(ctype, ("const float", "float")) elif value.dtype == np.float16: assert startswiths(ctype, ("const half2", "half2")) elif value.dtype == int or ( hasattr(np, "int") and value.dtype == np.int ): assert startswiths(ctype, ("const int", "int"))
[文档] def check_half2(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-check-half2-cn>` | :ref:`English <ckernel-check-half2-en>` ---- .. _ckernel-check-half2-cn: * **中文** 供子类按需实现的 ``half2`` 参数校验接口。 :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel-check-half2-en: * **English** Extension hook for subclasses to implement ``half2``-related checks. :param py_dict: Kernel argument dictionary :type py_dict: dict """ raise NotImplementedError
[文档] def get_ptrs(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel-get-ptrs-cn>` | :ref:`English <ckernel-get-ptrs-en>` ---- .. _ckernel-get-ptrs-cn: * **中文** 提取 ``py_dict`` 中每个张量参数的底层指针(或等价对象),按键排序后的参数顺序返回。 :param py_dict: kernel 参数字典 :type py_dict: dict :return: 指针元组 :rtype: tuple ---- .. _ckernel-get-ptrs-en: * **English** Collect underlying pointers (or equivalent objects) for tensor-like values in ``py_dict`` and return them as a tuple. :param py_dict: Kernel argument dictionary :type py_dict: dict :return: Tuple of argument pointers :rtype: tuple """ ret_list = [] for item in py_dict.values(): if isinstance(item, torch.Tensor): ret_list.append(item.data_ptr()) elif isinstance(item, cupy.ndarray): ret_list.append(item) else: raise TypeError return tuple(ret_list)
[文档] def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): r""" **API Language:** :ref:`中文 <ckernel-call-cn>` | :ref:`English <ckernel-call-en>` ---- .. _ckernel-call-cn: * **中文** 执行 CUDA kernel。调用前会完成设备一致性检查、连续化、参数类型检查和键集合校验。 :param grid: CUDA grid 配置 :type grid: tuple :param block: CUDA block 配置 :type block: tuple :param py_dict: kernel 实参字典,键需与 ``self.cparams`` 一一对应 :type py_dict: dict ---- .. _ckernel-call-en: * **English** Execute the CUDA kernel after validating device consistency, contiguous layout, ctypes compatibility, and key alignment with ``self.cparams``. :param grid: CUDA grid configuration :type grid: tuple :param block: CUDA block configuration :type block: tuple :param py_dict: Runtime argument dictionary matching ``self.cparams`` :type py_dict: dict """ device = self.get_device(py_dict) self.check_device(device, py_dict) self.set_contiguous(py_dict) self.check_ctypes(py_dict) self.check_half2(py_dict) py_dict = dict(sorted(py_dict.items())) self.check_keys(py_dict) assert sys.version_info.major >= 3 and sys.version_info.minor >= 6 # 需要使用有序词典 # python >= 3.6时,字典默认是有序的 cp_kernel = cupy.RawKernel( self.full_codes, self.kernel_name, options=configure.cuda_compiler_options, backend=configure.cuda_compiler_backend, ) with cuda_utils.DeviceEnvironment(device): cp_kernel(grid, block, self.get_ptrs(py_dict), *args_1, **kwargs)
[文档] def add_param(self, ctype: str, cname: str): r""" **API Language:** :ref:`中文 <ckernel-add-param-cn>` | :ref:`English <ckernel-add-param-en>` ---- .. _ckernel-add-param-cn: * **中文** 向 ``self.cparams`` 添加一个 CUDA 形参声明。 :param ctype: CUDA 参数类型字符串 :type ctype: str :param cname: CUDA 参数名 :type cname: str :raise ValueError: 当参数名重复或与保留名冲突时 ---- .. _ckernel-add-param-en: * **English** Add one CUDA parameter declaration to ``self.cparams``. :param ctype: CUDA parameter type string :type ctype: str :param cname: CUDA parameter name :type cname: str :raise ValueError: If the name already exists or conflicts with reserved names """ # example: ctype = 'const float *', cname = 'x' if cname in self.cparams: raise ValueError(f"{cname} has been added to cparams!") if cname in self.reserved_cnames: raise ValueError( f"{cname} is the reserved cname. You should change the name of your variable to avoid conflict." ) self.cparams[cname] = ctype
@property def declaration(self): codes = f""" #include <cuda_fp16.h> extern "C" __global__ void {self.kernel_name}( """ self.cparams = dict(sorted(self.cparams.items())) params_list = [] for cname, ctype in self.cparams.items(): params_list.append(f"{ctype} {cname}") codes += ", ".join(params_list) codes += """ ) """ return codes @property def head(self): return "{" @property def tail(self): return "}" @property def full_codes(self): r""" **API Language:** :ref:`中文 <ckernel-full-codes-cn>` | :ref:`English <ckernel-full-codes-en>` ---- .. _ckernel-full-codes-cn: * **中文** 返回拼接后的完整 CUDA 代码字符串。 :return: 完整 CUDA 源码 :rtype: str ---- .. _ckernel-full-codes-en: * **English** Return the full CUDA source string assembled from declaration/head/core/tail. :return: Full CUDA source code :rtype: str """ return ( wrap_with_comment(self.declaration, "declaration") + wrap_with_comment(self.head, "head") + wrap_with_comment(self.core, "core") + wrap_with_comment(self.tail, "tail") )
[文档] class CKernel1D(CKernel): def __init__(self, *args, **kwargs): r""" **API Language:** :ref:`中文 <ckernel1d-init-cn>` | :ref:`English <ckernel1d-init-en>` ---- .. _ckernel1d-init-cn: * **中文** 一维(逐元素)CUDA kernel 封装类,继承自 :class:`CKernel`。该类默认添加 ``numel`` 形参并保留 ``index`` 作为线程索引变量名。 :param kernel_name: CUDA kernel 名称(通过 ``*args, **kwargs`` 传入基类) :type kernel_name: str ---- .. _ckernel1d-init-en: * **English** 1D (element-wise) CUDA kernel wrapper inherited from :class:`CKernel`. It adds ``numel`` by default and reserves ``index`` as the thread index variable name. :param kernel_name: CUDA kernel name (forwarded to the base class via ``*args, **kwargs``) :type kernel_name: str """ super().__init__(*args, **kwargs) self.cparams["numel"] = "const int &" self.reserved_cnames.append("index") @property def head(self): r""" **API Language:** :ref:`中文 <ckernel1d-head-cn>` | :ref:`English <ckernel1d-head-en>` ---- .. _ckernel1d-head-cn: * **中文** 返回 1D kernel 头部代码,包含线程索引计算和 ``index < numel`` 边界判断。 :return: CUDA 头部代码片段 :rtype: str ---- .. _ckernel1d-head-en: * **English** Return the 1D kernel head code, including thread-index computation and the ``index < numel`` guard. :return: CUDA head code snippet :rtype: str """ codes = """ { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < numel) { """ return codes @property def tail(self): r""" **API Language:** :ref:`中文 <ckernel1d-tail-cn>` | :ref:`English <ckernel1d-tail-en>` ---- .. _ckernel1d-tail-cn: * **中文** 返回 1D kernel 尾部代码,用于闭合头部中的代码块。 :return: CUDA 尾部代码片段 :rtype: str ---- .. _ckernel1d-tail-en: * **English** Return the 1D kernel tail code that closes the blocks opened in :attr:`head`. :return: CUDA tail code snippet :rtype: str """ codes = """ } } """ return codes
[文档] def check_half2(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel1d-check-half2-cn>` | :ref:`English <ckernel1d-check-half2-en>` ---- .. _ckernel1d-check-half2-cn: * **中文** 检查 ``py_dict`` 中 ``half``/``float16`` 张量的元素个数是否为偶数,以满足 ``half2`` 访存需求。 :param py_dict: kernel 参数字典 :type py_dict: dict .. admonition:: 注意 :class: note :meth:`CKernel1D.__call__` 会在执行前自动对奇数长度 half 张量补齐,因此通常 无需手工补齐。 ---- .. _ckernel1d-check-half2-en: * **English** Validate that ``half``/``float16`` tensor lengths in ``py_dict`` are even, which is required by ``half2`` operations. :param py_dict: Kernel argument dictionary :type py_dict: dict .. admonition:: Note :class: note :meth:`CKernel1D.__call__` pads odd-length half tensors before kernel launch, so manual padding is usually unnecessary. """ for key, value in py_dict.items(): if isinstance(value, torch.Tensor): if value.dtype == torch.half: assert value.numel() % 2 == 0, ( f"please pad the numel of {key} to assert mod 2 == 0! (for half2)" ) if isinstance(value, cupy.ndarray): if value.dtype == np.float16: assert value.size % 2 == 0, ( f"please pad the numel of {key} to assert mod 2 == 0! (for half2)" )
[文档] def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): r""" **API Language:** :ref:`中文 <ckernel1d-call-cn>` | :ref:`English <ckernel1d-call-en>` ---- .. _ckernel1d-call-cn: * **中文** 执行 1D CUDA kernel。对于 ``half``/``float16`` 且元素个数为奇数的张量,会先补齐 末元素后再调用基类执行,完成后再恢复原始形状与长度。 :param grid: CUDA grid 配置 :type grid: tuple :param block: CUDA block 配置 :type block: tuple :param py_dict: kernel 参数字典,键应与 ``self.cparams`` 对应 :type py_dict: dict ---- .. _ckernel1d-call-en: * **English** Execute the 1D CUDA kernel. For odd-length ``half``/``float16`` tensors, this method pads one trailing element before delegating to the base call, then removes the padding and restores the original shape. :param grid: CUDA grid configuration :type grid: tuple :param block: CUDA block configuration :type block: tuple :param py_dict: Runtime argument dictionary aligned with ``self.cparams`` :type py_dict: dict """ # pad half2 pad_keys = [] pad_shapes = [] for key, value in py_dict.items(): if isinstance(value, torch.Tensor) and value.dtype == torch.half: if value.numel() % 2 != 0: pad_shapes.append(value.shape) pad_keys.append(key) value = value.flatten() value = torch.cat((value, value[-1].view(1))) py_dict[key] = value elif isinstance(value, cupy.ndarray) and value.dtype == np.float16: if value.size % 2 != 0: pad_shapes.append(value.shape) pad_keys.append(key) value = cupy.reshape(value, -1) value = cupy.concatenate((value, cupy.reshape(value[-1], 1))) py_dict[key] = value super().__call__(grid, block, py_dict, *args_1, **kwargs) # move pad values for key, shape in zip(pad_keys, pad_shapes): value = py_dict[key] value = value[:-1] if isinstance(value, torch.Tensor): value = value.view(shape) elif isinstance(value, cupy.ndarray): value = cupy.reshape(value, shape) py_dict[key] = value
[文档] def simple_call(self, **kwargs): r""" **API Language:** :ref:`中文 <ckernel1d-simple-call-cn>` | :ref:`English <ckernel1d-simple-call-en>` ---- .. _ckernel1d-simple-call-cn: * **中文** ``CKernel1D.__call__`` 的简化入口。自动从 ``kwargs`` 中推导设备与 ``numel``, 并使用配置中的线程数计算 ``blocks`` 后执行 kernel。 :param kwargs: kernel 参数键值对(不需要手动提供 ``numel``) :type kwargs: dict ---- .. _ckernel1d-simple-call-en: * **English** A convenience wrapper of :meth:`CKernel1D.__call__`. It infers device and ``numel`` from ``kwargs``, computes launch blocks with configured threads, and launches the kernel. :param kwargs: Kernel argument mapping (``numel`` is inferred automatically) :type kwargs: dict """ py_dict = kwargs device = self.get_device(py_dict) numel = None for value in kwargs.values(): if isinstance(value, torch.Tensor): numel = value.numel() elif isinstance(value, cupy.ndarray): numel = value.size if numel is None: raise ValueError("No torch.Tensor or cupy.ndarray in kwargs!") with cuda_utils.DeviceEnvironment(device): threads = configure.cuda_threads blocks = cuda_utils.cal_blocks(numel) numel = cupy.asarray(numel) py_dict["numel"] = numel self.__call__((blocks,), (threads,), py_dict)
[文档] class CKernel2D(CKernel): def __init__(self, kernel_name: str, reverse: bool = False): r""" **API Language:** :ref:`中文 <ckernel2d-init-cn>` | :ref:`English <ckernel2d-init-en>` ---- .. _ckernel2d-init-cn: * **中文** 二维 CUDA kernel 封装,继承自 :class:`CKernel`。默认包含 ``numel`` 与 ``N`` 两个内置参数,分别表示总元素数与每个时间步的元素数。二维张量按 ``[T, N]`` 解释,其中 ``T`` 为序列长度,``N`` 为单步元素数量。 ``reverse`` 控制时间维循环方向: ``True`` 时使用 ``for(int t = numel - N + index; t >= 0; t -= dt)``; ``False`` 时使用 ``for(int t = index; t < numel; t += dt)``。 :param kernel_name: CUDA kernel 名称 :type kernel_name: str :param reverse: 是否使用反向时间循环 :type reverse: bool ---- .. _ckernel2d-init-en: * **English** A 2D CUDA kernel wrapper derived from :class:`CKernel`. It provides built-in parameters ``numel`` (total element count) and ``N`` (elements per time step). Any 2D tensor is interpreted as ``[T, N]``, where ``T`` is sequence length. ``reverse`` controls the temporal loop direction: ``True`` uses ``for(int t = numel - N + index; t >= 0; t -= dt)``; ``False`` uses ``for(int t = index; t < numel; t += dt)``. :param kernel_name: CUDA kernel name :type kernel_name: str :param reverse: Whether to use reverse temporal traversal :type reverse: bool """ super().__init__(kernel_name) self.cparams["numel"] = "const int &" self.reverse = reverse self.cparams["N"] = "const int &" self.reserved_cnames.append("index") self.reserved_cnames.append("dt") self.reserved_cnames.append("t") self._pre_core = "" self._post_core = "" @property def pre_core(self): return self._pre_core @pre_core.setter def pre_core(self, value: str): self._pre_core = value @property def post_core(self): return self._post_core @post_core.setter def post_core(self, value: str): self._post_core = value
[文档] def check_shape(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel2d-check-shape-cn>` | :ref:`English <ckernel2d-check-shape-en>` ---- .. _ckernel2d-check-shape-cn: * **中文** 检查 ``py_dict`` 中的张量维度。所有 ``torch.Tensor`` 与 ``cupy.ndarray`` 的维度都必须不超过 2。 :param py_dict: kernel 参数字典 :type py_dict: dict ---- .. _ckernel2d-check-shape-en: * **English** Validate tensor dimensionality in ``py_dict``. All ``torch.Tensor`` and ``cupy.ndarray`` values must have ``ndim <= 2``. :param py_dict: Kernel argument dictionary :type py_dict: dict """ # all tensors should be ndim <= 2 for value in py_dict.values(): if isinstance(value, torch.Tensor): assert value.ndim <= 2 elif isinstance(value, cupy.ndarray): assert value.ndim <= 2
[文档] def check_half2(self, py_dict: dict): r""" **API Language:** :ref:`中文 <ckernel2d-check-half2-cn>` | :ref:`English <ckernel2d-check-half2-en>` ---- .. _ckernel2d-check-half2-cn: * **中文** 检查 ``py_dict`` 中半精度张量是否满足 ``half2`` 对齐要求(偶数长度)。 对 ``torch.half`` 和 ``np.float16``: 1D 张量要求 ``numel`` 为偶数,2D 张量要求 ``shape[1]`` 为偶数。 :param py_dict: kernel 参数字典 :type py_dict: dict .. admonition:: Note :class: note 实际执行前,:meth:`CKernel2D.__call__` 会自动补齐奇数长度半精度张量; 本函数用于约束与校验。 ---- .. _ckernel2d-check-half2-en: * **English** Check whether half-precision tensors in ``py_dict`` satisfy ``half2`` alignment requirements (even length). For ``torch.half`` and ``np.float16`` values: 1D tensors require even ``numel``; 2D tensors require even ``shape[1]``. :param py_dict: Kernel argument dictionary :type py_dict: dict .. admonition:: Note :class: note :meth:`CKernel2D.__call__` performs automatic padding for odd-sized half tensors before launch. This method is for validation checks. """ for key, value in py_dict.items(): if isinstance(value, torch.Tensor): if value.dtype == torch.half: if value.ndim <= 1: assert value.numel() % 2 == 0 elif value.ndim == 2: assert value.shape[1] % 2 == 0 elif isinstance(value, cupy.ndarray): if value.dtype == np.float16: if value.ndim <= 1: assert value.size % 2 == 0 elif value.ndim == 2: assert value.shape[1] % 2 == 0
[文档] def __call__(self, grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs): r""" **API Language:** :ref:`中文 <ckernel2d-call-cn>` | :ref:`English <ckernel2d-call-en>` ---- .. _ckernel2d-call-cn: * **中文** 执行二维 CUDA kernel。``*args_1`` 与 ``**kwargs`` 会透传给 :class:`cupy.RawKernel` 调用。 ``py_dict`` 中 ``key`` 必须与 ``self.cparams`` 的形参名一一对应, ``value`` 需为匹配数据类型的 ``torch.Tensor`` 或 ``cupy.ndarray``。 键顺序可以任意,内部会按形参顺序重排。 :param grid: CUDA grid 配置 :type grid: tuple :param block: CUDA block 配置 :type block: tuple :param py_dict: kernel 参数字典 :type py_dict: dict .. admonition:: Note :class: note ``py_dict`` 中张量必须为 1D 或 2D。对于奇数长度半精度张量, 调用前会自动补齐,执行后移除补齐并恢复原始形状。 ---- .. _ckernel2d-call-en: * **English** Execute the 2D CUDA kernel. ``*args_1`` and ``**kwargs`` are forwarded directly to :class:`cupy.RawKernel`. Keys in ``py_dict`` must match ``self.cparams`` one-to-one, and values must be ``torch.Tensor``/``cupy.ndarray`` objects with compatible dtypes. Key order is arbitrary because arguments are aligned internally by formal parameter order. :param grid: CUDA grid configuration :type grid: tuple :param block: CUDA block configuration :type block: tuple :param py_dict: Kernel argument dictionary :type py_dict: dict .. admonition:: Note :class: note Tensor inputs must be 1D or 2D. Odd-sized half-precision tensors are padded before launch, then unpadded and reshaped back afterward. """ self.check_shape(py_dict) # pad half2 pad_keys = [] pad_shapes = [] for key, value in py_dict.items(): if isinstance(value, torch.Tensor) and value.dtype == torch.half: if value.ndim <= 1: # 1D tensor if value.numel() % 2 != 0: pad_shapes.append(value.shape) pad_keys.append(key) value = value.flatten() value = torch.cat((value, value[-1].view(1))) py_dict[key] = value elif value.shape[1] % 2 != 0: # 2D tensor with shape = [T, N] and N % 2 != 0 pad_shapes.append(value.shape) pad_keys.append(key) value = torch.cat((value, value[:, -1].view(-1, 1)), dim=1) # [T, N] -> [T, N + 1] py_dict[key] = value elif isinstance(value, cupy.ndarray) and value.dtype == np.float16: if value.ndim <= 1: # 1D tensor if value.size % 2 != 0: pad_shapes.append(value.shape) pad_keys.append(key) value = cupy.reshape(value, -1) value = cupy.concatenate((value, cupy.reshape(value[-1], 1))) py_dict[key] = value elif value.shape[1] % 2 != 0: pad_shapes.append(value.shape) pad_keys.append(key) # [T, N] -> [T, N + 1] value = cupy.concatenate( (value, cupy.reshape(value[:, -1], (-1, 1))), axis=1 ) py_dict[key] = value super().__call__(grid, block, py_dict, *args_1, **kwargs) # move pad values for i, key in enumerate(pad_keys): value = py_dict[key] shape = pad_shapes[i] if isinstance(value, torch.Tensor): if value.ndim <= 1: value = value[:-1] value = value.view(shape) else: value = value[:, :-1] elif isinstance(value, cupy.ndarray): if value.ndim <= 1: value = value[:, -1] value = cupy.reshape(value, shape) else: value = value[:, :-1] py_dict[key] = value
@property def head(self): codes = """ { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < N) { const int dt = N; """ codes += wrap_with_comment(self.pre_core, "pre_core") if self.reverse: codes += """ for(int t = numel - N + index; t >= 0; t -= dt) { """ else: codes += """ for(int t = index; t < numel; t += dt) { """ return codes @property def tail(self): codes = """ } """ codes += wrap_with_comment(self.post_core, "post_core") codes += """ } } """ return codes
[文档] def simple_call(self, **kwargs): r""" **API Language:** :ref:`中文 <ckernel2d-simple-call-cn>` | :ref:`English <ckernel2d-simple-call-en>` ---- .. _ckernel2d-simple-call-cn: * **中文** :meth:`CKernel2D.__call__` 的便捷封装。该函数会从 ``kwargs`` 自动推断 设备、``numel``、``N``,并根据配置计算 CUDA ``threads`` 与 ``blocks`` 后执行。 :param kwargs: kernel 参数键值对(无需手动传入 ``numel`` 与 ``N``) :type kwargs: dict ---- .. _ckernel2d-simple-call-en: * **English** A convenience wrapper of :meth:`CKernel2D.__call__`. It infers device, ``numel``, and ``N`` from ``kwargs``, computes launch ``threads`` and ``blocks`` from configuration, and then launches the kernel. :param kwargs: Kernel argument mapping (``numel`` and ``N`` are inferred) :type kwargs: dict """ py_dict = kwargs device = self.get_device(py_dict) numel = None N = None for value in kwargs.values(): if isinstance(value, torch.Tensor) and value.ndim == 2: numel = value.numel() N = value.shape[1] elif isinstance(value, cupy.ndarray) and value.ndim == 2: numel = value.size N = value.shape[1] if numel is None or N is None: raise ValueError("No 2D torch.Tensor or cupy.ndarray in kwargs!") with cuda_utils.DeviceEnvironment(device): threads = configure.cuda_threads blocks = cuda_utils.cal_blocks(numel) numel = cupy.asarray(numel) N = cupy.asarray(N) py_dict["numel"] = numel py_dict["N"] = N self.__call__((blocks,), (threads,), py_dict)
[文档] class CodeTyper: def __init__(self, indent_num: int): r""" **API Language:** :ref:`中文 <codetyper-init-cn>` | :ref:`English <codetyper-init-en>` ---- .. _codetyper-init-cn: * **中文** CUDA 代码缩进与拼接工具。内部维护 ``self.codes``,可逐段追加代码并按缩进格式化。 :param indent_num: 初始缩进空格数 :type indent_num: int ---- .. _codetyper-init-en: * **English** A helper for formatting and assembling CUDA code with indentation. The accumulated code text is stored in ``self.codes``. :param indent_num: Number of spaces used for initial indentation :type indent_num: int """ self.indent = " " * indent_num self.codes = "\n"
[文档] def append(self, codes: str): r""" **API Language:** :ref:`中文 <codetyper-append-cn>` | :ref:`English <codetyper-append-en>` ---- .. _codetyper-append-cn: * **中文** 将输入 CUDA 代码片段追加到 ``self.codes``。函数按 ``;`` 分句并逐句写入, 同时处理 ``{``/``}`` 这类块边界语句。 :param codes: 待追加的 CUDA 代码 :type codes: str ---- .. _codetyper-append-en: * **English** Append CUDA code snippets into ``self.codes``. The method splits by ``;`` and writes each statement with current indentation, while handling ``{`` and ``}`` block boundary tokens. :param codes: CUDA code snippet to append :type codes: str """ codes = codes.replace("\n", "") codes = codes.split(";") for i in range(codes.__len__()): if codes[i].__len__() > 0: if codes[i] in ("{", "}"): self.codes += self.indent + codes[i] + "\n" else: self.codes += self.indent + codes[i] + ";\n"
[文档] class CodeBlock: def __init__(self, env: CodeTyper): r""" **API Language:** :ref:`中文 <codeblock-init-cn>` | :ref:`English <codeblock-init-en>` ---- .. _codeblock-init-cn: * **中文** ``CodeTyper`` 的上下文管理器工具,用于自动插入代码块 ``{...}`` 并维护缩进, 便于组织包含中间变量的多行 CUDA 逻辑。 :param env: 目标代码环境 :type env: CodeTyper ---- .. _codeblock-init-en: * **English** A context-manager utility for ``CodeTyper`` that inserts ``{...}`` blocks and adjusts indentation automatically. It is useful for composing multi-line CUDA logic with intermediate variables. :param env: Target code-typing environment :type env: CodeTyper """ self.env = env def __enter__(self): r""" **API Language:** :ref:`中文 <codeblock-enter-cn>` | :ref:`English <codeblock-enter-en>` ---- .. _codeblock-enter-cn: * **中文** 进入上下文时写入 ``{`` 并增加一级缩进。 :return: 当前上下文对象(默认返回 ``None``) :rtype: None ---- .. _codeblock-enter-en: * **English** Enter the context by appending ``{`` and increasing indentation by one level. :return: Current context object (defaults to ``None``) :rtype: None """ self.env.append("{") self.env.indent += " " def __exit__(self, exc_type, exc_val, exc_tb): r""" **API Language:** :ref:`中文 <codeblock-exit-cn>` | :ref:`English <codeblock-exit-en>` ---- .. _codeblock-exit-cn: * **中文** 退出上下文时回退一级缩进并写入 ``}``。 :param exc_type: 异常类型 :type exc_type: type :param exc_val: 异常值 :type exc_val: BaseException :param exc_tb: 回溯对象 :type exc_tb: traceback :return: ``False``(默认行为,不屏蔽异常) :rtype: bool ---- .. _codeblock-exit-en: * **English** Exit the context by decreasing indentation by one level and appending ``}``. :param exc_type: Exception type :type exc_type: type :param exc_val: Exception value :type exc_val: BaseException :param exc_tb: Traceback object :type exc_tb: traceback :return: ``False`` by default (exceptions are not suppressed) :rtype: bool """ self.env.indent = self.env.indent[:-1] self.env.append("}")