import logging
from typing import Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair, _single, _triple
from torch.types import _int, _size
from torch.utils.cpp_extension import load_inline
from . import tensor_cache
# CuPy is optional for this module: if importing it fails for any ordinary
# reason, record why and continue with `cupy = None`. Catching `Exception`
# (not `BaseException`) lets KeyboardInterrupt/SystemExit propagate.
try:
    import cupy
except Exception as e:
    logging.info("spikingjelly.activation_based.spike_op: %s", e)
    cupy = None
# Best-effort JIT build of a tiny C++ extension exposing convolution backward.
# On any ordinary failure (no compiler/CUDA toolkit, build error, ...) the
# module keeps working with `cpp_wrapper = None`.
try:
    logging.warning(
        "spikingjelly.activation_based.spike_op: try to use `torch.utils.cpp_extension.load_inline` to load cudnn functions."
    )
    # `_get_build_directory` is a private PyTorch helper used only to make the
    # hint below more helpful; if it is missing or its signature changes,
    # degrade the hint instead of aborting the extension build.
    try:
        _build_dir_hint = torch.utils.cpp_extension._get_build_directory("", False)
    except Exception:
        _build_dir_hint = "<unknown>"
    logging.warning(
        "If it is hanging, please try to delete torch_extensions cache directory. "
        f"(In most cases, the directory is {_build_dir_hint}.)"
    )
    cpp_wrapper = load_inline(
        name="cpp_wrapper",
        cpp_sources=r"""
#include <ATen/ATen.h>
#include <array>
#include <tuple>
#include <vector>
std::tuple<at::Tensor, at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation,
    at::IntArrayRef output_padding, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool, 2> output_mask) {
    (void)benchmark;
    (void)deterministic;
    (void)allow_tf32;
    auto grads = at::convolution_backward(
        grad_output, input, weight, c10::nullopt,
        stride, padding, dilation, false, output_padding, groups,
        {output_mask[0], output_mask[1], false});
    return std::make_tuple(std::get<0>(grads), std::get<1>(grads));
}
""",
        functions=["cudnn_convolution_backward"],
        with_cuda=True,
    )
except Exception as e:
    logging.info("spikingjelly.activation_based.spike_op: %s", e)
    cpp_wrapper = None
def _spike_conv_backward_common(
    spike: torch.Tensor,
    grad_output: torch.Tensor,
    weight: torch.Tensor,
    padding,
    stride,
    dilation,
    groups: int,
    output_mask: tuple[bool, bool],
):
    """Compute ``(grad_spike, grad_weight)`` for a standard convolution.

    Only non-transposed convolution is supported; ``output_padding`` is fixed
    to zeros to match the conv1d/2d/3d forward used in this module.

    :param spike: forward input of the convolution
    :param grad_output: gradient w.r.t. the convolution output
    :param weight: convolution weight
    :param padding: per-dimension padding (normally a normalized tuple)
    :param stride: per-dimension stride (normally a normalized tuple)
    :param dilation: per-dimension dilation (normally a normalized tuple)
    :param groups: number of convolution groups
    :param output_mask: ``(need_grad_spike, need_grad_weight)``; a gradient
        whose flag is ``False`` is not computed
    :return: tuple ``(grad_spike, grad_weight)``
    """

    def _as_list(value):
        # Accept either a sequence or a single int, as the original length check did.
        return list(value) if isinstance(value, (list, tuple)) else [value]

    output_padding = [0] * len(_as_list(stride))

    if cpp_wrapper is not None:
        return cpp_wrapper.cudnn_convolution_backward(
            spike,
            grad_output,
            weight,
            padding,
            stride,
            dilation,
            output_padding,
            groups,
            torch.backends.cudnn.benchmark,
            torch.backends.cudnn.deterministic,
            torch.backends.cudnn.allow_tf32,
            output_mask,
        )

    # The inline extension only forwards to at::convolution_backward, which is
    # also reachable from Python; use it directly instead of failing when the
    # extension could not be built (e.g. no compiler / CUDA toolkit available).
    grad_spike, grad_weight, _ = torch.ops.aten.convolution_backward(
        grad_output,
        spike,
        weight,
        None,  # bias_sizes: bias gradient is computed by the caller
        _as_list(stride),
        _as_list(padding),
        _as_list(dilation),
        False,  # transposed
        output_padding,
        groups,
        [bool(output_mask[0]), bool(output_mask[1]), False],
    )
    return grad_spike, grad_weight
