# 传播模式

## 单步传播与多步传播

SpikingJelly中的绝大多数模块（spikingjelly.clock_driven.rnn 除外），例如 spikingjelly.clock_driven.layer.Dropout，模块名的前缀中没有 MultiStep，表示这个模块的 forward 函数定义的是单步的前向传播：

class MultiStepContainer(nn.Sequential):
    """Wrap single-step modules so they can consume a whole time sequence.

    The contained modules are applied independently at every time step ``t``
    of the input, and the per-step outputs are re-assembled along a new
    leading time dimension.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, x_seq: torch.Tensor):
        """
        :param x_seq: shape=[T, batch_size, ...]
        :type x_seq: torch.Tensor
        :return: y_seq, shape=[T, batch_size, ...]
        :rtype: torch.Tensor
        """
        y_seq = []
        for t in range(x_seq.shape[0]):
            # run the wrapped (single-step) modules on the slice at time step t
            y_seq.append(super().forward(x_seq[t]))

        for t in range(y_seq.__len__()):
            # add back the time dimension so the slices can be concatenated
            y_seq[t] = y_seq[t].unsqueeze(0)

        # BUGFIX: the original snippet built y_seq but never returned it;
        # concatenate the per-step outputs into a [T, batch_size, ...] tensor.
        return torch.cat(y_seq, 0)


from spikingjelly.clock_driven import neuron, layer, functional
import torch

neuron_num = 4
T = 8
if_node = neuron.IFNode()
# inputs in [0, 2) so some steps cross the firing threshold
x = torch.rand([T, neuron_num]) * 2

# single-step usage: feed one time step at a time
for t in range(T):
    print(f'if_node output spikes at t={t}', if_node(x[t]))
# IF neurons are stateful (membrane potential) — reset before reuse
functional.reset_net(if_node)

# multi-step usage: wrap the same neuron and feed the whole sequence at once
ms_if_node = layer.MultiStepContainer(if_node)
print("multi step if_node output spikes\n", ms_if_node(x))
functional.reset_net(ms_if_node)


if_node output spikes at t=0 tensor([1., 1., 1., 0.])
if_node output spikes at t=1 tensor([0., 0., 0., 1.])
if_node output spikes at t=2 tensor([1., 1., 1., 1.])
if_node output spikes at t=3 tensor([0., 0., 1., 0.])
if_node output spikes at t=4 tensor([1., 1., 1., 1.])
if_node output spikes at t=5 tensor([1., 0., 0., 0.])
if_node output spikes at t=6 tensor([1., 0., 1., 1.])
if_node output spikes at t=7 tensor([1., 1., 1., 0.])
multi step if_node output spikes
tensor([[1., 1., 1., 0.],
[0., 0., 0., 1.],
[1., 1., 1., 1.],
[0., 0., 1., 0.],
[1., 1., 1., 1.],
[1., 0., 0., 0.],
[1., 0., 1., 1.],
[1., 1., 1., 0.]])


## 逐步传播与逐层传播

if_node = neuron.IFNode()
x = torch.rand([T, neuron_num]) * 2
# step-by-step propagation: iterate over the time dimension manually
for t in range(T):
    print(f'if_node output spikes at t={t}', if_node(x[t]))


net = nn.Sequential(M0, M1, M2)

# step-by-step propagation: the whole network advances one time step per call
for t in range(T):
    Y[t] = net(X[t])


net = nn.Sequential(M0, M1, M2)

# layer-by-layer propagation: the entire sequence X (shape [T, ...]) passes
# through each layer in turn, in a single call
Y = net(X)


# stateless layers give identical results whether applied per step or at once
for t in range(T):
    y[t] = fc(x[t])  # x.shape=[T, batch_size, in_features]


y = fc(x)  # x.shape=[T, batch_size, in_features]


with torch.no_grad():
    T = 16
    batch_size = 8
    x = torch.rand([T, batch_size, 4])
    # SeqToANNContainer flattens [T, N, ...] to [T*N, ...], applies the
    # stateless layers once, then restores the time dimension
    fc = SeqToANNContainer(nn.Linear(4, 2), nn.Linear(2, 3))
    print(fc(x).shape)


torch.Size([16, 8, 3])


## 包装前向传播

net_step_by_step = nn.Sequential(
    # BUGFIX: the conv layer was missing from the snippet; the printed
    # state_dict keys ('0.weight' for module 0, then '1.*' BatchNorm keys)
    # show a bias-free conv precedes the BatchNorm, as in NetStepByStep below.
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    neuron.IFNode()
)

net_layer_by_layer = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(16),
    ),
    neuron.MultiStepIFNode()
)

print('net_step_by_step.state_dict:', net_step_by_step.state_dict().keys())
print('net_layer_by_layer.state_dict:', net_layer_by_layer.state_dict().keys())


net_step_by_step.state_dict: odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_layer_by_layer.state_dict: odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])


class NetStepByStep(nn.Module):
    """Conv-BN-spiking-neuron block using step-by-step propagation."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(16)
        self.sn = neuron.IFNode()

    def forward(self, x):
        # x.shape = [N, C, H, W] — one time step per call
        x = self.conv(x)
        x = self.bn(x)
        x = self.sn(x)
        return x

class NetLayerByLayer1(NetStepByStep):
    """Same parameters as NetStepByStep, but layer-by-layer forward.

    Only ``forward`` changes, so ``state_dict`` keys stay identical and
    weights can be loaded interchangeably with NetStepByStep.
    """

    def forward(self, x_seq):
        # x_seq.shape = [T, N, C, H, W] — the whole sequence per call
        x_seq = functional.seq_to_ann_forward(x_seq, [self.conv, self.bn])
        x_seq = functional.multi_step_forward(x_seq, self.sn)
        return x_seq

class NetLayerByLayer2(NetStepByStep):
    """Layer-by-layer variant that swaps in a multi-step neuron.

    Reusing the attribute name ``sn`` keeps ``state_dict`` keys identical
    to NetStepByStep, so weights remain interchangeable.
    """

    def __init__(self):
        super().__init__()

        # replace the single-step neuron with a multi-step neuron
        del self.sn
        self.sn = neuron.MultiStepIFNode()

    def forward(self, x_seq):
        # x_seq.shape = [T, N, C, H, W] — the whole sequence per call
        x_seq = functional.seq_to_ann_forward(x_seq, [self.conv, self.bn])
        x_seq = self.sn(x_seq)
        return x_seq


NetStepByStep、NetLayerByLayer1、NetLayerByLayer2 的 state_dict.keys() 完全相同，模型权重可以互相加载。