传播模式

本教程作者: fangwei123456

单步传播与多步传播

SpikingJelly中的绝大多数模块(spikingjelly.clock_driven.rnn 除外),例如 spikingjelly.clock_driven.layer.Dropout,模块名的前缀中没有 MultiStep,表示这个模块的 forward 函数定义的是单步的前向传播:

输入 \(X_{t}\),输出 \(Y_{t}\)

而如果前缀中含有 MultiStep,例如 spikingjelly.clock_driven.layer.MultiStepDropout,则表面这个模块的 forward 函数定义的是多步的前向传播:

输入 \(X_{t}, t=0,1,...,T-1\),输出 \(Y_{t}, t=0,1,...,T-1\)

一个单步传播的模块,可以很容易被封装成多步传播的模块,spikingjelly.clock_driven.layer.MultiStepContainer 提供了非常简单的方式,将原始模块作为子模块,并在 forward 函数中实现在时间上的循环,代码如下所示:

class MultiStepContainer(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x_seq: torch.Tensor):
        y_seq = []
        for t in range(x_seq.shape[0]):
            y_seq.append(self.module(x_seq[t]))
            y_seq[-1].unsqueeze_(0)
        return torch.cat(y_seq, 0)

    def reset(self):
        if hasattr(self.module, 'reset'):
            self.module.reset()

我们使用这种方式来包装一个IF神经元:

from spikingjelly.clock_driven import neuron, layer
import torch

neuron_num = 4
T = 8
if_node = neuron.IFNode()
x = torch.rand([T, neuron_num]) * 2
for t in range(T):
    print(f'if_node output spikes at t={t}', if_node(x[t]))
if_node.reset()

ms_if_node = layer.MultiStepContainer(if_node)
print("multi step if_node output spikes\n", ms_if_node(x))
ms_if_node.reset()

输出为:

if_node output spikes at t=0 tensor([1., 1., 1., 0.])
if_node output spikes at t=1 tensor([0., 0., 0., 1.])
if_node output spikes at t=2 tensor([1., 1., 1., 1.])
if_node output spikes at t=3 tensor([0., 0., 1., 0.])
if_node output spikes at t=4 tensor([1., 1., 1., 1.])
if_node output spikes at t=5 tensor([1., 0., 0., 0.])
if_node output spikes at t=6 tensor([1., 0., 1., 1.])
if_node output spikes at t=7 tensor([1., 1., 1., 0.])
multi step if_node output spikes
 tensor([[1., 1., 1., 0.],
        [0., 0., 0., 1.],
        [1., 1., 1., 1.],
        [0., 0., 1., 0.],
        [1., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 0., 1., 1.],
        [1., 1., 1., 0.]])

两种方式的输出是完全相同的。

逐步传播与逐层传播

在以往的教程和样例中,我们定义的网络在运行时,是按照 逐步传播(step-by-step) 的方式,例如上文中的:

if_node = neuron.IFNode()
x = torch.rand([T, neuron_num]) * 2
for t in range(T):
    print(f'if_node output spikes at t={t}', if_node(x[t]))

逐步传播(step-by-step),指的是在前向传播时,先计算出整个网络在 \(t=0\) 的输出 \(Y_{0}\),然后再计算整个网络在 \(t=1\) 的输出 \(Y_{1}\),……,最终得到网络在所有时刻的输出 \(Y_{t}, t=0,1,...,T-1\)。例如下面这份代码(假定 M0, M1, M2 都是单步传播的模块):

net = nn.Sequential(M0, M1, M2)

for t in range(T):
    Y[t] = net(X[t])

前向传播的计算图的构建顺序如下所示:

../_images/step-by-step.png

对于SNN以及RNN,前向传播既发生在空域也发生在时域,逐步传播 逐步计算出整个网络在不同时刻的状态,我们可以很容易联想到,还可以使用另一种顺序来计算:逐层计算出每一层网络在所有时刻的状态。例如下面这份代码(假定 M0, M1, M2 都是多步传播的模块):

net = nn.Sequential(M0, M1, M2)

Y = net(X)

前向传播的计算图的构建顺序如下所示:

../_images/layer-by-layer.png

我们称这种方式为 逐层传播(layer-by-layer)逐层传播 在RNN以及SNN中也被广泛使用,例如 Low-activity supervised convolutional spiking neural networks applied to speech commands recognition 通过逐层计算的方式来获取每一层在所有时刻的输出,然后在时域上进行卷积,代码可见于 https://github.com/romainzimmer/s2net

逐步传播逐层传播 遍历计算图的顺序不同,但计算的结果是完全相同的。但 逐层传播 具有更大的并行性,因为当某一层是无状态的层,例如 torch.nn.Linear逐步传播 会按照下述方式计算:

for t in range(T):
    y[t] = fc(x[t])  # x.shape=[T, batch_size, in_features]

逐层传播 则可以并行计算:

y = fc(x)  # x.shape=[T, batch_size, in_features]

对于无状态的层,我们可以将 shape=[T, batch_size, ...] 的输入拼接成 shape=[T * batch_size, ...] 后,再送入这一层计算,避免在时间上的循环。spikingjelly.clock_driven.layer.SeqToANNContainerforward 函数中进行了这样的实现。我们可以直接使用这个模块:

with torch.no_grad():
    T = 16
    batch_size = 8
    x = torch.rand([T, batch_size, 4])
    fc = SeqToANNContainer(nn.Linear(4, 2), nn.Linear(2, 3))
    print(fc(x).shape)

输出为:

torch.Size([16, 8, 3])

输出仍然满足 shape=[T, batch_size, ...],可以直接送入到下一层网络。