def _normalize_conv_hyperparams(
    spike_dim: int, stride, padding, dilation
) -> tuple[tuple[int, ...], Union[str, tuple[int, ...]], tuple[int, ...]]:
    """Expand convolution hyper-parameters to one entry per spatial dimension.

    The input rank selects the expander: 3 -> ``_single``, 4 -> ``_pair``,
    5 -> ``_triple``. A string ``padding`` (e.g. ``"same"``) is returned as-is.

    :raises ValueError: if ``spike_dim`` is not 3, 4 or 5
    """
    for rank, expand in ((3, _single), (4, _pair), (5, _triple)):
        if spike_dim == rank:
            break
    else:
        raise ValueError(
            f"spikeConvolution only supports 3D/4D/5D input, but got dim={spike_dim}."
        )

    norm_stride = tuple(expand(stride))
    norm_padding = padding if isinstance(padding, str) else tuple(expand(padding))
    norm_dilation = tuple(expand(dilation))
    return norm_stride, norm_padding, norm_dilation
@torch.library.custom_op("sj::cupy_spike_linear_forward", mutates_args=())
def cupy_spike_linear_forward(
spike: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
return F.linear(spike, weight, bias)
@torch.library.register_fake("sj::cupy_spike_linear_forward")
def _cupy_spike_linear_forward_fake(
    spike: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
    """Fake (metadata-only) kernel of ``sj::cupy_spike_linear_forward``.

    Reuses :func:`torch.nn.functional.linear`, so the fake output metadata
    matches the eager kernel.
    """
    fake_out = F.linear(spike, weight, bias)
    return fake_out
def _setup_cupy_spike_linear_context(ctx, inputs, output):
    """Autograd ``setup_context`` hook for ``sj::cupy_spike_linear_forward``.

    Saves the weight through ``save_for_backward``. The spike input is not
    saved directly: only its shape and the token returned by
    ``tensor_cache.BOOL_TENSOR_CACHE.store_bool`` are kept. The forward
    output is not needed.
    """
    spike, weight, bias = inputs
    # Only the presence of a bias matters to the backward pass.
    ctx.has_bias = bias is not None
    ctx.save_for_backward(weight)
    ctx.s_shape = spike.shape
    ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
def _cupy_spike_linear_backward(ctx, grad_output):
    """Backward formula for ``sj::cupy_spike_linear_forward``.

    Rebuilds the spike input from the tensor-cache entry recorded at setup time.
    Returns gradients for ``(spike, weight, bias)``; the bias slot is ``None``
    when the forward was called without a bias.
    """
    (saved_weight,) = ctx.saved_tensors
    spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
    dtype = grad_output.dtype
    weight = saved_weight.to(dtype)

    # Collapse all leading dimensions so the weight/bias gradients are plain
    # 2-D reductions over every sample.
    out_features = grad_output.shape[-1]
    grad_2d = grad_output.reshape(-1, out_features)
    spike_2d = spike.reshape(-1, spike.shape[-1]).to(dtype)

    grad_spike = F.linear(grad_output, weight.t(), bias=None)
    grad_weight = torch.mm(grad_2d.t(), spike_2d)
    grad_bias = grad_2d.sum(0) if ctx.has_bias else None
    return grad_spike, grad_weight, grad_bias
# Attach the hand-written backward (and its setup hook) to the linear custom op.
# Passing the op object resolves to the same op definition as its qualified
# name "sj::cupy_spike_linear_forward".
torch.library.register_autograd(
    cupy_spike_linear_forward,
    _cupy_spike_linear_backward,
    setup_context=_setup_cupy_spike_linear_context,
)
@torch.library.custom_op("sj::cupy_spike_convolution_forward", mutates_args=())
def cupy_spike_convolution_forward(
spike: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
) -> torch.Tensor:
if spike.dim() == 3:
return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
if spike.dim() == 4:
return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
@torch.library.register_fake("sj::cupy_spike_convolution_forward")
def _cupy_spike_convolution_forward_fake(
    spike: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    stride: list[int],
    padding: list[int],
    dilation: list[int],
    groups: int,
) -> torch.Tensor:
    """Fake (metadata-only) kernel of ``sj::cupy_spike_convolution_forward``.

    Uses the same rank-based choice of ``conv1d``/``conv2d``/``conv3d`` as the
    eager kernel, so the output metadata agrees.
    """
    conv_fn = {3: F.conv1d, 4: F.conv2d}.get(spike.dim(), F.conv3d)
    return conv_fn(spike, weight, bias, stride, padding, dilation, groups)
def _setup_cupy_spike_convolution_context(ctx, inputs, output):
    """Autograd ``setup_context`` hook for ``sj::cupy_spike_convolution_forward``.

    Records:
    - normalized stride/padding/dilation and groups;
    - which inputs require gradients;
    - the weight, via ``save_for_backward``.

    The spike input is registered in ``tensor_cache.BOOL_TENSOR_CACHE`` only
    when a spike or weight gradient will be needed. The forward output is not
    used.
    """
    spike, weight, bias, stride, padding, dilation, groups = inputs
    ctx.stride, ctx.padding, ctx.dilation = _normalize_conv_hyperparams(
        spike.dim(), stride, padding, dilation
    )
    ctx.groups = groups
    ctx.has_bias = bias is not None

    ctx.need_grad_spike = bool(spike.requires_grad)
    ctx.need_grad_weight = bool(weight.requires_grad)
    ctx.need_grad_bias = bool(ctx.has_bias and bias.requires_grad)

    ctx.save_for_backward(weight)

    # The spike is only read back when grad_spike or grad_weight is computed.
    keep_spike = ctx.need_grad_spike or ctx.need_grad_weight
    if not keep_spike:
        ctx.s_shape = None
        ctx.s_tk = None
        return
    ctx.s_shape = spike.shape
    ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
def _cupy_spike_convolution_backward(ctx, grad_output):
    """Backward formula for ``sj::cupy_spike_convolution_forward``.

    Returns one gradient slot per forward argument:
    ``(spike, weight, bias, stride, padding, dilation, groups)``. The
    hyper-parameter slots are always ``None``, and a tensor slot is ``None``
    when setup recorded that it does not need a gradient.
    """
    grad_spike = None
    grad_weight = None
    if ctx.need_grad_spike or ctx.need_grad_weight:
        (weight,) = ctx.saved_tensors
        spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        # Cast both operands to the gradient dtype (the linear backward does the
        # same for its spike operand). The dtype returned by `get_float` is not
        # visible here; presumably it is fixed, so under mixed precision it could
        # differ from grad_output. This is a no-op when dtypes already agree.
        spike = spike.to(grad_output.dtype)
        weight = weight.to(grad_output.dtype)
        grad_spike, grad_weight = _spike_conv_backward_common(
            spike,
            grad_output,
            weight,
            ctx.padding,
            ctx.stride,
            ctx.dilation,
            ctx.groups,
            (ctx.need_grad_spike, ctx.need_grad_weight),
        )

    grad_bias = None
    if ctx.has_bias and ctx.need_grad_bias:
        # Sum over the batch dim and every spatial dim, keeping channels.
        reduce_dims = (0, *range(2, grad_output.dim()))
        grad_bias = grad_output.sum(dim=reduce_dims)

    return grad_spike, grad_weight, grad_bias, None, None, None, None
# Attach the hand-written backward (and its setup hook) to the convolution
# custom op. Passing the op object resolves to the same op definition as its
# qualified name "sj::cupy_spike_convolution_forward".
torch.library.register_autograd(
    cupy_spike_convolution_forward,
    _cupy_spike_convolution_backward,
    setup_context=_setup_cupy_spike_convolution_context,
)
def spike_linear(
    spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None
) -> Tensor:
    r"""
    .. _spike_linear-cn:

    .. _spike_linear-en:

    A special case of :func:`torch.nn.functional.linear` for spike inputs.

    Arguments and return value follow :func:`torch.nn.functional.linear`.
    CPU tensors are passed straight to :func:`torch.nn.functional.linear`.
    On other devices the ``sj::cupy_spike_linear_forward`` custom op is used.
    Its autograd setup stores ``spike`` through
    ``BOOL_TENSOR_CACHE.store_bool`` instead of keeping the full tensor. This
    gives lower memory consumption than :func:`torch.nn.functional.linear`
    when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element of ``spike`` must be 0 or 1: the backward pass rebuilds
        ``spike`` from the copy stored via ``store_bool``.
    """
    if spike.get_device() < 0:
        return F.linear(spike, weight, bias)
    return cupy_spike_linear_forward(spike, weight, bias)
def spike_conv1d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv1d-cn:

    .. _spike_conv1d-en:

    A special case of :func:`torch.nn.functional.conv1d` for spike inputs.

    Arguments and return value follow :func:`torch.nn.functional.conv1d`.
    CPU tensors are passed straight to :func:`torch.nn.functional.conv1d`.

    On other devices, batched 3-D inputs with numeric (or ``"valid"``) padding
    use the ``sj::cupy_spike_convolution_forward`` custom op. Its autograd
    setup stores ``spike`` through ``BOOL_TENSOR_CACHE.store_bool`` instead of
    keeping the full tensor. This gives lower memory consumption than
    :func:`torch.nn.functional.conv1d` when training on CUDA devices.

    Other string paddings (e.g. ``"same"``) and unbatched inputs fall back to
    :func:`torch.nn.functional.conv1d`.

    .. admonition:: Warning
        :class: warning

        Every element of ``spike`` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    if isinstance(padding, str) and padding == "valid":
        # "valid" means no padding, so the memory-saving op can serve it too.
        padding = 0
    if isinstance(padding, str) or spike.dim() != 3:
        # The custom op selects conv1d/2d/3d from the input rank, so it cannot
        # take unbatched inputs; it also has no string-padding support.
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_single(stride)),
        list(_single(padding)),
        list(_single(dilation)),
        groups,
    )
