Source code for spikingjelly.activation_based.spike_op

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
from torch.cuda.amp import custom_fwd, custom_bwd
import logging
from . import tensor_cache

from torch import Tensor
from typing import Optional, Union
from torch.types import _int, _size
from torch.nn.modules.utils import _single, _pair, _triple

try:
    import cupy
except BaseException as e:
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cupy = None


try:
    logging.warning('spikingjelly.activation_based.spike_op: trying to use `torch.utils.cpp_extension.load_inline` to load cudnn functions.')
    logging.warning(f'If it hangs, please try deleting the torch_extensions cache directory. (In most cases, the directory is {torch.utils.cpp_extension._get_build_directory("", False)}.)')
    cpp_wrapper = load_inline(
            name='cpp_wrapper',
            cpp_sources='using namespace at;',
            functions=[
                'cudnn_convolution_backward',
                'cudnn_convolution_backward_input',
                'cudnn_convolution_backward_weight'
            ],
            with_cuda=True
    )
except BaseException as e:
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cpp_wrapper = None

'''
aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)

There are two overloaded C++ methods named `cudnn_convolution`, so we need an alternative syntax to cast the overloaded function.
Refer to https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods and https://github.com/pytorch/pytorch/issues/39518 for more details.
    
aten/src/ATen/native/cudnn/ConvShared.cpp

Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)

aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask)
  
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
    
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
'''
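
# A minimal sketch (an illustrative addition, not part of the original module) of how
# the loaded `cpp_wrapper` maps to the C++ signature of `cudnn_convolution_backward`
# documented above. The shapes and the helper name are assumptions for demonstration;
# the function is defined for reference only and is never called by the library.
def _demo_cudnn_convolution_backward():
    if cpp_wrapper is None or not torch.cuda.is_available():
        return
    x = torch.rand([2, 3, 8, 8], device='cuda')       # input
    w = torch.rand([4, 3, 3, 3], device='cuda')       # weight
    y = F.conv2d(x, w)                                # forward pass
    grad_x, grad_w = cpp_wrapper.cudnn_convolution_backward(
        x, torch.ones_like(y), w,                     # input, grad_output, weight
        (0, 0), (1, 1), (1, 1), 1,                    # padding, stride, dilation, groups
        torch.backends.cudnn.benchmark,
        torch.backends.cudnn.deterministic,
        torch.backends.cudnn.allow_tf32,
        (True, True))                                 # output_mask: compute both gradients
    return grad_x, grad_w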

class spikeConvolution(torch.autograd.Function):
    # PyTorch only provides cudnn_convolution without bias.
    # Refer to https://github.com/pytorch/pytorch/issues/3823 for more details.
    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias, stride, padding, dilation, groups):
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            if ctx.needs_input_grad[1]:
                # cache the spike as a bool tensor to reduce memory consumption
                ctx.s_shape = spike.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
            if ctx.needs_input_grad[0]:
                ctx.save_for_backward(weight)
            ctx.padding = padding
            ctx.stride = stride
            ctx.dilation = dilation
            ctx.groups = groups
            ctx.weight_shape = weight.shape
            ctx.spike_shape = spike.shape  # required by cudnn_convolution_backward_input in backward

        if spike.dim() == 3:
            return F.conv1d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        elif spike.dim() == 4:
            return F.conv2d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
        elif spike.dim() == 5:
            return F.conv3d(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_spike = None
        grad_weight = None
        grad_bias = None
        if ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            weight = ctx.saved_tensors[0]
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            weight = weight.to(grad_output.dtype)
            grad_spike, grad_weight = cpp_wrapper.cudnn_convolution_backward(
                spike, grad_output, weight,
                ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32, (True, True))
        elif not ctx.needs_input_grad[0] and ctx.needs_input_grad[1]:
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            grad_weight = cpp_wrapper.cudnn_convolution_backward_weight(
                ctx.weight_shape, grad_output, spike,
                ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32)
        elif ctx.needs_input_grad[0] and not ctx.needs_input_grad[1]:
            weight = ctx.saved_tensors[0]
            weight = weight.to(grad_output.dtype)
            grad_spike = cpp_wrapper.cudnn_convolution_backward_input(
                ctx.spike_shape, grad_output, weight,
                ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32)

        if ctx.needs_input_grad[2]:
            # grad_output.shape = [N, C, *]
            out_channels = grad_output.shape[1]
            grad_bias = grad_output.transpose(0, 1).reshape(out_channels, -1).sum(1)
        return grad_spike, grad_weight, grad_bias, None, None, None, None
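
# A minimal gradient-equivalence sketch for spikeConvolution (an illustrative addition,
# assuming a CUDA device and a successfully compiled `cpp_wrapper`; shapes are
# assumptions): on a {0, 1}-valued input, the gradients should match those of F.conv2d.
def _demo_spike_convolution():
    if cpp_wrapper is None or not torch.cuda.is_available():
        return
    spike = (torch.rand([2, 3, 8, 8], device='cuda') > 0.5).float().requires_grad_()
    weight = torch.rand([4, 3, 3, 3], device='cuda', requires_grad=True)
    bias = torch.rand([4], device='cuda', requires_grad=True)
    # args: spike, weight, bias, stride, padding, dilation, groups
    y = spikeConvolution.apply(spike, weight, bias, (1, 1), (0, 0), (1, 1), 1)
    y.sum().backward()
    return spike.grad, weight.grad, bias.grad
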
class spikeLinear(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias=None):
        # spike.shape = [N, *, in_features]
        # weight.shape = [out_features, in_features]
        # bias.shape = [out_features]
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            if ctx.needs_input_grad[1]:
                # cache the spike as a bool tensor to reduce memory consumption
                ctx.s_shape = spike.shape
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
            if ctx.needs_input_grad[0]:
                ctx.save_for_backward(weight)
        return F.linear(spike, weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # grad_output.shape = [N, *, out_features]
        # The retrieval conditions must mirror forward: weight was saved when
        # needs_input_grad[0], and the spike was cached when needs_input_grad[1].
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[0]
        if ctx.needs_input_grad[1]:
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)

        grad_spike = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_spike = F.linear(grad_output, weight.t(), bias=None)
        if ctx.needs_input_grad[1]:
            in_features = spike.shape[-1]
            out_features = grad_output.shape[-1]
            # grad_output.reshape(-1, out_features).t().shape = [out_features, N*]
            # spike.reshape(-1, in_features).shape = [N*, in_features]
            grad_weight = torch.mm(grad_output.reshape(-1, out_features).t(),
                                   spike.reshape(-1, in_features).to(grad_output.dtype))
        if ctx.needs_input_grad[2]:
            out_features = grad_output.shape[-1]
            grad_bias = grad_output.reshape(-1, out_features).sum(0)
        return grad_spike, grad_weight, grad_bias
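
# A minimal usage sketch for spikeLinear (an illustrative addition, assuming a CUDA
# device; shapes are assumptions). The bool spike cache is only populated when weight
# gradients are required.
def _demo_spike_linear_function():
    if not torch.cuda.is_available():
        return
    spike = (torch.rand([16, 32], device='cuda') > 0.5).float().requires_grad_()
    weight = torch.rand([64, 32], device='cuda', requires_grad=True)
    bias = torch.rand([64], device='cuda', requires_grad=True)
    y = spikeLinear.apply(spike, weight, bias)      # y.shape = [16, 64]
    y.sum().backward()
    return spike.grad, weight.grad, bias.grad
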
def spike_linear(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """
    A specific case of :class:`torch.nn.functional.linear` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        # the input is on CPU; fall back to the standard implementation
        return F.linear(spike, weight, bias)
    else:
        return spikeLinear.apply(spike, weight, bias)
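
# An illustrative usage sketch for spike_linear (not part of the original module):
# for a CPU tensor, `spike.get_device() < 0` holds and the call falls back to F.linear.
def _demo_spike_linear():
    spike = (torch.rand([16, 32]) > 0.5).float()    # CPU tensor of 0/1 values
    weight = torch.rand([64, 32])
    return spike_linear(spike, weight)              # shape [16, 64], computed by F.linear
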
def spike_conv1d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :class:`torch.nn.functional.conv1d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv2d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :class:`torch.nn.functional.conv2d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv3d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: str = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    A specific case of :class:`torch.nn.functional.conv3d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)
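
# An illustrative usage sketch for the spike_conv* family (shapes are assumptions),
# shown for the 2D case; the 1D and 3D wrappers behave identically apart from
# dimensionality. On a CPU tensor the call falls back to F.conv2d.
def _demo_spike_conv2d():
    spike = (torch.rand([2, 3, 8, 8]) > 0.5).float()  # [N, C, H, W] of 0/1 values
    weight = torch.rand([4, 3, 3, 3])                 # [out_channels, in_channels, kH, kW]
    return spike_conv2d(spike, weight, stride=1, padding=0)  # shape [2, 4, 6, 6]
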
class SpikeLinear(nn.Linear):
    """
    A specific case of :class:`torch.nn.Linear` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def forward(self, spike: Tensor) -> Tensor:
        return spike_linear(spike, self.weight, self.bias)
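
# An illustrative usage sketch for SpikeLinear (shapes are assumptions): a drop-in
# replacement for torch.nn.Linear when the input is a 0/1 spike tensor.
def _demo_SpikeLinear_module():
    fc = SpikeLinear(32, 64)
    spike = (torch.rand([16, 32]) > 0.5).float()
    return fc(spike)   # shape [16, 64]
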
class SpikeConv1d(nn.Conv1d):
    """
    A specific case of :class:`torch.nn.Conv1d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv1d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride, _single(0), self.dilation, self.groups)
        return spike_conv1d(spike, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class SpikeConv2d(nn.Conv2d):
    """
    A specific case of :class:`torch.nn.Conv2d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv2d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride, _pair(0), self.dilation, self.groups)
        return spike_conv2d(spike, weight, bias, self.stride, self.padding, self.dilation, self.groups)

class SpikeConv3d(nn.Conv3d):
    """
    A specific case of :class:`torch.nn.Conv3d` whose inputs are spikes.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode != 'zeros':
            return spike_conv3d(F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                                weight, bias, self.stride, _triple(0), self.dilation, self.groups)
        return spike_conv3d(spike, weight, bias, self.stride, self.padding, self.dilation, self.groups)
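
# An illustrative usage sketch for the SpikeConv* modules (shapes are assumptions),
# shown for the 2D case: drop-in replacements for the corresponding torch.nn modules.
def _demo_SpikeConv2d_module():
    conv = SpikeConv2d(3, 4, kernel_size=3, padding=1)
    spike = (torch.rand([2, 3, 8, 8]) > 0.5).float()
    return conv(spike)   # shape [2, 4, 8, 8]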