# Source code for spikingjelly.activation_based.cuda_kernel.spike_op

import logging
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair, _single, _triple
from torch.types import _int, _size
from torch.utils.cpp_extension import load_inline

from . import tensor_cache

# cupy is treated as optional here: if importing it fails for any reason, the
# failure is logged at INFO level and the name `cupy` is bound to None instead.
try:
    import cupy
except BaseException as e:
    logging.info(f"spikingjelly.activation_based.spike_op: {e}")
    cupy = None


# Build (at import time) an inline C++ extension named `cpp_wrapper` that
# exposes a single function, `cudnn_convolution_backward`.
#
# The C++ function keeps a cuDNN-style argument list but forwards to
# `at::convolution_backward`: the `benchmark`, `deterministic` and
# `allow_tf32` arguments are accepted and explicitly ignored (cast to void),
# the `transposed` argument is passed as `false`, and the third output-mask
# entry is hard-coded to `false`. Only the first two elements of the result
# tuple are returned to Python.
#
# If anything goes wrong (including the build failing), the exception is
# logged at INFO level and `cpp_wrapper` is set to None; code that needs the
# extension must check for None before using it.
try:
    # Warn up front because the inline build can block for a long time.
    logging.warning(
        "spikingjelly.activation_based.spike_op: try to use `torch.utils.cpp_extension.load_inline` to load cudnn functions."
    )
    # NOTE(review): `_get_build_directory` is a private torch helper; it is
    # only used here to show the extension cache location in the message.
    logging.warning(
        f"If it is hanging, pleast try to delete torch_extensions cache directory. (In most cases, the directory is {torch.utils.cpp_extension._get_build_directory('', False)}.)"
    )
    cpp_wrapper = load_inline(
        name="cpp_wrapper",
        cpp_sources=r"""
        #include <ATen/ATen.h>
        #include <array>
        #include <tuple>
        #include <vector>

        std::tuple<at::Tensor, at::Tensor> cudnn_convolution_backward(
            const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
            at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation,
            at::IntArrayRef output_padding, int64_t groups,
            bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
            (void)benchmark;
            (void)deterministic;
            (void)allow_tf32;
            auto grads = at::convolution_backward(
                grad_output, input, weight, c10::nullopt,
                stride, padding, dilation, false, output_padding, groups,
                {output_mask[0], output_mask[1], false});
            return std::make_tuple(std::get<0>(grads), std::get<1>(grads));
        }
        """,
        functions=["cudnn_convolution_backward"],
        with_cuda=True,
    )
except BaseException as e:
    # BaseException (not just Exception) is caught, so even an interrupt
    # raised while building ends up here and leaves `cpp_wrapper` as None.
    logging.info(f"spikingjelly.activation_based.spike_op: {e}")
    cpp_wrapper = None


def _spike_conv_backward_common(
    spike: torch.Tensor,
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    padding,
    stride,
    dilation,
    groups: int,
    output_mask: tuple[bool, bool],
):
    """Run the convolution backward pass through the inline C++ extension.

    Only standard (non-transposed) convolution is handled, so the
    ``output_padding`` passed to the extension is always all zeros.

    Args:
        spike: Forward input of the convolution.
        grad_output: Gradient with respect to the convolution output.
        weight: Convolution weight.
        padding: Padding forwarded unchanged to the extension.
        stride: Stride forwarded unchanged to the extension; when it is a
            list or tuple, its length decides how many zeros are used for
            ``output_padding`` (otherwise a single zero is used).
        dilation: Dilation forwarded unchanged to the extension.
        groups: Number of convolution groups.
        output_mask: Two flags forwarded to the extension as its output mask
            (presumably "compute input gradient" and "compute weight
            gradient", in that order -- confirm against the C++ wrapper).

    Returns:
        Whatever ``cpp_wrapper.cudnn_convolution_backward`` returns.

    Raises:
        RuntimeError: If the inline extension is unavailable
            (``cpp_wrapper is None``).
    """
    if cpp_wrapper is None:
        raise RuntimeError(
            "cpp_wrapper is unavailable for spike convolution backward. "
            "Please ensure the inline extension can be built in this environment."
        )

    # One zero per stride entry; a scalar stride counts as one entry.
    if isinstance(stride, (list, tuple)):
        num_entries = len(stride)
    else:
        num_entries = 1
    zero_output_padding = [0 for _ in range(num_entries)]

    backward_fn = cpp_wrapper.cudnn_convolution_backward
    cudnn_flags = torch.backends.cudnn
    return backward_fn(
        spike,
        grad_output,
        weight,
        padding,
        stride,
        dilation,
        zero_output_padding,
        groups,
        cudnn_flags.benchmark,
        cudnn_flags.deterministic,
        cudnn_flags.allow_tf32,
        output_mask,
    )