def spike_conv2d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv2d-cn:

    .. _spike_conv2d-en:

    A special case of :func:`torch.nn.functional.conv2d` for spike inputs.

    Arguments and return value follow :func:`torch.nn.functional.conv2d`.
    CPU tensors are passed straight to :func:`torch.nn.functional.conv2d`.

    On other devices, batched 4-D inputs with numeric (or ``"valid"``) padding
    use the ``sj::cupy_spike_convolution_forward`` custom op. Its autograd
    setup stores ``spike`` through ``BOOL_TENSOR_CACHE.store_bool`` instead of
    keeping the full tensor. This gives lower memory consumption than
    :func:`torch.nn.functional.conv2d` when training on CUDA devices.

    Other string paddings (e.g. ``"same"``) and unbatched inputs fall back to
    :func:`torch.nn.functional.conv2d`.

    .. admonition:: Warning
        :class: warning

        Every element of ``spike`` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    if isinstance(padding, str) and padding == "valid":
        # "valid" means no padding, so the memory-saving op can serve it too.
        padding = 0
    if isinstance(padding, str) or spike.dim() != 4:
        # The custom op selects conv1d/2d/3d from the input rank (an unbatched
        # 3-D input would be routed to conv1d); it also has no string-padding
        # support.
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_pair(stride)),
        list(_pair(padding)),
        list(_pair(dilation)),
        groups,
    )
