Container
Author: fangwei123456
Translator: Qiu Haonan, fangwei123456
The major containers in SpikingJelly are:

- multi_step_forward in functional style and MultiStepContainer in module style
- seq_to_ann_forward in functional style and SeqToANNContainer in module style
- StepModeContainer for wrapping a single-step module for single/multi-step propagation
multi_step_forward can use a single-step module to implement multi-step propagation, and MultiStepContainer can wrap a single-step module into a multi-step module. For example:
import torch
from spikingjelly.activation_based import neuron, functional, layer
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]
# z_seq is identical to y_seq
To use a stateless ANN layer such as torch.nn.Conv2d, which requires input data with shape = [N, *], in multi-step mode, we can wrap it with the multi-step containers:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq
However, these ANN layers are stateless: \(Y[t]\) is determined only by \(X[t]\), so there is no need to compute \(Y[t]\) step-by-step. We can use seq_to_ann_forward or SeqToANNContainer to wrap them, which will reshape the input from shape = [T, N, *] to shape = [TN, *], send the data to the ANN layers, and reshape the output back to shape = [T, N, *]. The computations at different time-steps are then performed in parallel, which is faster:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
    T = 4
    N = 1
    C = 3
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])
    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)

    y_seq = functional.multi_step_forward(x_seq, (conv, bn))
    # y_seq.shape = [T, N, 8, H, W]

    net = layer.MultiStepContainer(conv, bn)
    z_seq = net(x_seq)
    # z_seq.shape = [T, N, 8, H, W]

    # z_seq is identical to y_seq

    p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
    # p_seq.shape = [T, N, 8, H, W]

    net = layer.SeqToANNContainer(conv, bn)
    q_seq = net(x_seq)
    # q_seq.shape = [T, N, 8, H, W]

    # q_seq is identical to p_seq, and also identical to y_seq and z_seq
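To check the speed-up on your own machine, a rough timing sketch is shown below; the tensor sizes and the use of time.perf_counter are illustrative assumptions, and the actual numbers depend on hardware and input shapes:

import time
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional

with torch.no_grad():
    # larger shapes than the example above, so the timing difference is easier to see
    T, N, C, H, W = 32, 16, 3, 32, 32
    x_seq = torch.rand([T, N, C, H, W])
    conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    bn.eval()

    t0 = time.perf_counter()
    functional.multi_step_forward(x_seq, (conv, bn))   # step-by-step loop over T
    t1 = time.perf_counter()
    functional.seq_to_ann_forward(x_seq, (conv, bn))   # single call on the merged [TN, C, H, W] batch
    t2 = time.perf_counter()

print(f'multi_step_forward: {t1 - t0:.4f} s, seq_to_ann_forward: {t2 - t1:.4f} s')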
Most frequently-used ANN modules have already been defined in spikingjelly.activation_based.layer. It is recommended to use the modules in spikingjelly.activation_based.layer rather than wrapping the ANN layers manually with a container. Although the modules in spikingjelly.activation_based.layer are implemented by wrapping the forward function with seq_to_ann_forward, they have the following advantages:

- Both single-step and multi-step modes are supported. When using SeqToANNContainer or MultiStepContainer to wrap modules, only the multi-step mode is supported.
- Wrapping with a container adds a prefix to the keys() of the state_dict, which brings some trouble when loading weights.
For example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron
ann = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU()
)

print(f'ann.state_dict.keys()={ann.state_dict().keys()}')

net_container = nn.Sequential(
    layer.SeqToANNContainer(
        nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(8),
    ),
    neuron.IFNode(step_mode='m')
)

print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')

net_origin = nn.Sequential(
    layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    neuron.IFNode(step_mode='m')
)

print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')

try:
    print('net_container is trying to load state dict from ann...')
    net_container.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_container can not load! The error message is\n', e)

try:
    print('net_origin is trying to load state dict from ann...')
    net_origin.load_state_dict(ann.state_dict())
    print('Load success!')
except BaseException as e:
    print('net_origin can not load! The error message is', e)
The outputs are
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
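If the plain ANN weights do need to be loaded into the container-wrapped network, one workaround is to rename the keys so that they carry the prefix introduced by the container. A minimal sketch, continuing the example above (the '0.' prefix matches the position of the SeqToANNContainer inside net_container and would have to be adapted for other architectures):

# add the prefix that the container introduces, e.g. '0.weight' -> '0.0.weight'
remapped = {'0.' + k: v for k, v in ann.state_dict().items()}
net_container.load_state_dict(remapped)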
MultiStepContainer and SeqToANNContainer only support the multi-step mode and do not allow switching to the single-step mode.
StepModeContainer works like a merged version of MultiStepContainer and SeqToANNContainer and can be used to wrap stateless or stateful single-step modules. The user should specify whether the wrapped modules are stateless or stateful when using this container. This container also supports switching between step modes.
Here is an example of wrapping a stateless layer:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer

with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    net = layer.StepModeContainer(
        False,
        nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(C),
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]
Here is an example of wrapping a stateful layer:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    net.step_mode = 'm'
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C, H, W]
    functional.reset_net(net)

    net.step_mode = 's'
    y = net(x_seq[0])
    # y.shape = [N, C, H, W]
    functional.reset_net(net)
It is safe to use set_step_mode to change the step mode of StepModeContainer. Only the step_mode of the container itself is changed, while the modules inside the container still use the single-step mode:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
    net = layer.StepModeContainer(
        True,
        neuron.IFNode()
    )
    functional.set_step_mode(net, 'm')
    print(f'net.step_mode={net.step_mode}')
    print(f'net[0].step_mode={net[0].step_mode}')
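Since only the container's own step_mode is switched, as described above, the expected output should be:

net.step_mode=m
net[0].step_mode=s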
If the module itself supports switching between single-step and multi-step modes, it is not recommended to use MultiStepContainer or StepModeContainer to wrap it, because the multi-step forward implemented by the container may not be as fast as the multi-step forward defined by the module itself.
In most cases, we use MultiStepContainer or StepModeContainer to wrap modules that do not define the multi-step forward, such as a network layer that exists in torch.nn but not in spikingjelly.activation_based.layer.
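For instance, the sketch below wraps torch.nn.PixelShuffle with StepModeContainer; the choice of PixelShuffle is purely illustrative, and you should first check whether the layer you need is already provided by spikingjelly.activation_based.layer in your version:

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional

with torch.no_grad():
    T = 4
    N = 2
    C = 4
    H = 8
    W = 8
    x_seq = torch.rand([T, N, C, H, W])

    net = layer.StepModeContainer(
        False,  # PixelShuffle is stateless
        nn.PixelShuffle(upscale_factor=2)
    )
    functional.set_step_mode(net, 'm')
    y_seq = net(x_seq)
    # y_seq.shape = [T, N, C // 4, H * 2, W * 2] = [4, 2, 1, 16, 16]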