spikingjelly.activation_based.layer.stateless_wrapper 源代码

from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.common_types import (
    _size_any_t,
    _size_1_t,
    _size_2_t,
    _size_3_t,
    _ratio_any_t,
)
import numpy as np

from .. import base, functional


__all__ = [
    "Conv1d",
    "Conv2d",
    "Conv3d",
    "Upsample",
    "ConvTranspose1d",
    "ConvTranspose2d",
    "ConvTranspose3d",
    "GroupNorm",
    "MaxPool1d",
    "MaxPool2d",
    "MaxPool3d",
    "AvgPool1d",
    "AvgPool2d",
    "AvgPool3d",
    "AdaptiveAvgPool1d",
    "AdaptiveAvgPool2d",
    "AdaptiveAvgPool3d",
    "Linear",
    "Flatten",
    "WSConv2d",
    "WSLinear",
]


################################################################################
# nn.Module wrappers with ``step_mode``                                        #
################################################################################


[文档] class Conv1d(nn.Conv1d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, padding: Union[str, _size_1_t] = 0, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <Conv1d.__init__-cn>` | :ref:`English <Conv1d.__init__-en>` ---- .. _Conv1d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Conv1d` ---- .. _Conv1d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Conv1d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 4: raise ValueError( f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class Conv2d(nn.Conv2d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <Conv2d.__init__-cn>` | :ref:`English <Conv2d.__init__-en>` ---- .. _Conv2d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Conv2d` ---- .. _Conv2d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Conv2d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) y_shape = [x.shape[0], x.shape[1]] y = super().forward(x.flatten(0, 1)) y_shape.extend(y.shape[1:]) x = y.view(y_shape) return x
[文档] class Conv3d(nn.Conv3d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, padding: Union[str, _size_3_t] = 0, dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <Conv3d.__init__-cn>` | :ref:`English <Conv3d.__init__-en>` ---- .. _Conv3d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Conv3d` ---- .. _Conv3d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Conv3d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 6: raise ValueError( f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class Upsample(nn.Upsample, base.StepModule): def __init__( self, size: Optional[_size_any_t] = None, scale_factor: Optional[_ratio_any_t] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <Upsample.__init__-cn>` | :ref:`English <Upsample.__init__-en>` ---- .. _Upsample.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Upsample` ---- .. _Upsample.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Upsample` for other parameters' API :return: None :rtype: None """ super().__init__( size, scale_factor, mode, align_corners, recompute_scale_factor ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor) -> Tensor: if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": x = functional.seq_to_ann_forward(x, super().forward) return x
[文档] class ConvTranspose1d(nn.ConvTranspose1d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, padding: _size_1_t = 0, output_padding: _size_1_t = 0, groups: int = 1, bias: bool = True, dilation: _size_1_t = 1, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <ConvTranspose1d.__init__-cn>` | :ref:`English <ConvTranspose1d.__init__-en>` ---- .. _ConvTranspose1d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.ConvTranspose1d` ---- .. _ConvTranspose1d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.ConvTranspose1d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 4: raise ValueError( f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class ConvTranspose2d(nn.ConvTranspose2d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: _size_2_t = 0, output_padding: _size_2_t = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <ConvTranspose2d.__init__-cn>` | :ref:`English <ConvTranspose2d.__init__-en>` ---- .. _ConvTranspose2d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.ConvTranspose2d` ---- .. _ConvTranspose2d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.ConvTranspose2d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) y_shape = [x.shape[0], x.shape[1]] y = super().forward(x.flatten(0, 1)) y_shape.extend(y.shape[1:]) x = y.view(y_shape) return x
[文档] class ConvTranspose3d(nn.ConvTranspose3d, base.StepModule): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, padding: _size_3_t = 0, output_padding: _size_3_t = 0, groups: int = 1, bias: bool = True, dilation: _size_3_t = 1, padding_mode: str = "zeros", step_mode: str = "s", ) -> None: r""" **API Language:** :ref:`中文 <ConvTranspose3d.__init__-cn>` | :ref:`English <ConvTranspose3d.__init__-en>` ---- .. _ConvTranspose3d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.ConvTranspose3d` ---- .. _ConvTranspose3d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.ConvTranspose3d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation, padding_mode, ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 6: raise ValueError( f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class GroupNorm(nn.GroupNorm, base.StepModule): def __init__( self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, step_mode="s", ): r""" **API Language:** :ref:`中文 <GroupNorm.__init__-cn>` | :ref:`English <GroupNorm.__init__-en>` ---- .. _GroupNorm.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.GroupNorm` ---- .. _GroupNorm.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.GroupNorm` for other parameters' API :return: None :rtype: None """ super().__init__(num_groups, num_channels, eps, affine) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": return super().forward(x) elif self.step_mode == "m": return functional.seq_to_ann_forward(x, super().forward)
[文档] class MaxPool1d(nn.MaxPool1d, base.StepModule): def __init__( self, kernel_size: _size_1_t, stride: Optional[_size_1_t] = None, padding: _size_1_t = 0, dilation: _size_1_t = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <MaxPool1d.__init__-cn>` | :ref:`English <MaxPool1d.__init__-en>` ---- .. _MaxPool1d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.MaxPool1d` ---- .. _MaxPool1d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.MaxPool1d` for other parameters' API :return: None :rtype: None """ super().__init__( kernel_size, stride, padding, dilation, return_indices, ceil_mode ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 4: raise ValueError( f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) if isinstance(y, tuple): y_shape = [x.shape[0], x.shape[1]] y_shape.extend(y[0].shape[1:]) return y[0].view(y_shape), y[1].view(y_shape) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class MaxPool2d(nn.MaxPool2d, base.StepModule): def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, dilation: _size_2_t = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <MaxPool2d.__init__-cn>` | :ref:`English <MaxPool2d.__init__-en>` ---- .. _MaxPool2d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.MaxPool2d` ---- .. _MaxPool2d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.MaxPool2d` for other parameters' API :return: None :rtype: None """ super().__init__( kernel_size, stride, padding, dilation, return_indices, ceil_mode ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) y_shape = [x.shape[0], x.shape[1]] y = super().forward(x.flatten(0, 1)) if isinstance(y, tuple): y_shape.extend(y[0].shape[1:]) return y[0].view(y_shape), y[1].view(y_shape) y_shape.extend(y.shape[1:]) x = y.view(y_shape) return x
[文档] class MaxPool3d(nn.MaxPool3d, base.StepModule): def __init__( self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, dilation: _size_3_t = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <MaxPool3d.__init__-cn>` | :ref:`English <MaxPool3d.__init__-en>` ---- .. _MaxPool3d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.MaxPool3d` ---- .. _MaxPool3d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.MaxPool3d` for other parameters' API :return: None :rtype: None """ super().__init__( kernel_size, stride, padding, dilation, return_indices, ceil_mode ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 6: raise ValueError( f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) if isinstance(y, tuple): y_shape = [x.shape[0], x.shape[1]] y_shape.extend(y[0].shape[1:]) return y[0].view(y_shape), y[1].view(y_shape) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class AvgPool1d(nn.AvgPool1d, base.StepModule): def __init__( self, kernel_size: _size_1_t, stride: _size_1_t = None, padding: _size_1_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <AvgPool1d.__init__-cn>` | :ref:`English <AvgPool1d.__init__-en>` ---- .. _AvgPool1d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AvgPool1d` ---- .. _AvgPool1d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AvgPool1d` for other parameters' API :return: None :rtype: None """ super().__init__(kernel_size, stride, padding, ceil_mode, count_include_pad) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 4: raise ValueError( f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" ) x = functional.seq_to_ann_forward(x, super().forward) return x
[文档] class AvgPool2d(nn.AvgPool2d, base.StepModule): def __init__( self, kernel_size: _size_2_t, stride: Optional[_size_2_t] = None, padding: _size_2_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <AvgPool2d.__init__-cn>` | :ref:`English <AvgPool2d.__init__-en>` ---- .. _AvgPool2d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AvgPool2d` ---- .. _AvgPool2d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AvgPool2d` for other parameters' API :return: None :rtype: None """ super().__init__( kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) t, n = x.shape[0], x.shape[1] out = super().forward(x.flatten(0, 1)) x = out.view(t, n, *out.shape[1:]) return x
[文档] class AvgPool3d(nn.AvgPool3d, base.StepModule): def __init__( self, kernel_size: _size_3_t, stride: Optional[_size_3_t] = None, padding: _size_3_t = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None, step_mode="s", ) -> None: r""" **API Language:** :ref:`中文 <AvgPool3d.__init__-cn>` | :ref:`English <AvgPool3d.__init__-en>` ---- .. _AvgPool3d.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AvgPool3d` ---- .. _AvgPool3d.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AvgPool3d` for other parameters' API :return: None :rtype: None """ super().__init__( kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override ) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 6: raise ValueError( f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" ) y = super().forward(x.flatten(0, 1)) x = y.view(x.shape[0], x.shape[1], *y.shape[1:]) return x
[文档] class AdaptiveAvgPool1d(nn.AdaptiveAvgPool1d, base.StepModule): def __init__(self, output_size, step_mode="s") -> None: r""" **API Language:** :ref:`中文 <AdaptiveAvgPool1d.__init__-cn>` | :ref:`English <AdaptiveAvgPool1d.__init__-en>` ---- .. _AdaptiveAvgPool1d.__init__-cn: * **中文** * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool1d` ---- .. _AdaptiveAvgPool1d.__init__-en: * **English** * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AdaptiveAvgPool1d` for other parameters' API :return: None :rtype: None """ super().__init__(output_size) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 4: raise ValueError( f"expected x with shape [T, N, C, L], but got x with shape {x.shape}!" ) x = functional.seq_to_ann_forward(x, super().forward) return x
[文档] class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d, base.StepModule): def __init__(self, output_size, step_mode="s") -> None: r""" **API Language:** :ref:`中文 <AdaptiveAvgPool2d.__init__-cn>` | :ref:`English <AdaptiveAvgPool2d.__init__-en>` ---- .. _AdaptiveAvgPool2d.__init__-cn: * **中文** * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool2d` ---- .. _AdaptiveAvgPool2d.__init__-en: * **English** * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AdaptiveAvgPool2d` for other parameters' API :return: None :rtype: None """ super().__init__(output_size) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) y_shape = [x.shape[0], x.shape[1]] x = super().forward(x.flatten(0, 1)) y_shape.extend(x.shape[1:]) x = x.view(y_shape) return x
[文档] class AdaptiveAvgPool3d(nn.AdaptiveAvgPool3d, base.StepModule): def __init__(self, output_size, step_mode="s") -> None: r""" **API Language:** :ref:`中文 <AdaptiveAvgPool3d.__init__-cn>` | :ref:`English <AdaptiveAvgPool3d.__init__-en>` ---- .. _AdaptiveAvgPool3d.__init__-cn: * **中文** * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.AdaptiveAvgPool3d` ---- .. _AdaptiveAvgPool3d.__init__-en: * **English** * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.AdaptiveAvgPool3d` for other parameters' API :return: None :rtype: None """ super().__init__(output_size) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": if x.dim() != 6: raise ValueError( f"expected x with shape [T, N, C, D, H, W], but got x with shape {x.shape}!" ) x = functional.seq_to_ann_forward(x, super().forward) return x
[文档] class Linear(nn.Linear, base.StepModule): def __init__( self, in_features: int, out_features: int, bias: bool = True, step_mode="s" ) -> None: r""" **API Language:** :ref:`中文 <Linear.__init__-cn>` | :ref:`English <Linear.__init__-en>` ---- .. _Linear.__init__-cn: * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Linear` ---- .. _Linear.__init__-en: * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Linear` for other parameters' API :return: None :rtype: None """ super().__init__(in_features, out_features, bias) self.step_mode = step_mode
[文档] class Flatten(nn.Flatten, base.StepModule): def __init__(self, start_dim: int = 1, end_dim: int = -1, step_mode="s") -> None: r""" **API Language:** :ref:`中文 <Flatten.__init__-cn>` | :ref:`English <Flatten.__init__-en>` ---- .. _Flatten.__init__-cn: * **中文** * **中文** :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str 其他的参数API参见 :class:`torch.nn.Flatten` ---- .. _Flatten.__init__-en: * **English** * **English** :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str Refer to :class:`torch.nn.Flatten` for other parameters' API :return: None :rtype: None """ super().__init__(start_dim, end_dim) self.step_mode = step_mode def extra_repr(self): return super().extra_repr() + f", step_mode={self.step_mode}"
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = super().forward(x) elif self.step_mode == "m": y_shape = [x.shape[0], x.shape[1]] x = super().forward(x.flatten(0, 1)) y_shape.extend(x.shape[1:]) x = x.view(y_shape) return x
################################################################################ # scaled weight standardization modules # ################################################################################
[文档] class WSConv2d(Conv2d): def __init__( self, in_channels: int, out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, padding: Union[str, _size_2_t] = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", step_mode: str = "s", gain: bool = True, eps: float = 1e-4, ) -> None: r""" **API Language:** :ref:`中文 <WSConv2d.__init__-cn>` | :ref:`English <WSConv2d.__init__-en>` ---- .. _WSConv2d.__init__-cn: * **中文** :param gain: 是否对权重引入可学习的缩放系数 :type gain: bool :param eps: 预防数值问题的小量 :type eps: float 其他的参数API参见 :class:`Conv2d` ---- .. _WSConv2d.__init__-en: * **English** :param gain: whether introduce learnable scale factors for weights :type step_mode: bool :param eps: a small number to prevent numerical problems :type eps: float Refer to :class:`Conv2d` for other parameters' API :return: None :rtype: None """ super().__init__( in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, step_mode, ) if gain: self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) else: self.gain = None self.eps = eps
[文档] def get_weight(self): fan_in = np.prod(self.weight.shape[1:]) mean = torch.mean(self.weight, axis=[1, 2, 3], keepdims=True) var = torch.var(self.weight, axis=[1, 2, 3], keepdims=True) weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5) if self.gain is not None: weight = weight * self.gain return weight
def _forward(self, x: Tensor): return F.conv2d( x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups, )
[文档] def forward(self, x: Tensor): if self.step_mode == "s": x = self._forward(x) elif self.step_mode == "m": if x.dim() != 5: raise ValueError( f"expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!" ) x = functional.seq_to_ann_forward(x, self._forward) return x
[文档] class WSLinear(Linear): def __init__( self, in_features: int, out_features: int, bias: bool = True, step_mode="s", gain=True, eps=1e-4, ) -> None: r""" **API Language:** :ref:`中文 <WSLinear.__init__-cn>` | :ref:`English <WSLinear.__init__-en>` ---- .. _WSLinear.__init__-cn: * **中文** :param gain: 是否对权重引入可学习的缩放系数 :type gain: bool :param eps: 预防数值问题的小量 :type eps: float 其他的参数API参见 :class:`Linear` ---- .. _WSLinear.__init__-en: * **English** :param gain: whether introduce learnable scale factors for weights :type step_mode: bool :param eps: a small number to prevent numerical problems :type eps: float Refer to :class:`Linear` for other parameters' API :return: None :rtype: None """ super().__init__(in_features, out_features, bias, step_mode) if gain: self.gain = nn.Parameter(torch.ones(self.out_channels, 1)) else: self.gain = None self.eps = eps
[文档] def get_weight(self): fan_in = np.prod(self.weight.shape[1:]) mean = torch.mean(self.weight, axis=[1], keepdims=True) var = torch.var(self.weight, axis=[1], keepdims=True) weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5) if self.gain is not None: weight = weight * self.gain return weight
[文档] def forward(self, x: Tensor): return F.linear(x, self.get_weight(), self.bias)