def _normalize_conv_hyperparams(
    spike_dim: int, stride, padding, dilation
) -> tuple[tuple[int, ...], Union[str, tuple[int, ...]], tuple[int, ...]]:
    """Expand convolution hyper-parameters to tuples matching the input rank.

    Args:
        spike_dim: Number of dimensions of the convolution input; 3, 4 and 5
            select the 1-, 2- and 3-element expansion respectively.
        stride: Scalar or iterable stride.
        padding: Scalar or iterable padding, or a string. A string is
            returned untouched.
        dilation: Scalar or iterable dilation.

    Returns:
        ``(stride, padding, dilation)`` where ``stride`` and ``dilation`` are
        tuples and ``padding`` is either a tuple or the original string.

    Raises:
        ValueError: If ``spike_dim`` is not 3, 4 or 5.
    """
    # Pick the torch n-tuple helper that corresponds to the input rank.
    if spike_dim == 3:
        expand = _single
    elif spike_dim == 4:
        expand = _pair
    elif spike_dim == 5:
        expand = _triple
    else:
        raise ValueError(
            f"spikeConvolution only supports 3D/4D/5D input, but got dim={spike_dim}."
        )

    stride_out = tuple(expand(stride))
    # String paddings are passed through as-is.
    if isinstance(padding, str):
        padding_out = padding
    else:
        padding_out = tuple(expand(padding))
    dilation_out = tuple(expand(dilation))
    return stride_out, padding_out, dilation_out


