Source code of module ``spikingjelly.activation_based.model.parametric_lif_net``.

import torch
import torch.nn as nn
from copy import deepcopy
from .. import layer
# Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks https://arxiv.org/abs/2007.05785

# Public API of this module: the spiking network architectures defined below.
__all__ = ['MNISTNet', 'FashionMNISTNet', 'NMNISTNet', 'CIFAR10Net', 'CIFAR10DVSNet', 'DVSGestureNet']


class MNISTNet(nn.Module):
    """Spiking CNN for MNIST from the Parametric-LIF paper
    (https://arxiv.org/abs/2007.05785).

    Two conv-BN-spike-pool stages, then a two-layer spiking MLP whose 100
    outputs are reduced by a voting layer. ``spiking_neuron`` is a factory
    (typically a neuron class) invoked once per spiking layer; ``kwargs`` are
    deep-copied for every invocation so each neuron layer owns independent
    (possibly learnable) parameters.
    """

    def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs):
        super().__init__()

        def sn():
            # A fresh neuron layer with its own copy of the constructor kwargs.
            return spiking_neuron(**deepcopy(kwargs))

        self.conv_fc = nn.Sequential(
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            sn(),
            layer.MaxPool2d(2, 2),

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            sn(),
            layer.MaxPool2d(2, 2),

            layer.Flatten(),
            layer.Dropout(0.5),
            # 28x28 input is pooled twice -> 7x7 feature maps.
            layer.Linear(channels * 7 * 7, 2048),
            sn(),
            layer.Dropout(0.5),
            layer.Linear(2048, 100),
            sn(),
            layer.VotingLayer()
        )

    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
class FashionMNISTNet(MNISTNet):
    """Same architecture as :class:`MNISTNet`, used for the Fashion-MNIST
    dataset (identical 1x28x28 input geometry)."""
class NMNISTNet(MNISTNet):
    """Variant of :class:`MNISTNet` for the event-based N-MNIST dataset.

    Only the input-facing layers differ: the first conv takes the two event
    polarity channels, and the first linear layer matches the larger spatial
    size left after the two pooling stages.
    """

    def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs):
        super().__init__(channels, spiking_neuron, **kwargs)
        # Two polarity channels instead of one grayscale channel.
        self.conv_fc[0] = layer.Conv2d(2, channels, kernel_size=3, padding=1, bias=False)
        # 8x8 feature maps after pooling (vs. 7x7 for 28x28 MNIST frames).
        self.conv_fc[-6] = layer.Linear(channels * 8 * 8, 2048)
class CIFAR10Net(nn.Module):
    """Spiking CNN for CIFAR-10 from the Parametric-LIF paper
    (https://arxiv.org/abs/2007.05785).

    Two downsampling stages of three conv-BN-spike triples each, followed by a
    spiking MLP and a 10-way voting layer. ``spiking_neuron`` is a factory
    called once per spiking layer with a deep copy of ``kwargs``.
    """

    def __init__(self, channels=256, spiking_neuron: callable = None, **kwargs):
        super().__init__()

        modules = []
        in_channels = 3  # RGB input for the very first conv only
        for _ in range(2):          # two downsampling stages
            for _ in range(3):      # three conv-BN-spike triples per stage
                modules += [
                    layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                    layer.BatchNorm2d(channels),
                    spiking_neuron(**deepcopy(kwargs)),
                ]
                in_channels = channels
            modules.append(layer.MaxPool2d(2, 2))

        modules += [
            layer.Flatten(),
            layer.Dropout(0.5),
            # 32x32 input pooled twice -> 8x8 feature maps.
            layer.Linear(channels * 8 * 8, 2048),
            spiking_neuron(**deepcopy(kwargs)),
            layer.Dropout(0.5),
            layer.Linear(2048, 100),
            spiking_neuron(**deepcopy(kwargs)),
            layer.VotingLayer(10),
        ]
        self.conv_fc = nn.Sequential(*modules)

    def forward(self, x):
        return self.conv_fc(x)