def spike_conv3d(
    spike: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Union[_int, _size] = 1,
    padding: Union[str, _int, _size] = "valid",
    dilation: Union[_int, _size] = 1,
    groups: _int = 1,
) -> Tensor:
    r"""
    .. _spike_conv3d-cn:

    .. _spike_conv3d-en:

    A special case of :func:`torch.nn.functional.conv3d` for spike inputs.

    Arguments and return value follow :func:`torch.nn.functional.conv3d`.
    CPU tensors are passed straight to :func:`torch.nn.functional.conv3d`.

    On other devices, batched 5-D inputs with numeric (or ``"valid"``) padding
    use the ``sj::cupy_spike_convolution_forward`` custom op. Its autograd
    setup stores ``spike`` through ``BOOL_TENSOR_CACHE.store_bool`` instead of
    keeping the full tensor. This gives lower memory consumption than
    :func:`torch.nn.functional.conv3d` when training on CUDA devices.

    Other string paddings (e.g. ``"same"``) and unbatched inputs fall back to
    :func:`torch.nn.functional.conv3d`.

    .. admonition:: Warning
        :class: warning

        Every element of ``spike`` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    if isinstance(padding, str) and padding == "valid":
        # "valid" means no padding, so the memory-saving op can serve it too.
        padding = 0
    if isinstance(padding, str) or spike.dim() != 5:
        # The custom op selects conv1d/2d/3d from the input rank (an unbatched
        # 4-D input would be routed to conv2d); it also has no string-padding
        # support.
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    return cupy_spike_convolution_forward(
        spike,
        weight,
        bias,
        list(_triple(stride)),
        list(_triple(padding)),
        list(_triple(dilation)),
        groups,
    )
class SpikeLinear(nn.Linear):
    r"""
    .. _SpikeLinear-cn:

    .. _SpikeLinear-en:

    A special case of :class:`torch.nn.Linear` for spike inputs.

    Construction and parameters are inherited unchanged from
    :class:`torch.nn.Linear`. ``forward`` delegates to ``spike_linear``, which
    has lower memory consumption than :class:`torch.nn.Linear` when training
    on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element of the input ``spike`` must be 0 or 1.
    """

    def forward(self, spike: Tensor) -> Tensor:
        """Apply this layer's weight and bias to ``spike`` via ``spike_linear``."""
        return spike_linear(spike, self.weight, self.bias)
