import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
from torch.cuda.amp import custom_fwd, custom_bwd
import logging
from . import tensor_cache
from torch import Tensor
from typing import Optional, Union
from torch.types import _int, _size
from torch.nn.modules.utils import _single, _pair, _triple
try:
    import cupy
except BaseException as e:
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cupy = None

try:
    logging.warning('spikingjelly.activation_based.spike_op: try to use `torch.utils.cpp_extension.load_inline` to load cudnn functions.')
    logging.warning(f'If it hangs, please try to delete the torch_extensions cache directory. (In most cases, the directory is {torch.utils.cpp_extension._get_build_directory("", False)}.)')
    cpp_wrapper = load_inline(
        name='cpp_wrapper',
        cpp_sources='using namespace at;',
        functions=[
            'cudnn_convolution_backward',
            'cudnn_convolution_backward_input',
            'cudnn_convolution_backward_weight'
        ],
        with_cuda=True
    )
except BaseException as e:
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cpp_wrapper = None
'''
aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)

There are two overloaded C++ methods named `cudnn_convolution`, so we need an alternative syntax to cast the overloaded function.
Refer to https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods and https://github.com/pytorch/pytorch/issues/39518 for more details.

aten/src/ATen/native/cudnn/ConvShared.cpp

Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)

aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask)

aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)

aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
'''
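
# A minimal sketch of how the wrappers above are used once `load_inline`
# succeeds (assuming a CUDA device; `output_mask` selects which of
# (grad_input, grad_weight) cudnn_convolution_backward should compute):
#
#   grad_input, grad_weight = cpp_wrapper.cudnn_convolution_backward(
#       input, grad_output, weight, padding, stride, dilation, groups,
#       torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
#       torch.backends.cudnn.allow_tf32, (True, True))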
class spikeConvolution(torch.autograd.Function):
    # PyTorch only provides cudnn_convolution without bias.
    # Refer to https://github.com/pytorch/pytorch/issues/3823 for more details.
    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias, stride, padding, dilation, groups):
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            if ctx.needs_input_grad[1]:
                # cache the binary input as a bool tensor to reduce memory consumption
                ctx.s_shape = spike.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
            if ctx.needs_input_grad[0]:
                ctx.save_for_backward(weight)
            ctx.padding = padding
            ctx.stride = stride
            ctx.dilation = dilation
            ctx.groups = groups
            ctx.weight_shape = weight.shape
            # spike_shape is required by cudnn_convolution_backward_input in backward
            ctx.spike_shape = spike.shape

        if spike.dim() == 3:
            return F.conv1d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        elif spike.dim() == 4:
            return F.conv2d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        elif spike.dim() == 5:
            return F.conv3d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_spike = None
        grad_weight = None
        grad_bias = None
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            weight = ctx.saved_tensors[0]
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            weight = weight.to(grad_output.dtype)
            # output_mask = (True, True) asks cudnn for both grad_input and grad_weight
            grad_spike, grad_weight = cpp_wrapper.cudnn_convolution_backward(
                spike, grad_output, weight, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32, (True, True))
        elif not ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            grad_weight = cpp_wrapper.cudnn_convolution_backward_weight(
                ctx.weight_shape, grad_output, spike, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32)
        elif ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
            weight = ctx.saved_tensors[0]
            weight = weight.to(grad_output.dtype)
            grad_spike = cpp_wrapper.cudnn_convolution_backward_input(
                ctx.spike_shape, grad_output, weight, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32)

        if ctx.needs_input_grad[2]:
            # grad_output.shape = [N, C, *]; the bias gradient is the sum of
            # grad_output over the batch and all spatial dimensions
            out_channels = grad_output.shape[1]
            grad_bias = grad_output.transpose(0, 1).reshape(out_channels, -1).sum(1)
        return grad_spike, grad_weight, grad_bias, None, None, None, None
class spikeLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias=None):
        # spike.shape = [N, *, in_features]
        # weight.shape = [out_features, in_features]
        # bias.shape = [out_features]
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            if ctx.needs_input_grad[1]:
                ctx.s_shape = spike.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
            if ctx.needs_input_grad[0]:
                ctx.save_for_backward(weight)
        return F.linear(spike, weight, bias)
    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # grad_output.shape = [N, *, out_features]
        if ctx.needs_input_grad[0]:
            # forward saved the weight only when the input gradient is needed
            weight = ctx.saved_tensors[0]
        if ctx.needs_input_grad[1]:
            # forward cached the spike only when the weight gradient is needed
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
        grad_spike = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_spike = F.linear(grad_output, weight.t(), bias=None)
        if ctx.needs_input_grad[1]:
            in_features = spike.shape[-1]
            out_features = grad_output.shape[-1]
            # grad_output.reshape(-1, out_features).t().shape = [out_features, N*]
            # spike.reshape(-1, in_features).shape = [N*, in_features]
            grad_weight = torch.mm(grad_output.reshape(-1, out_features).t(), spike.reshape(-1, in_features).to(grad_output.dtype))
        if ctx.needs_input_grad[2]:
            out_features = grad_output.shape[-1]
            grad_bias = grad_output.reshape(-1, out_features).sum(0)
        return grad_spike, grad_weight, grad_bias
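
# The manual backward above implements the standard gradients of a linear
# layer. As a sketch of the identities, for a 2D input x with y = x @ W.T + b:
#
#   dL/dx = dL/dy @ W
#   dL/dW = (dL/dy).T @ x
#   dL/db = (dL/dy).sum(0)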
def spike_linear(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """
    A specific case of :func:`torch.nn.functional.linear` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function consumes less memory than :func:`torch.nn.functional.linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        # get_device() returns -1 for CPU tensors: fall back to the plain op
        return F.linear(spike, weight, bias)
    else:
        return spikeLinear.apply(spike, weight, bias)
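
# A minimal usage sketch for spike_linear (the shapes and device below are
# illustrative assumptions, not part of the API):
#
#   spike = (torch.rand([16, 784], device='cuda') > 0.5).to(torch.float32)
#   weight = torch.randn([10, 784], device='cuda', requires_grad=True)
#   y = spike_linear(spike, weight)  # y.shape = [16, 10]
#   y.sum().backward()               # uses the memory-saving spikeLinear path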
def spike_conv1d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :func:`torch.nn.functional.conv1d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function consumes less memory than :func:`torch.nn.functional.conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)
def spike_conv2d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :func:`torch.nn.functional.conv2d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function consumes less memory than :func:`torch.nn.functional.conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)
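
# A minimal usage sketch for spike_conv2d (the shapes and device below are
# illustrative assumptions, not part of the API):
#
#   spike = (torch.rand([4, 3, 32, 32], device='cuda') > 0.5).to(torch.float32)
#   weight = torch.randn([8, 3, 3, 3], device='cuda', requires_grad=True)
#   y = spike_conv2d(spike, weight, padding=1)  # y.shape = [4, 8, 32, 32]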
def spike_conv3d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :func:`torch.nn.functional.conv3d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function consumes less memory than :func:`torch.nn.functional.conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)
class SpikeLinear(nn.Linear):
    """
    A specific case of :class:`torch.nn.Linear` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module consumes less memory than :class:`torch.nn.Linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    def forward(self, spike: Tensor) -> Tensor:
        return spike_linear(spike, self.weight, self.bias)
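
# A minimal usage sketch for SpikeLinear, which is a drop-in replacement for
# torch.nn.Linear when the input is binary (values below are illustrative):
#
#   fc = SpikeLinear(784, 10).cuda()
#   spike = (torch.rand([16, 784], device='cuda') > 0.5).to(torch.float32)
#   y = fc(spike)  # y.shape = [16, 10]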
class SpikeConv1d(nn.Conv1d):
    """
    A specific case of :class:`torch.nn.Conv1d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module consumes less memory than :class:`torch.nn.Conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv1d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride,
                                _single(0), self.dilation, self.groups)
        return spike_conv1d(spike, weight, bias, self.stride,
                            self.padding, self.dilation, self.groups)
class SpikeConv2d(nn.Conv2d):
    """
    A specific case of :class:`torch.nn.Conv2d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module consumes less memory than :class:`torch.nn.Conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv2d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride,
                                _pair(0), self.dilation, self.groups)
        return spike_conv2d(spike, weight, bias, self.stride,
                            self.padding, self.dilation, self.groups)
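
# A minimal usage sketch for SpikeConv2d, used exactly like torch.nn.Conv2d
# (values below are illustrative):
#
#   conv = SpikeConv2d(3, 8, kernel_size=3, padding=1).cuda()
#   spike = (torch.rand([4, 3, 32, 32], device='cuda') > 0.5).to(torch.float32)
#   y = conv(spike)  # y.shape = [4, 8, 32, 32]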
class SpikeConv3d(nn.Conv3d):
    """
    A specific case of :class:`torch.nn.Conv3d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module consumes less memory than :class:`torch.nn.Conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Every element in `spike` must be 0 or 1.
    """
    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv3d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride,
                                _triple(0), self.dilation, self.groups)
        return spike_conv3d(spike, weight, bias, self.stride,
                            self.padding, self.dilation, self.groups)