Wrappers

Author of this tutorial: fangwei123456

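SpikingJelly mainly provides the following wrappers:

  • the function-style multi_step_forward and the module-style MultiStepContainer

  • the function-style seq_to_ann_forward and the module-style SeqToANNContainer

  • StepModeContainer, which wraps single-step modules so that they can run in either single-step or multi-step mode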

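multi_step_forward runs a single-step module over a multi-step input, one time step after another, while MultiStepContainer wraps a single-step module into a multi-step module. For example: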

import torch
from spikingjelly.activation_based import neuron, functional, layer

net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]

net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]

# z_seq is identical to y_seq
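To make the behavior of multi_step_forward concrete, here is a minimal sketch of the idea in plain PyTorch: call the single-step module once per time step and stack the results. sketch_multi_step_forward and ToyAccumulator are hypothetical names used only for illustration. This is not SpikingJelly's source. The toy module keeps state between calls, which is why the example above calls net_s.reset() before reusing net_s.

```python
import torch
import torch.nn as nn


def sketch_multi_step_forward(x_seq: torch.Tensor, module: nn.Module) -> torch.Tensor:
    """Hypothetical sketch: call a single-step module once per time step."""
    y_seq = []
    for t in range(x_seq.shape[0]):
        # the module sees one step of shape [N, *]; any state it keeps carries over to t + 1
        y_seq.append(module(x_seq[t]))
    # stack the per-step outputs back into shape [T, N, *]
    return torch.stack(y_seq)


class ToyAccumulator(nn.Module):
    """Hypothetical stateful single-step module: outputs the running sum of its inputs."""

    def __init__(self):
        super().__init__()
        self.v = 0.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.v = self.v + x
        return self.v

    def reset(self):
        self.v = 0.


x_seq = torch.ones([4, 2, 3])  # [T, N, features]
acc = ToyAccumulator()
y_seq = sketch_multi_step_forward(x_seq, acc)
print(y_seq.shape)  # torch.Size([4, 2, 3])
print(y_seq[:, 0, 0])  # tensor([1., 2., 3., 4.])

acc.reset()  # clear the state before feeding a new sequence
```

Without the reset, a second call would continue accumulating from the last state instead of starting from zero.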

Stateless ANN layers such as torch.nn.Conv2d require input of shape = [N, *]. To use them in multi-step mode, they can be wrapped with the multi-step wrappers:

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

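However, ANN layers are stateless and do not depend on earlier time steps, so there is no need to compute them serially over time. Instead, they can be wrapped with the function-style seq_to_ann_forward or the module-style SeqToANNContainer. seq_to_ann_forward first reshapes data of shape = [T, N, *] into shape = [TN, *], feeds it to the stateless layers, and then reshapes the output back to shape = [T, N, *]. Because all time steps are computed in parallel, this is faster: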

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer

with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

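    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    # Switch to eval mode so that BatchNorm normalizes with its running statistics.
    # In training mode it would use batch statistics, which the multi-step wrappers
    # compute separately for each time step, while the seq-to-ANN wrappers compute
    # them over all T * N samples at once, so the outputs would differ.
    conv.eval()
    bn.eval()

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    net = layer.SeqToANNContainer(conv, bn)
    q_seq = net(x_seq)
    # q_seq.shape = [T, N, 8, H, W]

    # q_seq is identical to p_seq, and also identical to y_seq and z_seq
    # (up to floating-point error)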
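The reshaping performed by seq_to_ann_forward can likewise be sketched in plain PyTorch. sketch_seq_to_ann_forward is a hypothetical name, not the library's implementation. The sketch checks the flattened call against a per-time-step loop. The layers are put in eval mode because in training mode BatchNorm would compute its batch statistics over T * N samples instead of N samples per time step, and the two results would no longer agree.

```python
import torch
import torch.nn as nn


def sketch_seq_to_ann_forward(x_seq: torch.Tensor, stateless: nn.Module) -> torch.Tensor:
    """Hypothetical sketch: merge T and N, run the layers once, split T and N again."""
    T, N = x_seq.shape[0], x_seq.shape[1]
    y = stateless(x_seq.flatten(0, 1))  # [T, N, *] -> [T * N, *]
    return y.reshape(T, N, *y.shape[1:])  # [T * N, *] -> [T, N, *]


T, N, C, H, W = 4, 2, 3, 8, 8
x_seq = torch.rand([T, N, C, H, W])
# eval mode: BatchNorm uses running statistics, so the result does not depend
# on how many samples are processed together
net = nn.Sequential(
    nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
).eval()

with torch.no_grad():
    p_seq = sketch_seq_to_ann_forward(x_seq, net)
    # reference: the same layers applied one time step at a time
    y_seq = torch.stack([net(x_seq[t]) for t in range(T)])

print(p_seq.shape)  # torch.Size([4, 2, 8, 8, 8])
print(torch.allclose(p_seq, y_seq, atol=1e-5))
```

The single call over T * N samples replaces T separate calls. That is where the speed-up described above comes from.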

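Commonly used layers are already defined in spikingjelly.activation_based.layer. Using the layers from spikingjelly.activation_based.layer is recommended over wrapping layers manually with SeqToANNContainer. This holds even though the layers in spikingjelly.activation_based.layer are themselves implemented by wrapping their forward functions with these wrappers. The layers in spikingjelly.activation_based.layer have two advantages:

  • They support both single-step and multi-step modes, whereas layers wrapped by SeqToANNContainer or MultiStepContainer support only multi-step mode.

  • A wrapper adds an extra level of nesting to the keys() of the state_dict, which makes loading weights troublesome.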

For example:

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron


ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)

print(f'ann.state_dict.keys()={ann.state_dict().keys()}')

net_container = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')

net_origin = nn.Sequential(
    layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False, step_mode='m'),
    layer.BatchNorm2d(8, step_mode='m'),
    neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')

try:
    print('net_container is trying to load state dict from ann...')
    net_container.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_container can not load! The error message is\n', e)

try:
    print('net_origin is trying to load state dict from ann...')
    net_origin.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_origin can not load! The error message is', e)

The output is:

ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
    Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
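If a network has already been built with SeqToANNContainer, one possible workaround is to rename the keys of the ANN's state_dict before loading. The sketch below assumes the key layout printed above. add_key_prefix is a hypothetical helper. Plain nn.Sequential modules stand in for SeqToANNContainer and IFNode so that the example needs only PyTorch. This works because the 0.0.* keys above show the same nesting a nested nn.Sequential produces, and IFNode contributes no keys.

```python
import torch
import torch.nn as nn


def add_key_prefix(state_dict: dict, mapping: dict) -> dict:
    """Hypothetical helper: rename the first component of each key using `mapping`."""
    out = {}
    for key, value in state_dict.items():
        head, sep, tail = key.partition('.')
        out[mapping.get(head, head) + sep + tail] = value
    return out


ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Stand-in for net_container: the inner nn.Sequential plays the role of
# SeqToANNContainer (one extra level of nesting), and ReLU replaces the
# parameter-free IFNode, so the keys match net_container's keys above.
wrapped = nn.Sequential(
    nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    nn.ReLU(),
)

# "0.*" (conv) -> "0.0.*", "1.*" (batch norm) -> "0.1.*"
remapped = add_key_prefix(ann.state_dict(), {'0': '0.0', '1': '0.1'})
print(list(remapped.keys()))
# ['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked']

wrapped.load_state_dict(remapped)  # strict loading succeeds: no missing or unexpected keys
print(torch.equal(wrapped[0][0].weight, ann[0].weight))  # True
```

Renaming keys works but is easy to get wrong in deeper networks. Using the layers from spikingjelly.activation_based.layer avoids the extra nesting altogether.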

MultiStepContainer and SeqToANNContainer support only multi-step mode; they cannot be switched to single-step mode.

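StepModeContainer works like a combination of MultiStepContainer and SeqToANNContainer. It can wrap either stateless or stateful single-step modules, and you must specify whether the wrapped modules are stateful when creating the wrapper. Unlike those two wrappers, it also supports switching between single-step and multi-step modes.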

An example of wrapping stateless layers:

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    net = layer.StepModeContainer(
        False,
        nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(C),
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]

An example of wrapping a stateful layer:

import torch
from spikingjelly.activation_based import neuron, layer, functional


with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]
    functional.reset_net(net)

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]
    functional.reset_net(net)
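The dispatch logic described for StepModeContainer can be sketched as follows. SketchStepModeContainer and RunningSum are hypothetical classes written for illustration, not SpikingJelly's code. In single-step mode the wrapped modules are called directly. In multi-step mode, a stateful wrapper loops over time like multi_step_forward, while a stateless one merges T and N like seq_to_ann_forward.

```python
import torch
import torch.nn as nn


class SketchStepModeContainer(nn.Sequential):
    """Hypothetical sketch of a wrapper that can switch between step modes."""

    def __init__(self, stateful: bool, *modules: nn.Module):
        super().__init__(*modules)
        self.stateful = stateful
        self.step_mode = 's'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.step_mode == 's':
            # single step: x has shape [N, *]; run the wrapped modules directly
            return super().forward(x)
        if self.stateful:
            # multi-step, stateful: state must carry over from t to t + 1,
            # so call the wrapped modules once per time step
            ys = []
            for x_t in x:
                ys.append(super().forward(x_t))
            return torch.stack(ys)
        # multi-step, stateless: merge T and N and process all steps in one call
        T, N = x.shape[0], x.shape[1]
        y = super().forward(x.flatten(0, 1))
        return y.reshape(T, N, *y.shape[1:])


class RunningSum(nn.Module):
    """Hypothetical stateful module: returns the running sum of its inputs."""

    def __init__(self):
        super().__init__()
        self.v = 0.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.v = self.v + x
        return self.v


x_seq = torch.ones([4, 2, 3])  # [T, N, features]

stateful = SketchStepModeContainer(True, RunningSum())
stateful.step_mode = 'm'
print(stateful(x_seq)[:, 0, 0])  # tensor([1., 2., 3., 4.])

stateless = SketchStepModeContainer(False, nn.Linear(3, 5))
stateless.step_mode = 'm'
print(stateless(x_seq).shape)  # torch.Size([4, 2, 5])
stateless.step_mode = 's'
print(stateless(x_seq[0]).shape)  # torch.Size([2, 5])
```

The stateful flag must be supplied by the user. Merging T and N for a stateful module would silently discard the dependence between time steps.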

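Changing the step mode of a StepModeContainer with set_step_mode is safe. It changes only the step_mode of the wrapper itself, while the modules inside the wrapper remain in single-step mode: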

import torch
from spikingjelly.activation_based import neuron, layer, functional


with torch.no_grad():
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    functional.set_step_mode(net, 'm')
    print(f'net.step_mode={net.step_mode}')
    print(f'net[0].step_mode={net[0].step_mode}')
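One way to obtain the behavior described above is sketched below. It sets step_mode on every module that has the attribute, except modules nested inside a wrapper whose children must stay single-step. sketch_set_step_mode, KeepInnerStepMode and Toy are hypothetical names. The sketch only illustrates the stated behavior and is not the source of functional.set_step_mode.

```python
import torch.nn as nn


class KeepInnerStepMode(nn.Sequential):
    """Hypothetical wrapper whose own step_mode may change but whose children stay single-step."""

    def __init__(self, *modules: nn.Module):
        super().__init__(*modules)
        self.step_mode = 's'


class Toy(nn.Module):
    """Hypothetical module that only carries a step_mode attribute."""

    def __init__(self):
        super().__init__()
        self.step_mode = 's'


def sketch_set_step_mode(net: nn.Module, step_mode: str) -> None:
    # 1) collect every module nested strictly inside a KeepInnerStepMode wrapper
    protected = set()
    for m in net.modules():
        if isinstance(m, KeepInnerStepMode):
            protected.update(inner for inner in m.modules() if inner is not m)
    # 2) change step_mode everywhere else, including the wrappers themselves
    for m in net.modules():
        if hasattr(m, 'step_mode') and m not in protected:
            m.step_mode = step_mode


net = nn.Sequential(Toy(), KeepInnerStepMode(Toy()))
sketch_set_step_mode(net, 'm')
print(net[0].step_mode, net[1].step_mode, net[1][0].step_mode)  # m m s
```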

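If a module already supports switching between single-step and multi-step modes, wrapping it with MultiStepContainer or StepModeContainer is not recommended. The generic multi-step forward used by the wrappers may be slower than the multi-step forward defined by the module itself.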

MultiStepContainer or StepModeContainer is typically needed for modules that do not define a multi-step mode, for example a layer that exists in torch.nn but not in spikingjelly.activation_based.layer.