spikingjelly.activation_based.model.parametric_lif_net 源代码

import torch
import torch.nn as nn
from copy import deepcopy
from .. import layer
# Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks https://arxiv.org/abs/2007.05785

__all__ = [
    "MNISTNet",
    "FashionMNISTNet",
    "NMNISTNet",
    "CIFAR10Net",
    "CIFAR10DVSNet",
    "DVSGestureNet",
]


[文档] class MNISTNet(nn.Module): r""" **API Language:** :ref:`中文 <MNISTNet-cn>` | :ref:`English <MNISTNet-en>` ---- .. _MNISTNet-cn: * **中文** * **中文** 用于 MNIST 手写数字分类的 Parametric LIF 网络。 ---- .. _MNISTNet-en: * **English** * **English** Parametric LIF network for MNIST digit classification. """ def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs): super().__init__() self.conv_fc = nn.Sequential( layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False), layer.BatchNorm2d(channels), spiking_neuron(**deepcopy(kwargs)), layer.MaxPool2d(2, 2), layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False), layer.BatchNorm2d(channels), spiking_neuron(**deepcopy(kwargs)), layer.MaxPool2d(2, 2), layer.Flatten(), layer.Dropout(0.5), layer.Linear(channels * 7 * 7, 2048), spiking_neuron(**deepcopy(kwargs)), layer.Dropout(0.5), layer.Linear(2048, 100), spiking_neuron(**deepcopy(kwargs)), layer.VotingLayer(), )
[文档] def forward(self, x: torch.Tensor): return self.conv_fc(x)
[文档] class FashionMNISTNet(MNISTNet): r""" **API Language:** :ref:`中文 <FashionMNISTNet-cn>` | :ref:`English <FashionMNISTNet-en>` ---- .. _FashionMNISTNet-cn: * **中文** * **中文** 用于 Fashion-MNIST 分类的 Parametric LIF 网络。即 :class:`MNISTNet` 的别名。 :param args: 与 :class:`MNISTNet` 相同的参数 :type args: tuple :param kwargs: 与 :class:`MNISTNet` 相同的关键字参数 :type kwargs: dict ---- .. _FashionMNISTNet-en: * **English** * **English** Parametric LIF network for Fashion-MNIST classification. Alias of :class:`MNISTNet`. :param args: Same as :class:`MNISTNet` :type args: tuple :param kwargs: Same as :class:`MNISTNet` :type kwargs: dict """ pass
[文档] class NMNISTNet(MNISTNet): r""" **API Language:** :ref:`中文 <NMNISTNet-cn>` | :ref:`English <NMNISTNet-en>` ---- .. _NMNISTNet-cn: * **中文** * **中文** 用于 N-MNIST 事件流分类的 Parametric LIF 网络。基于 :class:`MNISTNet`,将首层卷积输入通道调整为 2。 :param channels: 卷积层的通道数 :type channels: int :param spiking_neuron: 脉冲神经元类 :type spiking_neuron: callable ---- .. _NMNISTNet-en: * **English** * **English** Parametric LIF network for N-MNIST event stream classification. Based on :class:`MNISTNet` with first conv layer adjusted to 2 input channels. :param channels: Number of channels in conv layers :type channels: int :param spiking_neuron: Spiking neuron class :type spiking_neuron: callable """ def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs): super().__init__(channels, spiking_neuron, **kwargs) self.conv_fc[0] = layer.Conv2d( 2, channels, kernel_size=3, padding=1, bias=False ) self.conv_fc[-6] = layer.Linear(channels * 8 * 8, 2048)
[文档] class CIFAR10Net(nn.Module): r""" **API Language:** :ref:`中文 <CIFAR10Net-cn>` | :ref:`English <CIFAR10Net-en>` ---- .. _CIFAR10Net-cn: * **中文** * **中文** 用于 CIFAR-10 分类的 Parametric LIF 网络。 ---- .. _CIFAR10Net-en: * **English** * **English** Parametric LIF network for CIFAR-10 classification. """ def __init__(self, channels=256, spiking_neuron: callable = None, **kwargs): super().__init__() conv = [] for i in range(2): for j in range(3): if conv.__len__() == 0: in_channels = 3 else: in_channels = channels conv.append( layer.Conv2d( in_channels, channels, kernel_size=3, padding=1, bias=False ) ) conv.append(layer.BatchNorm2d(channels)) conv.append(spiking_neuron(**deepcopy(kwargs))) conv.append(layer.MaxPool2d(2, 2)) self.conv_fc = nn.Sequential( *conv, layer.Flatten(), layer.Dropout(0.5), layer.Linear(channels * 8 * 8, 2048), spiking_neuron(**deepcopy(kwargs)), layer.Dropout(0.5), layer.Linear(2048, 100), spiking_neuron(**deepcopy(kwargs)), layer.VotingLayer(10), )
[文档] def forward(self, x): return self.conv_fc(x)
[文档] class CIFAR10DVSNet(nn.Module): r""" **API Language:** :ref:`中文 <CIFAR10DVSNet-cn>` | :ref:`English <CIFAR10DVSNet-en>` ---- .. _CIFAR10DVSNet-cn: * **中文** * **中文** 用于 CIFAR10-DVS 事件流分类的 Parametric LIF 网络。 ---- .. _CIFAR10DVSNet-en: * **English** * **English** Parametric LIF network for CIFAR10-DVS event stream classification. """ def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs): super().__init__() conv = [] for i in range(4): if conv.__len__() == 0: in_channels = 2 else: in_channels = channels conv.append( layer.Conv2d( in_channels, channels, kernel_size=3, padding=1, bias=False ) ) conv.append(layer.BatchNorm2d(channels)) conv.append(spiking_neuron(**deepcopy(kwargs))) conv.append(layer.MaxPool2d(2, 2)) self.conv_fc = nn.Sequential( *conv, layer.Flatten(), layer.Dropout(0.5), layer.Linear(channels * 8 * 8, 512), spiking_neuron(**deepcopy(kwargs)), layer.Dropout(0.5), layer.Linear(512, 100), spiking_neuron(**deepcopy(kwargs)), layer.VotingLayer(10), )
[文档] def forward(self, x: torch.Tensor): return self.conv_fc(x)
[文档] class DVSGestureNet(nn.Module): r""" **API Language:** :ref:`中文 <DVSGestureNet-cn>` | :ref:`English <DVSGestureNet-en>` ---- .. _DVSGestureNet-cn: * **中文** * **中文** 用于 DVS128 Gesture 手势识别的 Parametric LIF 网络。 ---- .. _DVSGestureNet-en: * **English** * **English** Parametric LIF network for DVS128 Gesture recognition. """ def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs): super().__init__() conv = [] for i in range(5): if conv.__len__() == 0: in_channels = 2 else: in_channels = channels conv.append( layer.Conv2d( in_channels, channels, kernel_size=3, padding=1, bias=False ) ) conv.append(layer.BatchNorm2d(channels)) conv.append(spiking_neuron(**deepcopy(kwargs))) conv.append(layer.MaxPool2d(2, 2)) self.conv_fc = nn.Sequential( *conv, layer.Flatten(), layer.Dropout(0.5), layer.Linear(channels * 4 * 4, 512), spiking_neuron(**deepcopy(kwargs)), layer.Dropout(0.5), layer.Linear(512, 110), spiking_neuron(**deepcopy(kwargs)), layer.VotingLayer(10), )
[文档] def forward(self, x: torch.Tensor): return self.conv_fc(x)
def test_models(): import torch from .. import neuron, surrogate, functional x = torch.rand([2, 1, 28, 28]) net = MNISTNet(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 1, 28, 28]) print(net(x).shape) functional.reset_net(net) del net del x x = torch.rand([2, 1, 28, 28]) net = FashionMNISTNet(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 1, 28, 28]) print(net(x).shape) functional.reset_net(net) del net del x x = torch.rand([2, 2, 32, 32]) net = NMNISTNet(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 2, 32, 32]) print(net(x).shape) functional.reset_net(net) del net del x x = torch.rand([2, 3, 32, 32]) net = CIFAR10Net(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 3, 32, 32]) print(net(x).shape) functional.reset_net(net) del net del x x = torch.rand([2, 2, 128, 128]) net = CIFAR10DVSNet(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 2, 128, 128]) print(net(x).shape) functional.reset_net(net) del net del x x = torch.rand([2, 2, 128, 128]) net = DVSGestureNet(16, neuron.IFNode, surrogate_function=surrogate.ATan()) print(net(x).shape) functional.reset_net(net) functional.set_step_mode(net, "m") x = torch.rand([4, 2, 2, 128, 128]) print(net(x).shape) functional.reset_net(net) del net del x