@torch.library.custom_op("sj::cupy_spike_linear_forward", mutates_args=())
def cupy_spike_linear_forward(
    spike: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    return F.linear(spike, weight, bias)


@torch.library.register_fake("sj::cupy_spike_linear_forward")
def _cupy_spike_linear_forward_fake(
    spike: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fake implementation registered for ``sj::cupy_spike_linear_forward``.

    It runs the same ``F.linear`` call as the real op on whatever tensors it
    is given.
    """
    fake_output = F.linear(spike, weight, bias)
    return fake_output


def _setup_cupy_spike_linear_context(ctx, inputs, output):
    """Record on ``ctx`` what the spike-linear backward needs.

    Args:
        ctx: Autograd context object to populate.
        inputs: The op inputs, unpacked as ``(spike, weight, bias)``.
        output: The op output; unused.
    """
    del output  # not needed for the backward pass
    spike_in, weight_in, bias_in = inputs

    # The spike tensor is handed to the bool tensor cache rather than saved
    # directly; the backward rebuilds it from the returned key and the shape.
    # NOTE(review): presumably this is where the memory saving comes from --
    # confirm against tensor_cache.
    ctx.s_shape = spike_in.shape
    ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike_in)

    ctx.save_for_backward(weight_in)
    ctx.has_bias = bias_in is not None


def _cupy_spike_linear_backward(ctx, grad_output):
    """Backward of the spike-linear op.

    Args:
        ctx: Context populated by the matching setup function; provides the
            saved weight, the cache key/shape of the spike input and
            ``has_bias``.
        grad_output: Gradient with respect to the op output.

    Returns:
        ``(grad_spike, grad_weight, grad_bias)``; ``grad_bias`` is ``None``
        when the forward had no bias.
    """
    (saved_weight,) = ctx.saved_tensors
    spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)

    # All arithmetic is done in the dtype of the incoming gradient.
    grad_dtype = grad_output.dtype
    weight_cast = saved_weight.to(grad_dtype)

    # d(out)/d(spike): grad_output @ weight.
    grad_spike = F.linear(grad_output, weight_cast.t(), bias=None)

    # Flatten every leading dimension so the remaining gradients become plain
    # 2-D reductions.
    grad_2d = grad_output.reshape(-1, grad_output.shape[-1])
    spike_2d = spike.reshape(-1, spike.shape[-1]).to(grad_dtype)
    grad_weight = torch.mm(grad_2d.t(), spike_2d)

    if ctx.has_bias:
        grad_bias = grad_2d.sum(0)
    else:
        grad_bias = None
    return grad_spike, grad_weight, grad_bias


# Attach the autograd formula to the spike-linear custom op:
# `_setup_cupy_spike_linear_context` records what the backward needs and
# `_cupy_spike_linear_backward` computes the gradients.
torch.library.register_autograd(
    "sj::cupy_spike_linear_forward",
    _cupy_spike_linear_backward,
    setup_context=_setup_cupy_spike_linear_context,
)


@torch.library.custom_op("sj::cupy_spike_convolution_forward", mutates_args=())
def cupy_spike_convolution_forward(
    spike: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
) -> torch.Tensor:
    if spike.dim() == 3:
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    if spike.dim() == 4:
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)


@torch.library.register_fake("sj::cupy_spike_convolution_forward")
def _cupy_spike_convolution_forward_fake(
    spike: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
) -> torch.Tensor:
    """Fake implementation registered for ``sj::cupy_spike_convolution_forward``.

    It uses the same rank-based dispatch as the real op: ``F.conv1d`` for
    3-D input, ``F.conv2d`` for 4-D input and ``F.conv3d`` otherwise.
    """
    conv_by_rank = {3: F.conv1d, 4: F.conv2d}
    conv_fn = conv_by_rank.get(spike.dim(), F.conv3d)
    return conv_fn(spike, weight, bias, stride, padding, dilation, groups)


def _setup_cupy_spike_convolution_context(ctx, inputs, output):
    """Record on ``ctx`` what the spike-convolution backward needs.

    Args:
        ctx: Autograd context object to populate.
        inputs: The op inputs, unpacked as
            ``(spike, weight, bias, stride, padding, dilation, groups)``.
        output: The op output; unused.
    """
    del output  # not needed for the backward pass
    spike, weight, bias, stride, padding, dilation, groups = inputs
    norm_stride, norm_padding, norm_dilation = _normalize_conv_hyperparams(
        spike.dim(), stride, padding, dilation
    )
    bias_given = bias is not None

    # Remember which gradients were requested so the backward can skip the rest.
    ctx.need_grad_spike = bool(spike.requires_grad)
    ctx.need_grad_weight = bool(weight.requires_grad)
    ctx.need_grad_bias = bool(bias_given and bias.requires_grad)
    ctx.save_for_backward(weight)

    # The spike input goes into the bool tensor cache only when the backward
    # will actually call the convolution backward kernel.
    ctx.s_shape = None
    ctx.s_tk = None
    if ctx.need_grad_spike or ctx.need_grad_weight:
        ctx.s_shape = spike.shape
        ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)

    ctx.stride = norm_stride
    ctx.padding = norm_padding
    ctx.dilation = norm_dilation
    ctx.groups = groups
    ctx.has_bias = bias_given


def _cupy_spike_convolution_backward(ctx, grad_output):
    """Backward of the spike-convolution op.

    Args:
        ctx: Context populated by the matching setup function.
        grad_output: Gradient with respect to the op output.

    Returns:
        A 7-tuple aligned with the op inputs:
        ``(grad_spike, grad_weight, grad_bias, None, None, None, None)``.
        Gradients that were not requested are ``None``.
    """
    grad_spike, grad_weight = None, None
    wanted = (ctx.need_grad_spike, ctx.need_grad_weight)
    if any(wanted):
        (saved_weight,) = ctx.saved_tensors
        spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        # The weight is cast to the dtype of the incoming gradient first.
        weight_cast = saved_weight.to(grad_output.dtype)
        grad_spike, grad_weight = _spike_conv_backward_common(
            spike,
            grad_output,
            weight_cast,
            ctx.padding,
            ctx.stride,
            ctx.dilation,
            ctx.groups,
            wanted,
        )

    grad_bias = None
    if ctx.has_bias and ctx.need_grad_bias:
        # Sum over every dimension except dim 1 (batch plus all trailing dims).
        trailing_axes = tuple(range(2, grad_output.dim()))
        grad_bias = grad_output.sum(dim=(0,) + trailing_axes)

    # stride, padding, dilation and groups receive no gradient.
    return grad_spike, grad_weight, grad_bias, None, None, None, None


# Attach the autograd formula to the spike-convolution custom op:
# `_setup_cupy_spike_convolution_context` records what the backward needs and
# `_cupy_spike_convolution_backward` computes the gradients.
torch.library.register_autograd(
    "sj::cupy_spike_convolution_forward",
    _cupy_spike_convolution_backward,
    setup_context=_setup_cupy_spike_convolution_context,
)


def spike_linear(
    spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None
) -> Tensor:
    r"""
    .. _spike_linear-en:

    A specific case of :class:`torch.nn.functional.linear` for inputs that
    are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than
        :class:`torch.nn.functional.linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.

    :param spike: input spikes
    :param weight: linear weight
    :param bias: optional linear bias
    :return: the linear transform of ``spike``
    """
    # A negative device index means the tensor is not on an accelerator, so
    # the stock implementation is used.
    if spike.get_device() < 0:
        return F.linear(spike, weight, bias)
    return cupy_spike_linear_forward(spike, weight, bias)
def spike_conv1d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv1d-en:

    A specific case of :class:`torch.nn.functional.conv1d` for inputs that
    are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than
        :class:`torch.nn.functional.conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.

    :param spike: input spikes
    :param weight: convolution weight
    :param bias: optional convolution bias
    :param stride: convolution stride
    :param padding: a padding string, an int, or a sequence of ints
    :param dilation: convolution dilation
    :param groups: number of convolution groups
    :return: the convolution of ``spike`` with ``weight``
    """
    # The stock implementation handles tensors with a negative device index
    # and every string padding; only numeric padding on an accelerator goes
    # through the custom op.
    if spike.get_device() < 0 or isinstance(padding, str):
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_single(stride)),
        list(_single(padding)),
        list(_single(dilation)),
        groups,
    )
def spike_conv2d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv2d-en:

    A specific case of :class:`torch.nn.functional.conv2d` for inputs that
    are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than
        :class:`torch.nn.functional.conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.

    :param spike: input spikes
    :param weight: convolution weight
    :param bias: optional convolution bias
    :param stride: convolution stride
    :param padding: a padding string, an int, or a sequence of ints
    :param dilation: convolution dilation
    :param groups: number of convolution groups
    :return: the convolution of ``spike`` with ``weight``
    """
    # The stock implementation handles tensors with a negative device index
    # and every string padding; only numeric padding on an accelerator goes
    # through the custom op.
    if spike.get_device() < 0 or isinstance(padding, str):
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_pair(stride)),
        list(_pair(padding)),
        list(_pair(dilation)),
        groups,
    )