class SpikeConv1d(nn.Conv1d):
    r"""
    .. _SpikeConv1d-cn:

    .. _SpikeConv1d-en:

    A special case of :class:`torch.nn.Conv1d` for spike inputs.

    Construction and parameters are inherited unchanged from
    :class:`torch.nn.Conv1d`; the convolution itself runs through
    ``spike_conv1d``, which has lower memory consumption than
    :class:`torch.nn.Conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element of the input ``spike`` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight``/``bias`` using ``spike_conv1d``.

        For a ``padding_mode`` other than ``"zeros"``, the input is padded
        explicitly with :func:`torch.nn.functional.pad` first, and the
        convolution then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv1d(
                padded, weight, bias, self.stride, _single(0), self.dilation, self.groups
            )
        return spike_conv1d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
class SpikeConv2d(nn.Conv2d):
    r"""
    .. _SpikeConv2d-cn:

    .. _SpikeConv2d-en:

    A special case of :class:`torch.nn.Conv2d` for spike inputs.

    Construction and parameters are inherited unchanged from
    :class:`torch.nn.Conv2d`; the convolution itself runs through
    ``spike_conv2d``, which has lower memory consumption than
    :class:`torch.nn.Conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element of the input ``spike`` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight``/``bias`` using ``spike_conv2d``.

        For a ``padding_mode`` other than ``"zeros"``, the input is padded
        explicitly with :func:`torch.nn.functional.pad` first, and the
        convolution then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv2d(
                padded, weight, bias, self.stride, _pair(0), self.dilation, self.groups
            )
        return spike_conv2d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )
class SpikeConv3d(nn.Conv3d):
    r"""
    .. _SpikeConv3d-cn:

    .. _SpikeConv3d-en:

    A special case of :class:`torch.nn.Conv3d` for spike inputs.

    Construction and parameters are inherited unchanged from
    :class:`torch.nn.Conv3d`; the convolution itself runs through
    ``spike_conv3d``, which has lower memory consumption than
    :class:`torch.nn.Conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element of the input ``spike`` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        """Convolve ``spike`` with ``weight``/``bias`` using ``spike_conv3d``.

        For a ``padding_mode`` other than ``"zeros"``, the input is padded
        explicitly with :func:`torch.nn.functional.pad` first, and the
        convolution then runs with zero padding.
        """
        if self.padding_mode != "zeros":
            padded = F.pad(
                spike, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            return spike_conv3d(
                padded, weight, bias, self.stride, _triple(0), self.dilation, self.groups
            )
        return spike_conv3d(
            spike, weight, bias, self.stride, self.padding, self.dilation, self.groups
        )