class CIFAR10DVSNet(nn.Module):
    """Spiking CNN for the event-based CIFAR10-DVS dataset
    (https://arxiv.org/abs/2007.05785).

    Four conv-BN-spike-pool stages on two polarity channels, then a spiking
    MLP and a 10-way voting layer. ``spiking_neuron`` is a factory called once
    per spiking layer with a deep copy of ``kwargs``.
    """

    def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs):
        super().__init__()

        modules = []
        in_channels = 2  # event polarity channels for the first conv only
        for _ in range(4):  # four downsampling stages
            modules += [
                layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                layer.BatchNorm2d(channels),
                spiking_neuron(**deepcopy(kwargs)),
                layer.MaxPool2d(2, 2),
            ]
            in_channels = channels

        modules += [
            layer.Flatten(),
            layer.Dropout(0.5),
            # 128x128 frames pooled four times -> 8x8 feature maps.
            layer.Linear(channels * 8 * 8, 512),
            spiking_neuron(**deepcopy(kwargs)),
            layer.Dropout(0.5),
            layer.Linear(512, 100),
            spiking_neuron(**deepcopy(kwargs)),
            layer.VotingLayer(10),
        ]
        self.conv_fc = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
class DVSGestureNet(nn.Module):
    """Spiking CNN for the DVS128 Gesture dataset
    (https://arxiv.org/abs/2007.05785).

    Five conv-BN-spike-pool stages on two polarity channels, then a spiking
    MLP with 110 outputs reduced by a 10-way voting layer (11 gesture
    classes). ``spiking_neuron`` is a factory called once per spiking layer
    with a deep copy of ``kwargs``.
    """

    def __init__(self, channels=128, spiking_neuron: callable = None, **kwargs):
        super().__init__()

        modules = []
        in_channels = 2  # event polarity channels for the first conv only
        for _ in range(5):  # five downsampling stages
            modules += [
                layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False),
                layer.BatchNorm2d(channels),
                spiking_neuron(**deepcopy(kwargs)),
                layer.MaxPool2d(2, 2),
            ]
            in_channels = channels

        modules += [
            layer.Flatten(),
            layer.Dropout(0.5),
            # 128x128 frames pooled five times -> 4x4 feature maps.
            layer.Linear(channels * 4 * 4, 512),
            spiking_neuron(**deepcopy(kwargs)),
            layer.Dropout(0.5),
            layer.Linear(512, 110),
            spiking_neuron(**deepcopy(kwargs)),
            layer.VotingLayer(10),
        ]
        self.conv_fc = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
def test_models():
    """Smoke-test every network architecture defined in this module.

    Each network is built with 16 channels and IF neurons, run once in
    single-step mode on a ``[N, C, H, W]`` batch and once in multi-step mode
    on a ``[T, N, C, H, W]`` batch (T=4), printing the output shape of each
    forward pass. Neuron states are reset between runs.

    The original implementation repeated the same stanza six times; it is
    folded into a single loop over (network class, input shape) pairs with
    identical behavior and ordering.
    """
    from .. import neuron, surrogate, functional

    # (network class, single-step [N, C, H, W] input shape), in the same
    # order as the original copy-pasted stanzas.
    configs = [
        (MNISTNet, [2, 1, 28, 28]),
        (FashionMNISTNet, [2, 1, 28, 28]),
        (NMNISTNet, [2, 2, 32, 32]),
        (CIFAR10Net, [2, 3, 32, 32]),
        (CIFAR10DVSNet, [2, 2, 128, 128]),
        (DVSGestureNet, [2, 2, 128, 128]),
    ]
    for net_class, shape in configs:
        net = net_class(16, neuron.IFNode, surrogate_function=surrogate.ATan())

        # Single-step mode: one frame per forward call.
        x = torch.rand(shape)
        print(net(x).shape)
        functional.reset_net(net)

        # Multi-step mode: prepend a time dimension of T=4.
        functional.set_step_mode(net, 'm')
        x = torch.rand([4] + shape)
        print(net(x).shape)
        functional.reset_net(net)

        # Release the network before building the next one.
        del net
        del x