def spike_conv3d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv3d-en:

    A specific case of :class:`torch.nn.functional.conv3d` for inputs that
    are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than
        :class:`torch.nn.functional.conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.

    :param spike: input spikes
    :param weight: convolution weight
    :param bias: optional convolution bias
    :param stride: convolution stride
    :param padding: a padding string, an int, or a sequence of ints
    :param dilation: convolution dilation
    :param groups: number of convolution groups
    :return: the convolution of ``spike`` with ``weight``
    """
    # The stock implementation handles tensors with a negative device index
    # and every string padding; only numeric padding on an accelerator goes
    # through the custom op.
    if spike.get_device() < 0 or isinstance(padding, str):
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_triple(stride)),
        list(_triple(padding)),
        list(_triple(dilation)),
        groups,
    )
class SpikeLinear(nn.Linear):
    r"""
    .. _SpikeLinear-en:

    A specific case of :class:`torch.nn.Linear` for inputs that are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Linear`
        when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def forward(self, spike: Tensor) -> Tensor:
        """Apply the layer to ``spike`` using this module's weight and bias."""
        return spike_linear(spike, self.weight, self.bias)
class SpikeConv1d(nn.Conv1d):
    r"""
    .. _SpikeConv1d-en:

    A specific case of :class:`torch.nn.Conv1d` for inputs that are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv1d`
        when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight`` according to the padding mode."""
        if self.padding_mode != "zeros":
            # Non-zero padding modes pad explicitly first, then convolve
            # with zero padding.
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv1d(
                padded,
                weight,
                bias,
                self.stride,
                _single(0),
                self.dilation,
                self.groups,
            )
        return spike_conv1d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
class SpikeConv2d(nn.Conv2d):
    r"""
    .. _SpikeConv2d-en:

    A specific case of :class:`torch.nn.Conv2d` for inputs that are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv2d`
        when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight`` according to the padding mode."""
        if self.padding_mode != "zeros":
            # Non-zero padding modes pad explicitly first, then convolve
            # with zero padding.
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv2d(
                padded,
                weight,
                bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
        return spike_conv2d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
class SpikeConv3d(nn.Conv3d):
    r"""
    .. _SpikeConv3d-en:

    A specific case of :class:`torch.nn.Conv3d` for inputs that are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv3d`
        when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight`` according to the padding mode."""
        if self.padding_mode != "zeros":
            # Non-zero padding modes pad explicitly first, then convolve
            # with zero padding.
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv3d(
                padded,
                weight,
                bias,
                self.stride,
                _triple(0),
                self.dilation,
                self.groups,
            )
        return spike_conv3d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )