# Source code of spikingjelly.activation_based.model.spiking_resnet

import torch
import torch.nn as nn
from copy import deepcopy

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torchvision._internally_replaced_utils import load_state_dict_from_url

from .. import layer


# Public API of this module (what ``from ... import *`` exports): the model
# class plus one factory function per architecture variant. Kept as a plain
# literal so static tools (linters, IDEs, Sphinx) can read it.
__all__ = [
    "SpikingResNet",
    "spiking_resnet18",
    "spiking_resnet34",
    "spiking_resnet50",
    "spiking_resnet101",
    "spiking_resnet152",
    "spiking_resnext50_32x4d",
    "spiking_resnext101_32x8d",
    "spiking_wide_resnet50_2",
    "spiking_wide_resnet101_2",
]

# Download URLs of the torchvision ImageNet checkpoints, keyed by architecture
# name. Every file lives under the same prefix and is named "<arch>-<digest>.pth".
model_urls = {
    arch: "https://download.pytorch.org/models/{}-{}.pth".format(arch, digest)
    for arch, digest in (
        ("resnet18", "f37072fd"),
        ("resnet34", "b627a593"),
        ("resnet50", "0676ba61"),
        ("resnet101", "63fe2227"),
        ("resnet152", "394f9c45"),
        ("resnext50_32x4d", "7cdf4587"),
        ("resnext101_32x8d", "8ba56ff5"),
        ("wide_resnet50_2", "95faca4d"),
        ("wide_resnet101_2", "32ee1156"),
    )
}

# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """
    Build a bias-free 3x3 :class:`layer.Conv2d`.

    The padding is set equal to ``dilation``, so with ``stride=1`` the
    spatial size of the input is preserved.

    :param in_planes: number of input channels
    :type in_planes: int
    :param out_planes: number of output channels
    :type out_planes: int
    :param stride: convolution stride, default ``1``
    :type stride: int
    :param groups: number of convolution groups, default ``1``
    :type groups: int
    :param dilation: dilation rate (also used as the padding), default ``1``
    :type dilation: int
    :return: the 3x3 convolution layer
    :rtype: layer.Conv2d
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return layer.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """
    Build a bias-free 1x1 (pointwise) :class:`layer.Conv2d`.

    :param in_planes: number of input channels
    :type in_planes: int
    :param out_planes: number of output channels
    :type out_planes: int
    :param stride: convolution stride, default ``1``
    :type stride: int
    :return: the 1x1 convolution layer
    :rtype: layer.Conv2d
    """
    options = dict(kernel_size=1, stride=stride, bias=False)
    return layer.Conv2d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
    """
    Two-convolution residual block (used by ResNet-18/34) with spiking neurons.

    Main path: ``conv3x3 -> norm -> neuron -> conv3x3 -> norm``. The shortcut
    (``x`` itself, or ``downsample(x)`` when a downsample module is given) is
    added in place to the main path, and the sum goes through a second neuron.

    :param inplanes: number of input channels
    :param planes: number of output channels (``expansion`` is ``1``)
    :param stride: stride of the first 3x3 convolution
    :param downsample: optional module applied to ``x`` to build the shortcut;
        its output must be addable in place to the main-path output
    :param groups: must be ``1``
    :param base_width: must be ``64``
    :param dilation: values greater than ``1`` are rejected
    :param norm_layer: normalization layer factory called with a channel count;
        defaults to ``layer.BatchNorm2d``
    :param spiking_neuron: callable that creates one spiking neuron layer
    :param kwargs: keyword arguments passed, freshly deep-copied, to every
        ``spiking_neuron`` call
    :raises ValueError: if ``groups != 1`` or ``base_width != 64``
    :raises NotImplementedError: if ``dilation > 1``
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
        spiking_neuron: callable = None,
        **kwargs,
    ):
        super().__init__()
        norm = layer.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        def new_neuron():
            # Each neuron gets its own deep copy so no kwargs object is shared.
            return spiking_neuron(**deepcopy(kwargs))

        # Submodules are created in a fixed order (it decides RNG consumption
        # during weight init and the state_dict key order).
        # conv1 and the downsample shortcut both apply the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm(planes)
        self.sn1 = new_neuron()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm(planes)
        self.sn2 = new_neuron()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.sn1(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.sn2(y)


class Bottleneck(nn.Module):
    """
    Bottleneck residual block with spiking neurons.

    Used by ResNet-50/101/152, ResNeXt and Wide ResNet.

    Main path: ``conv1x1 -> norm -> neuron -> conv3x3 -> norm -> neuron ->
    conv1x1 -> norm``. The shortcut (``x`` or ``downsample(x)``) is added in
    place to the main path and the sum goes through a third neuron. The inner
    width is ``int(planes * base_width / 64) * groups`` and the output has
    ``planes * expansion`` channels.

    The stride is placed on the 3x3 convolution (``conv2``), as in torchvision.
    The original paper "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385) places it on the first 1x1 convolution.
    This variant is known as ResNet V1.5 and, according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch,
    improves accuracy.

    :param inplanes: number of input channels
    :param planes: base number of channels; the output has ``planes * 4``
    :param stride: stride of the 3x3 convolution
    :param downsample: optional module applied to ``x`` to build the shortcut;
        its output must be addable in place to the main-path output
    :param groups: number of groups of the 3x3 convolution
    :param base_width: channel width per group
    :param dilation: dilation of the 3x3 convolution
    :param norm_layer: normalization layer factory called with a channel count;
        defaults to ``layer.BatchNorm2d``
    :param spiking_neuron: callable that creates one spiking neuron layer
    :param kwargs: keyword arguments passed, freshly deep-copied, to every
        ``spiking_neuron`` call
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
        spiking_neuron: callable = None,
        **kwargs,
    ):
        super().__init__()
        norm = layer.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion

        def new_neuron():
            # Each neuron gets its own deep copy so no kwargs object is shared.
            return spiking_neuron(**deepcopy(kwargs))

        # Submodules are created in a fixed order (it decides RNG consumption
        # during weight init and the state_dict key order).
        # conv2 and the downsample shortcut both apply the stride.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm(width)
        self.sn1 = new_neuron()
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm(width)
        self.sn2 = new_neuron()
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = norm(out_channels)
        self.sn3 = new_neuron()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.sn1(self.bn1(self.conv1(x)))
        y = self.sn2(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.sn3(y)


class SpikingResNet(nn.Module):
    r"""
    Spiking ResNet.

    A port of torchvision's ResNet (``torchvision/models/resnet.py``) built from
    :mod:`spikingjelly` layers. Every ReLU is replaced by a spiking neuron
    created by ``spiking_neuron``.

    The class derives from :class:`torch.nn.Module`, not from
    ``torchvision.models.ResNet``. Its submodule names (``conv1``, ``bn1``,
    ``layer1`` ... ``layer4``, ``fc``) mirror torchvision's, which is what
    allows torchvision ImageNet checkpoints to be loaded with
    :meth:`load_state_dict`.

    :param block: residual block type (``BasicBlock`` or ``Bottleneck``)
    :type block: type
    :param layers: number of residual blocks in each of the four stages
    :type layers: list
    :param num_classes: number of outputs of the final linear layer
    :type num_classes: int
    :param zero_init_residual: if ``True``, zero-initialize the weight of the
        last normalization layer in every residual branch
    :type zero_init_residual: bool
    :param groups: number of groups of the blocks' 3x3 convolutions
    :type groups: int
    :param width_per_group: channel width per group (``base_width`` of the blocks)
    :type width_per_group: int
    :param replace_stride_with_dilation: three flags telling whether the stride-2
        downsampling of ``layer2``/``layer3``/``layer4`` is replaced by a dilated
        convolution; ``None`` means ``[False, False, False]``
    :type replace_stride_with_dilation: Optional[List[bool]]
    :param norm_layer: normalization layer factory called with a channel count;
        defaults to ``layer.BatchNorm2d``
    :type norm_layer: Optional[Callable]
    :param spiking_neuron: callable creating one spiking neuron layer. It is
        called as ``spiking_neuron(**kwargs)``, with a fresh deep copy of
        ``kwargs``, for every neuron in the network. Required.
    :type spiking_neuron: callable
    :param kwargs: extra keyword arguments for ``spiking_neuron``
    :type kwargs: dict
    :raises TypeError: if ``spiking_neuron`` is not callable (e.g. left as ``None``)
    :raises ValueError: if ``replace_stride_with_dilation`` does not have exactly
        three elements
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=1000,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
        spiking_neuron: callable = None,
        **kwargs,
    ):
        super(SpikingResNet, self).__init__()
        if not callable(spiking_neuron):
            # Fail early with an explicit message. Otherwise the first
            # neuron construction raises an opaque
            # "'NoneType' object is not callable".
            raise TypeError(
                "spiking_neuron must be a callable that creates a spiking "
                "neuron layer, got {!r}".format(spiking_neuron)
            )
        if norm_layer is None:
            norm_layer = layer.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group

        # Stem: 7x7/2 convolution, norm, neuron, 3x3/2 max-pool.
        self.conv1 = layer.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.sn1 = spiking_neuron(**deepcopy(kwargs))
        self.maxpool = layer.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages; stages 2-4 halve the resolution unless the
        # stride is replaced by dilation.
        self.layer1 = self._make_layer(
            block, 64, layers[0], spiking_neuron=spiking_neuron, **kwargs
        )
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0],
            spiking_neuron=spiking_neuron,
            **kwargs,
        )
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1],
            spiking_neuron=spiking_neuron,
            **kwargs,
        )
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2],
            spiking_neuron=spiking_neuron,
            **kwargs,
        )

        # Head: global average pooling and a linear classifier.
        self.avgpool = layer.AdaptiveAvgPool2d((1, 1))
        self.fc = layer.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, layer.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (layer.BatchNorm2d, layer.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        stride=1,
        dilate=False,
        spiking_neuron: callable = None,
        **kwargs,
    ):
        """
        Build one stage made of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a ``downsample`` shortcut (1x1 conv + norm).
        With ``dilate=True`` the stride is folded into ``self.dilation``
        instead and the first block uses stride 1.

        Side effects: updates ``self.inplanes`` and, when dilating,
        ``self.dilation``.

        :return: the stage as an ``nn.Sequential`` of blocks
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # The first block uses the dilation in effect *before* this stage,
        # as in torchvision.
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
                spiking_neuron,
                **kwargs,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                    spiking_neuron=spiking_neuron,
                    **kwargs,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.sn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Single-step mode flattens from dim 1; multi-step mode flattens from
        # dim 2 (presumably to keep a leading time-step dim -- confirm
        # against spikingjelly's step_mode convention).
        if self.avgpool.step_mode == "s":
            x = torch.flatten(x, 1)
        elif self.avgpool.step_mode == "m":
            x = torch.flatten(x, 2)
        x = self.fc(x)

        return x

    def forward(self, x):
        """Run the network on ``x`` and return the classifier output."""
        return self._forward_impl(x)
def _spiking_resnet(
    arch, block, layers, pretrained, progress, spiking_neuron, **kwargs
):
    """
    Build a :class:`SpikingResNet`, optionally loading torchvision weights.

    :param arch: key into ``model_urls`` naming the checkpoint to download
    :param block: residual block type
    :param layers: number of blocks per stage
    :param pretrained: if true, download ``model_urls[arch]`` and load it with
        ``load_state_dict``
    :param progress: whether to show a download progress bar
    :param spiking_neuron: callable creating a spiking neuron layer
    :param kwargs: forwarded unchanged to :class:`SpikingResNet`
    :return: the constructed model
    """
    net = SpikingResNet(block, layers, spiking_neuron=spiking_neuron, **kwargs)
    if not pretrained:
        return net
    checkpoint = load_state_dict_from_url(model_urls[arch], progress=progress)
    net.load_state_dict(checkpoint)
    return net
def spiking_resnet18(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNet-18 (``BasicBlock``, stages ``[2, 2, 2, 2]``).

    :param pretrained: if ``True``, load the torchvision ``resnet18`` ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
    :type kwargs: dict
    :return: Spiking ResNet-18
    :rtype: torch.nn.Module
    """
    return _spiking_resnet(
        "resnet18",
        BasicBlock,
        [2, 2, 2, 2],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnet34(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNet-34 (``BasicBlock``, stages ``[3, 4, 6, 3]``).

    :param pretrained: if ``True``, load the torchvision ``resnet34`` ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
    :type kwargs: dict
    :return: Spiking ResNet-34
    :rtype: torch.nn.Module
    """
    return _spiking_resnet(
        "resnet34",
        BasicBlock,
        [3, 4, 6, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnet50(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNet-50 (``Bottleneck``, stages ``[3, 4, 6, 3]``).

    :param pretrained: if ``True``, load the torchvision ``resnet50`` ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
    :type kwargs: dict
    :return: Spiking ResNet-50
    :rtype: torch.nn.Module
    """
    return _spiking_resnet(
        "resnet50",
        Bottleneck,
        [3, 4, 6, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnet101(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNet-101 (``Bottleneck``, stages ``[3, 4, 23, 3]``).

    :param pretrained: if ``True``, load the torchvision ``resnet101`` ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
    :type kwargs: dict
    :return: Spiking ResNet-101
    :rtype: torch.nn.Module
    """
    return _spiking_resnet(
        "resnet101",
        Bottleneck,
        [3, 4, 23, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnet152(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNet-152 (``Bottleneck``, stages ``[3, 8, 36, 3]``).

    :param pretrained: if ``True``, load the torchvision ``resnet152`` ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
    :type kwargs: dict
    :return: Spiking ResNet-152
    :rtype: torch.nn.Module
    """
    return _spiking_resnet(
        "resnet152",
        Bottleneck,
        [3, 8, 36, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnext50_32x4d(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNeXt-50 32x4d.

    The network uses ``Bottleneck`` blocks, stages ``[3, 4, 6, 3]``,
    ``groups=32`` and ``width_per_group=4``.

    :param pretrained: if ``True``, load the torchvision ``resnext50_32x4d``
        ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call. ``groups``
        and ``width_per_group`` are always overwritten with 32 and 4.
    :type kwargs: dict
    :return: Spiking ResNeXt-50 32x4d
    :rtype: torch.nn.Module
    """
    # Fixed by the architecture; any caller-supplied values are replaced.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4
    return _spiking_resnet(
        "resnext50_32x4d",
        Bottleneck,
        [3, 4, 6, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_resnext101_32x8d(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking ResNeXt-101 32x8d.

    The network uses ``Bottleneck`` blocks, stages ``[3, 4, 23, 3]``,
    ``groups=32`` and ``width_per_group=8``.

    :param pretrained: if ``True``, load the torchvision ``resnext101_32x8d``
        ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call. ``groups``
        and ``width_per_group`` are always overwritten with 32 and 8.
    :type kwargs: dict
    :return: Spiking ResNeXt-101 32x8d
    :rtype: torch.nn.Module
    """
    # Fixed by the architecture; any caller-supplied values are replaced.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 8
    return _spiking_resnet(
        "resnext101_32x8d",
        Bottleneck,
        [3, 4, 23, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_wide_resnet50_2(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking Wide ResNet-50-2.

    This is the spiking version of
    `Wide Residual Networks <https://arxiv.org/pdf/1605.07146.pdf>`_. It is
    identical to ResNet-50 (``Bottleneck``, stages ``[3, 4, 6, 3]``) except
    that ``width_per_group`` is doubled to 128, which doubles the inner width
    of every bottleneck.

    :param pretrained: if ``True``, load the torchvision ``wide_resnet50_2``
        ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
        ``width_per_group`` is always overwritten with 128.
    :type kwargs: dict
    :return: Spiking Wide ResNet-50-2
    :rtype: torch.nn.Module
    """
    # Fixed by the architecture; any caller-supplied value is replaced.
    kwargs["width_per_group"] = 64 * 2
    return _spiking_resnet(
        "wide_resnet50_2",
        Bottleneck,
        [3, 4, 6, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )
def spiking_wide_resnet101_2(
    pretrained=False, progress=True, spiking_neuron: callable = None, **kwargs
):
    r"""
    Construct a Spiking Wide ResNet-101-2.

    This is the spiking version of
    `Wide Residual Networks <https://arxiv.org/pdf/1605.07146.pdf>`_. It is
    identical to ResNet-101 (``Bottleneck``, stages ``[3, 4, 23, 3]``) except
    that ``width_per_group`` is doubled to 128, which doubles the inner width
    of every bottleneck.

    :param pretrained: if ``True``, load the torchvision ``wide_resnet101_2``
        ImageNet weights
    :type pretrained: bool
    :param progress: whether to display a download progress bar
    :type progress: bool
    :param spiking_neuron: callable that creates a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: keyword arguments accepted by :class:`SpikingResNet`
        (e.g. ``num_classes``, ``zero_init_residual``, ``norm_layer``). Any
        other key is forwarded to every ``spiking_neuron`` call.
        ``width_per_group`` is always overwritten with 128.
    :type kwargs: dict
    :return: Spiking Wide ResNet-101-2
    :rtype: torch.nn.Module
    """
    # Fixed by the architecture; any caller-supplied value is replaced.
    kwargs["width_per_group"] = 64 * 2
    return _spiking_resnet(
        "wide_resnet101_2",
        Bottleneck,
        [3, 4, 23, 3],
        pretrained,
        progress,
        spiking_neuron,
        **kwargs,
    )