欢迎来到惊蜇(SpikingJelly)的文档
SpikingJelly 是一个基于 PyTorch ,使用脉冲神经网络(Spiking Neural Network, SNN)进行深度学习的框架。
版本说明
自 0.0.0.0.14
版本开始,包括 clock_driven
和 event_driven
在内的模块被重命名了,请参考教程 从老版本迁移。
不同版本文档的地址(其中 latest 是开发版):
安装
注意,SpikingJelly是基于PyTorch的,需要确保环境中已经安装了PyTorch,才能安装spikingjelly。
奇数版本是开发版,随着GitHub/OpenI不断更新。偶数版本是稳定版,可以从PyPI获取。
从 PyPI 安装最新的稳定版本:
pip install spikingjelly
从源代码安装最新的开发版:
通过 GitHub:
git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
通过 OpenI :
git clone https://git.openi.org.cn/OpenI/spikingjelly.git
cd spikingjelly
python setup.py install
从老版本迁移
本教程作者: fangwei123456
新版的SpikingJelly改动较大,使用老版本SpikingJelly的用户若想迁移到新版本,则需要阅读此教程。SpikingJelly的版本升级尽可能前向兼容,因此用户无需做出太多代码上的修改,即可轻松迁移到新版本。
推荐老版本用户也阅读新版本的教程 基本概念。
“老版本SpikingJelly”均指的是版本号 <=0.0.0.0.12
的SpikingJelly。
子包重命名
新版的SpikingJelly对子包进行了重命名,与老版本的对应关系为:
| 老版本 | 新版本 |
|---|---|
| clock_driven | activation_based |
| event_driven | timing_based |
单步多步模块和传播模式
<=0.0.0.0.12
的老版本SpikingJelly,在默认情况下所有模块都是单步的,除非其名称含有前缀 MultiStep
。而新版的SpikingJelly,则不再使用前缀对单步和多步模块进行区分,取而代之的是同一个模块,拥有单步和多步两种步进模式,使用 step_mode
进行控制。具体信息可以参见 基本概念。
因而在新版本中不再有单独的多步模块,取而代之的则是融合了单步和多步的统一模块。例如,在老版本的SpikingJelly中,若想使用单步LIF神经元,是按照如下方式:
from spikingjelly.clock_driven import neuron
lif = neuron.LIFNode()
在新版本中,所有模块默认是单步的,所以与老版本的代码几乎相同,除了将 clock_driven
换成了 activation_based
:
from spikingjelly.activation_based import neuron
lif = neuron.LIFNode()
在老版本的SpikingJelly中,若想使用多步LIF神经元,是按照如下方式:
from spikingjelly.clock_driven import neuron
lif = neuron.MultiStepLIFNode()
在新版本中,单步和多步模块进行了统一,因此只需要指定为多步模块即可:
from spikingjelly.activation_based import neuron
lif = neuron.LIFNode(step_mode='m')
在老版本中,若想分别搭建一个逐步传播和逐层传播的网络,按照如下方式:
import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
# step-by-step
net_sbs = nn.Sequential(
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
neuron.IFNode()
)
y_seq = functional.multi_step_forward(x_seq, net_sbs)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net_sbs)
# layer-by-layer
net_lbl = nn.Sequential(
layer.SeqToANNContainer(
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
),
neuron.MultiStepIFNode()
)
y_seq = net_lbl(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net_lbl)
而在新版本中,由于单步和多步模块已经融合,可以通过 spikingjelly.activation_based.functional.set_step_mode
对整个网络的步进模式进行转换。在所有模块使用单步模式时,整个网络就可以使用逐步传播;所有模块都使用多步模式时,整个网络就可以使用逐层传播:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
# the network uses step-by-step because step_mode='s' is the default value for all modules
net = nn.Sequential(
layer.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(C),
neuron.IFNode()
)
y_seq = functional.multi_step_forward(x_seq, net)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
# set the network to use layer-by-layer
functional.set_step_mode(net, step_mode='m')
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
基本概念
本教程作者: fangwei123456
本教程介绍了 spikingjelly.activation_based
的一些基本概念,推荐所有用户在使用SpikingJelly框架前进行阅读。
SpikingJelly框架是基于PyTorch的SNN深度学习框架,使用SpikingJelly框架的用户应该首先熟悉PyTorch的使用。如果用户对PyTorch不甚了解,我们推荐用户先学习 PyTorch的基础教程 。
基于激活值的表示方法
spikingjelly.activation_based
使用取值仅为0或1的张量表示脉冲,例如:
import torch
v = torch.rand([8])
v_th = 0.5
spike = (v >= v_th).to(v)
print('spike =', spike)
# spike = tensor([0., 0., 0., 1., 1., 0., 1., 0.])
数据格式
在 spikingjelly.activation_based
中,数据有两种格式,分别为:
- 表示单个时刻的数据,其 shape = [N, *],其中 N 是batch维度,* 表示任意额外的维度
- 表示多个时刻的数据,其 shape = [T, N, *],其中 T 是数据的时间维度,N 是batch维度,* 表示任意额外的维度
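下面用一小段代码直观展示这两种数据格式(仅作示意,具体的维度取值为假设):

import torch

N, T, C, H, W = 4, 8, 3, 32, 32
x = torch.rand([N, C, H, W])         # 单个时刻的数据, shape = [N, *]
x_seq = torch.rand([T, N, C, H, W])  # 多个时刻的数据, shape = [T, N, *]
print(x.shape, x_seq.shape)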
步进模式
spikingjelly.activation_based
中的模块,具有两种传播模式,分别是单步模式(single-step)和多步模式(multi-step)。在单步模式下,数据使用 shape = [N, *]
的格式;而在多步模式下,数据使用 shape = [T, N, *]
的格式。
模块在初始化时可以指定其使用的步进模式 step_mode
,也可以在构建后直接进行修改:
import torch
from spikingjelly.activation_based import neuron
net = neuron.IFNode(step_mode='m')
# 'm' is the multi-step mode
net.step_mode = 's'
# 's' is the single-step mode
如果我们想给单步模式的模块输入 shape = [T, N, *]
的序列数据,通常需要手动做一个时间上的循环,将数据拆成 T
个 shape = [N, *]
的数据并逐步输入进去。让我们新建一层IF神经元,设置为单步模式,将数据逐步输入并得到输出:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = []
for t in range(T):
x = x_seq[t] # x.shape = [N, C, H, W]
y = net_s(x) # y.shape = [N, C, H, W]
y_seq.append(y.unsqueeze(0))
y_seq = torch.cat(y_seq)
# y_seq.shape = [T, N, C, H, W]
multi_step_forward
提供了将 shape = [T, N, *]
的序列数据输入到单步模块进行逐步的前向传播的封装,使用起来更加方便:
import torch
from spikingjelly.activation_based import neuron, functional
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
但是,直接将模块设置成多步模块,其实更为便捷:
import torch
from spikingjelly.activation_based import neuron
net_m = neuron.IFNode(step_mode='m')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = net_m(x_seq)
# y_seq.shape = [T, N, C, H, W]
为了保持与老版本SpikingJelly代码的兼容性,所有模块的默认步进模式都是单步。
状态的保存和重置
SNN中的神经元等模块,与RNN类似,带有隐藏状态,其输出 \(Y[t]\) 不仅仅与当前时刻的输入 \(X[t]\) 有关,还与上一个时刻末的状态 \(H[t-1]\) 有关,即 \(Y[t] = f(X[t], H[t-1])\)。
PyTorch的设计为RNN将状态也一并输出,可以参考 torch.nn.RNN
的API文档。而在 spikingjelly.activation_based
中,状态会被保存在模块内部。例如,我们新建一层IF神经元,设置为单步模式,查看给与输入前的默认电压,和给与输入后的电压:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(net_s)
print(f'the initial v={net_s.v}')
y = net_s(x)
print(f'x={x}')
print(f'y={y}')
print(f'v={net_s.v}')
# outputs are:
'''
IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
the initial v=0.0
x=tensor([0.5543, 0.0350, 0.2171, 0.6740])
y=tensor([0., 0., 0., 0.])
v=tensor([0.5543, 0.0350, 0.2171, 0.6740])
'''
在初始化后,IF神经元层的 v
会被设置为0,首次给与输入后 v
会自动广播到与输入相同的 shape
。
若我们给与一个新的输入,则应该先清除神经元之前的状态,让其恢复到初始化状态,可以通过调用模块的 self.reset()
函数实现:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(f'check point 0: v={net_s.v}')
y = net_s(x)
print(f'check point 1: v={net_s.v}')
net_s.reset()
print(f'check point 2: v={net_s.v}')
x = torch.rand([8])
y = net_s(x)
print(f'check point 3: v={net_s.v}')
# outputs are:
'''
check point 0: v=0.0
check point 1: v=tensor([0.9775, 0.6598, 0.7577, 0.2952])
check point 2: v=0.0
check point 3: v=tensor([0.8728, 0.9031, 0.2278, 0.5089, 0.1059, 0.0479, 0.5008, 0.8530])
'''
方便起见,还可以通过调用 spikingjelly.activation_based.functional.reset_net
将整个网络中的所有有状态模块进行重置。
若网络使用了有状态的模块,在训练和推理时,务必在处理完毕一个batch的数据后进行重置:
from spikingjelly.activation_based import functional
# ...
for x, label in tqdm(train_data_loader):
# ...
optimizer.zero_grad()
y = net(x)
loss = criterion(y, label)
loss.backward()
optimizer.step()
functional.reset_net(net)
# Never forget to reset the network!
如果忘了重置,在推理时可能输出错误的结果,而在训练时则会直接报错:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
传播模式
若一个网络全部由单步模块构成,则整个网络的计算顺序是按照逐步传播(step-by-step)的模式进行,例如:
for t in range(T):
x = x_seq[t]
y = net(x)
y_seq_step_by_step.append(y.unsqueeze(0))
y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)
如果网络全部由多步模块构成,则整个网络的计算顺序是按照逐层传播(layer-by-layer)的模式进行,例如:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4
N = 2
C = 8
x_seq = torch.rand([T, N, C]) * 64.
net = nn.Sequential(
layer.Linear(C, 4),
neuron.IFNode(),
layer.Linear(4, 2),
neuron.IFNode()
)
functional.set_step_mode(net, step_mode='m')
with torch.no_grad():
y_seq_layer_by_layer = x_seq
for i in range(net.__len__()):
y_seq_layer_by_layer = net[i](y_seq_layer_by_layer)
在绝大多数情况下我们不需要显式的实现 for i in range(net.__len__())
这样的循环,因为 torch.nn.Sequential
已经帮我们实现过了,因此实际上我们可以这样做:
y_seq_layer_by_layer = net(x_seq)
逐步传播和逐层传播,实际上只是计算顺序不同,它们的计算结果是完全相同的:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4
N = 2
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W]) * 64.
net = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, stride=1, bias=False),
neuron.IFNode(),
layer.MaxPool2d(2, 2),
neuron.IFNode(),
layer.Flatten(start_dim=1),
layer.Linear(8 * H // 2 * W // 2, 10),
neuron.IFNode(),
)
print(f'net={net}')
with torch.no_grad():
y_seq_step_by_step = []
for t in range(T):
x = x_seq[t]
y = net(x)
y_seq_step_by_step.append(y.unsqueeze(0))
y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)
# we can also use `y_seq_step_by_step = functional.multi_step_forward(x_seq, net)` to get the same results
print(f'y_seq_step_by_step=\n{y_seq_step_by_step}')
functional.reset_net(net)
functional.set_step_mode(net, step_mode='m')
y_seq_layer_by_layer = net(x_seq)
max_error = (y_seq_layer_by_layer - y_seq_step_by_step).abs().max()
print(f'max_error={max_error}')
上面这段代码的输出为:
net=Sequential(
(0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=s)
(3): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(4): Flatten(start_dim=1, end_dim=-1, step_mode=s)
(5): Linear(in_features=128, out_features=10, bias=True)
(6): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
y_seq_step_by_step=
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0., 0., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 1., 0., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]]])
max_error=0.0
下面的图片展示了逐步传播构建计算图的顺序:

下面的图片展示了逐层传播构建计算图的顺序:

SNN的计算图有2个维度,分别是时间步数和网络深度,网络的传播实际上就是生成完整计算图的过程,正如上面的2张图片所示。实际上,逐步传播是深度优先遍历,而逐层传播是广度优先遍历。
尽管两者区别仅在于计算顺序,但计算速度和内存消耗上会略有区别。
- 在使用梯度替代法训练时,通常推荐使用逐层传播。在正确构建网络的情况下,逐层传播的并行度更大,速度更快。
- 在内存受限时使用逐步传播,例如ANN2SNN任务中需要用到非常大的 T。因为在逐层传播模式下,对无状态的层而言,真正的 batch size 是 TN 而不是 N(参见下一个教程,以及本列表之后的示意代码),当 T 太大时内存消耗极大。
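下面是一个极简的示意代码(假设性示例,并非框架内部的真实实现),说明逐层传播时无状态层为何等效于以 TN 作为 batch size 进行计算:

import torch
import torch.nn as nn

T, N, C = 4, 2, 8
x_seq = torch.rand([T, N, C])
fc = nn.Linear(C, C)

x_flat = x_seq.flatten(0, 1)   # [T, N, C] -> [T*N, C],无状态层可一次并行处理 T*N 个样本
y_flat = fc(x_flat)
y_seq = y_flat.view(T, N, C)   # 还原回 [T, N, C]
print(y_seq.shape)             # torch.Size([4, 2, 8])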
包装器
本教程作者: fangwei123456
SpikingJelly中主要提供了如下几种包装器:
- 函数风格的 multi_step_forward 和模块风格的 MultiStepContainer
- 函数风格的 seq_to_ann_forward 和模块风格的 SeqToANNContainer
- 对单步模块进行包装以进行单步/多步传播的 StepModeContainer
multi_step_forward
可以将一个单步模块进行多步传播,而 MultiStepContainer
则可以将一个单步模块包装成多步模块,例如:
import torch
from spikingjelly.activation_based import neuron, functional, layer
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]
# z_seq is identical to y_seq
对于无状态的ANN网络层,例如 torch.nn.Conv2d
,其本身要求输入数据的 shape = [N, *]
,若用于多步模式,则可以用多步的包装器进行包装:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
但是ANN的网络层本身是无状态的,不存在前序依赖,没有必要在时间上串行的计算,可以使用函数风格的 seq_to_ann_forward
或模块风格的 SeqToANNContainer
进行包装。seq_to_ann_forward
将 shape = [T, N, *]
的数据首先变换为 shape = [TN, *]
,再送入无状态的网络层进行计算,输出的结果会被重新变换为 shape = [T, N, *]
。不同时刻的数据是并行计算的,因而速度更快:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
# p_seq.shape = [T, N, 8, H, W]
net = layer.SeqToANNContainer(conv, bn)
q_seq = net(x_seq)
# q_seq.shape = [T, N, 8, H, W]
# q_seq is identical to p_seq, and also identical to y_seq and z_seq
常用的网络层,在 spikingjelly.activation_based.layer
已经定义过,更推荐使用 spikingjelly.activation_based.layer
中的网络层,而不是使用 SeqToANNContainer
手动包装,尽管 spikingjelly.activation_based.layer
中的网络层实际上就是用包装器包装 forward 函数实现的。spikingjelly.activation_based.layer
中的网络层,优势在于:
- 支持单步和多步模式,而 SeqToANNContainer 和 MultiStepContainer 包装的层,只支持多步模式
- 包装器会使得 state_dict 的 keys() 也增加一层包装,给加载权重带来麻烦
例如
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron
ann = nn.Sequential(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
nn.ReLU()
)
print(f'ann.state_dict.keys()={ann.state_dict().keys()}')
net_container = nn.Sequential(
layer.SeqToANNContainer(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
),
neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')
net_origin = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')
try:
print('net_container is trying to load state dict from ann...')
net_container.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_container can not load! The error message is\n', e)
try:
print('net_origin is trying to load state dict from ann...')
net_origin.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_origin can not load! The error message is', e)
输出为
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
MultiStepContainer
和 SeqToANNContainer
都是只支持多步模式的,不允许切换为单步模式。
StepModeContainer
类似于融合版的 MultiStepContainer
和 SeqToANNContainer
,可以用于包装无状态或有状态的单步模块,需要在包装时指明是否有状态。与前两者不同,此包装器支持切换单步和多步模式。
包装无状态层的示例:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
False,
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
包装有状态层的示例:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
True,
neuron.IFNode()
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
functional.reset_net(net)
使用 set_step_mode
改变 StepModeContainer
是安全的,只会改变包装器本身的 step_mode
,而包装器内的模块仍然保持单步:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
net = layer.StepModeContainer(
True,
neuron.IFNode()
)
functional.set_step_mode(net, 'm')
print(f'net.step_mode={net.step_mode}')
print(f'net[0].step_mode={net[0].step_mode}')
如果模块本身就支持单步和多步模式的切换,则不推荐使用 MultiStepContainer
或 StepModeContainer
对其进行包装。因为包装器使用的多步前向传播,可能不如模块自身定义的前向传播速度快。
通常需要用到 MultiStepContainer
或 StepModeContainer
的是一些没有定义多步的模块,例如一个在 torch.nn
中存在,但在 spikingjelly.activation_based.layer
中不存在的网络层。
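例如,下面用 MultiStepContainer 包装一个 torch.nn.PixelShuffle 进行多步传播(此处假设它在 layer 中没有对应实现,仅作示例):

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer

with torch.no_grad():
    T, N, C, H, W = 4, 2, 8, 8, 8
    x_seq = torch.rand([T, N, C, H, W])
    # 假设 PixelShuffle 没有现成的多步实现,用包装器把它变成多步模块
    net = layer.MultiStepContainer(nn.PixelShuffle(upscale_factor=2))
    y_seq = net(x_seq)
    print(y_seq.shape)  # torch.Size([4, 2, 2, 16, 16])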
神经元
本教程作者: fangwei123456
本节教程主要关注 spikingjelly.activation_based.neuron
,介绍脉冲神经元。
脉冲神经元模型
在 spikingjelly
中,我们约定,只能输出脉冲,即0或1的神经元,都可以称之为“脉冲神经元”。使用脉冲神经元的网络,进而也可以称之为脉冲神经网络(Spiking Neural Networks, SNNs)。spikingjelly.activation_based.neuron
中定义了各种常见的脉冲神经元模型,我们以 spikingjelly.activation_based.neuron.IFNode
为例来介绍脉冲神经元。
首先导入相关的模块:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
新建一个IF神经元层:
if_layer = neuron.IFNode()
IF神经元层有一些构造参数,在API文档中对这些参数有详细的解释,我们暂时只关注下面几个重要的参数:
v_threshold – 神经元的阈值电压
v_reset – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
;如果设置为None
,则电压会被减去v_threshold
surrogate_function – 反向传播时用来计算脉冲函数梯度的替代函数
你可能会好奇这一层神经元的数量是多少。对于 spikingjelly.activation_based.neuron
中的绝大多数神经元层,神经元的数量是在初始化或调用 reset()
函数重新初始化后,根据第一次接收的输入的 shape
自动决定的。
与RNN中的神经元非常类似,脉冲神经元也是有状态的,或者说是有记忆。脉冲神经元的状态变量,一般是它的膜电位 \(V[t]\)。因此,spikingjelly.activation_based.neuron
中的神经元,都有成员变量 v
。可以打印出刚才新建的IF神经元层的膜电位:
print(f'if_layer.v={if_layer.v}')
# if_layer.v=0.0
可以发现,现在的 if_layer.v
是 0.0
,因为我们还没有给与它任何输入。我们给与几个不同的输入,观察神经元的电压的 shape
,可以发现它与输入的数量是一致的:
x = torch.rand(size=[2, 3])
if_layer(x)
print(f'x.shape={x.shape}, if_layer.v.shape={if_layer.v.shape}')
# x.shape=torch.Size([2, 3]), if_layer.v.shape=torch.Size([2, 3])
if_layer.reset()
x = torch.rand(size=[4, 5, 6])
if_layer(x)
print(f'x.shape={x.shape}, if_layer.v.shape={if_layer.v.shape}')
# x.shape=torch.Size([4, 5, 6]), if_layer.v.shape=torch.Size([4, 5, 6])
if_layer.reset()
脉冲神经元是有状态的,在输入下一个样本前,一定要先调用 reset()
函数清除之前的状态。
\(V[t]\) 和输入 \(X[t]\) 的关系是什么样的?在脉冲神经元中,\(V[t]\) 不仅取决于当前时刻的输入 \(X[t]\),还取决于它在上一个时刻末的膜电位 \(V[t-1]\)。
通常使用阈下(指的是膜电位不超过阈值电压 \(V_{threshold}\) 时)神经动态方程 \(\frac{\mathrm{d}V(t)}{\mathrm{d}t} = f(V(t), X(t))\) 描述连续时间的脉冲神经元的充电过程,例如对于IF神经元,充电方程为:

\[\frac{\mathrm{d}V(t)}{\mathrm{d}t} = X(t)\]
spikingjelly.activation_based.neuron
中的神经元,使用离散的差分方程来近似连续的微分方程。在差分方程的视角下,IF神经元的充电方程为:

\[V[t] - V[t-1] = X[t]\]

因此可以得到 \(V[t]\) 的表达式为

\[V[t] = f(V[t-1], X[t]) = V[t-1] + X[t]\]
可以在 spikingjelly.activation_based.neuron.IFNode.neuronal_charge
中找到如下所示的代码:
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
不同的神经元,充电方程不尽相同。但膜电位超过阈值电压后,释放脉冲,以及释放脉冲后,膜电位的重置都是相同的。因此它们全部继承自 spikingjelly.activation_based.neuron.BaseNode
,共享相同的放电、重置方程。可以在 spikingjelly.activation_based.neuron.BaseNode.neuronal_fire
中找到释放脉冲的代码:
def neuronal_fire(self):
self.spike = self.surrogate_function(self.v - self.v_threshold)
surrogate_function()
在前向传播时是阶跃函数,只要输入大于或等于0,就会返回1,否则会返回0。我们将这种元素仅为0或1的 tensor
视为脉冲。
释放脉冲消耗了神经元之前积累的电荷,因此膜电位会有一个瞬间的降低,即膜电位的重置。在SNN中,对膜电位重置的实现,有2种方式:
Hard方式:释放脉冲后,膜电位直接被设置成重置电压:\(V[t] = V_{reset}\)
Soft方式:释放脉冲后,膜电位减去阈值电压:\(V[t] = V[t] - V_{threshold}\)
可以发现,对于使用Soft方式的神经元,并不需要重置电压 \(V_{reset}\) 这个变量。spikingjelly.activation_based.neuron
中的神经元,在构造函数的参数之一 v_reset
,默认为 0.0
,表示神经元使用Hard方式;若设置为 None
,则会使用Soft方式。在 spikingjelly.activation_based.neuron.BaseNode.neuronal_reset
中可以找到膜电位重置的代码:
# The following codes are for tutorials. The actual codes are different, but have the similar behavior.
def neuronal_reset(self):
if self.v_reset is None:
self.v = self.v - self.spike * self.v_threshold
else:
self.v = (1. - self.spike) * self.v + self.spike * self.v_reset
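下面用一个小例子直观对比两种重置方式(仅作演示,阈值使用默认的 1.0):

import torch
from spikingjelly.activation_based import neuron

x = torch.as_tensor([1.5])
hard_if = neuron.IFNode(v_reset=0.)    # Hard 方式
soft_if = neuron.IFNode(v_reset=None)  # Soft 方式
print(hard_if(x), hard_if.v)  # 释放脉冲后 v 被直接置为 v_reset = 0
print(soft_if(x), soft_if.v)  # 释放脉冲后 v = 1.5 - v_threshold = 0.5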
描述离散脉冲神经元的三个方程
至此,我们可以用充电、放电、重置,这3个离散方程来描述任意的离散脉冲神经元。充电、放电方程为:

\[H[t] = f(V[t-1], X[t])\]

\[S[t] = \Theta(H[t] - V_{threshold})\]

其中 \(\Theta(x)\) 即为构造函数参数中的 surrogate_function
,是一个阶跃函数:

\[\Theta(x) =
\begin{cases}
1, & x \geq 0 \\
0, & x < 0
\end{cases}\]

Hard方式重置方程为:

\[V[t] = H[t] \cdot (1 - S[t]) + V_{reset} \cdot S[t]\]

Soft方式重置方程为:

\[V[t] = H[t] - V_{threshold} \cdot S[t]\]
其中 \(X[t]\) 是外源输入,例如电压增量;为了避免混淆,我们使用 \(H[t]\) 表示神经元充电后、释放脉冲前的膜电位;\(V[t]\) 是神经元释放脉冲后的膜电位;\(f(V[t-1], X[t])\) 是神经元的状态更新方程,不同的神经元,区别就在于更新方程不同。
神经元的动态如下图所示(图片来自 Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks):

仿真
接下来,我们将逐步给与神经元输入,并查看它的膜电位和输出脉冲。
现在让我们给与IF神经元层持续的输入,并画出其放电后的膜电位和输出脉冲:
if_layer.reset()
x = torch.as_tensor([0.02])
T = 150
s_list = []
v_list = []
for t in range(T):
s_list.append(if_layer(x))
v_list.append(if_layer.v)
dpi = 300
figsize = (12, 8)
visualizing.plot_one_neuron_v_s(torch.cat(v_list).numpy(), torch.cat(s_list).numpy(), v_threshold=if_layer.v_threshold,
v_reset=if_layer.v_reset,
figsize=figsize, dpi=dpi)
plt.show()
我们给与的输入 shape=[1]
,因此这个IF神经元层只有1个神经元。它的膜电位和输出脉冲随着时间变化情况如下:
下面我们将神经元层重置,并给与 shape=[32]
的输入,查看这32个神经元的膜电位和输出脉冲:
if_layer.reset()
T = 50
x = torch.rand([32]) / 8.
s_list = []
v_list = []
for t in range(T):
s_list.append(if_layer(x).unsqueeze(0))
v_list.append(if_layer.v.unsqueeze(0))
s_list = torch.cat(s_list)
v_list = torch.cat(v_list)
figsize = (12, 8)
dpi = 200
visualizing.plot_2d_heatmap(array=v_list.numpy(), title='membrane potentials', xlabel='simulating step',
ylabel='neuron index', int_x_ticks=True, x_max=T, figsize=figsize, dpi=dpi)
visualizing.plot_1d_spikes(spikes=s_list.numpy(), title='output spikes', xlabel='simulating step',
ylabel='neuron index', figsize=figsize, dpi=dpi)
plt.show()
结果如下:
步进模式和后端
在 基本概念 中我们已经介绍过单步和多步模式,在本教程前面的内容中,我们使用的都是单步模式。切换成多步模式非常简单,只需要设置 step_mode
即可:
import torch
from spikingjelly.activation_based import neuron, functional
if_layer = neuron.IFNode(step_mode='s')
T = 8
N = 2
x_seq = torch.rand([T, N])
y_seq = functional.multi_step_forward(x_seq, if_layer)
if_layer.reset()
if_layer.step_mode = 'm'
y_seq = if_layer(x_seq)
if_layer.reset()
此外,部分神经元在多步模式下支持 cupy
后端。在 cupy
模式下,前反向传播会使用CuPy进行加速:
import torch
from spikingjelly.activation_based import neuron
if_layer = neuron.IFNode()
print(f'if_layer.backend={if_layer.backend}')
# if_layer.backend=torch
print(f'step_mode={if_layer.step_mode}, supported_backends={if_layer.supported_backends}')
# step_mode=s, supported_backends=('torch',)
if_layer.step_mode = 'm'
print(f'step_mode={if_layer.step_mode}, supported_backends={if_layer.supported_backends}')
# step_mode=m, supported_backends=('torch', 'cupy')
device = 'cuda:0'
if_layer.to(device)
if_layer.backend = 'cupy' # switch to the cupy backend
print(f'if_layer.backend={if_layer.backend}')
# if_layer.backend=cupy
x_seq = torch.rand([8, 4], device=device)
y_seq = if_layer(x_seq)
if_layer.reset()
自定义神经元
如前所述,SpikingJelly使用充电、放电、重置三个方程来描述脉冲神经元,在 BaseNode
中可以找到对应的代码,单步模式下的前向传播 single_step_forward
函数即是由这3个过程组成:
# spikingjelly.activation_based.neuron.BaseNode
def single_step_forward(self, x: torch.Tensor):
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
其中 neuronal_fire
和 neuronal_reset
对绝大多数神经元都是相同的,因而在 BaseNode
中就已经定义了。不同的神经元主要是构造函数和充电方程 neuronal_charge
不同。因此,若想实现新的神经元,则只需要更改构造函数和充电方程即可。
假设我们构造一种平方积分发放神经元,其充电方程为:

\[V[t] = V[t-1] + X[t]^{2}\]
实现方式如下:
import torch
from spikingjelly.activation_based import neuron
class SquareIFNode(neuron.BaseNode):
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x ** 2
BaseNode
继承自 MemoryModule
。MemoryModule
默认的多步传播,是使用 for t in range(T)
来循环调用单步传播实现的。因此我们定义 neuronal_charge
后, single_step_forward
就已经是完整的了,进而 multi_step_forward
也可以被使用。
使用平方积分发放神经元进行单步或多步传播:
import torch
from spikingjelly.activation_based import neuron
class SquareIFNode(neuron.BaseNode):
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x ** 2
sif_layer = SquareIFNode()
T = 4
N = 1
x_seq = torch.rand([T, N])
print(f'x_seq={x_seq}')
for t in range(T):
yt = sif_layer(x_seq[t])
print(f'sif_layer.v[{t}]={sif_layer.v}')
sif_layer.reset()
sif_layer.step_mode = 'm'
y_seq = sif_layer(x_seq)
print(f'y_seq={y_seq}')
sif_layer.reset()
输出为
x_seq=tensor([[0.7452],
[0.8062],
[0.6730],
[0.0942]])
sif_layer.v[0]=tensor([0.5554])
sif_layer.v[1]=tensor([0.])
sif_layer.v[2]=tensor([0.4529])
sif_layer.v[3]=tensor([0.4618])
y_seq=tensor([[0.],
[1.],
[0.],
[0.]])
梯度替代
本教程作者: fangwei123456
在 神经元 中我们已经提到过,描述神经元放电过程的 \(S[t] = \Theta(H[t] - V_{threshold})\),使用了一个Heaviside阶跃函数:

\[\Theta(x) =
\begin{cases}
1, & x \geq 0 \\
0, & x < 0
\end{cases}\]

按照定义,其导数为冲激函数:

\[\delta(x) =
\begin{cases}
+\infty, & x = 0 \\
0, & x \neq 0
\end{cases}\]
直接使用冲激函数进行梯度下降,显然会使得网络的训练极其不稳定。为了解决这一问题,各种梯度替代法(the surrogate gradient method)被相继提出,参见此综述 Surrogate Gradient Learning in Spiking Neural Networks。
替代函数在神经元中被用于生成脉冲,查看 BaseNode.neuronal_fire
的源代码可以发现:
# spikingjelly.activation_based.neuron
class BaseNode(base.MemoryModule):
def __init__(..., surrogate_function: Callable = surrogate.Sigmoid(), ...)
# ...
self.surrogate_function = surrogate_function
# ...
def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
梯度替代法的原理是,在前向传播时使用 \(y = \Theta(x)\),而在反向传播时则使用 \(\frac{\mathrm{d}y}{\mathrm{d}x} = \sigma'(x)\),而非\(\frac{\mathrm{d}y}{\mathrm{d}x} = \Theta'(x)\),其中 \(\sigma(x)\) 即为替代函数。\(\sigma(x)\) 通常是一个形状与 \(\Theta(x)\) 类似,但光滑连续的函数。
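下面是梯度替代原理的一个极简示意(并非 SpikingJelly 的实际实现,替代函数这里假设取 arctan 形式),仅用 torch.autograd.Function 说明“前向用 \(\Theta(x)\)、反向用 \(\sigma'(x)\)”的思想:

import math
import torch

class ToySurrogate(torch.autograd.Function):
    # 前向使用阶跃函数,反向使用平滑函数的梯度(仅为原理示意)
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        return (x >= 0.).to(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # arctan 形式替代函数的梯度
        sg = ctx.alpha / 2. / (1. + (math.pi / 2. * ctx.alpha * x) ** 2)
        return grad_output * sg, None

x = torch.rand([4]) - 0.5
x.requires_grad = True
y = ToySurrogate.apply(x, 2.0)
y.sum().backward()
print(x.grad)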
在 spikingjelly.activation_based.surrogate
中提供了一些常用的替代函数,其中Sigmoid函数 \(\sigma(x, \alpha) = \frac{1}{1 + \exp(-\alpha x)}\) 为 spikingjelly.activation_based.surrogate.Sigmoid
,下图展示了原始的Heaviside阶跃函数 Heaviside
、 alpha=5
时的Sigmoid原函数 Primitive
以及其梯度 Gradient
:
替代函数的使用比较简单,使用替代函数就像是使用函数一样:
import torch
from spikingjelly.activation_based import surrogate
sg = surrogate.Sigmoid(alpha=4.)
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = sg(x)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
输出为:
x=tensor([-0.1303, 0.4976, 0.3364, 0.4296, 0.2779, 0.4580, 0.4447, 0.2466],
requires_grad=True)
y=tensor([0., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<sigmoidBackward>)
x.grad=tensor([0.9351, 0.4231, 0.6557, 0.5158, 0.7451, 0.4759, 0.4943, 0.7913])
每个替代函数,除了有形如 spikingjelly.activation_based.surrogate.Sigmoid
的模块风格API,也提供了形如 spikingjelly.activation_based.surrogate.sigmoid
函数风格的API。模块风格的API使用驼峰命名法,而函数风格的API使用下划线命名法,关系类似于 torch.nn
和 torch.nn.functional
,下面是几个示例:
| 模块 | 函数 |
|---|---|
| surrogate.Sigmoid | surrogate.sigmoid |
| surrogate.ATan | surrogate.atan |
下面是函数风格API的用法示例:
import torch
from spikingjelly.activation_based import surrogate
alpha = 4.
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = surrogate.sigmoid.apply(x, alpha)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
替代函数通常会有1个或多个控制形状的超参数,例如 spikingjelly.activation_based.surrogate.Sigmoid
中的 alpha
。SpikingJelly中替代函数的形状参数,默认情况下是使得替代函数梯度最大值为1,这在一定程度上可以避免梯度累乘导致的梯度爆炸问题。
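以 Sigmoid 替代函数为例,其替代梯度为 \(\sigma'(x) = \alpha \cdot s(\alpha x)(1 - s(\alpha x))\)(其中 \(s\) 为标准sigmoid函数),在 \(x=0\) 处取最大值 \(\alpha / 4\);默认 alpha=4 时最大梯度恰好为1。下面用一小段代码做数值验证(仅作示意):

import torch
from spikingjelly.activation_based import surrogate

sg = surrogate.Sigmoid(alpha=4.)
x = torch.linspace(-2., 2., 1001, requires_grad=True)
sg(x).sum().backward()
print(x.grad.max())  # 约为 1.0,在 x=0 附近取得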
监视器
本教程作者: fangwei123456
在 spikingjelly.activation_based.monitor
中定义了几个通用的监视器类,用户可以使用这些监视器实现复杂的数据记录功能。下面以一个简单的网络为例进行介绍。
基本使用
所有的监视器的用法类似,以 spikingjelly.activation_based.monitor.OutputMonitor
为例进行介绍。
首先我们搭建起一个简单的多步网络。为了避免无脉冲释放,我们将权重全部设置为正值:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer
net = nn.Sequential(
layer.Linear(8, 4),
neuron.IFNode(),
layer.Linear(4, 2),
neuron.IFNode()
)
for param in net.parameters():
param.data.abs_()
functional.set_step_mode(net, 'm')
spikingjelly.activation_based.monitor.OutputMonitor
可以记录网络中任何类型为 instance
的模块的输出。脉冲神经元层的输出即为脉冲,因此我们可以使用 OutputMonitor
来构建一个脉冲监视器,记录网络中所有 neuron.IFNode
的输出脉冲:
spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode)
T = 4
N = 1
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
net(x_seq)
要记录的数据,会根据生成顺序,保存在 .records
的 list
中:
print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
输出为:
spike_seq_monitor.records=
[tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]]), tensor([[[0., 0.]],
[[1., 0.]],
[[0., 1.]],
[[1., 0.]]])]
也可以使用索引操作,直接访问被记录的第 i
个数据:
print(f'spike_seq_monitor[0]={spike_seq_monitor[0]}')
输出为:
spike_seq_monitor[0]=tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]])
.monitored_layers
记录了被监视器监控的层的名字:
print(f'net={net}')
print(f'spike_seq_monitor.monitored_layers={spike_seq_monitor.monitored_layers}')
输出为:
net=Sequential(
(0): Linear(in_features=8, out_features=4, bias=True)
(1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): Linear(in_features=4, out_features=2, bias=True)
(3): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
spike_seq_monitor.monitored_layers=['1', '3']
可以直接通过层的名字作为索引,访问某一层被记录的数据。这返回的是一个 list
:
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")
输出为:
spike_seq_monitor['1']=[tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]])]
可以通过调用 .clear_recorded_data()
来清空已经记录的数据:
spike_seq_monitor.clear_recorded_data()
print(f'spike_seq_monitor.records={spike_seq_monitor.records}')
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")
输出为:
spike_seq_monitor.records=[]
spike_seq_monitor['1']=[]
所有的 monitor
在析构时都会自动删除已经注册的钩子,但python的内存回收机制并不保证在手动调用 del
时一定会进行析构。因此删除一个监视器,并不能保证钩子也立刻被删除:
del spike_seq_monitor
# 钩子可能仍然在起作用
若想立刻删除钩子,应该通过以下方式:
spike_seq_monitor.remove_hooks()
OutputMonitor
还支持在记录数据时就对数据进行简单的处理,只需要指定构造函数中的 function_on_output
即可。 function_on_output
的默认值是 lambda x: x
,也就是默认不进行任何处理。我们想要记录每个时刻的脉冲发放频率,首先要定义脉冲发放频率如何计算:
def cal_firing_rate(s_seq: torch.Tensor):
# s_seq.shape = [T, N, *]
return s_seq.flatten(1).mean(1)
接下来就可以以此来构建发放率监视器:
fr_monitor = monitor.OutputMonitor(net, neuron.IFNode, cal_firing_rate)
通过 .disable()
可以让 monitor
暂停记录,而 .enable()
则可以让其重新开始记录:
with torch.no_grad():
functional.reset_net(net)
fr_monitor.disable()
net(x_seq)
functional.reset_net(net)
print(f'after call fr_monitor.disable(), fr_monitor.records=\n{fr_monitor.records}')
fr_monitor.enable()
net(x_seq)
print(f'after call fr_monitor.enable(), fr_monitor.records=\n{fr_monitor.records}')
functional.reset_net(net)
del fr_monitor
输出为:
after call fr_monitor.disable(), fr_monitor.records=
[]
after call fr_monitor.enable(), fr_monitor.records=
[tensor([0.0000, 1.0000, 0.5000, 1.0000]), tensor([0., 1., 0., 1.])]
记录模块成员变量
若想记录模块的成员变量,例如神经元的电压,可以通过 spikingjelly.activation_based.monitor.AttributeMonitor
实现。
神经元构造参数中的 store_v_seq: bool = False
表示在默认情况下,只记录当前时刻的电压,不记录所有时刻的电压序列。现在我们想记录所有时刻的电压,则将其更改为 True
:
for m in net.modules():
if isinstance(m, neuron.IFNode):
m.store_v_seq = True
接下来,新建记录电压序列的监视器并进行记录:
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
net(x_seq)
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
functional.reset_net(net)
del v_seq_monitor
输出为:
v_seq_monitor.records=
[tensor([[[0.8102, 0.8677, 0.8153, 0.9200]],
[[0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.8129, 0.0000, 0.9263]],
[[0.0000, 0.0000, 0.0000, 0.0000]]]), tensor([[[0.2480, 0.4848]],
[[0.0000, 0.0000]],
[[0.8546, 0.6674]],
[[0.0000, 0.0000]]])]
记录模块输入
设置输入监视器的方法,和设置输出监视器的如出一辙:
input_monitor = monitor.InputMonitor(net, neuron.IFNode)
with torch.no_grad():
net(x_seq)
print(f'input_monitor.records=\n{input_monitor.records}')
functional.reset_net(net)
del input_monitor
输出为:
input_monitor.records=
[tensor([[[1.1710, 0.7936, 0.9325, 0.8227]],
[[1.4373, 0.7645, 1.2167, 1.3342]],
[[1.6011, 0.9850, 1.2648, 1.2650]],
[[0.9322, 0.6143, 0.7481, 0.9770]]]), tensor([[[0.8072, 0.7733]],
[[1.1186, 1.2176]],
[[1.0576, 1.0153]],
[[0.4966, 0.6030]]])]
记录模块的输入梯度 \(\frac{\partial L}{\partial Y}\)
如果我们想要记录每一层脉冲神经元的输入梯度 \(\frac{\partial L}{\partial S}\),则可以使用 spikingjelly.activation_based.monitor.GradOutputMonitor
轻松实现:
spike_seq_grad_monitor = monitor.GradOutputMonitor(net, neuron.IFNode)
net(x_seq).sum().backward()
print(f'spike_seq_grad_monitor.records=\n{spike_seq_grad_monitor.records}')
functional.reset_net(net)
del spike_seq_grad_monitor
输出为:
spike_seq_grad_monitor.records=
[tensor([[[1., 1.]],
[[1., 1.]],
[[1., 1.]],
[[1., 1.]]]), tensor([[[ 0.0803, 0.0383, 0.1035, 0.1177]],
[[-0.1013, -0.1346, -0.0561, -0.0085]],
[[ 0.5364, 0.6285, 0.3696, 0.1818]],
[[ 0.3704, 0.4747, 0.2201, 0.0596]]])]
由于我们使用 .sum().backward()
,因而损失传给最后一层输出脉冲的梯度全为1。
记录模块的输出梯度 \(\frac{\partial L}{\partial X}\)
使用 spikingjelly.activation_based.monitor.GradInputMonitor
可以轻松记录模块的输出梯度 \(\frac{\partial L}{\partial X}\)。
让我们构建一个深度网络,调节替代函数的 alpha
并比较不同 alpha
下的梯度的幅值:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer, surrogate
net = []
for i in range(10):
net.append(layer.Linear(8, 8))
net.append(neuron.IFNode())
net = nn.Sequential(*net)
functional.set_step_mode(net, 'm')
T = 4
N = 1
x_seq = torch.rand([T, N, 8])
input_grad_monitor = monitor.GradInputMonitor(net, neuron.IFNode, function_on_grad_input=torch.norm)
for alpha in [0.1, 0.5, 2, 4, 8]:
for m in net.modules():
if isinstance(m, surrogate.Sigmoid):
m.alpha = alpha
net(x_seq).sum().backward()
print(f'alpha={alpha}, input_grad_monitor.records=\n{input_grad_monitor.records}\n')
functional.reset_net(net)
# zero grad
for param in net.parameters():
param.grad.zero_()
input_grad_monitor.records.clear()
输出为:
alpha=0.1, input_grad_monitor.records=
[tensor(0.3868), tensor(0.0138), tensor(0.0003), tensor(9.1888e-06), tensor(1.0164e-07), tensor(1.9384e-09), tensor(4.0199e-11), tensor(8.6942e-13), tensor(1.3389e-14), tensor(2.7714e-16)]
alpha=0.5, input_grad_monitor.records=
[tensor(1.7575), tensor(0.2979), tensor(0.0344), tensor(0.0045), tensor(0.0002), tensor(1.5708e-05), tensor(1.6167e-06), tensor(1.6107e-07), tensor(1.1618e-08), tensor(1.1097e-09)]
alpha=2, input_grad_monitor.records=
[tensor(3.3033), tensor(1.2917), tensor(0.4673), tensor(0.1134), tensor(0.0238), tensor(0.0040), tensor(0.0008), tensor(0.0001), tensor(2.5466e-05), tensor(3.9537e-06)]
alpha=4, input_grad_monitor.records=
[tensor(3.5353), tensor(1.6377), tensor(0.7076), tensor(0.2143), tensor(0.0369), tensor(0.0069), tensor(0.0026), tensor(0.0006), tensor(0.0003), tensor(8.5736e-05)]
alpha=8, input_grad_monitor.records=
[tensor(4.3944), tensor(2.4396), tensor(0.8996), tensor(0.4376), tensor(0.0640), tensor(0.0122), tensor(0.0053), tensor(0.0016), tensor(0.0013), tensor(0.0005)]
使用单层全连接SNN识别MNIST
本教程作者:Yanqi-Chen
本节教程将介绍如何使用编码器与替代梯度方法训练一个最简单的MNIST分类网络。
从头搭建一个简单的SNN网络
在PyTorch中搭建神经网络时,我们可以简单地使用nn.Sequential
将多个网络层堆叠得到一个前馈网络,输入数据将依序流经各个网络层得到输出。
MNIST数据集包含若干尺寸为\(28\times 28\)的8位灰度图像,总共有0~9共10个类别。以MNIST的分类为例,一个简单的单层ANN网络如下:
nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 10, bias=False),
nn.Softmax()
)
我们也可以用完全类似结构的SNN来进行分类任务。就这个网络而言,只需要先去掉所有的激活函数,再将神经元添加到原来激活函数的位置,这里我们选择的是LIF神经元。神经元之间的连接层需要用spikingjelly.activation_based.layer
包装:
nn.Sequential(
layer.Flatten(),
layer.Linear(28 * 28, 10, bias=False),
neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
)
其中膜电位衰减常数\(\tau\)需要通过参数tau
设置,替代函数这里选择surrogate.ATan
。
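把上面的片段补全成可运行的最小定义如下(其中 tau=2.0 为假设的取值,仅作演示):

import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, surrogate

tau = 2.0  # 假设的膜电位时间常数,仅作演示

net = nn.Sequential(
    layer.Flatten(),
    layer.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
)
print(net)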
训练SNN网络
首先指定好训练参数如学习率等以及若干其他配置
优化器默认使用Adam,以及使用泊松编码器,在每次输入图片时进行脉冲编码
# 使用Adam优化器
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# 使用泊松编码器
encoder = encoding.PoissonEncoder()
训练代码的编写需要遵循以下三个要点:

1. 脉冲神经元的输出是二值的,而直接将单次运行的结果用于分类极易受到编码带来的噪声干扰。因此一般认为脉冲网络的输出是输出层一段时间内的发放频率(或称发放率),发放率的高低表示该类别的响应大小。因此网络需要运行一段时间,即使用 T 个时刻后的平均发放率作为分类依据。
2. 我们希望的理想结果是除了正确的神经元以最高频率发放,其他神经元保持静默。常常采用交叉熵损失或者MSE损失,这里我们使用实际效果更好的MSE损失。
3. 每次网络仿真结束后,需要重置网络状态。
结合以上三点,得到训练循环的核心代码如下:
for epoch in range(start_epoch, args.epochs):
start_time = time.time()
net.train()
train_loss = 0
train_acc = 0
train_samples = 0
for img, label in train_data_loader:
optimizer.zero_grad()
img = img.to(args.device)
label = label.to(args.device)
label_onehot = F.one_hot(label, 10).float()
# 混合精度训练
if scaler is not None:
with amp.autocast():
out_fr = 0.
# 运行T个时间步
for t in range(args.T):
encoded_img = encoder(img)
out_fr += net(encoded_img)
out_fr = out_fr / args.T
# out_fr是shape=[batch_size, 10]的tensor
# 记录整个仿真时长内,输出层的10个神经元的脉冲发放率
loss = F.mse_loss(out_fr, label_onehot)
# 损失函数为输出层神经元的脉冲发放频率,与真实类别的MSE
# 这样的损失函数会使得:当标签i给定时,输出层中第i个神经元的脉冲发放频率趋近1,而其他神经元的脉冲发放频率趋近0
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
out_fr = 0.
for t in range(args.T):
encoded_img = encoder(img)
out_fr += net(encoded_img)
out_fr = out_fr / args.T
loss = F.mse_loss(out_fr, label_onehot)
loss.backward()
optimizer.step()
train_samples += label.numel()
train_loss += loss.item() * label.numel()
# 正确率的计算方法如下。认为输出层中脉冲发放频率最大的神经元的下标i是分类结果
train_acc += (out_fr.argmax(1) == label).float().sum().item()
# 优化一次参数后,需要重置网络的状态,因为SNN的神经元是有“记忆”的
functional.reset_net(net)
完整的代码位于activation_based.examples.lif_fc_mnist.py
,在代码中我们还使用了Tensorboard来保存训练日志。可以直接在命令行运行它:
$ python -m spikingjelly.activation_based.examples.lif_fc_mnist --help
usage: lif_fc_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N]
[-data-dir DATA_DIR] [-out-dir OUT_DIR]
[-resume RESUME] [-amp] [-opt {sgd,adam}]
[-momentum MOMENTUM] [-lr LR] [-tau TAU]
LIF MNIST Training
optional arguments:
-h, --help show this help message and exit
-T T simulating time-steps
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-opt {sgd,adam} use which optimizer. SGD or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
-tau TAU parameter tau of LIF neuron
需要注意的是,训练这样的SNN,所需显存数量与仿真时长 T
线性相关,更长的 T
相当于使用更小的仿真步长,训练更为“精细”,但训练效果不一定更好。T
太大时,SNN在时间上展开后会变成一个非常深的网络,这将导致BPTT计算梯度时容易衰减或爆炸。
另外由于我们使用了泊松编码器,因此需要较大的 T
保证编码带来的噪声不太大。
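泊松编码器把归一化到 [0, 1] 的像素值作为每个时间步的发放概率,各时间步独立采样,因此多步平均后的发放率会趋近像素值。下面是一个简单的使用示例(仅作演示):

import torch
from spikingjelly.activation_based import encoding

encoder = encoding.PoissonEncoder()
img = torch.rand([2, 1, 28, 28])  # 假设已归一化到 [0, 1] 的图片
T = 64
spikes = torch.stack([encoder(img) for _ in range(T)])
print(spikes.shape)                 # torch.Size([64, 2, 1, 28, 28])
print(spikes.unique())              # tensor([0., 1.])
print(spikes.mean(0)[0, 0, 0, :4])  # 多步平均发放率近似于 img[0, 0, 0, :4]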
训练结果
取tau=2.0,T=100,batch_size=64,lr=1e-3
,对应的运行命令为
python -m spikingjelly.activation_based.examples.lif_fc_mnist -tau 2.0 -T 100 -device cuda:0 -b 64 -epochs 100 -data-dir <PATH to MNIST> -amp -opt adam -lr 1e-3 -j 8
其中为了加快训练速度,启用了混合精度训练。训练100个Epoch后,将会输出两个npy文件以及训练日志。测试集上的最高正确率为92.9%,通过matplotlib可视化得到的正确率曲线如下
选取测试集中第一张图片:

用训好的模型进行分类,得到分类结果
Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
通过visualizing
模块中的函数可视化得到输出层的电压以及脉冲如下图所示
可以看到除了正确类别对应的神经元外,其它神经元均未发放任何脉冲。完整的训练代码可见 activation_based/examples/lif_fc_mnist.py 。
使用卷积SNN识别Fashion-MNIST
本教程作者: fangwei123456
在本节教程中,我们将搭建一个卷积脉冲神经网络,对 Fashion-MNIST 数据集进行分类。Fashion-MNIST数据集与MNIST数据集的格式相同,均为 1 * 28 * 28
的灰度图片。
网络结构
我们使用最常见的卷积神经网络结构。具体而言,网络结构为:
{Conv2d-BatchNorm2d-IFNode-MaxPool2d}-{Conv2d-BatchNorm2d-IFNode-MaxPool2d}-{Linear-IFNode}
网络结构的定义如下:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
class CSNN(nn.Module):
def __init__(self, T: int, channels: int, use_cupy=False):
super().__init__()
self.T = T
self.conv_fc = nn.Sequential(
layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(channels),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.MaxPool2d(2, 2), # 14 * 14
layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(channels),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.MaxPool2d(2, 2), # 7 * 7
layer.Flatten(),
layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.Linear(channels * 4 * 4, 10, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan()),
)
为了更快的训练速度,我们将网络设置成多步模式,并根据构造函数的要求,决定是否使用 cupy
后端:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
def __init__(self, T: int, channels: int, use_cupy=False):
# ...
functional.set_step_mode(self, step_mode='m')
if use_cupy:
functional.set_backend(self, backend='cupy')
将图片直接输入到SNN,而不是编码后再输入,是近年来深度SNN的常见做法,我们在此教程中也使用这样的方法。在这种情况下,实际的 图片-脉冲
编码是由网络中的前三层,也就是 {Conv2d-BatchNorm2d-IFNode}
完成。
网络的输入直接是 shape=[N, C, H, W]
的图片,我们将其添加时间维度,并复制 T
次,得到 shape=[T, N, C, H, W]
的序列,然后送入到网络层。网络的输出定义为最后一层脉冲神经元的脉冲发放频率。因而,网络的前向传播定义为:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
def forward(self, x: torch.Tensor):
# x.shape = [N, C, H, W]
x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1) # [N, C, H, W] -> [T, N, C, H, W]
x_seq = self.conv_fc(x_seq)
fr = x_seq.mean(0)
return fr
网络训练
网络的训练方式、损失函数定义、分类结果的确定均与上一节教程相同,不再赘述。唯一的区别是,使用Fashion-MNIST数据集:
# spikingjelly.activation_based.examples.conv_fashion_mnist
train_set = torchvision.datasets.FashionMNIST(
root=args.data_dir,
train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
test_set = torchvision.datasets.FashionMNIST(
root=args.data_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
可以使用如下命令查看训练参数:
(sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h
usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
[-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR] [-channels CHANNELS]
Classify Fashion-MNIST
optional arguments:
-h, --help show this help message and exit
-T T simulating time-steps
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of Fashion-MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-cupy use cupy backend
-opt OPT use which optimizer. SDG or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
-channels CHANNELS channels of CSNN
-save-es SAVE_ES dir for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
我们使用如下命令进行训练,其中为了加快训练速度,启用了混合精度训练和CuPy后端:
python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 128 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8
输出为:
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
CSNN(
(conv_fc): Sequential(
(0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(6): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(8): Flatten(start_dim=1, end_dim=-1, step_mode=m)
(9): Linear(in_features=6272, out_features=2048, bias=False)
(10): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(11): Linear(in_features=2048, out_features=10, bias=False)
(12): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
Mkdir ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy.
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =0, train_loss = 0.0325, train_acc = 0.7875, test_loss = 0.0248, test_acc = 0.8543, max_test_acc = 0.8543
train speed = 7109.7899 images/s, test speed = 7936.2602 images/s
escape time = 2022-05-24 21:42:15
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =1, train_loss = 0.0217, train_acc = 0.8734, test_loss = 0.0201, test_acc = 0.8758, max_test_acc = 0.8758
train speed = 7712.5343 images/s, test speed = 7902.5029 images/s
escape time = 2022-05-24 21:43:13
...
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =63, train_loss = 0.0024, train_acc = 0.9941, test_loss = 0.0113, test_acc = 0.9283, max_test_acc = 0.9308
train speed = 7627.8147 images/s, test speed = 7868.9090 images/s
escape time = 2022-05-24 21:42:16
最终获得了 max_test_acc = 0.9308
的性能。如果精心调整超参数,通常还可以获得更高的性能。
下图展示了训练过程中正确率的变化:
可视化编码器
如前所述,我们将图片直接送入网络,实际的编码过程是由网络中的首个 {Conv2d-BatchNorm2d-IFNode}
实现的。现在让我们提取出网络中的编码器,输入图片,并将输出脉冲可视化,代码如下:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
# ...
def spiking_encoder(self):
return self.conv_fc[0:3]
def main():
# ...
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
net.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
start_epoch = checkpoint['epoch'] + 1
max_test_acc = checkpoint['max_test_acc']
if args.save_es is not None and args.save_es != '':
encoder = net.spiking_encoder()
with torch.no_grad():
for img, label in test_data_loader:
img = img.to(args.device)
label = label.to(args.device)
# img.shape = [N, C, H, W]
img_seq = img.unsqueeze(0).repeat(net.T, 1, 1, 1, 1) # [N, C, H, W] -> [T, N, C, H, W]
spike_seq = encoder(img_seq)
functional.reset_net(encoder)
to_pil_img = torchvision.transforms.ToPILImage()
vs_dir = os.path.join(args.save_es, 'visualization')
os.mkdir(vs_dir)
img = img.cpu()
spike_seq = spike_seq.cpu()
img = F.interpolate(img, scale_factor=4, mode='bilinear')
# 28 * 28 is too small to read. So, we interpolate it to a larger size
for i in range(label.shape[0]):
vs_dir_i = os.path.join(vs_dir, f'{i}')
os.mkdir(vs_dir_i)
to_pil_img(img[i]).save(os.path.join(vs_dir_i, f'input.png'))
for t in range(net.T):
print(f'saving {i}-th sample with t={t}...')
# spike_seq.shape = [T, N, C, H, W]
visualizing.plot_2d_feature_map(spike_seq[t][i], 8, spike_seq.shape[2] // 8, 2, f'$S[{t}]$')
plt.savefig(os.path.join(vs_dir_i, f's_{t}.png'))
plt.savefig(os.path.join(vs_dir_i, f's_{t}.pdf'))
plt.savefig(os.path.join(vs_dir_i, f's_{t}.svg'))
plt.clf()
exit()
# ...
我们加载已经训练好的模型,设置 batch_size=4
(表示我们只保存4张图片和对应的编码后的脉冲),将图片保存到 ./logs
下,按照如下命令运行:
python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 4 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -resume ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy/checkpoint_latest.pth -save-es ./logs
运行后图片会保存到 ./logs/visualization
文件夹中。下面展示2个输入图片,和对应的编码后的脉冲:


神经形态数据集处理
本教程作者: fangwei123456
spikingjelly.datasets
中集成了常用的神经形态数据集,包括 N-MNIST 1, CIFAR10-DVS 2, DVS128 Gesture 3, N-Caltech101 1, ASLDVS 4 等。所有数据集的处理都遵循类似的步骤,开发人员也可以很轻松的添加新数据集代码。在本节教程中,我们将以 DVS128 Gesture 为例,展示如何使用惊蜇框架处理神经形态数据集。
自动下载和手动下载
CIFAR10-DVS等数据集支持自动下载。支持自动下载的数据集,在首次运行时原始数据集将会被下载到数据集根目录下的 download
文件夹。每个数据集的 downloadable()
函数定义了该数据集是否能够自动下载,而 resource_url_md5()
函数定义了各个文件的下载链接和MD5。示例:
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
print('CIFAR10-DVS downloadable', CIFAR10DVS.downloadable())
print('resource, url, md5\n', CIFAR10DVS.resource_url_md5())
print('DVS128Gesture downloadable', DVS128Gesture.downloadable())
print('resource, url, md5\n', DVS128Gesture.resource_url_md5())
输出为:
CIFAR10-DVS downloadable True
resource, url, md5
[('airplane.zip', 'https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'), ('automobile.zip', 'https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'), ('bird.zip', 'https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'), ('cat.zip', 'https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'), ('deer.zip', 'https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'), ('dog.zip', 'https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'), ('frog.zip', 'https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'), ('horse.zip', 'https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'), ('ship.zip', 'https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'), ('truck.zip', 'https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7')]
DVS128Gesture downloadable False
resource, url, md5
[('DvsGesture.tar.gz', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '8a5c71fb11e24e5ca5b11866ca6c00a1'), ('gesture_mapping.csv', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '109b2ae64a0e1f3ef535b18ad7367fd1'), ('LICENSE.txt', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '065e10099753156f18f51941e6e44b66'), ('README.txt', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', 'a0663d3b1d8307c329a43d949ee32d19')]
DVS128 Gesture数据集不支持自动下载,但它的 resource_url_md5()
函数会打印出获取下载地址的网址。DVS128 Gesture数据集可以从 https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794 进行下载。box网站不支持在不登陆的情况下使用代码直接下载,因此用户需要手动从网站上下载。将数据集下载到 E:/datasets/DVS128Gesture/download
,下载完成后这个文件夹的目录结构为
.
|-- DvsGesture.tar.gz
|-- LICENSE.txt
|-- README.txt
`-- gesture_mapping.csv
备注
不同框架对DVS128 Gesture数据集的预处理方式可能不同,这或许导致不同的训练集和测试集样本数量。请参考 spikingjelly.datasets.dvs128_gesture.DVS128Gesture
的API文档获取更多信息。
获取Event数据
创建训练集和测试集,其中参数 data_type='event'
表示我们使用Event数据。
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='event')
运行这段代码,惊蜇框架将会完成以下工作:
1. 检测数据集是否存在,如果存在,则进行MD5校验,确认数据集无误后,开始进行解压。将原始数据解压到同级目录下的 extract 文件夹。
2. DVS128 Gesture中的每个样本,是在不同光照环境下,对不同表演者进行录制的手势视频。一个AER文件中包含了多个手势,对应的会有一个csv文件来标注整个视频内各个时间段内都是哪种手势。因此,单个的视频文件并不是一个类别,而是多个类别的集合。惊蜇框架会启动多线程进行划分,将每个视频中的每个手势类别文件单独提取出来。
下面是运行过程中的命令行输出:
The [D:/datasets/DVS128Gesture/download] directory for saving downloaed files already exists, check files...
Mkdir [D:/datasets/DVS128Gesture/extract].
Extract [D:/datasets/DVS128Gesture/download/DvsGesture.tar.gz] to [D:/datasets/DVS128Gesture/extract].
Mkdir [D:/datasets/DVS128Gesture/events_np].
Start to convert the origin data from [D:/datasets/DVS128Gesture/extract] to [D:/datasets/DVS128Gesture/events_np] in np.ndarray format.
Mkdir [('D:/datasets/DVS128Gesture//events_np//train', 'D:/datasets/DVS128Gesture//events_np//test').
Mkdir ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9'] in [D:/datasets/DVS128Gesture/events_np/train] and ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9'] in [D:/datasets/DVS128Gesture/events_np/test].
Start the ThreadPoolExecutor with max workers = [8].
Start to split [D:/datasets/DVS128Gesture/extract/DvsGesture/user02_fluorescent.aedat] to samples.
[D:/datasets/DVS128Gesture/events_np/train/0/user02_fluorescent_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/train/1/user02_fluorescent_0.npz] saved.
......
[D:/datasets/DVS128Gesture/events_np/test/8/user29_lab_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/test/9/user29_lab_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/test/10/user29_lab_0.npz] saved.
Used time = [1017.27s].
All aedat files have been split to samples and saved into [('D:/datasets/DVS128Gesture//events_np//train', 'D:/datasets/DVS128Gesture//events_np//test')].
提取各个手势类别的速度较慢,需要耐心等待。运行完成后,同级目录下会多出一个 events_np
文件夹,其中包含训练集和测试集:
|-- events_np
| |-- test
| `-- train
打印一个数据:
event, label = train_set[0]
for k in event.keys():
print(k, event[k])
print('label', label)
得到输出为:
t [80048267 80048277 80048278 ... 85092406 85092538 85092700]
x [49 55 55 ... 60 85 45]
y [82 92 92 ... 96 86 90]
p [1 0 0 ... 1 0 0]
label 0
其中 event
使用字典格式存储Events数据,键为 ['t', 'x', 'y', 'p']
;label
是数据的标签,DVS128 Gesture共有11类。
获取Frame数据
将原始的Event流积分成Frame数据,是常用的处理方法,我们采用 5 的实现方式。我们将原始的Event数据记为 \(E(x_{i}, y_{i}, t_{i}, p_{i}), 0 \leq i < N\);设置 split_by='number'
表示从Event数量 \(N\) 上进行划分,接近均匀地划分为 frames_num=20
,也就是 \(T\) 段。记积分后的Frame数据中的某一帧为 \(F(j)\),在 \((p, x, y)\) 位置的像素值为 \(F(j, p, x, y)\);\(F(j)\) 是从Event流中索引介于 \(j_{l}\) 和 \(j_{r}\) 的Event积分而来:

\[j_{l} = \left\lfloor \frac{N}{T} \right\rfloor \cdot j\]

\[j_{r} =
\begin{cases}
\left\lfloor \frac{N}{T} \right\rfloor \cdot (j + 1), & j < T - 1 \\
N, & j = T - 1
\end{cases}\]

\[F(j, p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\]
其中 \(\lfloor \cdot \rfloor\) 是向下取整,\(\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) 是示性函数,当且仅当 \((p, x, y) = (p_{i}, x_{i}, y_{i})\) 时取值为1,否则为0。
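下面用一段简短的 numpy 代码示意这种分段积分的计算方式(仅为原理演示,并非框架的实际实现):

import numpy as np

def integrate_segment_to_frame(x, y, p, H, W, i_l, i_r):
    # 将索引位于 [i_l, i_r) 的事件累加到一帧 [2, H, W] 中(原理示意)
    frame = np.zeros([2, H, W])
    np.add.at(frame, (p[i_l:i_r], y[i_l:i_r], x[i_l:i_r]), 1)
    return frame

# 构造一些随机事件作为演示
N, H, W = 1000, 128, 128
x = np.random.randint(0, W, N)
y = np.random.randint(0, H, N)
p = np.random.randint(0, 2, N)
frame = integrate_segment_to_frame(x, y, p, H, W, 0, N // 2)
print(frame.shape, frame.sum())  # (2, 128, 128) 500.0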
运行下列代码,惊蜇框架就会开始进行积分,创建Frame数据集:
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
命令行的输出为:
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/1].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/10].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/2].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/3].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/4].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/5].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/6].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/7].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/8].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/9].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/0].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/1].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/10].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/2].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/3].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/4].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/5].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/6].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/7].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/8].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9].
Start ThreadPoolExecutor with max workers = [8].
Start to integrate [D:/datasets/DVS128Gesture/events_np/test/0/user24_fluorescent_0.npz] to frames and save to [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
Start to integrate [D:/datasets/DVS128Gesture/events_np/test/0/user24_fluorescent_led_0.npz] to frames and save to [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
......
Frames [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9/user23_lab_0.npz] saved.Frames [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9/user23_led_0.npz] saved.
Used time = [102.11s].
运行后,同级目录下会出现 frames_number_20_split_by_number
文件夹,这里存放了积分生成的Frame数据。
打印一个数据:
frame, label = train_set[0]
print(frame.shape)
得到输出为:
(20, 2, 128, 128)
查看1个积分好的Frame数据:
from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
显示效果如下图所示:

固定时间间隔积分
使用固定时间间隔积分,更符合实际物理系统。例如每 10 ms
积分一次,则长度为 L ms
的数据,可以得到 math.floor(L / 10) 帧。但神经形态数据集中每个样本的长度往往不相同,因此会得到不同长度的帧数据。使用惊蜇框架提供的 spikingjelly.datasets.pad_sequence_collate
和 spikingjelly.datasets.padded_sequence_mask
可以很方便的对不等长数据进行对齐和还原。
示例代码:
import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask, dvs128_gesture
root='D:/datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(root, data_type='frame', duration=1000000, train=True)
for i in range(5):
x, y = train_set[i]
print(f'x[{i}].shape=[T, C, H, W]={x.shape}')
train_data_loader = DataLoader(train_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
print(f'x_len={x_len}')
mask = padded_sequence_mask(x_len) # mask.shape = [T, N]
print(f'mask=\n{mask.t().int()}')
break
输出为:
The directory [D:/datasets/DVS128Gesture\duration_1000000] already exists.
x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
x[4].shape=[T, C, H, W]=(7, 2, 128, 128)
x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
x_len=tensor([6, 6, 5, 5, 7])
mask=
tensor([[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
自定义积分方法
惊蜇框架支持用户自定义积分方法。用户只需要提供积分函数 custom_integrate_function
以及保存frames的文件夹名 custom_integrated_frames_dir_name
。
custom_integrate_function
是用户定义的函数,输入是 events, H, W
,其中 events
是一个python字典，键为
['t', 'x', 'y', 'p']
值为 numpy.ndarray
类型。H
是数据高度,W
是数据宽度。例如,对于DVS手势数据集,H=128, W=128。
这个函数的返回值应该是frames。
custom_integrated_frames_dir_name
可以为 None
,在这种情况下,保存frames的文件夹名会被设置成 custom_integrate_function.__name__
。
例如,我们定义这样一种积分方式:随机将全部events一分为二,然后积分成2帧。我们可定义如下函数:
import numpy as np
import spikingjelly.datasets as sjds
from typing import Dict

def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames = np.zeros([2, 2, H, W])
    t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
    return frames
接下来创建数据集:
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)
运行完毕后,在 root_dir
目录下出现了 integrate_events_to_2_frames_randomly
文件夹,保存了我们的frame数据。
查看一下我们积分得到的数据:
from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)

惊蜇框架还支持其他的积分方式,阅读API文档以获取更多信息。
- 1(1,2)
Orchard, Garrick, et al. “Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.” Frontiers in Neuroscience, vol. 9, 2015, pp. 437–437.
- 2
Li, Hongmin, et al. “CIFAR10-DVS: An Event-Stream Dataset for Object Classification.” Frontiers in Neuroscience, vol. 11, 2017, pp. 309–309.
- 3
Amir, Arnon, et al. “A Low Power, Fully Event-Based Gesture Recognition System.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017, pp. 7388–7397.
- 4
Bi, Yin, et al. “Graph-Based Object Classification for Neuromorphic Vision Sensing.” 2019 IEEE/CVF International Conference on Computer Vision (ICCV), 2019, pp. 491–501.
- 5
Fang, Wei, et al. “Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks.” ArXiv: Neural and Evolutionary Computing, 2020.
分类 DVS Gesture
本教程作者: fangwei123456
在 神经形态数据集处理 中我们已经学习了如何使用神经形态数据集,下面让我们搭建SNN对其进行分类。
网络结构
我们将使用 1 一文中定义的网络,其结构如下:

1 一文中的所有网络都在 spikingjelly.activation_based.model.parametric_lif_net
中进行了定义,其中用于DVS Gesture的网络结构为:
# spikingjelly.activation_based.model.parametric_lif_net
import torch
import torch.nn as nn
from .. import layer

class DVSGestureNet(nn.Module):
    def __init__(self, channels=128, spiking_neuron: callable = None, *args, **kwargs):
        super().__init__()

        conv = []
        for i in range(5):
            if conv.__len__() == 0:
                in_channels = 2
            else:
                in_channels = channels

            conv.append(layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False))
            conv.append(layer.BatchNorm2d(channels))
            conv.append(spiking_neuron(*args, **kwargs))
            conv.append(layer.MaxPool2d(2, 2))

        self.conv_fc = nn.Sequential(
            *conv,

            layer.Flatten(),
            layer.Dropout(0.5),
            layer.Linear(channels * 4 * 4, 512),
            spiking_neuron(*args, **kwargs),

            layer.Dropout(0.5),
            layer.Linear(512, 110),
            spiking_neuron(*args, **kwargs),

            layer.VotingLayer(10)
        )

    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
训练
训练的代码与之前的教程 使用卷积SNN识别Fashion-MNIST 几乎相同,相同之处不再赘述,下面只介绍差异部分。
定义网络,使用多步模式。若使用 CuPy
则将所有的 neuron.LIFNode
设置为 cupy
后端:
# spikingjelly.activation_based.examples.classify_dvsg
import torch
import sys
import torch.nn.functional as F
from torch.cuda import amp
from spikingjelly.activation_based import functional, surrogate, neuron
from spikingjelly.activation_based.model import parametric_lif_net
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import os
import argparse
import datetime

def main():
    # ...
    net = parametric_lif_net.DVSGestureNet(channels=args.channels, spiking_neuron=neuron.LIFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

    functional.set_step_mode(net, 'm')
    if args.cupy:
        functional.set_backend(net, 'cupy', instance=neuron.LIFNode)
    # ...
新建数据集:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    train_set = DVS128Gesture(root=args.data_dir, train=True, data_type='frame', frames_number=args.T, split_by='number')
    test_set = DVS128Gesture(root=args.data_dir, train=False, data_type='frame', frames_number=args.T, split_by='number')
    # ...
注意，由 DataLoader 打包的数据，第0维总是batch维度，因此从 DataLoader 读取的数据实际上是 shape = [N, T, C, H, W]，我们需要将其转换为SpikingJelly多步模式所使用的 shape = [T, N, C, H, W]:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    for epoch in range(start_epoch, args.epochs):
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.to(args.device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            # ...

        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.to(args.device)
                frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
                # ...
    # ...
DVS Gesture有11类,因此在生成one hot的target时别忘了设置为11类:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    label_onehot = F.one_hot(label, 11).float()
    # ...
DVSGestureNet
输出的并不是脉冲发放频率,而是 shape = [T, N, 11]
的原始输出:
# spikingjelly.activation_based.model.parametric_lif_net
class DVSGestureNet(nn.Module):
    # ...
    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
因此,我们需要对输出在时间维度上求平均后,得到脉冲发放频率,然后才去计算损失和正确率:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    out_fr = net(frame).mean(0)
    loss = F.mse_loss(out_fr, label_onehot)
    # ...
运行我们的网络:
python -m spikingjelly.activation_based.examples.classify_dvsg -T 16 -device cuda:0 -b 16 -epochs 64 -data-dir /datasets/DVSGesture/ -amp -cupy -opt adam -lr 0.001 -j 8
得到输出为:
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
DVSGestureNet(
(conv_fc): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(2): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(6): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(10): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(14): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(17): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(18): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(19): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(20): Flatten(start_dim=1, end_dim=-1, step_mode=m)
(21): Dropout(p=0.5)
(22): Linear(in_features=2048, out_features=512, bias=True)
(23): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(24): Dropout(p=0.5)
(25): Linear(in_features=512, out_features=110, bias=True)
(26): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(27): VotingLayer(voting_size=10, step_mode=m)
)
)
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
Mkdir ./logs/T16_b16_adam_lr0.001_c128_amp_cupy.
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 0, train_loss = 0.0666, train_acc = 0.3964, test_loss = 0.0514, test_acc = 0.6042, max_test_acc = 0.6042
train speed = 92.7646 images/s, test speed = 115.2935 images/s
escape time = 2022-05-25 21:31:54
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 1, train_loss = 0.0463, train_acc = 0.6036, test_loss = 0.0439, test_acc = 0.6319, max_test_acc = 0.6319
train speed = 101.5938 images/s, test speed = 120.5184 images/s
escape time = 2022-05-25 21:30:48
...
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 63, train_loss = 0.0011, train_acc = 0.9991, test_loss = 0.0103, test_acc = 0.9375, max_test_acc = 0.9375
train speed = 100.4324 images/s, test speed = 121.0402 images/s
escape time = 2022-05-25 21:30:51
最终获得了 max_test_acc = 0.9375
的性能。如果精心调整超参数、增加训练 epochs
,通常还能获得更高的性能。
下图展示了训练过程中的正确率曲线:
自连接和有状态突触
本教程作者: fangwei123456
自连接模块
自连接指的是从输出到输入的连接,例如 1 一文中的SRNN(recurrent networks of spiking neurons),如下图所示:

使用SpikingJelly框架很容易构建出带有自连接的模块。考虑最简单的一种情况,我们给神经元增加一个回路,使得它在 \(t\) 时刻的输出 \(s[t]\),会与下一个时刻的外界输入 \(x[t+1]\) 相加,共同作为输入。这可以由 spikingjelly.activation_based.layer.ElementWiseRecurrentContainer
轻松实现。 ElementWiseRecurrentContainer
是一个包装器,给任意的 sub_module
增加一个额外的自连接。连接的形式可以使用用户自定义的逐元素函数操作 \(z=f(x, y)\) 来实现。记 \(x[t]\) 为
\(t\) 时刻整个模块的输入,\(i[t]\) 和 \(y[t]\) 是
sub_module
的输入和输出(注意 \(y[t]\) 同时也是整个模块的输出),则
\[i[t] = f(x[t], y[t-1])\]
其中 \(f\) 是用户自定义的逐元素操作。默认 \(y[-1] = 0\)。
现在让我们用 ElementWiseRecurrentContainer
来包装一个IF神经元，逐元素操作设置为加法，因而
\[i[t] = x[t] + y[t-1]\]
我们给予 \(x[t]=[1.5, 0, ..., 0]\) 的输入:
import torch
from spikingjelly.activation_based import layer, functional, neuron

T = 8
N = 1

def element_wise_add(x, y):
    return x + y

net = layer.ElementWiseRecurrentContainer(neuron.IFNode(), element_wise_add)
print(net)
x = torch.zeros([T, N])
x[0] = 1.5
for t in range(T):
    print(t, f'x[t]={x[t]}, s[t]={net(x[t])}')

functional.reset_net(net)
输出为:
ElementWiseRecurrentContainer(
element-wise function=<function element_wise_add at 0x00000158FC15ACA0>, step_mode=s
(sub_module): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
0 x[t]=tensor([1.5000]), s[t]=tensor([1.])
1 x[t]=tensor([0.]), s[t]=tensor([1.])
2 x[t]=tensor([0.]), s[t]=tensor([1.])
3 x[t]=tensor([0.]), s[t]=tensor([1.])
4 x[t]=tensor([0.]), s[t]=tensor([1.])
5 x[t]=tensor([0.]), s[t]=tensor([1.])
6 x[t]=tensor([0.]), s[t]=tensor([1.])
7 x[t]=tensor([0.]), s[t]=tensor([1.])
可以发现,由于存在自连接,即便 \(t \ge 1\) 时 \(x[t]=0\),由于输出的脉冲能传回到输入,神经元也能持续释放脉冲。
可以使用 spikingjelly.activation_based.layer.LinearRecurrentContainer
实现更复杂的全连接形式的自连接。
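下面给出一个使用 LinearRecurrentContainer 的最小示意代码（其中的特征维度、时间步数均为随意设定，仅作演示），用法与后文Sequential FashionMNIST实验中的 FeedBackNet 相同:
import torch
from spikingjelly.activation_based import layer, neuron, functional

T = 4
N = 2
C = 4
# 给IF神经元增加一个全连接形式的反馈连接
net = layer.LinearRecurrentContainer(
    neuron.IFNode(),
    in_features=C, out_features=C, bias=True
)
print(net)
x_seq = torch.rand([T, N, C])
for t in range(T):
    print(net(x_seq[t]))
functional.reset_net(net)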
有状态的突触
2 3 等文章使用有状态的突触。将 spikingjelly.activation_based.layer.SynapseFilter
放在普通无状态突触的后面，对突触输出的电流进行滤波，就可以得到有状态的突触，例如:
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron
stateful_conv = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
    layer.SynapseFilter(tau=100.)
)
Sequential FashionMNIST上的对比实验
接下来让我们在Sequential FashionMNIST上做一个简单的实验，验证自连接和有状态突触是否有助于改善网络的记忆能力。Sequential FashionMNIST指的是将原始的FashionMNIST图片一行一行或一列一列地输入，而不是一次性输入整张图片。在这种情况下，网络必须具有一定的记忆能力，才能做出正确的分类。我们将会把图片一列一列地输入，这样对网络而言，就像是从左到右“阅读”一样，如下图所示:

下图中展示了被读入的列:

首先导入相关的包:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets
from spikingjelly.activation_based import neuron, surrogate, layer, functional
from torch.cuda import amp
import os, argparse
from torch.utils.tensorboard import SummaryWriter
import time
import datetime
import sys
我们定义一个普通的前馈网络 PlainNet
:
class PlainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
我们在 PlainNet
的第一层脉冲神经元后增加一个 spikingjelly.activation_based.layer.SynapseFilter
,得到一个新的网络 StatefulSynapseNet
:
class StatefulSynapseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.SynapseFilter(tau=2., learnable=True),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
我们给 PlainNet
的第一层脉冲神经元增加一个反馈连接 spikingjelly.activation_based.layer.LinearRecurrentContainer
得到 FeedBackNet
:
class FeedBackNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            layer.LinearRecurrentContainer(
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True),
                in_features=32, out_features=32, bias=True
            ),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
下图展示了3种网络的结构:

完整的代码位于 spikingjelly.activation_based.examples.rsnn_sequential_fmnist。我们可以通过命令行直接运行。运行参数为:
usage: rsnn_sequential_fmnist.py [-h] [-model MODEL] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR] [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]
Classify Sequential Fashion-MNIST
optional arguments:
-h, --help show this help message and exit
-model MODEL use which model, "plain", "ss" (StatefulSynapseNet) or "fb" (FeedBackNet)
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of Fashion-MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-cupy use cupy backend
-opt OPT use which optimizer. SDG or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
分别训练3个模型:
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model plain
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model fb
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model ss
下图展示了3种网络的训练曲线:
可以发现, StatefulSynapseNet
和 FeedBackNet
的性能都高于 PlainNet
,表明自连接和有状态突触都有助于提升网络的记忆能力。
- 1
Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
- 2
Diehl P U, Cook M. Unsupervised learning of digit recognition using spike-timing-dependent plasticity[J]. Frontiers in computational neuroscience, 2015, 9: 99.
- 3
Fang H, Shrestha A, Zhao Z, et al. Exploiting Neuron and Synapse Filter Dynamics in Spatial Temporal Learning of Deep Spiking Neural Network[J].
训练大规模SNN
本教程作者: fangwei123456
使用 activation_based.model
在 spikingjelly.activation_based.model
中定义了一些经典网络模型,可以直接拿来使用,使用方法与 torchvision.models
类似。以Spiking ResNet 1 为例:
import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
print(f's_resnet18={s_resnet18}')
with torch.no_grad():
    T = 4
    N = 1
    x_seq = torch.rand([T, N, 3, 224, 224])
    functional.set_step_mode(s_resnet18, 'm')
    y_seq = s_resnet18(x_seq)
    print(f'y_seq.shape={y_seq.shape}')
    functional.reset_net(s_resnet18)
输出为:
s_resnet18=SpikingResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
y_seq.shape=torch.Size([4, 1, 1000])
SpikingJelly中的Spiking ResNet按照 torchvision 中ResNet的结构搭建，state_dict().keys() 与之保持相同，因此支持直接加载预训练权重，设置 pretrained=True 即可:
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=True, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
使用 activation_based.model.train_classify
spikingjelly.activation_based.model.train_classify
是根据 torchvision 0.12 references 的分类代码进行改动而来,使用这个模块可以很方便的进行训练。
spikingjelly.activation_based.model.train_classify.Trainer
提供了较为灵活的训练方式,预留了一些接口给用户改动。例如, spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer
定义了如何设置优化器,默认为:
# spikingjelly.activation_based.model.train_classify
class Trainer:
    # ...
    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("sgd"):
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov="nesterov" in opt_name,
            )
        elif opt_name == "rmsprop":
            optimizer = torch.optim.RMSprop(
                parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
            )
        elif opt_name == "adamw":
            optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
        return optimizer

    def main(self, args):
        # ...
        optimizer = self.set_optimizer(args, parameters)
        # ...
如果我们增加一个优化器,例如 Adamax
,只需要继承并重写此方法,例如:
class MyTrainer(train_classify.Trainer):
    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("adamax"):
            optimizer = torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)
            return optimizer
        else:
            return super().set_optimizer(args, parameters)
默认的 Trainer.get_args_parser
已经包含了较多的参数设置:
(pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h
usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING]
[--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY]
[--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT]
[--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION]
[--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED]
[--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda]
PyTorch Classification Training
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
dataset path
--model MODEL model name
--device DEVICE device (Use cuda or cpu Default: cuda)
-b BATCH_SIZE, --batch-size BATCH_SIZE
images per gpu, the total batch size is $NGPU x batch_size
--epochs N number of total epochs to run
-j N, --workers N number of data loading workers (default: 16)
--opt OPT optimizer
--lr LR initial learning rate
--momentum M momentum
--wd W, --weight-decay W
weight decay (default: 0.)
--norm-weight-decay NORM_WEIGHT_DECAY
weight decay for Normalization layers (default: None, same value as --wd)
--label-smoothing LABEL_SMOOTHING
label smoothing (default: 0.1)
--mixup-alpha MIXUP_ALPHA
mixup alpha (default: 0.2)
--cutmix-alpha CUTMIX_ALPHA
cutmix alpha (default: 1.0)
--lr-scheduler LR_SCHEDULER
the lr scheduler (default: cosa)
--lr-warmup-epochs LR_WARMUP_EPOCHS
the number of epochs to warmup (default: 5)
--lr-warmup-method LR_WARMUP_METHOD
the warmup method (default: linear)
--lr-warmup-decay LR_WARMUP_DECAY
the decay for lr
--lr-step-size LR_STEP_SIZE
decrease lr every step-size epochs
--lr-gamma LR_GAMMA decrease lr by a factor of lr-gamma
--output-dir OUTPUT_DIR
path to save outputs
--resume RESUME path of checkpoint. If set to 'latest', it will try to load the latest checkpoint
--start-epoch N start epoch
--cache-dataset Cache the datasets for quicker initialization. It also serializes the transforms
--sync-bn Use sync batch norm
--test-only Only test the model
--pretrained Use pre-trained models from the modelzoo
--auto-augment AUTO_AUGMENT
auto augment policy (default: ta_wide)
--random-erase RANDOM_ERASE
random erasing probability (default: 0.1)
--world-size WORLD_SIZE
number of distributed processes
--dist-url DIST_URL url used to set up distributed training
--model-ema enable tracking Exponential Moving Average of model parameters
--model-ema-steps MODEL_EMA_STEPS
the number of iterations that controls how often to update the EMA model (default: 32)
--model-ema-decay MODEL_EMA_DECAY
decay factor for Exponential Moving Average of model parameters (default: 0.99998)
--interpolation INTERPOLATION
the interpolation method (default: bilinear)
--val-resize-size VAL_RESIZE_SIZE
the resize size used for validation (default: 232)
--val-crop-size VAL_CROP_SIZE
the central crop size used for validation (default: 224)
--train-crop-size TRAIN_CROP_SIZE
the random crop size used for training (default: 176)
--clip-grad-norm CLIP_GRAD_NORM
the maximum gradient norm (default None)
--ra-sampler whether to use Repeated Augmentation in training
--ra-reps RA_REPS number of repetitions for Repeated Augmentation (default: 4)
--prototype Use prototype model builders instead those from main area
--weights WEIGHTS the weights enum name to load
--seed SEED the random seed
--print-logdir print the dirs for tensorboard logs and pt files and exit
--clean delete the dirs for tensorboard logs and pt files
--disable-pinmemory not use pin memory in dataloader, which can help reduce memory consumption
--disable-amp not use automatic mixed precision training
--local_rank LOCAL_RANK
args for DDP, which should not be set by user
--disable-uda not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation
如果想增加参数,仍然可以通过继承的方式实现:
class MyTrainer(train_classify.Trainer):
    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--do-something', type=str, help="do something")
        return parser
Trainer
的许多其他函数都可以进行补充修改或覆盖,方法类似,不再赘述。
对于 Trainer
及用户自己继承实现的子类,可以通过如下方式调用并进行训练:
trainer = Trainer()
args = trainer.get_args_parser().parse_args()
trainer.main(args)
Trainer
在训练中会自动计算训练集、测试集的 Acc@1, Acc@5, loss
并使用 tensorboard
保存为日志文件,此外训练过程中的最新一个epoch的模型以及测试集性能最高的模型也会被保存下来。 Trainer
支持Distributed Data Parallel训练。
在ImageNet上训练
Trainer
默认的数据加载函数load_data
加载 ImageNet 2 数据集。结合Trainer
和spikingjelly.activation_based.model.spiking_resnet
,我们可以轻松训练大型深度SNN,示例代码如下:
# spikingjelly.activation_based.model.train_imagenet_example
import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify


class SResNetTrainer(train_classify.Trainer):
    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to process train sample before send it to model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to process test sample before send it to model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        return y.mean(0)  # return firing rate

    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

    def load_model(self, args, num_classes):
        if args.model in spiking_resnet.__all__:
            model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode,
                                                        surrogate_function=surrogate.ATan(), detach_reset=True)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.IFNode)

            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")


if __name__ == "__main__":
    trainer = SResNetTrainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
代码位于 spikingjelly.activation_based.model.train_imagenet_example
,可以直接运行。
在单卡上进行训练:
python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
在多卡上进行训练:
python -m torch.distributed.launch --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
- 1
He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- 2
Deng, Jia, et al. “Imagenet: A large-scale hierarchical image database.” 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009.
STDP学习
本教程作者: fangwei123456
生物可解释性的学习规则一直备受SNN研究者的关注。在SpikingJelly中提供了STDP(Spike Timing Dependent Plasticity) 学习器,可以用于卷积或全连接层的权重学习。
STDP(Spike Timing Dependent Plasticity)
STDP(Spike Timing Dependent Plasticity)最早是由 1 提出,是在生物实验中发现的一种突触可塑性机制。实验发现,突触权重 受到突触连接的前神经元(pre)和后神经元(post)的脉冲发放的影响,具体而言是:
如果pre神经元先发放脉冲,post神经元后发放脉冲,则突触的权重会增大; 如果pre神经元后发放脉冲,post神经元先发放脉冲,则突触的权重会减小。
生理实验数据拟合的曲线如下图 2 所示:

STDP可以使用如下公式进行拟合:
\[\Delta w = \begin{cases} A \exp\left(-\frac{|t_{pre} - t_{post}|}{\tau_{+}}\right), & t_{pre} \leq t_{post}\\ -B \exp\left(-\frac{|t_{pre} - t_{post}|}{\tau_{-}}\right), & t_{pre} > t_{post}\end{cases}\]
其中 \(A, B\) 是突触权重变化的最大值，\(\tau_{+}, \tau_{-}\) 是时间常数。
上述标准的STDP公式在实践中使用较为繁琐,因其需要记录前后神经元所有的脉冲发放时刻。实践中通常使用迹 3 的方式来实现STDP。
对于pre神经元 \(i\) 和post神经元 \(j\),分别使用迹 \(tr_{pre}[i]\) 和 \(tr_{post}[j]\) 来记录其脉冲发放。迹的更新类似于LIF神经元:
\[tr_{pre}[i][t] = tr_{pre}[i][t-1] - \frac{tr_{pre}[i][t-1]}{\tau_{pre}} + s[i][t]\]
\[tr_{post}[j][t] = tr_{post}[j][t-1] - \frac{tr_{post}[j][t-1]}{\tau_{post}} + s[j][t]\]
其中 \(\tau_{pre}, \tau_{post}\) 是pre和post迹的时间常数，\(s[i][t], s[j][t]\) 是在 \(t\) 时刻pre神经元 \(i\) 和post神经元 \(j\) 发放的脉冲，取值仅为0或1。
突触权重的更新按照:
\[\Delta w[i][j][t] = F_{post}(w[i][j][t]) \cdot tr_{pre}[i][t] \cdot s[j][t] - F_{pre}(w[i][j][t]) \cdot tr_{post}[j][t] \cdot s[i][t]\]
其中 \(F_{pre}, F_{post}\) 是控制突触改变量的函数。
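为了直观理解上面的迹与权重更新公式，下面给出一个针对单个突触的纯Python数值示意（与 STDPLearner 的具体实现无关，时间常数、脉冲序列等取值仅作演示）:
import torch

T = 8
tau_pre, tau_post = 2., 2.

# 随机生成pre和post神经元的脉冲序列(取值为0或1)
s_pre = (torch.rand(T) > 0.7).float().tolist()
s_post = (torch.rand(T) > 0.7).float().tolist()

w = 0.4          # 突触权重
tr_pre = 0.      # pre迹
tr_post = 0.     # post迹
for t in range(T):
    # 迹的更新: 指数衰减 + 脉冲叠加
    tr_pre = tr_pre - tr_pre / tau_pre + s_pre[t]
    tr_post = tr_post - tr_post / tau_post + s_post[t]
    # 权重更新: 这里简单地取 F_pre = F_post = 1
    dw = tr_pre * s_post[t] - tr_post * s_pre[t]
    w = w + dw
    print(f't={t}, tr_pre={tr_pre:.3f}, tr_post={tr_post:.3f}, w={w:.3f}')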
STDP优化器
spikingjelly.activation_based.learning.STDPLearner
提供了STDP优化器的实现,支持卷积和全连接层,请读者先阅读其API文档以获取使用方法。
我们使用 STDPLearner
搭建一个最简单的 1x1
网络,pre和post都只有一个神经元,并且将权重设置为 0.4
:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt
torch.manual_seed(0)
def f_weight(x):
    return torch.clamp(x, -1, 1.)

tau_pre = 2.
tau_post = 2.
T = 128
N = 1
lr = 0.01
net = nn.Sequential(
    layer.Linear(1, 1, bias=False),
    neuron.IFNode()
)
nn.init.constant_(net[0].weight.data, 0.4)
STDPLearner
可以将负的权重的更新量 - delta_w * scale
叠加到参数的梯度上,因而与深度学习完全兼容。
我们可以将其和优化器、学习率调节器等深度学习中的模块一起使用。这里我们使用最简单的权重更新策略:
\[W = W - lr \cdot \nabla W\]
其中 \(\nabla W\) 是使用STDP得到的权重更新量取负后的 - delta_w * scale
。因而借助优化器可以实现 weight.data = weight.data - lr * weight.grad = weight.data + lr * delta_w * scale
。
这可以使用最朴素的 torch.optim.SGD
实现,只需要设置 momentum=0.
:
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)
接下来生成输入脉冲,设置 STDPLearner
:
in_spike = (torch.rand([T, N, 1]) > 0.7).float()
stdp_learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1], tau_pre=tau_pre, tau_post=tau_post,
                                    f_pre=f_weight, f_post=f_weight)
接下来送入数据计算。需要注意的是,为了便于画图,我们会将输出数据进行 squeeze()
,这样使得 shape = [T, N, 1]
的数据变为 shape = [T]
:
out_spike = []
trace_pre = []
trace_post = []
weight = []
with torch.no_grad():
    for t in range(T):
        optimizer.zero_grad()
        out_spike.append(net(in_spike[t]).squeeze())
        stdp_learner.step(on_grad=True)  # 将STDP学习得到的参数更新量叠加到参数的梯度上
        optimizer.step()
        weight.append(net[0].weight.data.clone().squeeze())
        trace_pre.append(stdp_learner.trace_pre.squeeze())
        trace_post.append(stdp_learner.trace_post.squeeze())
in_spike = in_spike.squeeze()
out_spike = torch.stack(out_spike)
trace_pre = torch.stack(trace_pre)
trace_post = torch.stack(trace_post)
weight = torch.stack(weight)
完整的代码位于 spikingjelly/activation_based/examples/stdp_trace.py
。
将 in_spike, out_spike, trace_pre, trace_post, weight
画出,得到下图:
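画图的一种可行写法如下（仅为示意，与示例文件 stdp_trace.py 中的绘图代码不一定完全相同）:
from matplotlib import pyplot as plt

# in_spike, out_spike, trace_pre, trace_post, weight 均为 shape = [T] 的张量
fig, axes = plt.subplots(5, 1, sharex=True)
for ax, data, name in zip(
        axes,
        [in_spike, out_spike, trace_pre, trace_post, weight],
        ['in_spike', 'out_spike', 'trace_pre', 'trace_post', 'weight']):
    ax.plot(data.numpy())
    ax.set_ylabel(name)
axes[-1].set_xlabel('t')
plt.show()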
这与 3 中的Fig.3是一致的(注意下图中使用 j 表示pre神经元,i 表示后神经元,与我们相反):

与梯度下降混合使用
在SNN中一种广泛使用STDP的做法是，使用STDP和梯度下降分别训练网络中的不同层。下面介绍如何使用 STDPLearner
实现这一做法。
我们的目标是搭建一个深度卷积SNN,使用STDP训练卷积层,使用梯度下降法训练全连接层。首先定义超参数:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional
T = 8
N = 2
C = 3
H = 32
W = 32
lr = 0.1
tau_pre = 2.
tau_post = 100.
step_mode = 'm'
我们使用 shape = [T, N, C, H, W] = [8, 2, 3, 32, 32]
的输入。
接下来定义STDP的权重函数以及网络,这里我们搭建的是一个简单的卷积SNN,且使用多步模式:
def f_weight(x):
    return torch.clamp(x, -1, 1.)

net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)
functional.set_step_mode(net, step_mode)
我们希望使用STDP训练 layer.Conv2d
,其他层使用梯度下降训练。首先定义使用STDP训练的层类型:
instances_stdp = (layer.Conv2d, )
对于每个类型为 instances_stdp
的层,我们都使用一个STDP学习器:
stdp_learners = []

for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )
接下来进行参数分组,将类型为 instances_stdp
的层参数和其他类型的层的参数,分别放置到不同的优化器中。这里我们使用 Adam
作为梯度下降训练的参数的优化器,使用 SGD
作为STDP训练的参数的优化器:
params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)
params_stdp_set = set(params_stdp)

params_gradient_descent = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gradient_descent.append(p)
optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)
在实际任务中,输入和输出应该是从数据集中抽样得到的,我们这里仅仅是做示例,因此手动生成:
x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])
接下来就是参数优化的主要步骤了,在实际任务中下面的代码通常会放到训练的主循环中。我们的代码与纯梯度下降会略有不同。
首先清零所有梯度,进行前向传播,计算出损失,反向传播:
optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()
需要注意的是,尽管 optimizer_gd
只会对 params_gradient_descent
中的参数进行梯度下降,但调用 loss.backward()
后整个网络中所有的参数都会计算出梯度,包括那些我们只想使用STDP进行优化的参数。
因此,我们需要将使用梯度下降得到的 params_stdp
的梯度进行清零:
optimizer_stdp.zero_grad()
接下来就是使用STDP学习器计算出参数更新量,然后使用2个优化器,对整个网络的参数进行更新:
for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)
optimizer_gd.step()
optimizer_stdp.step()
以 STDPLearner
为代表的所有学习器都是 MemoryModule
的子类,其内部记忆状态包括了突触前后神经元的迹 trace_pre, trace_post
;另外,学习器内部用于记录神经元活动的监视器存储了突触前后神经元的发放历史;这些发放历史也可以视作学习器的内部记忆状态。因此,必须及时调用学习器的 reset()
方法，来清空其内部记忆状态，从而防止内存/显存消耗量随着训练而不断增长！通常的做法是：在每个batch结束后，将学习器和网络一起重置:
functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()
完整的示例代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional

T = 8
N = 2
C = 3
H = 32
W = 32
lr = 0.1
tau_pre = 2.
tau_post = 100.
step_mode = 'm'

def f_weight(x):
    return torch.clamp(x, -1, 1.)

net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)

instances_stdp = (layer.Conv2d, )

stdp_learners = []

for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )

params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)
params_stdp_set = set(params_stdp)

params_gradient_descent = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gradient_descent.append(p)

optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)

x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])

optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()

optimizer_stdp.zero_grad()

for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()

functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()
- 1
Bi, Guo-qiang, and Mu-ming Poo. “Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type.” Journal of neuroscience 18.24 (1998): 10464-10472.
- 2
Froemke, Robert C., et al. “Contribution of individual spikes in burst-induced long-term synaptic modification.” Journal of neurophysiology (2006).
- 3(1,2)
Morrison, Abigail, Markus Diesmann, and Wulfram Gerstner. “Phenomenological models of synaptic plasticity based on spike timing.” Biological cybernetics 98.6 (2008): 459-478.
ANN转换SNN
本教程作者: DingJianhao, fangwei123456, Lv Liuzhenghao
本节教程主要关注 spikingjelly.activation_based.ann2snn
,介绍如何将训练好的ANN转换SNN,并且在SpikingJelly框架上进行仿真。
相关API见此处 API参考 。
较早的实现方案中有两套实现:基于ONNX 和 基于PyTorch。本版本基于torch.fx。fx专门用于对nn.Module实例进行转换,而且在构建计算图时会原生地将复杂模型解耦合。一起来看看吧!
ANN转换SNN的理论基础
SNN相比于ANN,产生的脉冲是离散的,这有利于高效的通信。在ANN大行其道的今天,SNN的直接训练需要较多资源。自然我们会想到使用现在非常成熟的ANN转换到SNN,希望SNN也能有类似的表现。这就牵扯到如何搭建起ANN和SNN桥梁的问题。现在SNN主流的方式是采用频率编码,因此对于输出层,我们会用神经元输出脉冲数来判断类别。发放率和ANN有没有关系呢?
幸运的是,ANN中的ReLU神经元非线性激活和SNN中IF神经元(采用减去阈值 \(V_{threshold}\) 方式重置)的发放率有着极强的相关性,我们可以借助这个特性来进行转换。这里说的神经元更新方式,也就是 神经元教程 中提到的Soft方式。
实验:IF神经元脉冲发放频率和输入的关系
我们给与恒定输入到IF神经元,观察其输出脉冲和脉冲发放频率。首先导入相关的模块,新建IF神经元层,确定输入并画出每个IF神经元的输入 \(x_{i}\):
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np
plt.rcParams['figure.dpi'] = 200
if_node = neuron.IFNode(v_reset=None)
T = 128
x = torch.arange(-0.2, 1.2, 0.04)
plt.scatter(torch.arange(x.shape[0]), x)
plt.title('Input $x_{i}$ to IF neurons')
plt.xlabel('Neuron index $i$')
plt.ylabel('Input $x_{i}$')
plt.grid(linestyle='-.')
plt.show()
接下来,将输入送入到IF神经元层,并运行 T=128
步,观察各个神经元发放的脉冲、脉冲发放频率:
s_list = []
for t in range(T):
    s_list.append(if_node(x).unsqueeze(0))
out_spikes = np.asarray(torch.cat(s_list))
visualizing.plot_1d_spikes(out_spikes, 'IF neurons\' spikes and firing rates', 't', 'Neuron index $i$')
plt.show()
可以发现,脉冲发放的频率在一定范围内,与输入 \(x_{i}\) 的大小成正比。
接下来,让我们画出IF神经元脉冲发放频率和输入 \(x_{i}\) 的曲线,并与 \(\mathrm{ReLU}(x_{i})\) 对比:
plt.subplot(1, 2, 1)
firing_rate = np.mean(out_spikes, axis=1)
plt.plot(x, firing_rate)
plt.title('Input $x_{i}$ and firing rate')
plt.xlabel('Input $x_{i}$')
plt.ylabel('Firing rate')
plt.grid(linestyle='-.')
plt.subplot(1, 2, 2)
plt.plot(x, x.relu())
plt.title('Input $x_{i}$ and ReLU($x_{i}$)')
plt.xlabel('Input $x_{i}$')
plt.ylabel('ReLU($x_{i}$)')
plt.grid(linestyle='-.')
plt.show()
可以发现,两者的曲线几乎一致。需要注意的是,脉冲频率不可能高于1,因此IF神经元无法拟合ANN中ReLU的输入大于1的情况。
理论证明
文献 1 对ANN转SNN提供了解析的理论基础。理论说明,SNN中的IF神经元是ReLU激活函数在时间上的无偏估计器。
针对神经网络第一层即输入层，讨论SNN神经元的发放率 \(r\) 和对应ANN中激活的关系。假定输入恒定为 \(z \in [0,1]\)。对于采用减法重置的IF神经元，其膜电位 \(V\) 随时间变化为:
\[V_t = V_{t-1} + z - V_{threshold}\theta_t\]
其中 \(V_{threshold}\) 为发放阈值，通常设为1.0；\(\theta_t\) 为输出脉冲。\(T\) 时间步内的平均发放率可以通过对膜电位求和得到:
\[\sum_{t=1}^{T} V_t = \sum_{t=1}^{T} V_{t-1} + zT - V_{threshold}\sum_{t=1}^{T}\theta_t\]
将含有 \(V_t\) 的项全部移项到左边，两边同时除以 \(T\):
\[\frac{V_T - V_0}{T} = z - V_{threshold}\frac{N}{T} = z - V_{threshold} r\]
其中 \(N\) 为 \(T\) 时间步内脉冲数，\(\frac{N}{T}\) 就是发放率 \(r\)。利用 \(z = V_{threshold} a\) 即:
\[r = a - \frac{V_T - V_0}{T V_{threshold}}\]
故在仿真时间步 \(T\) 无限长情况下:
\[r = a\]
类似地，针对神经网络更高层，文献 1 进一步说明层间发放率满足:
\[r^l = a^l r_{max} - \frac{V^l_T}{T V_{threshold}}\]
转换到脉冲神经网络
转换主要解决两个问题:
ANN为了快速训练和收敛提出了批归一化(Batch Normalization)。批归一化旨在将ANN输出归一化到0均值,这与SNN的特性相违背。因此,可以将BN的参数吸收到前面的参数层中(Linear、Conv2d)
根据转换理论,ANN的每层输入输出需要被限制在[0,1]范围内,这就需要对参数进行缩放(模型归一化)
◆ BatchNorm参数吸收
假定BatchNorm的参数为 \(\gamma\) (BatchNorm.weight
), \(\beta\) (BatchNorm.bias
), \(\mu\) (BatchNorm.running_mean
) ,
\(\sigma\) (BatchNorm.running_var
,\(\sigma = \sqrt{\mathrm{running\_var}}\))。具体参数定义详见
torch.nn.BatchNorm1d 。
参数模块(例如Linear)具有参数 \(W\) 和 \(b\) 。BatchNorm参数吸收就是将BatchNorm的参数通过运算转移到参数模块的 \(W\) 中,使得数据输入新模块的输出和有BatchNorm时相同。
对此，新模型的 \(\bar{W}\) 和 \(\bar{b}\) 公式表示为:
\[\bar{W} = \frac{\gamma}{\sigma} W\]
\[\bar{b} = \frac{\gamma}{\sigma} (b - \mu) + \beta\]
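下面用一个简单的例子验证该公式（仅为示意，ann2snn.Converter 内部fuse的实现细节请以源码为准）:
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
bn.eval()  # 推理模式, 使用running_mean和running_var

gamma, beta = bn.weight.data, bn.bias.data
mu = bn.running_mean
sigma = torch.sqrt(bn.running_var + bn.eps)  # 数值计算中通常加上eps

# 按上式把BN参数吸收进卷积层
fused = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
fused.weight.data = conv.weight.data * (gamma / sigma).reshape(-1, 1, 1, 1)
fused.bias.data = gamma / sigma * (conv.bias.data - mu) + beta

x = torch.rand([1, 3, 8, 8])
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-6))  # True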
◆ 模型归一化
对于某个参数模块，假定得到了其输入张量和输出张量，其输入张量的最大值为 \(\lambda_{pre}\)，输出张量的最大值为 \(\lambda\)。那么，归一化后的权重 \(\hat{W}\) 为:
\[\hat{W} = W \frac{\lambda_{pre}}{\lambda}\]
归一化后的偏置 \(\hat{b}\) 为:
\[\hat{b} = \frac{b}{\lambda}\]
ANN每层输出的分布虽然服从某个特定分布,但是数据中常常会存在较大的离群值,这会导致整体神经元发放率降低。 为了解决这一问题,鲁棒归一化将缩放因子从张量的最大值调整为张量的p分位点。文献中推荐的分位点值为99.9。
到现在为止,我们对神经网络做的操作,在数值上是完全等价的。当前的模型表现应该与原模型相同。
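下面给出模型归一化的一个简化示意（只演示单个Linear层，缩放因子取激活的最大值，仅供理解，Converter 的实际实现请以源码为准）:
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(4, 4, bias=True)
x = torch.rand([128, 4])  # 该层的输入张量(来自数据集的一批样本)

with torch.no_grad():
    lambda_pre = x.max()           # 输入张量的最大值
    lam = torch.relu(fc(x)).max()  # 输出张量的最大值; 鲁棒归一化可改用分位点,
                                   # 例如 torch.quantile(torch.relu(fc(x)).flatten(), 0.999)
    fc.weight.data *= lambda_pre / lam
    fc.bias.data /= lam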
转换中，我们需要将原模型中的ReLU激活函数变为IF神经元。
对于ANN中的平均池化，我们需要将其转化为空间下采样。由于IF神经元可以等效ReLU激活函数，空间下采样后增加IF神经元与否对结果的影响极小。
对于ANN中的最大池化，目前没有非常理想的方案。目前的最佳方案为使用基于动量累计脉冲的门控函数控制脉冲通道 1 。此处我们依然推荐使用avgpool2d。
仿真时，依照转换理论，SNN需要输入恒定的模拟输入。使用Poisson编码器将会带来准确率的降低。
实现与可选配置
ann2snn框架在2022年4月又迎来一次较大更新。取消了parser和simulator两大类。使用converter类替代了之前的方案。目前的方案更加简洁,并且具有更多转换设置空间。
ann2snn框架在2022年10月再次更新。在converter类中添加fuse方法,将bn层参数吸收进conv层。
◆ Converter类
该类用于将ReLU的ANN转换为SNN。
这里实现了常见的三种模式:
最常见的是最大电流转换模式,它利用前后层的激活上限,使发放率最高的情况能够对应激活取得最大值的情况。使用这种模式需要将参数mode设置为 max
2 。
99.9%电流转换模式利用99.9%的激活分位点限制了激活上限。使用这种模式需要将参数mode设置为 99.9%
1 。
缩放转换模式下,用户需要给定缩放参数到模式中,即可利用缩放后的激活最大值对电流进行限制。使用这种模式需要将参数mode设置为0-1的浮点数。
实现了可选的BatchNorm层参数吸收功能:
设置 fuse_flag
为 True
(默认值) ,以进行conv层与bn层的参数融合。
转换后ReLU模块被删除,SNN需要的新模块(包括VoltageScaler、IFNode等)被创建并存放在 snn tailor
父模块中。
由于返回值的类型为fx.GraphModule,所以您可以使用print(fx.GraphModule.graph)查看计算图及前向传播关系。更多API参见 GraphModule .
识别MNIST
原ANN
现在我们使用 ann2snn
,搭建一个简单卷积网络,对MNIST数据集进行分类。
首先定义我们的网络结构 (见 ann2snn.sample_models.mnist_cnn
):
class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Conv2d(32, 32, 3, 1),
            nn.BatchNorm2d(32, eps=1e-3),
            nn.ReLU(),
            nn.AvgPool2d(2, 2),

            nn.Flatten(),
            nn.Linear(32, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.network(x)
        return x
注意:如果遇到需要将tensor展开的情况,就在网络中定义一个 nn.Flatten
模块,在forward函数中需要使用定义的Flatten而不是view函数。
定义我们的超参数:
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
device = 'cuda'
dataset_dir = 'G:/Dataset/mnist'
batch_size = 100
T = 50
这里的T就是一会儿推理时使用的推理时间步。
如果您想训练的话,还需要初始化数据加载器、优化器、损失函数,例如:
lr = 1e-3
epochs = 10
# 定义损失函数
loss_function = nn.CrossEntropyLoss()
# 使用Adam优化器
optimizer = torch.optim.Adam(ann.parameters(), lr=lr, weight_decay=5e-4)
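数据加载器可以按如下方式构建（示意代码，数据变换方式可按需调整，复用前文定义的 dataset_dir 和 batch_size）:
import torchvision
from torch.utils.data import DataLoader

transform = torchvision.transforms.ToTensor()
train_data_loader = DataLoader(
    torchvision.datasets.MNIST(root=dataset_dir, train=True, transform=transform, download=True),
    batch_size=batch_size, shuffle=True, drop_last=False)
test_data_loader = DataLoader(
    torchvision.datasets.MNIST(root=dataset_dir, train=False, transform=transform, download=True),
    batch_size=batch_size, shuffle=False, drop_last=False)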
训练ANN。示例中,我们的模型训练了10个epoch。训练时测试集准确率变化情况如下:
Epoch: 0 100%|██████████| 600/600 [00:05<00:00, 112.04it/s]
Validating Accuracy: 0.972
Epoch: 1 100%|██████████| 600/600 [00:05<00:00, 105.43it/s]
Validating Accuracy: 0.986
Epoch: 2 100%|██████████| 600/600 [00:05<00:00, 107.49it/s]
Validating Accuracy: 0.987
Epoch: 3 100%|██████████| 600/600 [00:05<00:00, 109.26it/s]
Validating Accuracy: 0.990
Epoch: 4 100%|██████████| 600/600 [00:05<00:00, 103.98it/s]
Validating Accuracy: 0.984
Epoch: 5 100%|██████████| 600/600 [00:05<00:00, 100.42it/s]
Validating Accuracy: 0.989
Epoch: 6 100%|██████████| 600/600 [00:06<00:00, 96.24it/s]
Validating Accuracy: 0.991
Epoch: 7 100%|██████████| 600/600 [00:05<00:00, 104.97it/s]
Validating Accuracy: 0.992
Epoch: 8 100%|██████████| 600/600 [00:05<00:00, 106.45it/s]
Validating Accuracy: 0.991
Epoch: 9 100%|██████████| 600/600 [00:05<00:00, 111.93it/s]
Validating Accuracy: 0.991
训练好模型后,我们快速加载一下模型测试一下保存好的模型性能:
model.load_state_dict(torch.load('SJ-mnist-cnn_model-sample.pth'))
acc = val(model, device, test_data_loader)
print('ANN Validating Accuracy: %.4f' % (acc))
输出结果如下:
100%|██████████| 200/200 [00:02<00:00, 89.44it/s]
ANN Validating Accuracy: 0.9870
使用Converter进行转换
使用Converter进行转换非常简单,只需要参数中设置希望使用的模式即可。例如使用MaxNorm,需要先定义一个 ann2snn.Converter
,并且把模型forward给这个对象:
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
snn_model就是输出的SNN模型。查看snn_model的网络结构(BatchNorm2d的缺失,是由于转换过程中进行的conv_bn_fuse,也就是将bn层的参数吸收进conv层):
ANN(
(network): Module(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(3): AvgPool2d(kernel_size=2, stride=2, padding=0)
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(7): AvgPool2d(kernel_size=2, stride=2, padding=0)
(8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(11): AvgPool2d(kernel_size=2, stride=2, padding=0)
(12): Flatten(start_dim=1, end_dim=-1)
(13): Linear(in_features=32, out_features=10, bias=True)
(15): Softmax(dim=1)
)
(snn tailor): Module(
(0): Module(
(0): VoltageScaler(0.240048)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(4.165831)
)
(1): Module(
(0): VoltageScaler(0.307485)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(3.252196)
)
(2): Module(
(0): VoltageScaler(0.141659)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(7.059210)
)
(3): Module(
(0): VoltageScaler(0.060785)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(16.451399)
)
)
)
snn_model的类型为 GraphModule
,参见 GraphModule 。
调用 GraphModule.graph.print_tabular()
方法,用表格的形式查看模型的计算图的中间表示:
#snn_model.graph.print_tabular()
opcode name target args kwargs
----------- -------------- -------------- ----------------- --------
placeholder x x () {}
call_module network_0 network.0 (x,) {}
call_module snn_tailor_0_1 snn tailor.0.0 (network_0,) {}
call_module snn_tailor_0_2 snn tailor.0.1 (snn_tailor_0_1,) {}
call_module snn_tailor_0_3 snn tailor.0.2 (snn_tailor_0_2,) {}
call_module network_3 network.3 (snn_tailor_0_3,) {}
call_module network_4 network.4 (network_3,) {}
call_module snn_tailor_1_1 snn tailor.1.0 (network_4,) {}
call_module snn_tailor_1_2 snn tailor.1.1 (snn_tailor_1_1,) {}
call_module snn_tailor_1_3 snn tailor.1.2 (snn_tailor_1_2,) {}
call_module network_7 network.7 (snn_tailor_1_3,) {}
call_module network_8 network.8 (network_7,) {}
call_module snn_tailor_2_1 snn tailor.2.0 (network_8,) {}
call_module snn_tailor_2_2 snn tailor.2.1 (snn_tailor_2_1,) {}
call_module snn_tailor_2_3 snn tailor.2.2 (snn_tailor_2_2,) {}
call_module network_11 network.11 (snn_tailor_2_3,) {}
call_module network_12 network.12 (network_11,) {}
call_module network_13 network.13 (network_12,) {}
call_module snn_tailor_3_1 snn tailor.3.0 (network_13,) {}
call_module snn_tailor_3_2 snn tailor.3.1 (snn_tailor_3_1,) {}
call_module snn_tailor_3_3 snn tailor.3.2 (snn_tailor_3_2,) {}
call_module network_15 network.15 (snn_tailor_3_3,) {}
output output output (network_15,) {}
不同转换模式的对比
按照这个例子,我们分别定义模式为 max
,99.9%
,1.0/2
,1.0/3
,1.0/4
, 1.0/5
情况下的SNN转换并分别推理T步得到准确率。
print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_max_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_max_accs[-1]))
print('---------------------------------------------')
print('Converting using RobustNorm')
model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/2 max(activation) as scales...')
model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_two_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_two_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/3 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_three_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_three_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/4 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 4, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_four_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_four_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/5 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 5, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_five_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_five_accs[-1]))
观察控制栏输出:
---------------------------------------------
Converting using MaxNorm
100%|██████████| 600/600 [00:04<00:00, 128.25it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.44it/s] SNN accuracy (simulation 50 time-steps): 0.9777
---------------------------------------------
Converting using RobustNorm
100%|██████████| 600/600 [00:19<00:00, 31.06it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.75it/s] SNN accuracy (simulation 50 time-steps): 0.9841
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 600/600 [00:04<00:00, 126.64it/s] ]Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.90it/s] SNN accuracy (simulation 50 time-steps): 0.9844
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 126.27it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.73it/s] SNN accuracy (simulation 50 time-steps): 0.9828
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 128.94it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.47it/s] SNN accuracy (simulation 50 time-steps): 0.9747
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 121.18it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.42it/s] SNN accuracy (simulation 50 time-steps): 0.9487
---------------------------------------------
可以看到，模型转换的速度非常快；推理也很快，推理200个batch仅需约11s（GTX 2080ti）。根据模型输出的随时间变化的准确率，我们可以绘制不同设置下的准确率曲线。
fig = plt.figure()
plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
plt.plot(np.arange(0, T), mode_two_accs, label='mode: 1.0/2')
plt.plot(np.arange(0, T), mode_three_accs, label='mode: 1.0/3')
plt.plot(np.arange(0, T), mode_four_accs, label='mode: 1.0/4')
plt.plot(np.arange(0, T), mode_five_accs, label='mode: 1.0/5')
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()

不同的设置可以得到不同的结果,有的推理速度快,但是最终精度低,有的推理慢,但是精度高。用户可以根据自己的需求选择模型设置。
- 1(1,2,3,4,5,6)
Rueckauer B, Lungu I-A, Hu Y, Pfeiffer M and Liu S-C (2017) Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification. Front. Neurosci. 11:682.
- 2
Diehl, Peter U. , et al. Fast classifying, high-accuracy spiking deep networks through weight and threshold balancing. Neural Networks (IJCNN), 2015 International Joint Conference on IEEE, 2015.
- 3
Rueckauer, B., Lungu, I. A., Hu, Y., & Pfeiffer, M. (2016). Theory and tools for the conversion of analog to spiking convolutional neural networks. arXiv preprint arXiv:1612.04052.
- 4
Sengupta, A., Ye, Y., Wang, R., Liu, C., & Roy, K. (2019). Going deeper in spiking neural networks: Vgg and residual architectures. Frontiers in neuroscience, 13, 95.
遗产教程
本教程作者: fangwei123456
由于开发者精力有限,有一些教程并未随着SpikingJelly的代码更新而同步更新,还有一些教程被精简合并进了新版教程。下面列出一些可能对读者有帮助的老版教程。
Activation-based 的设计来源
在早期的框架中,Activation-based 被称之为 Clock-driven,下面是相关的教程:
编码器
新版框架还没有来得及进行更新,可以先查看老版本的教程:
ANN转换SNN
新版框架还没有来得及进行更新,可以先查看老版本的教程:
SNN在其他任务的应用
步进模式的设计来源
CUPY后端的设计来源
诚待英才
我们非常欢迎有余力的读者,将这些教程更新到与框架的master版本匹配,并提交Pull Request到master版本中。
编写CUPY神经元
本教程作者: fangwei123456
本教程介绍如何编写CUPY后端的神经元。本教程需要如下的前置知识:
- 了解CUDA,能够实现简单的逐元素CUDA内核
- 能够使用 torch.autograd.Function 实现自定义反向传播
- 已经阅读了 spikingjelly.activation_based.auto_cuda.base 的全部API文档,能够使用 spikingjelly.activation_based.auto_cuda.base 编写2D CUDA内核
实现IF神经元的CUDA多步前向传播
假设我们要编写一个python函数用于神经元进行多步前向传播(FPTT),则这个函数的输入应该至少包括:
- v_init: shape = [N],表示神经元在当前时刻的初始电压(上一个时刻放电后的电压)。其中 N 为神经元的数量;当神经元是多维时,N 应该是神经元展平后的数量
- x_seq: shape = [T, N],表示 T 个time-steps的输入
- v_th: float,表示阈值电压
如果使用 hard reset,则还需要增加一个参数:
- v_reset: float,表示重置电压
这个函数的输出应该包括:
- spike_seq: shape = [T, N],表示输出的 T 个time-steps的脉冲
- v_seq: shape = [T, N],表示 T 个time-steps的放电后的电压。我们需要输出所有时刻而不仅仅是最后时刻的电压,因为有时可能会用到这些数据
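在编写CUDA内核之前,可以先用纯PyTorch写出一个等价的参考实现,便于理解上述输入输出的含义,也便于之后校验结果(以下仅为示意,假设是使用 hard reset 的IF神经元,函数名为示例):

import torch

def if_fptt_reference(x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float):
    # x_seq.shape = [T, N], v_init.shape = [N]
    T = x_seq.shape[0]
    spike_seq = torch.zeros_like(x_seq)
    v_seq = torch.zeros_like(x_seq)
    v = v_init
    for t in range(T):
        h = v + x_seq[t]                         # 充电: H[t] = V[t-1] + X[t]
        spike = (h >= v_th).to(x_seq)            # 放电: S[t] = Θ(H[t] - V_th)
        v = h * (1. - spike) + v_reset * spike   # hard reset
        spike_seq[t] = spike
        v_seq[t] = v
    return spike_seq, v_seq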
若将FPTT写成CUDA函数,则函数参数仍然包括上述参数,但还需要一些额外的参数。
spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronFPTTKernel 继承自 spikingjelly.activation_based.auto_cuda.base.CKernel2D,是神经元进行多步前向传播(FPTT)的CUDA内核基类。
我们可以查看其默认的CUDA参数声明:
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronFPTTKernel(hard_reset=True, dtype='float')
for key, value in base_kernel.cparams.items():
    print(f'key="{key}",'.ljust(20), f'value="{value}"'.ljust(20))
输出为:
key="numel", value="const int &"
key="N", value="const int &"
key="x_seq", value="const float *"
key="v_v_seq", value="float *"
key="h_seq", value="float *"
key="spike_seq", value="float *"
key="v_th", value="float &"
key="v_reset", value="float &"
绝大多数参数与之前相同,不同的参数包括:
- numel: 元素总数,即 numel = T * N
- N: 神经元的数量
- v_v_seq: shape = [T + 1, N],由 v_init 和 v_seq 合并得到
- h_seq: shape = [T, N],充电后、放电前的电压,反向传播时需要用到
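其中 v_v_seq 的构造方式可以用如下示意代码来理解(假设 v_init.shape = [N],v_seq.shape = [T, N]):

import torch
v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq), dim=0)  # shape = [T + 1, N]
# v_v_seq[0] 即 v_init,v_v_seq[1:] 即 v_seq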
NeuronFPTTKernel 作为神经元FPTT的基类,类似于 spikingjelly.activation_based.neuron.BaseNode,已经实现了放电和重置方程。我们在实现新神经元的FPTT CUDA内核时,只需要继承 NeuronFPTTKernel 并补充充电方程即可。
我们首先查看一下 NeuronFPTTKernel 的完整代码:
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronFPTTKernel(hard_reset=True, dtype='float')
print(base_kernel.full_codes)
输出为:
#include <cuda_fp16.h>
extern "C" __global__
void NeuronFPTTKernel_float_hard_reset(
const int & numel, const int & N, const float * x_seq, float * v_v_seq, float * h_seq, float * spike_seq, float & v_th, float & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
// neuronal_charge should be defined here!;
spike_seq[t] = (h_seq[t] - v_th) >= 0.0f ? 1.0f: 0.0f;
v_v_seq[t + dt] = h_seq[t] * (1.0f - spike_seq[t]) + v_reset * spike_seq[t];
}
}
}
可以发现,这个内核已经比较完善,仅需要我们补充部分代码。NeuronFPTTKernel 提供了 neuronal_charge 函数:
class NeuronFPTTKernel(base.CKernel2D):
    # ...
    def neuronal_charge(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`H[t] = f(X[t], V[t-1], ...)`.

        This function should define how ``h_seq[t]`` is calculated by ``x_seq[t], v_v_seq[t]`` and other params if
        the neuron needs.

        For example, the IF neuron define this function as:

        .. code-block:: python

            def neuronal_charge(self) -> str:
                # note that v_v_seq[t] is v_seq[t - dt]
                return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
        """
        return '// neuronal_charge should be defined here!'
如果想要实现新的神经元,只需要重定义这个函数。现在以最简单的IF神经元为例,其充电方程为 \(H[t] = V[t-1] + X[t]\),实现方式为:
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeFPTTKernel(neuron_kernel.NeuronFPTTKernel):
    def neuronal_charge(self) -> str:
        # note that v_v_seq[t] is v_seq[t - dt]
        return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
if_fptt_kernel = IFNodeFPTTKernel(hard_reset=True, dtype='float')
print(if_fptt_kernel.full_codes)
输出为:
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeFPTTKernel_float_hard_reset(
const int & numel, const int & N, const float * x_seq, float * v_v_seq, float * h_seq, float * spike_seq, float & v_th, float & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
h_seq[t] = x_seq[t] + v_v_seq[t];
spike_seq[t] = (h_seq[t] - v_th) >= 0.0f ? 1.0f: 0.0f;
v_v_seq[t + dt] = h_seq[t] * (1.0f - spike_seq[t]) + v_reset * spike_seq[t];
}
}
}
这其实就是一个完整的CUDA内核了。可以发现,NeuronFPTTKernel
给编写CUDA内核带来了极大的方便。
需要注意的是,这里我们使用:
def neuronal_charge(self) -> str:
    # note that v_v_seq[t] is v_seq[t - dt]
    return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
而不是手动编写:
def neuronal_charge(self) -> str:
    # note that v_v_seq[t] is v_seq[t - dt]
    return 'h_seq[t] = x_seq[t] + v_v_seq[t];'
原因在于 spikingjelly.activation_based.auto_cuda.cfunction 提供的函数通常同时包括 float 和 half2 两种数据类型的实现,比我们手动编写两种更便捷。
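例如,可以直接打印 cfunction.add 在两种数据类型下生成的代码(注释中的结果仅为示意,以实际运行输出为准):

from spikingjelly.activation_based.auto_cuda import cfunction

print(cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype='float'))
# 类似于: h_seq[t] = x_seq[t] + v_v_seq[t];
print(cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype='half2'))
# 类似于: h_seq[t] = __hadd2(x_seq[t], v_v_seq[t]);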
若设置 dtype='half2',可以直接得到半精度的内核:
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeFPTTKernel(neuron_kernel.NeuronFPTTKernel):
    def neuronal_charge(self) -> str:
        # note that v_v_seq[t] is v_seq[t - dt]
        return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
if_fptt_kernel = IFNodeFPTTKernel(hard_reset=True, dtype='half2')
print(if_fptt_kernel.full_codes)
输出为:
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeFPTTKernel_half2_hard_reset(
const int & numel, const int & N, const half2 * x_seq, half2 * v_v_seq, half2 * h_seq, half2 * spike_seq, half2 & v_th, half2 & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
h_seq[t] = __hadd2(x_seq[t], v_v_seq[t]);
spike_seq[t] = __hgeu2(__hsub2(h_seq[t], v_th), __float2half2_rn(0.0f));
v_v_seq[t + dt] = __hfma2(h_seq[t], __hsub2(__float2half2_rn(1.0f), spike_seq[t]), __hmul2(v_reset, spike_seq[t]));
}
}
}
实现IF神经元的CUDA多步反向传播
多步反向传播(BPTT)要比多步前向传播更为复杂。我们首先回顾SpikingJelly中的前向传播定义(以 hard reset 为例):

\(H[t] = f(V[t-1], X[t])\)

\(S[t] = \Theta(H[t] - V_{th})\)

\(V[t] = H[t] \cdot (1 - S[t]) + V_{reset} \cdot S[t]\)

我们在前文中实现的前向传播可以表示为 \(S[1,2,...,T], V[1,2,...,T] = F_{fp}(X[1,2,...,T], V[0])\);相应的,我们需要实现的反向传播为 \(\frac{\mathrm{d} L}{\mathrm{d} X[1,2,...,T]}, \frac{\mathrm{d} L}{\mathrm{d} V[0]} = F_{bp}(\frac{\mathrm{d} L}{\mathrm{d} S[1,2,...,T]}, \frac{\mathrm{d} L}{\mathrm{d} V[1,2,...,T]})\)。
因而,BPTT函数所需要的输入为:
- grad_spike_seq: shape = [T, N],表示损失对 T 个时刻的输出脉冲 spike_seq 的梯度
- grad_v_seq: shape = [T, N],表示损失对 T 个时刻的放电后的电压 v_seq 的梯度
BPTT函数的输出为:
- grad_x_seq: shape = [T, N],表示损失对 T 个时刻的输入 x_seq 的梯度
- grad_v_init: shape = [N],表示损失对 v_init 的梯度
根据前向传播,推出反向传播的计算式为:

\(\frac{\mathrm{d} L}{\mathrm{d} H[t]} = \frac{\mathrm{d} L}{\mathrm{d} S[t]} \frac{\mathrm{d} S[t]}{\mathrm{d} H[t]} + \left(\frac{\mathrm{d} L}{\mathrm{d} V[t]} + \frac{\mathrm{d} L}{\mathrm{d} H[t+1]} \frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}\right) \frac{\mathrm{d} V[t]}{\mathrm{d} H[t]}\)

\(\frac{\mathrm{d} L}{\mathrm{d} X[t]} = \frac{\mathrm{d} L}{\mathrm{d} H[t]} \frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}, \quad \frac{\mathrm{d} L}{\mathrm{d} V[0]} = \frac{\mathrm{d} L}{\mathrm{d} H[1]} \frac{\mathrm{d} H[1]}{\mathrm{d} V[0]}\)

其中 \(D_{reset}\) 表示是否detach reset(detach reset时 \(D_{reset}=1\),否则 \(D_{reset}=0\)),对于 hard reset:

\(\frac{\mathrm{d} V[t]}{\mathrm{d} H[t]} = 1 - S[t] + (1 - D_{reset}) \cdot (V_{reset} - H[t]) \cdot \frac{\mathrm{d} S[t]}{\mathrm{d} H[t]}\)

将上述公式合并,即可得到完整的BPTT计算式,其与下文内核代码中的逐行计算一一对应。
上述公式中,\(\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}, \frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}\) 是由神经元的充电方程\(H[t] = f(V[t - 1], X[t])\) 决定,与特定的神经元相关;\(\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]}\) 由替代函数决定;其余部分则是通用的。
因而,spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronBPTTKernel 也实现了通用的计算部分。我们首先查看其函数参数:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
for key, value in base_kernel.cparams.items():
    print(f'key="{key}",'.ljust(22), f'value="{value}"'.ljust(20))
输出为:
key="numel", value="const int &"
key="N", value="const int &"
key="grad_spike_seq", value="const float *"
key="grad_v_seq", value="const float *"
key="h_seq", value="const float *"
key="grad_x_seq", value="float *"
key="grad_v_init", value="float *"
key="v_th", value="float &"
key="v_reset", value="float &"
参数含义在前文中已经介绍过。这里需要注意,我们设置了 NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, ...),因为在反向传播时需要指定替代函数。
在SpikingJelly的替代函数类中,提供了 cuda_codes 函数以生成反向传播的CUDA代码。以 spikingjelly.activation_based.surrogate.Sigmoid 为例,其定义为:
class Sigmoid(SurrogateFunctionBase):
    # ...
    def cuda_codes(self, y: str, x: str, dtype: str):
        return cfunction.sigmoid_backward(y=y, x=x, alpha=self.alpha, dtype=dtype)
我们尝试打印出反向传播的代码:
from spikingjelly.activation_based import surrogate
print(surrogate.Sigmoid().cuda_codes(y='grad_s', x='over_th', dtype='float'))
输出为:
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
grad_s = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
如果我们要自行实现支持CUDA反向传播的替代函数,也应该遵循类似的规范,按照如下格式进行定义:
class CustomSurrogateFunction:
    # ...
    def cuda_codes(self, y: str, x: str, dtype: str):
        # ...
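例如,下面是一个假设的矩形窗替代函数的示意实现(类名与参数 w 均为示例,此处只实现了 float 版本;梯度在 |x| < w 时取 1/(2w),否则取 0):

class RectangularSurrogate:
    def __init__(self, w: float = 0.5):
        self.w = w

    def cuda_codes(self, y: str, x: str, dtype: str):
        # 返回一段CUDA代码字符串,把对 x 的替代梯度赋值给 y
        assert dtype == 'float', '本示例只实现了 float 版本'
        return f'{y} = (fabsf({x}) < {self.w}f) ? {1. / (2. * self.w)}f : 0.0f;'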
接下来查看 NeuronBPTTKernel 的完整内核代码:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
print(base_kernel.full_codes)
输出为:
#include <cuda_fp16.h>
extern "C" __global__
void NeuronBPTTKernel_float_hard_reset_nodetach_reset(
const int & N, const float * grad_spike_seq, float * grad_v_init, const float * grad_v_seq, float * grad_x_seq, const float * h_seq, const int & numel, float & v_reset, float & v_th
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
float grad_h = 0.0f;
for(int t = numel - N + index; t >= 0; t -= dt)
{
const float over_th = h_seq[t] - v_th;
const float spike_seq_t = over_th >= 0.0f ? 1.0f: 0.0f;
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
const float grad_s_to_h = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
float grad_v_to_h = (1.0f) - spike_seq_t;
{
float temp_var = v_reset - h_seq[t];
temp_var = temp_var * grad_s_to_h;
grad_v_to_h = temp_var + grad_v_to_h;
}
// grad_h_next_to_v should be defined here!;
grad_h = grad_h * grad_h_next_to_v;
grad_h = grad_v_seq[t] + grad_h;
grad_h = grad_h * grad_v_to_h;
{
float temp_var = grad_spike_seq[t] * grad_s_to_h;
grad_h = grad_h + temp_var;
}
// grad_h_to_x should be defined here!;
grad_x_seq[t] = grad_h * grad_h_to_x;
}
// grad_h_next_to_v should be defined here!;
grad_v_init[index] = grad_h * grad_h_next_to_v;
}
}
上述代码中注释的位置,即提示我们需要补充的位置。它们在 NeuronBPTTKernel
中有对应的函数:
class NeuronBPTTKernel(base.CKernel2D):
    # ...
    def grad_h_next_to_v(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t+1]}{\\mathrm{d} V[t]}`.

        This function should define how ``grad_h_next_to_v`` is calculated. Note that ``grad_h_next_to_v`` has not been
        declared. Thus, this function should also declare ``grad_h_next_to_v``.

        For example, the IF neuron define this function as:

        .. code-block:: python

            def grad_h_next_to_v(self) -> str:
                return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
        """
        return '// grad_h_next_to_v should be defined here!'

    def grad_h_to_x(self) -> str:
        """
        :return: CUDA code
        :rtype: str

        Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t]}{\\mathrm{d} X[t]}`.

        This function should define how ``grad_h_to_x`` is calculated. Note that ``grad_h_to_x`` has not been
        declared. Thus, this function should also declare ``grad_h_to_x``.

        For example, the IF neuron define this function as:

        .. code-block:: python

            def grad_h_to_x(self) -> str:
                return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
        """
        return '// grad_h_to_x should be defined here!'
对于IF神经元,\(\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}=1, \frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}=1\)。因此,可以很容易实现IF神经元的BPTT内核:
class IFNodeBPTTKernel(neuron_kernel.NeuronBPTTKernel):
    def grad_h_next_to_v(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)

    def grad_h_to_x(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
接下来,就可以打印出完整的IF神经元BPTT的CUDA内核:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeBPTTKernel(neuron_kernel.NeuronBPTTKernel):
    def grad_h_next_to_v(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)

    def grad_h_to_x(self) -> str:
        return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)

kernel = IFNodeBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
print(kernel.full_codes)
输出为:
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeBPTTKernel_float_hard_reset_nodetach_reset(
const int & N, const float * grad_spike_seq, float * grad_v_init, const float * grad_v_seq, float * grad_x_seq, const float * h_seq, const int & numel, float & v_reset, float & v_th
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
float grad_h = 0.0f;
for(int t = numel - N + index; t >= 0; t -= dt)
{
const float over_th = h_seq[t] - v_th;
const float spike_seq_t = over_th >= 0.0f ? 1.0f: 0.0f;
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
const float grad_s_to_h = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
float grad_v_to_h = (1.0f) - spike_seq_t;
{
float temp_var = v_reset - h_seq[t];
temp_var = temp_var * grad_s_to_h;
grad_v_to_h = temp_var + grad_v_to_h;
}
const float grad_h_next_to_v = 1.0f;
grad_h = grad_h * grad_h_next_to_v;
grad_h = grad_v_seq[t] + grad_h;
grad_h = grad_h * grad_v_to_h;
{
float temp_var = grad_spike_seq[t] * grad_s_to_h;
grad_h = grad_h + temp_var;
}
const float grad_h_to_x = 1.0f;
grad_x_seq[t] = grad_h * grad_h_to_x;
}
const float grad_h_next_to_v = 1.0f;
grad_v_init[index] = grad_h * grad_h_next_to_v;
}
}
Python包装
接下来,使用 torch.autograd.Function
对FPTT和BPTT进行包装。
spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronATGFBase
提供了一些通用的函数用来包装。我们将在实现IF神经元的autograd Function时进行使用。建议首先阅读 NeuronATGFBase
的API文档,我们在下文中会默认读者已经了解其各个函数的使用。
首先需要确定输入。在SpikingJelly中,CUDA内核由神经元类生成,并作为autograd Function前向传播的输入传入,而不是在autograd Function内部生成(0.0.0.0.12及之前的老版本是在Function内部生成的)。前向传播的定义如下:
class IFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
                forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
接下来根据输入,生成 py_dict
,并交给 NeuronATGFBase.pre_forward
处理:
py_dict = {
'x_seq': x_seq,
'v_init': v_init,
'v_th': v_th,
'v_reset': v_reset
}
requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)
接下来就可以直接调用前向传播了:
forward_kernel((blocks,), (threads,), py_dict)
接下来,我们需要保存反向传播所需的参数:
NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
backward_kernel=backward_kernel)
最后返回 T 个time-steps的脉冲和电压。不要忘了 v_v_seq[1:] 才是要返回的 v_seq,因此返回值为:
return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]
完整的前向传播代码为:
class IFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
                forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
        py_dict = {
            'x_seq': x_seq,
            'v_init': v_init,
            'v_th': v_th,
            'v_reset': v_reset
        }
        requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)

        forward_kernel((blocks,), (threads,), py_dict)

        NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
                                numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
                                backward_kernel=backward_kernel)

        return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]
接下来实现反向传播。反向传播函数的输入,是前向传播函数的输出tensor的梯度tensor,因此输入是:
class IFNodeATGF(torch.autograd.Function):
    @staticmethod
    def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):
借助 NeuronATGFBase.pre_backward
,进行预处理,得到执行反向传播内核的参数:
backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
然后直接执行反向传播内核:
backward_kernel((blocks,), (threads,), py_dict)
最后返回梯度。前向传播有几个输入,反向传播就有几个返回值:
return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None
完整的代码为:
class IFNodeATGF(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
                forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
        py_dict = {
            'x_seq': x_seq,
            'v_init': v_init,
            'v_th': v_th,
            'v_reset': v_reset
        }
        requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)

        forward_kernel((blocks,), (threads,), py_dict)

        NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
                                numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
                                backward_kernel=backward_kernel)

        return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]

    @staticmethod
    def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):
        backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
        backward_kernel((blocks,), (threads,), py_dict)

        return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None
实现CUPY后端
利用之前我们已经定义好的 IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF
,我们实现一个简化的IF神经元,并添加CUPY后端。
完整的代码如下:
from spikingjelly.activation_based.auto_cuda.neuron_kernel import IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF
# put sources of ``IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF`` before the following codes
import torch
from typing import Callable
from spikingjelly.activation_based import base, surrogate
class CUPYIFNode(base.MemoryModule):
    def __init__(self, v_threshold: float = 1., v_reset: float or None = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.surrogate_function = surrogate_function
        self.detach_reset = detach_reset
        self.step_mode = 'm'
        if v_reset is not None:
            self.register_memory('v', v_reset)
        else:
            self.register_memory('v', 0.)

    def multi_step_forward(self, x_seq: torch.Tensor):
        if isinstance(self.v, float):
            self.v = torch.zeros_like(x_seq[0])

        hard_reset = self.v_reset is not None
        if x_seq.dtype == torch.float:
            dtype = 'float'
        elif x_seq.dtype == torch.half:
            dtype = 'half2'

        forward_kernel = IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype)
        backward_kernel = IFNodeBPTTKernel(surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype)

        # all tensors will be regarded as 2D or 1D, so we flatten the inputs
        spike_seq, v_seq = IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(), self.v_threshold, self.v_reset, forward_kernel, backward_kernel)

        spike_seq = spike_seq.view(x_seq.shape)
        self.v = v_seq[-1].view(x_seq.shape[1:])
        return spike_seq
接下来,让我们与纯pytorch实现对比输出误差:
from spikingjelly.activation_based import neuron
@torch.no_grad()
def max_error(x: torch.Tensor, y: torch.Tensor):
    return (x - y).abs().max()
T = 8
N = 64
C = 32 * 32 * 32
device = 'cuda:0'
x_seq = torch.rand([T, N, C], device=device, requires_grad=True)
net_cupy = CUPYIFNode()
y_cupy = net_cupy(x_seq)
y_cupy.sum().backward()
x_grad_cupy = x_seq.grad.clone()
x_seq.grad.zero_()
net_torch = neuron.IFNode(backend='torch', step_mode='m')
y_torch = net_torch(x_seq)
y_torch.sum().backward()
x_grad_torch = x_seq.grad.clone()
print('max error of y_seq', max_error(y_cupy, y_torch))
print('max error of x_seq.grad', max_error(x_grad_cupy, x_grad_torch))
输出为:
max error of y_seq tensor(0., device='cuda:0')
max error of x_seq.grad tensor(1.3113e-06, device='cuda:0')
可以发现,基本没有误差,我们的实现是正确的。
接下来对比速度。实验在 NVIDIA Quadro RTX 6000
上进行:
from spikingjelly.activation_based import neuron, cuda_utils, functional
def forward_backward(net: torch.nn.Module, x_seq: torch.Tensor):
    y_seq = net(x_seq)
    y_seq.sum().backward()
    x_seq.grad.zero_()
    functional.reset_net(net)

N = 64
C = 32 * 32 * 32
device = 'cuda:0'

net_cupy = CUPYIFNode()
net_torch = neuron.IFNode(backend='torch', step_mode='m')

repeats = 16

for dtype in [torch.float, torch.half]:
    for T in [2, 4, 8, 16, 32]:
        x_seq = torch.rand([T, N, C], device=device, requires_grad=True, dtype=dtype)
        t_cupy = cuda_utils.cal_fun_t(repeats, device, forward_backward, net_cupy, x_seq)
        t_torch = cuda_utils.cal_fun_t(repeats, device, forward_backward, net_torch, x_seq)
        print(f'dtype={dtype}, T={T},'.ljust(30), f't_torch / t_cupy = {round(t_torch / t_cupy, 2)}')
输出为:
dtype=torch.float32, T=2, t_torch / t_cupy = 0.59
dtype=torch.float32, T=4, t_torch / t_cupy = 1.47
dtype=torch.float32, T=8, t_torch / t_cupy = 2.67
dtype=torch.float32, T=16, t_torch / t_cupy = 4.17
dtype=torch.float32, T=32, t_torch / t_cupy = 6.93
dtype=torch.float16, T=2, t_torch / t_cupy = 0.68
dtype=torch.float16, T=4, t_torch / t_cupy = 1.31
dtype=torch.float16, T=8, t_torch / t_cupy = 2.2
dtype=torch.float16, T=16, t_torch / t_cupy = 4.77
dtype=torch.float16, T=32, t_torch / t_cupy = 6.7
可以发现,在使用梯度替代法训练时常用的 T >= 4 的情况下,手动编写的CUPY内核拥有较大的加速效果。
当 T 较小时,PyTorch比手写CUPY快是正常的:SpikingJelly中的PyTorch函数大多使用jit进行了封装,而手写的CUPY逐元素内核在这种规模下慢于jit优化后的PyTorch逐元素操作。
在灵汐芯片上推理
本教程作者: fangwei123456
在GPU上训练float16模型
我们使用的是 灵汐科技 的lynxi HP300芯片,完全支持float16,对于float32也可支持但会有一定的计算误差。从我们的使用经验来看,使用float32容易出现误差逐层累计的情况,因此最好使用float16。
将 spikingjelly.activation_based.examples.conv_fashion_mnist
中的网络稍作更改,改为使用float16训练:
# ...
net = CSNN(T=args.T, channels=args.channels, use_cupy=args.cupy).half()
# ...
for img, label in train_data_loader:
    optimizer.zero_grad()
    img = img.to(args.device).half()
    label = label.to(args.device)
    label_onehot = F.one_hot(label, 10).half()
    # ...
    train_acc += (out_fr.argmax(1) == label).half().sum().item()
    # ...
# ...
with torch.no_grad():
    for img, label in test_data_loader:
        img = img.to(args.device).half()
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).half()
        # ...
        test_acc += (out_fr.argmax(1) == label).half().sum().item()
# ...
将修改后的文件保存为 w1.py,进行训练。需要注意,训练时不再使用AMP:
python w1.py -T 4 -device cuda:0 -b 128 -epochs 64 -data-dir /datasets/FashionMNIST/ -cupy -opt sgd -lr 0.1 -j 8
训练完成后:
Namespace(T=4, device='cuda:0', b=128, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=False, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128, save_es=None)
./logs/T4_b128_sgd_lr0.1_c128_cupy
epoch = 63, train_loss = 0.0041, train_acc = 0.9836, test_loss = 0.0110, test_acc = 0.9312, max_test_acc = 0.9330
train speed = 8056.0318 images/s, test speed = 11152.5812 images/s
escape time = 2022-08-16 10:52:51
最高正确率为 0.9330,模型保存在 ./logs/T4_b128_sgd_lr0.1_c128_cupy 中:
cxhpc@lxnode01:~/fangwei/tempdir/fmnist_test/logs/T4_b128_sgd_lr0.1_c128_cupy$ ls
args.txt checkpoint_latest.pth checkpoint_max.pth events.out.tfevents.1660617801.mlg-ThinkStation-P920.3234566.0
模型编译
并非所有SpikingJelly中的模块都支持灵汐的芯片。为了正确编译,spikingjelly.activation_based.lynxi_exchange 提供了将SpikingJelly的部分网络层转换到受支持的网络层的函数:可以使用 spikingjelly.activation_based.lynxi_exchange.to_lynxi_supported_module 转换单个网络层,或使用 spikingjelly.activation_based.lynxi_exchange.to_lynxi_supported_modules 转换多个网络层。
需要注意的是,灵汐的芯片不支持5D的tensor,而 shape = [T, N, C, H, W] 经常出现在多步模式下的卷积层之间。对于使用多步模式的网络,使用 to_lynxi_supported_module
或 to_lynxi_supported_modules
进行转换时,会将输入视作 shape = [TN, *]。
例如,查看转换神经元的源代码可以发现,在多步模式下,输入被当作 shape = [TN, *],首先被reshape到 shape = [T, N, *] 然后才进行计算。由于灵汐不支持5D的tensor,神经元 内部直接reshape为3D的tensor:
# spikingjelly/activation_based/lynxi_exchange.py
class BaseNode(nn.Module):
    # ...
    def forward(self, x: torch.Tensor, v: torch.Tensor = None):
        # ...
        elif self.step_mode == 'm':
            x = x.reshape(self.T, x.shape[0] // self.T, -1)
        # ...
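也就是说,多步模式下各层之间传递的张量可以按如下方式在 [T, N, *] 与 [TN, *] 之间互相转换(以下仅为示意,数值为示例):

import torch

T, N = 4, 16
x_seq = torch.rand([T, N, 1, 28, 28])     # SpikingJelly多步模式下的输入
x_tn = x_seq.flatten(0, 1)                # [T, N, 1, 28, 28] -> [TN, 1, 28, 28],灵汐支持的格式
x_back = x_tn.reshape(T, N, 1, 28, 28)    # 恢复为 [T, N, 1, 28, 28]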
接下来,我们将训练好的识别FashionMNIST的网络进行转换。原始网络的定义如下:
# spikingjelly/activation_based/examples/conv_fashion_mnist.py
class CSNN(nn.Module):
    def __init__(self, T: int, channels: int, use_cupy=False):
        super().__init__()
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 14 * 14

            layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            layer.BatchNorm2d(channels),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.MaxPool2d(2, 2),  # 7 * 7

            layer.Flatten(),
            layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),

            layer.Linear(channels * 4 * 4, 10, bias=False),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
        )

        functional.set_step_mode(self, step_mode='m')
        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr
我们首先加载原始网络:
net_sj = conv_fashion_mnist.CSNN(T=args.T, channels=args.channels)
net_sj.eval()
ckp = torch.load(args.pt_path, map_location='cpu')
print(f'max_test_acc={ckp["max_test_acc"]}')
net_sj.load_state_dict(ckp['net'])
然后转换为支持的网络层:
module_list = lynxi_exchange.to_lynxi_supported_modules(net_sj.conv_fc, args.T)
需要注意的是,根据原始网络的定义,net_sj.conv_fc 的输出 shape = [T, N, C];我们转换后,输出的 shape = [TN, C]。为了得到分类结果,我们需要求得发放率。
因此,新建一个网络:
class InferenceNet(nn.Module):
    def __init__(self, T: int, modules_list: list):
        super().__init__()
        self.T = T
        self.module_list = nn.Sequential(*modules_list)

    def forward(self, x: torch.Tensor):
        # x.shape = [N, C, H, W]
        x = x.repeat(self.T, 1, 1, 1)
        # [N, C, H, W] -> [TN, C, H, W]

        x = self.module_list(x)

        # [TN, *] -> [T, N, *]
        x = x.reshape(self.T, x.shape[0] // self.T, -1)

        return x.mean(0)
net = InferenceNet(args.T, module_list)
net.eval()
print(net)
输出为:
InferenceNet(
(module_list): Sequential(
(0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): IFNode()
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): IFNode()
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(8): Flatten(start_dim=1, end_dim=-1)
(9): Linear(in_features=6272, out_features=2048, bias=False)
(10): IFNode()
(11): Linear(in_features=2048, out_features=10, bias=False)
(12): IFNode()
)
)
接下来对模型进行编译:
output_path = lynxi_exchange.compile_lynxi_model(lynxi_model_path, net, in_data_type='float16',
out_data_type='float16',
input_shape_dict={'x': torch.Size([batch_size, 1, 28, 28])})
推理
推理时,首先加载编译好的网络:
net_lynxi = lynxi_exchange.load_lynxi_model(device_id, output_path)
然后将pytorch的输入tensor转换为灵汐的tensor,送入网络;将输出的灵汐tensor转化为pytorch的tensor,便于计算正确率:
test_acc = 0
test_samples = 0
with torch.no_grad():
    for img, label in tqdm.tqdm(test_data_loader, disable=False):
        y = net_lynxi(lynxi_exchange.torch_tensor_to_lynxi(img, device_id))
        y = lynxi_exchange.lynxi_tensor_to_torch(y, shape=[label.shape[0], 10], dtype='float16')
        test_acc += (y.argmax(1) == label).half().sum().item()
        test_samples += img.shape[0]

test_acc = test_acc / test_samples
print(f'lynxi inference accuracy = {test_acc}')
最终正确率为:
lynxi inference accuracy = 0.9316
完整代码和输入输出
完整的代码位于 spikingjelly/activation_based/examples/lynxi_fmnist_inference.py,运行的命令行参数为:
(fangwei) cxhpc@lxnode01:~/fangwei/spikingjelly$ python -m spikingjelly.activation_based.examples.lynxi_fmnist_inference -epochs
lynxi_exchange.py[line:185]-CRITICAL: lyngor.version=1.1.0
usage: test.py [-h] [-T T] [-j N] [-data-dir DATA_DIR] [-channels CHANNELS]
[-b B] [-pt-path PT_PATH] [-out-model-path OUT_MODEL_PATH]
[-lynxi-device LYNXI_DEVICE]
Inference on Lynxi chips
optional arguments:
-h, --help show this help message and exit
-T T simulating time-steps
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of Fashion-MNIST dataset
-channels CHANNELS channels of CSNN
-b B batch size
-pt-path PT_PATH checkpoint file path for conv_fashion_mnist.CSNN
-out-model-path OUT_MODEL_PATH
path for saving the model compiled by lynxi
-lynxi-device LYNXI_DEVICE
device id for lynxi
完整的输出日志为:
CRITICAL:root:lyngor.version=1.1.0
lynxi_exchange.py[line:185]-CRITICAL: lyngor.version=1.1.0
Namespace(T=4, b=16, channels=128, data_dir=None, j=4, lynxi_device=0, out_model_path='/home/cxhpc/fangwei/tempdir/fmnist_test/lynxi_model', pt_path='/home/cxhpc/fangwei/tempdir/fmnist_test/logs/T4_b128_sgd_lr0.1_c128_cupy/checkpoint_max.pth')
max_test_acc=0.933
InferenceNet(
(module_list): Sequential(
(0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): IFNode()
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): IFNode()
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(8): Flatten(start_dim=1, end_dim=-1)
(9): Linear(in_features=6272, out_features=2048, bias=False)
(10): IFNode()
(11): Linear(in_features=2048, out_features=10, bias=False)
(12): IFNode()
)
)
util.py[line:268]-INFO: [optimize] Total time running optimize(2): 1.9529 seconds
util.py[line:268]-INFO: [apu_build+optimize] Total time running apu_build(1): 4.8967 seconds
Aborted (core dumped)
builder.py[line:252]-ERROR: abc_map is error, error num is -2
INFO: build abc map failed, try to build by auto mode
util.py[line:268]-INFO: [optimize] Total time running optimize(1519): 1.2367 seconds
util.py[line:268]-INFO: [apu_build+optimize] Total time running apu_build(1518): 4.1377 seconds
lx_map compile option:
git tag : LX_APU_0626
APU_LOG_LEVEL: 1
isNewCmd : true
gen_golden : false
savePDF : false
sanityCheck : false
dynPattern : false
release : true
logFile : "APU.log"
batch : 1
MC conv info:
bHasMCConv : true
bFastModeConv: true
test thread received convert primitive worker done message. 占用CPU时间 = 0.62s (累计用时 0.62s)
====================================
test thread received resource assign worker done message. 占用CPU时间 = 0.71s (累计用时 1.33s)
====================================
test thread received core slice worker done message. 占用CPU时间 = 65.14s (累计用时 66.47s)
====================================
test thread received core map worker done message. 占用CPU时间 = 693.75s (累计用时 760.22s)
====================================
test thread received core mem arrange worker done message. 占用CPU时间 = 129.37s (累计用时 889.59s)
====================================
test thread received rc cfg worker done message. 占用CPU时间 = 176.15s (累计用时 1065.74s)
====================================
test thread received route map worker done message. 占用CPU时间 = 17.04s (累计用时 1082.78s)
====================================
支持多batch编译,最大支持数准确值请参考batchsize=2的信息说明,当前结果: 10
test thread received print worker done message. 占用CPU时间 = 23.28s (累计用时 1106.06s)
====================================
util.py[line:268]-INFO: [map] Total time running apu_map(3034): 1110.0334 seconds
util.py[line:268]-INFO: [build+map] Total time running build(0): 1136.2683 seconds
['net_params.json', 'Net_0']
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [02:36<00:00, 4.00it/s]
lynxi inference accuracy = 0.9316
转换到Lava框架以进行Loihi部署
本教程作者: fangwei123456
感谢 AllenYolk 和 banzhuangonglxh 对 lava_exchange 模块的贡献
Lava框架简介
Lava 是Intel主导开发的神经形态计算框架,支持Intel Loihi芯片的部署。Lava 提供了一个名为 Lava DL 的深度学习子包,可以搭建和训练深度SNN。
若想将SNN部署到Loihi芯片运行,则需要使用Lava框架。SpikingJelly中提供了对应的转换模块,可以将SpikingJelly中的模块或训练的网络转换到Lava框架,以便将网络部署到Loihi芯片运行。其基本流程为:
SpikingJelly -> Lava DL -> Lava -> Loihi
与Lava相关的模块,都定义在 spikingjelly.activation_based.lava_exchange
中。
基本转换
数据格式转换
Lava DL默认数据格式为 shape = [N, *, T]
,其中 N
是batch维度,T
是time-step维度。而SpikingJelly中的模块在多步模式(step_mode = 'm'
)下,使用的数据格式是shape = [T, N, *]
。因此,lava_exchange
提供了两种格式的相互转换函数,TNX_to_NXT
和NXT_to_TNX
。示例如下:
import torch
from spikingjelly.activation_based import lava_exchange
T = 6
N = 4
C = 2
x_seq = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq)
print(f'x_seq_la.shape=[N, C, T]={x_seq_la.shape}')
x_seq_sj = lava_exchange.NXT_to_TNX(x_seq_la)
print(f'x_seq_sj.shape=[T, N, C]={x_seq_sj.shape}')
输出为:
x_seq_la.shape=[N, C, T]=torch.Size([4, 2, 6])
x_seq_sj.shape=[T, N, C]=torch.Size([6, 4, 2])
神经元转换
SpikingJelly中的神经元可以直接转换为Lava DL中的神经元。由于开发者精力有限,目前仅支持最常用的IF神经元和LIF神经元,其他神经元将视用户需求添加。
使用 to_lava_neuron
进行转换,示例如下:
import torch
from spikingjelly.activation_based import lava_exchange, neuron
if_sj = neuron.IFNode(v_threshold=1., v_reset=0., step_mode='m')
if_la = lava_exchange.to_lava_neuron(if_sj)
T = 8
N = 2
C = 1
x_seq_sj = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq_sj)
print('output of sj(reshaped to NXT):\n', lava_exchange.TNX_to_NXT(if_sj(x_seq_sj)))
print('output of lava:\n', if_la(x_seq_la))
输出为:
output of sj(reshaped to NXT):
tensor([[[0., 0., 1., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 1., 0., 1., 0., 1.]]])
output of lava:
tensor([[[0., 0., 1., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 1., 0., 1., 0., 1.]]])
使用LIF神经元的示例如下:
import torch
from spikingjelly.activation_based import lava_exchange, neuron
if_sj = neuron.LIFNode(tau=50., decay_input=False, v_threshold=1., v_reset=0., step_mode='m')
if_la = lava_exchange.to_lava_neuron(if_sj)
T = 8
N = 2
C = 1
x_seq_sj = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq_sj)
print('output of sj:\n', lava_exchange.TNX_to_NXT(if_sj(x_seq_sj)))
print('output of lava:\n', if_la(x_seq_la))
输出为:
output of sj:
tensor([[[0., 1., 0., 1., 0., 0., 1., 0.]],
[[0., 0., 1., 0., 0., 1., 0., 1.]]])
output of lava:
tensor([[[0., 1., 0., 1., 0., 0., 1., 0.]],
[[0., 0., 1., 0., 0., 1., 0., 1.]]])
突触转换
常用的卷积、全连接、池化层都支持转换。需要注意的是:
- 不支持bias
- Lava只支持求和池化,相当于平均池化不做平均
示例如下:
from spikingjelly.activation_based import lava_exchange, layer
conv = layer.Conv2d(3, 4, kernel_size=3, stride=1, bias=False)
fc = layer.Linear(4, 2, bias=False)
ap = layer.AvgPool2d(2, 2)
conv_la = lava_exchange.conv2d_to_lava_synapse_conv(conv)
fc_la = lava_exchange.linear_to_lava_synapse_dense(fc)
sp_la = lava_exchange.avgpool2d_to_lava_synapse_pool(ap)
print(f'conv_la={conv_la}')
print(f'fc_la={fc_la}')
print(f'sp_la={sp_la}')
输出为:
WARNING:root:The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer.
conv_la=Conv(3, 4, kernel_size=(3, 3, 1), stride=(1, 1, 1), bias=False)
fc_la=Dense(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
sp_la=Pool(1, 1, kernel_size=(2, 2, 1), stride=(2, 2, 1), bias=False)
Lava DL中几乎所有突触都是由 torch.nn.Conv3d 实现的,因此打印出来的 kernel_size 和 stride 会显示为含有3个元素的tuple。
BlockContainer
使用Lava DL的一般流程是:
1. 使用Lava DL框架中的 Blocks 搭建并训练网络
2. 将网络导出为hdf5文件
3. 使用Lava框架读取hdf5文件,以Lava的格式重建网络,并使用Loihi或CPU仿真的Loihi进行推理
具体信息,请参考 Lava: Deep Learning。
Blocks 可以被视作突触和神经元组成的集合。例如,lava.lib.dl.slayer.block.cuba.Conv
实际上就是由卷积突触和CUBA神经元组成的。
需要注意的是,为了进行网络部署,Blocks
中的突触权重和神经元的神经动态都进行了量化,因此 Blocks
并不是简单的synapse + neuron
,而是 quantize(synapse) + quantize(neuron)
。
SpikingJelly提供了 BlockContainer,主要特点如下:
- 支持替代梯度训练
- 对突触和神经动态进行了量化,与 lava.lib.dl.slayer.block 具有完全相同的输出
- 支持直接转换为一个 lava.lib.dl.slayer.block
目前 BlockContainer 仅支持 lava_exchange.CubaLIFNode,但支持自动将输入的 IFNode 和 LIFNode 转换为 CubaLIFNode。例如:
from spikingjelly.activation_based import lava_exchange, layer, neuron
fc_block_sj = lava_exchange.BlockContainer(
synapse=layer.Linear(8, 1, bias=False),
neu=neuron.IFNode(),
step_mode='m'
)
print('fc_block_sj=\n', fc_block_sj)
fc_block_la = fc_block_sj.to_lava_block()
print('fc_block_la=\n', fc_block_la)
输出为:
fc_block_sj=
BlockContainer(
(synapse): Linear(in_features=8, out_features=1, bias=False)
(neuron): CubaLIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
fc_block_la=
Dense(
(neuron): Neuron()
(synapse): Dense(8, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
)
MNIST CSNN示例
最后,让我们训练一个用于分类MNIST的卷积SNN,并转换到Lava DL框架。
网络定义如下:
class MNISTNet(nn.Module):
    def __init__(self, channels: int = 16):
        super().__init__()
        self.conv_fc = nn.Sequential(
            lava_exchange.BlockContainer(
                nn.Conv2d(1, channels, kernel_size=3, stride=1, padding=1, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),

            lava_exchange.BlockContainer(
                nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),
            # 14 * 14

            lava_exchange.BlockContainer(
                nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),

            lava_exchange.BlockContainer(
                nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),
            # 7 * 7

            lava_exchange.BlockContainer(
                nn.Flatten(),
                None
            ),

            lava_exchange.BlockContainer(
                nn.Linear(channels * 7 * 7, 128, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),

            lava_exchange.BlockContainer(
                nn.Linear(128, 10, bias=False),
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
            ),
        )

    def forward(self, x):
        return self.conv_fc(x)
我们为其增加一个转换到Lava DL网络的转换函数,在训练完成后可以使用:
def to_lava(self):
    ret = []

    for i in range(self.conv_fc.__len__()):
        m = self.conv_fc[i]
        if isinstance(m, lava_exchange.BlockContainer):
            ret.append(m.to_lava_block())

    return nn.Sequential(*ret)
接下来,对这个网络进行训练即可。训练流程与普通网络区别不大,只是在 lava_exchange.BlockContainer
内部,突触和神经动态都做了量化,这会导致正确率低于普通网络。部分训练代码如下:
encoder = encoding.PoissonEncoder(step_mode='m')
# ...
for img, label in train_data_loader:
    optimizer.zero_grad()
    img = img.to(args.device)
    label = label.to(args.device)
    img = img.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)

    fr = net(encoder(img)).mean(0)
    loss = F.cross_entropy(fr, label)
    loss.backward()
    optimizer.step()
# ...
当我们训练完成后,将网络转换到Lava DL,并检查测试集的正确率:
net_ladl = net.to_lava().to(args.device)
net_ladl.eval()
test_loss = 0
test_acc = 0
test_samples = 0
with torch.no_grad():
    for img, label in test_data_loader:
        img = img.to(args.device)
        label = label.to(args.device)
        img = img.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)
        img = encoder(img)
        img = lava_exchange.TNX_to_NXT(img)

        fr = net_ladl(img).mean(-1)
        loss = F.cross_entropy(fr, label)

        test_samples += label.numel()
        test_loss += loss.item() * label.numel()
        test_acc += (fr.argmax(1) == label).float().sum().item()

test_loss /= test_samples
test_acc /= test_samples
print('test acc[lava dl] =', test_acc)
最后,我们将Lava DL的网络导出hdf5,这样之后可以使用Lava框架加载,并在Loihi或者CPU模拟的Loihi上进行推理。具体流程请参考 Network Exchange (NetX) Library。
导出部分的代码如下:
def export_hdf5(net, filename):
    # network export to hdf5 format
    h = h5py.File(filename, 'w')
    layer = h.create_group('layer')
    for i, b in enumerate(net):
        handle = layer.create_group(f'{i}')
        b.export_hdf5(handle)

export_hdf5(net_ladl, os.path.join(args.out_dir, 'net_la.net'))
完整的代码位于 spikingjelly.activation_based.examples.lava_mnist
,命令行参数如下:
(lava-env) wfang@mlg-ThinkStation-P920:~/tempdir/w1$ python -m spikingjelly.activation_based.examples.lava_mnist -h
usage: lava_mnist.py [-h] [-T T] [-b B] [-device DEVICE] [-data-dir DATA_DIR]
[-channels CHANNELS] [-epochs EPOCHS] [-lr LR] [-out-dir OUT_DIR]
options:
-h, --help show this help message and exit
-T T simulating time-steps
-b B batch size
-device DEVICE device
-data-dir DATA_DIR root dir of the MNIST dataset
-channels CHANNELS channels of CSNN
-epochs EPOCHS training epochs
-lr LR learning rate
-out-dir OUT_DIR path for saving weights
在启动后,会首先训练网络,然后转换到Lava DL并进行推理,最后将hdf5格式的网络导出:
(lava-env) wfang@mlg-ThinkStation-P920:~/tempdir/w1$ python -m spikingjelly.activation_based.examples.lava_mnist -T 32 -device cuda:0 -b 128 -epochs 16 -data-dir /datasets/MNIST/ -lr 0.1 -channels 16
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
epoch = 0, train_loss = 1.7607, train_acc = 0.7245, test_loss = 1.5243, test_acc = 0.9443, max_test_acc = 0.9443
# ...
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
epoch = 15, train_loss = 1.4743, train_acc = 0.9881, test_loss = 1.4760, test_acc = 0.9855, max_test_acc = 0.9860
finish training
test acc[sj] = 0.9855
test acc[lava dl] = 0.9863
save net.state_dict() to ./net.pt
save net_ladl.state_dict() to ./net_ladl.pt
export net_ladl to ./net_la.net
模块文档
文档索引
引用和出版物
如果您在自己的工作中用到了惊蜇(SpikingJelly),您可以按照下列格式进行引用:
@misc{SpikingJelly,
title = {SpikingJelly},
author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Tian, Yonghong and other contributors},
year = {2020},
howpublished = {\url{https://github.com/fangwei123456/spikingjelly}},
note = {Accessed: YYYY-MM-DD},
}
其中的 YYYY-MM-DD 需要更改为您的工作使用的惊蜇(SpikingJelly)版本对应的最后一次代码修改日期。
使用惊蜇(SpikingJelly)的出版物可见于 Publications using SpikingJelly。
项目信息
北京大学信息科学技术学院数字媒体所媒体学习组 Multimedia Learning Group 和 鹏城实验室 是SpikingJelly的主要开发者。
开发人员名单可见于 贡献者 。
友情链接
Welcome to SpikingJelly’s documentation
SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
Notification
From the version 0.0.0.0.14
, modules including clock_driven
and event_driven
are renamed. Please refer to the tutorial Migrate From Old Versions.
Docs for different versions (latest is the developing version):
Installation
Note that SpikingJelly is based on PyTorch. Please make sure that you have installed PyTorch before you install SpikingJelly.
The odd version number is the developing version, which is updated with GitHub/OpenI repository. The even version number is the stable version and available at PyPI.
Install the last stable version from PyPI:
pip install spikingjelly
Install the latest developing version from the source codes:
From GitHub:
git clone https://github.com/fangwei123456/spikingjelly.git
cd spikingjelly
python setup.py install
From OpenI:
git clone https://git.openi.org.cn/OpenI/spikingjelly.git
cd spikingjelly
python setup.py install
Migrate From Old Versions
Author: fangwei123456
There are some differences between the old and new versions of SpikingJelly. We recommend that users who are familiar with the old version read this tutorial before trying the new version. SpikingJelly keeps good compatibility with the old version, so users do not need to change much of their code to migrate to the new version.
We also recommend that users read the tutorial Basic Conception.
The old version of SpikingJelly means the version number <=0.0.0.0.12
.
Rename of Packages
In the new version, SpikingJelly renames some sub-packages, which are:
| Old | New |
|---|---|
| clock_driven | activation_based |
| event_driven | timing_based |
Step Mode and Propagation Patterns
All modules in the old version (<=0.0.0.0.12
) of SpikingJelly are the single-step modules by default, except for the module that has the prefix MultiStep
.
The new version of SpikingJelly does not use the prefix to distinguish the single/multi-step module. Now the step mode is controlled by the module itself, which is the attribute step_mode
. Refer to Basic Conception for more details.
Hence, there is no separately defined multi-step module in the new version of SpikingJelly. Now one module can act as both a single-step module and a multi-step module, determined by whether step_mode is 's' or 'm'.
In the old version of SpikingJelly, if we want to use the LIF neuron with single-step, we write codes as:
from spikingjelly.clock_driven import neuron
lif = neuron.LIFNode()
In the new version of SpikingJelly, all modules are single-step modules by default. We write codes similar to the old version, except that we replace clock_driven with activation_based:
:
from spikingjelly.activation_based import neuron
lif = neuron.LIFNode()
In the old version of SpikingJelly, if we want to use the LIF neuron with multi-step, we should write codes as:
from spikingjelly.clock_driven import neuron
lif = neuron.MultiStepLIFNode()
In the new version of SpikingJelly, one module can use both single-step and multi-step. We can use the LIF neuron with multi-step easily by setting step_mode='m'
:
from spikingjelly.activation_based import neuron
lif = neuron.LIFNode(step_mode='m')
In the old version of SpikingJelly, we use the step-by-step or layer-by-layer propagation patterns as the following codes:
import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
# step-by-step
net_sbs = nn.Sequential(
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
neuron.IFNode()
)
y_seq = functional.multi_step_forward(x_seq, net_sbs)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net_sbs)
# layer-by-layer
net_lbl = nn.Sequential(
layer.SeqToANNContainer(
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
),
neuron.MultiStepIFNode()
)
y_seq = net_lbl(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net_lbl)
In the new version of SpikingJelly, we can use spikingjelly.activation_based.functional.set_step_mode to change the step mode of all modules in the whole network. If all modules use single-step, the network can use the step-by-step propagation pattern; if all modules use multi-step, the network can use the layer-by-layer propagation pattern:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
# the network uses step-by-step because step_mode='s' is the default value for all modules
net = nn.Sequential(
layer.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(C),
neuron.IFNode()
)
y_seq = functional.multi_step_forward(x_seq, net)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
# set the network to use layer-by-layer
functional.set_step_mode(net, step_mode='m')
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
Basic Conception
Author: fangwei123456
Translator: Qiu Haonan, fangwei123456
This tutorial introduces spikingjelly.activation_based
. It is recommended that all users read this tutorial before using SpikingJelly.
SpikingJelly is a deep learning framework for Spiking Neural Network (SNN) based on PyTorch. Users who want to use SpikingJelly should first be familiar with the usage of PyTorch. If the user doesn't know much about PyTorch, we recommend learning from the basic PyTorch Tutorials first.
Activation-based Representation
spikingjelly.activation_based
uses tensors whose element is only 0 or 1 to represent spikes. For example:
import torch
v = torch.rand([8])
v_th = 0.5
spike = (v >= v_th).to(v)
print('spike =', spike)
# spike = tensor([0., 0., 0., 1., 1., 0., 1., 0.])
Data Format
In spikingjelly.activation_based, there are two formats of data:
- Data in a single time-step with shape = [N, *], where N is the batch dimension and * represents any extra dimensions.
- Data in many time-steps with shape = [T, N, *], where T is the time-step dimension, N is the batch dimension and * represents any additional dimensions.
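A minimal sketch of the two formats (the sizes are illustrative):

import torch

N, T, C = 4, 8, 3
x_single = torch.rand([N, C])    # data in a single time-step, shape = [N, *]
x_seq = torch.rand([T, N, C])    # data in many time-steps, shape = [T, N, *]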
Step Mode
Modules in spikingjelly.activation_based
have two propagation modes, which are the single-step mode ‘s’ and the multi-step mode ‘m’. In single-step mode, the data use the shape = [N, *]
format. In multi-step mode, the data use the shape = [T, N, *]
format.
The user can set step_mode
of a module in its __init__
or change step_mode
anytime after the module is built.
import torch
from spikingjelly.activation_based import neuron
net = neuron.IFNode(step_mode='m')
# 'm' is the multi-step mode
net.step_mode = 's'
# 's' is the single-step mode
If we want to input the sequence data with shape = [T, N, *]
to a single-step module, we need to implement a for-loop in time-steps manually, which splits the sequence data into T
data with shape = [N, *]
and sends the data step-by-step. Let’s create a new layer of IF neurons, set it to single-step mode, and input sequence data step-by-step:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = []
for t in range(T):
    x = x_seq[t]  # x.shape = [N, C, H, W]
    y = net_s(x)  # y.shape = [N, C, H, W]
    y_seq.append(y.unsqueeze(0))
y_seq = torch.cat(y_seq)
# y_seq.shape = [T, N, C, H, W]
multi_step_forward
wraps the for-loop in time-steps for single-step modules to handle sequence data with shape = [T, N, *]
, which is more convenient to use:
import torch
from spikingjelly.activation_based import neuron, functional
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
However, the best usage is to set the module as a multi-step module directly:
import torch
from spikingjelly.activation_based import neuron
net_m = neuron.IFNode(step_mode='m')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = net_m(x_seq)
# y_seq.shape = [T, N, C, H, W]
To maintain compatibility with codes using older versions of SpikingJelly, the default step mode for all modules in SpikingJelly is single-step.
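For example, a newly created module uses the single-step mode unless we change it:

import torch
from spikingjelly.activation_based import neuron

net = neuron.IFNode()
print(net.step_mode)  # 's', the single-step mode, is the default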
Saving and Resetting of States
Similar to RNN, neurons and other modules in SNN have hidden states, and their outputs \(Y[t]\) are determined not only by the input \(X[t]\) at the current time-step t, but also by the state \(H[t-1]\) at the last time-step t-1, which is \(Y[t] = f(X[t], H[t-1])\).
In PyTorch, RNN outputs not only \(Y\) but also \(H\). Refer to torch.nn.RNN
for more details. Different from PyTorch, the states are stored inside the module in spikingjelly.activation_based
. For example, let us create a new layer of IF neurons, set them to single-step mode, and check the default voltage before and after giving inputs:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(net_s)
print(f'the initial v={net_s.v}')
y = net_s(x)
print(f'x={x}')
print(f'y={y}')
print(f'v={net_s.v}')
# outputs are:
'''
IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
the initial v=0.0
x=tensor([0.5543, 0.0350, 0.2171, 0.6740])
y=tensor([0., 0., 0., 0.])
v=tensor([0.5543, 0.0350, 0.2171, 0.6740])
'''
After initialization, the v
of the IF neurons layer is set to 0 and is automatically broadcast to have the same shape
as the input.
If we give a new input sample, we should clear the previous states of the neurons and reset the neurons to the initialization states, which can be done by calling the module’s self.reset()
function:
import torch
from spikingjelly.activation_based import neuron
net_s = neuron.IFNode(step_mode='s')
x = torch.rand([4])
print(f'check point 0: v={net_s.v}')
y = net_s(x)
print(f'check point 1: v={net_s.v}')
net_s.reset()
print(f'check point 2: v={net_s.v}')
x = torch.rand([8])
y = net_s(x)
print(f'check point 3: v={net_s.v}')
# outputs are:
'''
check point 0: v=0.0
check point 1: v=tensor([0.9775, 0.6598, 0.7577, 0.2952])
check point 2: v=0.0
check point 3: v=tensor([0.8728, 0.9031, 0.2278, 0.5089, 0.1059, 0.0479, 0.5008, 0.8530])
'''
For convenience, we can also call spikingjelly.activation_based.functional.reset_net
to reset all modules in a network.
If the network uses one or more stateful modules, it must be reset after processing one batch of data during training and inference:
from spikingjelly.activation_based import functional
# ...
for x, label in tqdm(train_data_loader):
    # ...
    optimizer.zero_grad()
    y = net(x)
    loss = criterion(y, label)
    loss.backward()
    optimizer.step()
    functional.reset_net(net)
    # Never forget to reset the network!
If we forget to reset, we may get a wrong output during inference or an error during training:
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
Propagation Patterns
If all modules in a network are single-step modules, the computation graph of the entire network is built step-by-step. For example:
y_seq_step_by_step = []
for t in range(T):
    x = x_seq[t]
    y = net(x)
    y_seq_step_by_step.append(y.unsqueeze(0))

y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)
If all modules in a network are multi-step modules, the computation graph of the entire network is built layer-by-layer. For example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4
N = 2
C = 8
x_seq = torch.rand([T, N, C]) * 64.
net = nn.Sequential(
layer.Linear(C, 4),
neuron.IFNode(),
layer.Linear(4, 2),
neuron.IFNode()
)
functional.set_step_mode(net, step_mode='m')
with torch.no_grad():
    y_seq_layer_by_layer = x_seq
    for i in range(net.__len__()):
        y_seq_layer_by_layer = net[i](y_seq_layer_by_layer)
In most cases, we don’t need an explicit implementation of for i in range(net.__len__())
, because torch.nn.Sequential
has already done that for us. So, we write codes in the following simple style:
y_seq_layer_by_layer = net(x_seq)
The only difference between step-by-step and layer-by-layer is the building order of the computation graph, and their outputs are identical:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, layer
T = 4
N = 2
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W]) * 64.
net = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, stride=1, bias=False),
neuron.IFNode(),
layer.MaxPool2d(2, 2),
neuron.IFNode(),
layer.Flatten(start_dim=1),
layer.Linear(8 * H // 2 * W // 2, 10),
neuron.IFNode(),
)
print(f'net={net}')
with torch.no_grad():
    y_seq_step_by_step = []
    for t in range(T):
        x = x_seq[t]
        y = net(x)
        y_seq_step_by_step.append(y.unsqueeze(0))

    y_seq_step_by_step = torch.cat(y_seq_step_by_step, 0)
    # we can also use `y_seq_step_by_step = functional.multi_step_forward(x_seq, net)` to get the same results

    print(f'y_seq_step_by_step=\n{y_seq_step_by_step}')

    functional.reset_net(net)
    functional.set_step_mode(net, step_mode='m')
    y_seq_layer_by_layer = net(x_seq)

    max_error = (y_seq_layer_by_layer - y_seq_step_by_step).abs().max()
    print(f'max_error={max_error}')
The outputs of the above codes are:
net=Sequential(
(0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=s)
(3): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(4): Flatten(start_dim=1, end_dim=-1, step_mode=s)
(5): Linear(in_features=128, out_features=10, bias=True)
(6): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
y_seq_step_by_step=
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
[[0., 1., 0., 0., 0., 0., 0., 1., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]],
[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 1., 0., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 0., 0., 0., 1., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 1., 0.]]])
max_error=0.0
The following figure shows how the computation graph is built in the step-by-step propagation pattern:

The following figure shows how the computation graph is built in the layer-by-layer propagation pattern:

There are two dimensions in the computation graph of SNN, which are the time-step and the depth dimension. As the above figures show, the propagation of SNN is the building of the computation graph. We can find that the step-by-step propagation pattern is a Depth-First-Search (DFS) for traversing the computation graph, while the layer-by-layer propagation pattern is a Breadth-First-Search (BFS) for traversing the computation graph.
Although the difference is only in the building order of the computation graph, there are still some slight differences in computation speed and memory consumption of the two propagation patterns.
When using the surrogate gradient method to train an SNN directly, it is recommended to use the layer-by-layer propagation pattern. When the network is built correctly, the layer-by-layer propagation pattern has the advantage of parallelism and speed.
Use the step-by-step propagation pattern when memory is limited, e.g., when a large T is required in the ANN2SNN task. In the layer-by-layer propagation pattern, the real batch size for stateless layers is TN rather than N (refer to the next tutorial); when T is too large, the memory consumption may be excessive.
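As a shape-level illustration of why the effective batch size becomes TN for stateless layers (plain PyTorch, not SpikingJelly-specific):
import torch
T, N, C, H, W = 4, 2, 3, 8, 8
x_seq = torch.rand([T, N, C, H, W])
x_flat = x_seq.flatten(0, 1)    # what a stateless layer sees in layer-by-layer mode
print(x_flat.shape)             # torch.Size([8, 3, 8, 8]): batch size T * N = 8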
Container
Author: fangwei123456
Translator: Qiu Haonan, fangwei123456
The major containers in SpikingJelly are:
- multi_step_forward in functional style and MultiStepContainer in module style
- seq_to_ann_forward in functional style and SeqToANNContainer in module style
- StepModeContainer for wrapping a single-step module for single/multi-step propagation
multi_step_forward can use a single-step module to implement multi-step propagation, and MultiStepContainer can wrap a single-step module into a multi-step module. For example:
import torch
from spikingjelly.activation_based import neuron, functional, layer
net_s = neuron.IFNode(step_mode='s')
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
y_seq = functional.multi_step_forward(x_seq, net_s)
# y_seq.shape = [T, N, C, H, W]
net_s.reset()
net_m = layer.MultiStepContainer(net_s)
z_seq = net_m(x_seq)
# z_seq.shape = [T, N, C, H, W]
# z_seq is identical to y_seq
A stateless ANN layer such as torch.nn.Conv2d requires input data with shape = [N, *]. To use it in multi-step mode, we can wrap it with the multi-step containers:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
However, the ANN layers are stateless and \(Y[t]\) is only determined by \(X[t]\). Hence, it is not necessary to calculate \(Y[t]\) step-by-step. We can use seq_to_ann_forward or SeqToANNContainer to wrap them, which will reshape the input with shape = [T, N, *] to shape = [TN, *], send the data to the ANN layers, and reshape the output back to shape = [T, N, *]. The calculations at different time-steps are then done in parallel, which is faster:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer
with torch.no_grad():
T = 4
N = 1
C = 3
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
conv = nn.Conv2d(C, 8, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
y_seq = functional.multi_step_forward(x_seq, (conv, bn))
# y_seq.shape = [T, N, 8, H, W]
net = layer.MultiStepContainer(conv, bn)
z_seq = net(x_seq)
# z_seq.shape = [T, N, 8, H, W]
# z_seq is identical to y_seq
p_seq = functional.seq_to_ann_forward(x_seq, (conv, bn))
# p_seq.shape = [T, N, 8, H, W]
net = layer.SeqToANNContainer(conv, bn)
q_seq = net(x_seq)
# q_seq.shape = [T, N, 8, H, W]
# q_seq is identical to p_seq, and also identical to y_seq and z_seq
Most frequently-used ANN modules have been defined in spikingjelly.activation_based.layer. It is recommended to use the modules in spikingjelly.activation_based.layer, rather than wrapping the ANN layers manually with a container. Although the modules in spikingjelly.activation_based.layer are implemented by wrapping the forward function with seq_to_ann_forward, they have the following advantages:
- Both single-step and multi-step modes are supported, while modules wrapped by SeqToANNContainer or MultiStepContainer only support the multi-step mode.
- The wrapping container adds a prefix to the keys() of the state_dict, which brings some trouble when loading weights.
For example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, neuron
ann = nn.Sequential(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
nn.ReLU()
)
print(f'ann.state_dict.keys()={ann.state_dict().keys()}')
net_container = nn.Sequential(
layer.SeqToANNContainer(
nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
),
neuron.IFNode(step_mode='m')
)
print(f'net_container.state_dict.keys()={net_container.state_dict().keys()}')
net_origin = nn.Sequential(
layer.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(8),
neuron.IFNode(step_mode='m')
)
print(f'net_origin.state_dict.keys()={net_origin.state_dict().keys()}')
try:
print('net_container is trying to load state dict from ann...')
net_container.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_container can not load! The error message is\n', e)
try:
print('net_origin is trying to load state dict from ann...')
net_origin.load_state_dict(ann.state_dict())
print('Load success!')
except BaseException as e:
print('net_origin can not load! The error message is', e)
The outputs are
ann.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container.state_dict.keys()=odict_keys(['0.0.weight', '0.1.weight', '0.1.bias', '0.1.running_mean', '0.1.running_var', '0.1.num_batches_tracked'])
net_origin.state_dict.keys()=odict_keys(['0.weight', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked'])
net_container is trying to load state dict from ann...
net_container can not load! The error message is
Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.0.weight", "0.1.weight", "0.1.bias", "0.1.running_mean", "0.1.running_var".
Unexpected key(s) in state_dict: "0.weight", "1.weight", "1.bias", "1.running_mean", "1.running_var", "1.num_batches_tracked".
net_origin is trying to load state dict from ann...
Load success!
MultiStepContainer and SeqToANNContainer only support the multi-step mode and do not allow switching to the single-step mode.
StepModeContainer works like a merged version of MultiStepContainer and SeqToANNContainer, and can be used to wrap stateless or stateful single-step modules. The user should specify whether the wrapped modules are stateless or stateful when using this container. This container also supports switching step modes.
Here is an example of wrapping a stateless layer:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
False,
nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C),
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
Here is an example of wrapping a stateful layer:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
T = 4
N = 2
C = 4
H = 8
W = 8
x_seq = torch.rand([T, N, C, H, W])
net = layer.StepModeContainer(
True,
neuron.IFNode()
)
net.step_mode = 'm'
y_seq = net(x_seq)
# y_seq.shape = [T, N, C, H, W]
functional.reset_net(net)
net.step_mode = 's'
y = net(x_seq[0])
# y.shape = [N, C, H, W]
functional.reset_net(net)
It is safe to use set_step_mode
to change the step mode of StepModeContainer
. Only the step_mode
of the container itself is changed, and the modules inside the container still use single-step:
import torch
from spikingjelly.activation_based import neuron, layer, functional
with torch.no_grad():
net = layer.StepModeContainer(
True,
neuron.IFNode()
)
functional.set_step_mode(net, 'm')
print(f'net.step_mode={net.step_mode}')
print(f'net[0].step_mode={net[0].step_mode}')
If a module itself supports switching between single-step and multi-step modes, it is not recommended to wrap it with MultiStepContainer or StepModeContainer, because the multi-step forward implemented by the container may not be as fast as the multi-step forward defined by the module itself.
In most cases, we use MultiStepContainer
or StepModeContainer
to wrap modules which do not define the multi-step forward, such as a network layer that exists in torch.nn
but does not exist in spikingjelly.activation_based.layer
.
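For instance, a sketch of wrapping a plain torch.nn module with StepModeContainer (nn.GELU is used purely for illustration; whether a given layer already has a counterpart in spikingjelly.activation_based.layer should be checked in the API docs):
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional
net = nn.Sequential(
    layer.Linear(4, 4),
    layer.StepModeContainer(False, nn.GELU()),  # False: the wrapped module is stateless
)
functional.set_step_mode(net, 'm')
x_seq = torch.rand([2, 3, 4])   # [T, N, C]
y_seq = net(x_seq)
# y_seq.shape = [2, 3, 4]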
Neuron
Author: fangwei123456
This tutorial is about spikingjelly.activation_based.neuron
and introduces the spiking neurons.
Spiking Neuron Modules
In SpikingJelly, we define the spiking neuron as the neuron that can only output spikes (or tensor whose element can only be 0 or 1). The network which uses spiking neurons is the Spiking Neural Network (SNN). Many frequently-used spiking neurons are defined in spikingjelly.activation_based.neuron
. Let us use the spikingjelly.activation_based.neuron.IFNode
as the example to learn how to use neurons in SpikingJelly.
Firstly, let us import modules:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
Define an IF neurons layer:
if_layer = neuron.IFNode()
There are some parameters for building IF neurons; refer to the API docs for more details. For the moment, we just focus on the following parameters:
- v_threshold – threshold of this neuron layer
- v_reset – reset voltage of this neuron layer. If not None, the neuron's voltage will be set to v_reset after firing a spike; if None, the neuron's voltage will subtract v_threshold after firing a spike
- surrogate_function – the function for calculating surrogate gradients of the Heaviside step function in backward
The user may be curious about how many neurons are in this layer. For most of the neuron layers in spikingjelly.activation_based.neuron, the number of neurons is determined by the shape of the first input received after the layer is initialized or reset by reset().
Similar to RNN cells, the spiking neuron is stateful (or has memory). The state of spiking neurons is the membrane potentials \(V[t]\). All neurons in spikingjelly.activation_based.neuron
have the attribute v
. We can print the v
:
print(f'if_layer.v={if_layer.v}')
# if_layer.v=0.0
We can find that if_layer.v
is 0.0
because we have not given the neurons layer any input. Let us give different inputs and check the v.shape
. We can find that it is the same as the shape of the input:
x = torch.rand(size=[2, 3])
if_layer(x)
print(f'x.shape={x.shape}, if_layer.v.shape={if_layer.v.shape}')
# x.shape=torch.Size([2, 3]), if_layer.v.shape=torch.Size([2, 3])
if_layer.reset()
x = torch.rand(size=[4, 5, 6])
if_layer(x)
print(f'x.shape={x.shape}, if_layer.v.shape={if_layer.v.shape}')
# x.shape=torch.Size([4, 5, 6]), if_layer.v.shape=torch.Size([4, 5, 6])
if_layer.reset()
Note that the spiking neurons are stateful. So, we must call reset()
before we give a new input sample to the spiking neurons.
What is the relation between \(V[t]\) and \(X[t]\)? In spiking neurons, \(V[t]\) is determined not only by the input \(X[t]\) at the current time-step t, but also by the membrane potential \(V[t-1]\) at the last time-step t-1.
We use the sub-threshold neuronal dynamics \(\frac{\mathrm{d}V(t)}{\mathrm{d}t} = f(V(t), X(t))\) to describe the charging of continuous-time spiking neurons. For the IF neuron, the charging function is:
\(\frac{\mathrm{d}V(t)}{\mathrm{d}t} = X(t)\)
spikingjelly.activation_based.neuron uses the discrete-time difference equation to approximate the continuous-time ordinary differential equation. The discrete-time difference equation of the IF neuron is:
\(V[t] - V[t-1] = X[t]\)
from which \(V[t]\) can be obtained as:
\(V[t] = V[t-1] + X[t]\)
We can find the following codes in spikingjelly.activation_based.neuron.IFNode.neuronal_charge
:
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
Different spiking neurons have different charging equations. But after the membrane potential exceeds the threshold voltage, the firing and resetting equations are the same. Hence, these equations are inherited from spikingjelly.activation_based.neuron.BaseNode
. We can find the codes in spikingjelly.activation_based.neuron.BaseNode.neuronal_fire
:
def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
surrogate_function() behaves as the Heaviside step function in forward: it returns 1 when the input is greater than or equal to 0, and returns 0 otherwise. We regard a tensor whose elements are only 0 or 1 as a spike.
Firing a spike consumes the accumulated charge and makes the membrane potential drop instantly, which is the neuronal reset. In SNNs, there are two kinds of reset:
Hard reset: the membrane potential will be set to the reset voltage after firing: \(V[t] = V_{reset}\)
Soft reset: the membrane potential will decrease the threshold potential after firing: \(V[t] = V[t] - V_{threshold}\)
We can find that a neuron using soft reset does not need the attribute \(V_{reset}\). The default value of v_reset in the __init__ function of the neurons in spikingjelly.activation_based.neuron is 0.0, and the neuron uses hard reset by default. If we set v_reset = None, the neuron will use soft reset. We can find the codes for the neuronal reset in spikingjelly.activation_based.neuron.BaseNode.neuronal_reset:
# The following codes are for tutorials. The actual codes are different but have similar behavior.
def neuronal_reset(self):
if self.v_reset is None:
self.v = self.v - self.spike * self.v_threshold
else:
self.v = (1. - self.spike) * self.v + self.spike * self.v_reset
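A small sketch comparing the two reset modes on a single IF neuron (the input value 1.5 is arbitrary):
import torch
from spikingjelly.activation_based import neuron
x = torch.full([1], 1.5)
hard_if = neuron.IFNode(v_threshold=1., v_reset=0.)    # hard reset
soft_if = neuron.IFNode(v_threshold=1., v_reset=None)  # soft reset
hard_if(x)
soft_if(x)
print(hard_if.v)   # tensor([0.]): the potential is set back to v_reset
print(soft_if.v)   # tensor([0.5]): the potential is decreased by v_threshold (1.5 - 1.0)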
Three equations for describing spiking neurons
Now we can use the three equations, i.e., neuronal charge, neuronal fire, and neuronal reset, to describe all kinds of spiking neurons:
\(H[t] = f(V[t-1], X[t])\)
\(S[t] = \Theta(H[t] - V_{threshold})\)
where \(\Theta(x)\) is the surrogate_function in the parameters of __init__. \(\Theta(x)\) is the Heaviside step function:
\(\Theta(x) = 1\) for \(x \geq 0\), and \(\Theta(x) = 0\) for \(x < 0\)
The hard reset equation is:
\(V[t] = H[t] \cdot (1 - S[t]) + V_{reset} \cdot S[t]\)
The soft reset equation is:
\(V[t] = H[t] - V_{threshold} \cdot S[t]\)
where \(X[t]\) is the external input. To avoid confusion, we use \(H[t]\) to represent the membrane potential after neuronal charging but before neuronal firing. \(V[t]\) is the membrane potential after neuronal firing. \(f(V[t-1], X[t])\) is the neuronal charging function, and is different for different neurons.
The neuronal dynamics can be described by the following figure (the figure is cited from Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks):

Simulation
Now let us give inputs to the spiking neurons step-by-step, check the membrane potential and output spikes, and plot them:
if_layer.reset()
x = torch.as_tensor([0.02])
T = 150
s_list = []
v_list = []
for t in range(T):
s_list.append(if_layer(x))
v_list.append(if_layer.v)
dpi = 300
figsize = (12, 8)
visualizing.plot_one_neuron_v_s(torch.cat(v_list).numpy(), torch.cat(s_list).numpy(), v_threshold=if_layer.v_threshold,
v_reset=if_layer.v_reset,
figsize=figsize, dpi=dpi)
plt.show()
The input has shape=[1]
. So, there is only 1 neuron. Its membrane potential and output spikes are:
Reset the neurons layer, and give the input with shape=[32]
. Then we can check the membrane potential and output spikes of these 32 neurons:
if_layer.reset()
T = 50
x = torch.rand([32]) / 8.
s_list = []
v_list = []
for t in range(T):
s_list.append(if_layer(x).unsqueeze(0))
v_list.append(if_layer.v.unsqueeze(0))
s_list = torch.cat(s_list)
v_list = torch.cat(v_list)
figsize = (12, 8)
dpi = 200
visualizing.plot_2d_heatmap(array=v_list.numpy(), title='membrane potentials', xlabel='simulating step',
ylabel='neuron index', int_x_ticks=True, x_max=T, figsize=figsize, dpi=dpi)
visualizing.plot_1d_spikes(spikes=s_list.numpy(), title='output spikes', xlabel='simulating step',
ylabel='neuron index', figsize=figsize, dpi=dpi)
plt.show()
The results are:
Step mode and backend
We have introduced step modes in Basic Conception. In the above codes, we use the single-step mode. By setting step_mode
, we can switch to multi-step easily:
import torch
from spikingjelly.activation_based import neuron, functional
if_layer = neuron.IFNode(step_mode='s')
T = 8
N = 2
x_seq = torch.rand([T, N])
y_seq = functional.multi_step_forward(x_seq, if_layer)
if_layer.reset()
if_layer.step_mode = 'm'
y_seq = if_layer(x_seq)
if_layer.reset()
In addition, some neurons support the cupy backend when using the multi-step mode. The cupy backend can accelerate both forward and backward:
import torch
from spikingjelly.activation_based import neuron
if_layer = neuron.IFNode()
print(f'if_layer.backend={if_layer.backend}')
# if_layer.backend=torch
print(f'step_mode={if_layer.step_mode}, supported_backends={if_layer.supported_backends}')
# step_mode=s, supported_backends=('torch',)
if_layer.step_mode = 'm'
print(f'step_mode={if_layer.step_mode}, supported_backends={if_layer.supported_backends}')
# step_mode=m, supported_backends=('torch', 'cupy')
device = 'cuda:0'
if_layer.to(device)
if_layer.backend = 'cupy' # switch to the cupy backend
print(f'if_layer.backend={if_layer.backend}')
# if_layer.backend=cupy
x_seq = torch.rand([8, 4], device=device)
y_seq = if_layer(x_seq)
if_layer.reset()
Custom Spiking Neurons
As mentioned above, SpikingJelly uses three equations, i.e., neuronal charge, neuronal fire, and neuronal reset, to describe all kinds of spiking neurons. We can find the corresponding codes in BaseNode. The single-step forward, i.e., the single_step_forward function, is composed of the three equations:
# spikingjelly.activation_based.neuron.BaseNode
def single_step_forward(self, x: torch.Tensor):
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
neuronal_fire and neuronal_reset are the same for most spiking neurons and are defined by BaseNode. Neurons differ in their __init__ and neuronal_charge functions. Hence, if we want to implement a new kind of spiking neuron, we only need to change the __init__ and neuronal_charge functions.
Suppose we want to build a Square-Integrate-and-Fire neuron, whose neuronal charge equation is:
\(H[t] = V[t-1] + X[t]^{2}\)
We can implement this kind of neuron by the following codes:
import torch
from spikingjelly.activation_based import neuron
class SquareIFNode(neuron.BaseNode):
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x ** 2
BaseNode inherits from MemoryModule, which by default uses a for t in range(T) loop over the single-step forward function to implement the multi-step forward. So, once we define neuronal_charge, single_step_forward is complete, and multi_step_forward is also complete.
Use our SquareIFNode
to implement the single/multi-step forward:
import torch
from spikingjelly.activation_based import neuron
class SquareIFNode(neuron.BaseNode):
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x ** 2
sif_layer = SquareIFNode()
T = 4
N = 1
x_seq = torch.rand([T, N])
print(f'x_seq={x_seq}')
for t in range(T):
yt = sif_layer(x_seq[t])
print(f'sif_layer.v[{t}]={sif_layer.v}')
sif_layer.reset()
sif_layer.step_mode = 'm'
y_seq = sif_layer(x_seq)
print(f'y_seq={y_seq}')
sif_layer.reset()
The outputs are:
x_seq=tensor([[0.7452],
[0.8062],
[0.6730],
[0.0942]])
sif_layer.v[0]=tensor([0.5554])
sif_layer.v[1]=tensor([0.])
sif_layer.v[2]=tensor([0.4529])
sif_layer.v[3]=tensor([0.4618])
y_seq=tensor([[0.],
[1.],
[0.],
[0.]])
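We can verify these values by hand with the charging equation \(H[t] = V[t-1] + X[t]^{2}\): \(0.7452^{2} \approx 0.555\), which is below the threshold 1, so no spike is fired; \(0.555 + 0.8062^{2} \approx 1.205 \geq 1\), so the neuron fires and is hard-reset to 0; \(0.6730^{2} \approx 0.453\); and \(0.453 + 0.0942^{2} \approx 0.462\), matching the printed v and y_seq.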
Surrogate Gradient Method
Author: fangwei123456
As mentioned in Neuron, the Heaviside function \(S[t] = \Theta(H[t] - V_{threshold})\) is used to describe the neuronal firing. The Heaviside function is:
\(\Theta(x) = 1\) for \(x \geq 0\), and \(\Theta(x) = 0\) for \(x < 0\)
Its derivative is the unit impulse function, which is defined by:
\(\delta(x) = +\infty\) for \(x = 0\), and \(\delta(x) = 0\) otherwise
If we use the unit impulse function to calculate the gradient and apply the gradient descent, the training will be very unstable. To solve this problem, the surrogate gradient method is proposed. Refer to Surrogate Gradient Learning in Spiking Neural Networks for more details.
The surrogate function is used to generate spikes, which can be found in the codes of BaseNode.neuronal_fire
:
# spikingjelly.activation_based.neuron
class BaseNode(base.MemoryModule):
def __init__(..., surrogate_function: Callable = surrogate.Sigmoid(), ...)
# ...
self.surrogate_function = surrogate_function
# ...
def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
The surrogate gradient method uses \(y = \Theta(x)\) in forward and \(\frac{\mathrm{d}y}{\mathrm{d}x} = \sigma'(x)\), rather than \(\frac{\mathrm{d}y}{\mathrm{d}x} = \Theta'(x)\), in backward, where \(\sigma(x)\) is the surrogate function. In most cases, \(\sigma(x)\) is a continuous and smooth function whose shape is similar to \(\Theta(x)\). spikingjelly.activation_based.surrogate provides many frequently-used surrogate functions. For example, the sigmoid surrogate function spikingjelly.activation_based.surrogate.Sigmoid is \(\sigma(x, \alpha) = \frac{1}{1 + \exp(-\alpha x)}\). The following figure shows the primitive Heaviside function, the sigmoid function when alpha=5, and its gradient:
We can use the surrogate function easily, just as we use other functions:
import torch
from spikingjelly.activation_based import surrogate
sg = surrogate.Sigmoid(alpha=4.)
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = sg(x)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
The outputs are:
x=tensor([-0.1303, 0.4976, 0.3364, 0.4296, 0.2779, 0.4580, 0.4447, 0.2466],
requires_grad=True)
y=tensor([0., 1., 1., 1., 1., 1., 1., 1.], grad_fn=<sigmoidBackward>)
x.grad=tensor([0.9351, 0.4231, 0.6557, 0.5158, 0.7451, 0.4759, 0.4943, 0.7913])
All surrogate functions have a module style API, e.g., spikingjelly.activation_based.surrogate.Sigmoid, and a functional style API, e.g., spikingjelly.activation_based.surrogate.sigmoid. The module style API uses Camel-Case to name modules, while the functional API uses Snake-Case to name functions. Their relation is similar to that between torch.nn and torch.nn.functional. Here are some examples:
module | function
---|---
surrogate.Sigmoid | surrogate.sigmoid
surrogate.ATan | surrogate.atan
Here is an example of using the functional API:
import torch
from spikingjelly.activation_based import surrogate
alpha = 4.
x = torch.rand([8]) - 0.5
x.requires_grad = True
y = surrogate.sigmoid.apply(x, alpha)
y.sum().backward()
print(f'x={x}')
print(f'y={y}')
print(f'x.grad={x.grad}')
Most surrogate functions have one or more hyper-parameters that control their shape, e.g., alpha of spikingjelly.activation_based.surrogate.Sigmoid. In SpikingJelly, the default shape hyper-parameters are set so that the maximum of the surrogate function's gradient is 1, which can relieve the gradient vanishing or exploding problem caused by the cumulative product of gradients.
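A quick check of this claim for the Sigmoid surrogate, whose gradient is \(\alpha \sigma(\alpha x)(1 - \sigma(\alpha x))\) and peaks at \(x = 0\) with value \(\alpha / 4\) (consistent with the x.grad values printed above), so the default \(\alpha = 4\) gives a maximum gradient of 1:
import torch
from spikingjelly.activation_based import surrogate
x = torch.zeros([1], requires_grad=True)
surrogate.Sigmoid(alpha=4.)(x).sum().backward()
print(x.grad)   # tensor([1.]): the maximum of the surrogate gradient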
Monitor
Author: fangwei123456
spikingjelly.activation_based.monitor
has defined some commonly used monitors, with which the users can record the data that they are interested in. Now let us try these monitors.
Usage
All monitors have similar usage. Let us take spikingjelly.activation_based.monitor.OutputMonitor
as the example.
Firstly, let us build a simple network and set it to the multi-step mode. To ensure that some spikes are fired, we set all weights to be positive:
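The network-building code is omitted here; a minimal sketch that is consistent with the printed structure and outputs below (two Linear layers of sizes 8→4 and 4→2, each followed by an IFNode, with all weights made positive and the spiking neurons switched to multi-step mode) would be:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional
net = nn.Sequential(
    nn.Linear(8, 4),
    neuron.IFNode(),
    nn.Linear(4, 2),
    neuron.IFNode()
)
# make all weights positive so that the network is guaranteed to fire some spikes
for param in net.parameters():
    param.data.abs_()
functional.set_step_mode(net, 'm')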
spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode)
T = 4
N = 1
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
net(x_seq)
The recorded data will be stored in .records
whose type is list
. The data are recorded in the order in which they are generated:
print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')
The outputs are:
spike_seq_monitor.records=
[tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]]), tensor([[[0., 0.]],
[[1., 0.]],
[[0., 1.]],
[[1., 0.]]])]
We can also use the index to get the i
-th data:
print(f'spike_seq_monitor[0]={spike_seq_monitor[0]}')
The outputs are:
spike_seq_monitor[0]=tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]])
The names of monitored layers are stored in .monitored_layers
:
print(f'net={net}')
print(f'spike_seq_monitor.monitored_layers={spike_seq_monitor.monitored_layers}')
The outputs are:
net=Sequential(
(0): Linear(in_features=8, out_features=4, bias=True)
(1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): Linear(in_features=4, out_features=2, bias=True)
(3): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
spike_seq_monitor.monitored_layers=['1', '3']
We can also use the name as the index to get the recorded data of the layer, which are stored in a list
:
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")
The outputs are:
spike_seq_monitor['1']=[tensor([[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]],
[[0., 0., 0., 0.]],
[[1., 1., 1., 1.]]])]
We can call .clear_recorded_data()
to clear the recorded data:
spike_seq_monitor.clear_recorded_data()
print(f'spike_seq_monitor.records={spike_seq_monitor.records}')
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")
The outputs are:
spike_seq_monitor.records=[]
spike_seq_monitor['1']=[]
All monitors remove their hooks when they are deleted. However, Python does not guarantee that the __del__() function of a monitor will be called, even if we call del a_monitor manually:
del spike_seq_monitor
# hooks may still work
Instead, we should call remove_hooks
to remove all hooks:
spike_seq_monitor.remove_hooks()
OutputMonitor
can also process the data when recording, which is implemented by function_on_output
. The default value of function_on_output
is lambda x: x
, which means the original data are recorded. If we want to record the firing rates, we can define a function that calculates the firing rates:
def cal_firing_rate(s_seq: torch.Tensor):
# s_seq.shape = [T, N, *]
return s_seq.flatten(1).mean(1)
Then, we can set this function as function_on_output
to get a firing rates monitor:
fr_monitor = monitor.OutputMonitor(net, neuron.IFNode, cal_firing_rate)
.disable() can pause the monitor, and .enable() can resume it:
with torch.no_grad():
fr_monitor.disable()
net(x_seq)
functional.reset_net(net)
print(f'after call fr_monitor.disable(), fr_monitor.records=\n{fr_monitor.records}')
fr_monitor.enable()
net(x_seq)
print(f'after call fr_monitor.enable(), fr_monitor.records=\n{fr_monitor.records}')
functional.reset_net(net)
del fr_monitor
The outputs are:
after call fr_monitor.disable(), fr_monitor.records=
[]
after call fr_monitor.enable(), fr_monitor.records=
[tensor([0.0000, 1.0000, 0.5000, 1.0000]), tensor([0., 1., 0., 1.])]
Record Attributes
To record the attributes of some modules, e.g., the membrane potential, we can use spikingjelly.activation_based.monitor.AttributeMonitor
.
store_v_seq: bool = False is the default argument in __init__ of spiking neurons, which means that only v at the last time-step will be stored, and v_seq at each time-step will not be stored. To record all \(V[t]\), we set store_v_seq = True:
for m in net.modules():
if isinstance(m, neuron.IFNode):
m.store_v_seq = True
Then, we use spikingjelly.activation_based.monitor.AttributeMonitor
to record:
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
net(x_seq)
print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
functional.reset_net(net)
del v_seq_monitor
The outputs are:
v_seq_monitor.records=
[tensor([[[0.8102, 0.8677, 0.8153, 0.9200]],
[[0.0000, 0.0000, 0.0000, 0.0000]],
[[0.0000, 0.8129, 0.0000, 0.9263]],
[[0.0000, 0.0000, 0.0000, 0.0000]]]), tensor([[[0.2480, 0.4848]],
[[0.0000, 0.0000]],
[[0.8546, 0.6674]],
[[0.0000, 0.0000]]])]
Record Inputs
To record inputs, we can use spikingjelly.activation_based.monitor.InputMonitor
, which is similar to spikingjelly.activation_based.monitor.OutputMonitor
:
input_monitor = monitor.InputMonitor(net, neuron.IFNode)
with torch.no_grad():
net(x_seq)
print(f'input_monitor.records=\n{input_monitor.records}')
functional.reset_net(net)
del input_monitor
The outputs are:
input_monitor.records=
[tensor([[[1.1710, 0.7936, 0.9325, 0.8227]],
[[1.4373, 0.7645, 1.2167, 1.3342]],
[[1.6011, 0.9850, 1.2648, 1.2650]],
[[0.9322, 0.6143, 0.7481, 0.9770]]]), tensor([[[0.8072, 0.7733]],
[[1.1186, 1.2176]],
[[1.0576, 1.0153]],
[[0.4966, 0.6030]]])]
Record the Output Gradients \(\frac{\partial L}{\partial Y}\)
We can use spikingjelly.activation_based.monitor.GradOutputMonitor to record the gradients of each module's output, which are \(\frac{\partial L}{\partial S}\) for the spiking neuron layers:
spike_seq_grad_monitor = monitor.GradOutputMonitor(net, neuron.IFNode)
net(x_seq).sum().backward()
print(f'spike_seq_grad_monitor.records=\n{spike_seq_grad_monitor.records}')
functional.reset_net(net)
del spike_seq_grad_monitor
The outputs are:
spike_seq_grad_monitor.records=
[tensor([[[1., 1.]],
[[1., 1.]],
[[1., 1.]],
[[1., 1.]]]), tensor([[[ 0.0803, 0.0383, 0.1035, 0.1177]],
[[-0.1013, -0.1346, -0.0561, -0.0085]],
[[ 0.5364, 0.6285, 0.3696, 0.1818]],
[[ 0.3704, 0.4747, 0.2201, 0.0596]]])]
Note that the gradients of the last layer's output spikes are all 1 because we use .sum().backward().
Record the Input Gradients \(\frac{\partial L}{\partial X}\)
We can use spikingjelly.activation_based.monitor.GradInputMonitor to record the gradients of each module's input, \(\frac{\partial L}{\partial X}\).
Let us build a deep SNN, tune alpha
for surrogate functions, and compare the effect:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer, surrogate
net = []
for i in range(10):
net.append(layer.Linear(8, 8))
net.append(neuron.IFNode())
net = nn.Sequential(*net)
functional.set_step_mode(net, 'm')
T = 4
N = 1
x_seq = torch.rand([T, N, 8])
input_grad_monitor = monitor.GradInputMonitor(net, neuron.IFNode, function_on_grad_input=torch.norm)
for alpha in [0.1, 0.5, 2, 4, 8]:
for m in net.modules():
if isinstance(m, surrogate.Sigmoid):
m.alpha = alpha
net(x_seq).sum().backward()
print(f'alpha={alpha}, input_grad_monitor.records=\n{input_grad_monitor.records}\n')
functional.reset_net(net)
# zero grad
for param in net.parameters():
param.grad.zero_()
input_grad_monitor.records.clear()
The outputs are:
alpha=0.1, input_grad_monitor.records=
[tensor(0.3868), tensor(0.0138), tensor(0.0003), tensor(9.1888e-06), tensor(1.0164e-07), tensor(1.9384e-09), tensor(4.0199e-11), tensor(8.6942e-13), tensor(1.3389e-14), tensor(2.7714e-16)]
alpha=0.5, input_grad_monitor.records=
[tensor(1.7575), tensor(0.2979), tensor(0.0344), tensor(0.0045), tensor(0.0002), tensor(1.5708e-05), tensor(1.6167e-06), tensor(1.6107e-07), tensor(1.1618e-08), tensor(1.1097e-09)]
alpha=2, input_grad_monitor.records=
[tensor(3.3033), tensor(1.2917), tensor(0.4673), tensor(0.1134), tensor(0.0238), tensor(0.0040), tensor(0.0008), tensor(0.0001), tensor(2.5466e-05), tensor(3.9537e-06)]
alpha=4, input_grad_monitor.records=
[tensor(3.5353), tensor(1.6377), tensor(0.7076), tensor(0.2143), tensor(0.0369), tensor(0.0069), tensor(0.0026), tensor(0.0006), tensor(0.0003), tensor(8.5736e-05)]
alpha=8, input_grad_monitor.records=
[tensor(4.3944), tensor(2.4396), tensor(0.8996), tensor(0.4376), tensor(0.0640), tensor(0.0122), tensor(0.0053), tensor(0.0016), tensor(0.0013), tensor(0.0005)]
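As the records show, with a small alpha (e.g., 0.1) the gradient norm shrinks by several orders of magnitude from layer to layer and quickly vanishes in the deeper layers, while a larger alpha keeps the gradients from decaying as fast. This is the gradient vanishing problem that the default shape hyper-parameters (a maximum surrogate gradient of 1) try to relieve.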
Single Fully Connected Layer SNN to Classify MNIST
Author: Yanqi-Chen
Translator: Lv Liuzhenghao
This tutorial introduces how to train a simple SNN to classify MNIST using an encoder and the surrogate gradient method.
Build a simple SNN
When building neural networks with PyTorch, we can simply use nn.Sequential
to stack layers to get a feedforward network, where input data will flow through each layer in order to get the output.
The MNIST dataset contains 8-bit grey-scale images of size \(28\times 28\) belonging to 10 categories (the digits 0-9). A simple single-layer ANN to classify MNIST is as follows:
nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 10, bias=False),
nn.Softmax()
)
An SNN with a similar structure can also be used for classification tasks. For this network, all activation functions should be replaced with spiking neurons (LIF neurons are used here), and the connections between neurons should be wrapped with modules from spikingjelly.activation_based.layer:
nn.Sequential(
layer.Flatten(),
layer.Linear(28 * 28, 10, bias=False),
neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
)
The membrane time constant \(\tau\) is set by tau, and surrogate.ATan is used as the surrogate gradient function.
Train the SNN
Training parameters like learning rate and other configurations need to be set:
Adam is used as the optimizer by default, and the Poisson encoder is used to encode input images into spikes.
# Use Adam optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
# Use PoissonEncoder
encoder = encoding.PoissonEncoder()
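To make the encoding step concrete, here is a small sketch (not part of the example script) of how the Poisson encoder turns pixel intensities into spikes:
import torch
from spikingjelly.activation_based import encoding
encoder = encoding.PoissonEncoder()
img = torch.rand([1, 28, 28])           # a fake image with pixel values in [0, 1]
spike_t0 = encoder(img)                 # each pixel fires with probability equal to its intensity
spike_t1 = encoder(img)                 # a new random draw at every call, i.e., at every time-step
print(spike_t0.unique())                # tensor([0., 1.])
print(torch.equal(spike_t0, spike_t1))  # usually False: the encoding is stochastic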
There are three key points to follow when programming the training codes:
1. The output of a spiking neuron is binary, and the output of a single run is easily disturbed by the noise introduced by the encoding. Therefore, the firing rate of the output layer over a period of time is generally regarded as the output of the SNN; the value of the firing rate indicates the response intensity of the corresponding category. So we should run the SNN for a period of time T and take the average firing rate over T as the classification evidence.
2. The ideal outcome is that, except for the neuron of the correct class firing at the highest rate, the other neurons keep silent. Cross-entropy loss or MSE loss is often used; here we use MSE loss for its better effect.
3. After each network simulation, the network state should be reset by functional.reset_net(net).
The core training codes are as follows:
for epoch in range(start_epoch, args.epochs):
start_time = time.time()
net.train()
train_loss = 0
train_acc = 0
train_samples = 0
for img, label in train_data_loader:
optimizer.zero_grad()
img = img.to(args.device)
label = label.to(args.device)
label_onehot = F.one_hot(label, 10).float()
# Mixed-precision training
if scaler is not None:
with amp.autocast():
out_fr = 0.
# Run T time steps
for t in range(args.T):
encoded_img = encoder(img)
out_fr += net(encoded_img)
out_fr = out_fr / args.T
# out_fr is tensor whose shape is [batch_size, 10]
# The firing rate of 10 neurons in the output layer was recorded during the whole simulation period
loss = F.mse_loss(out_fr, label_onehot)
# The loss function is the MSE between the firing rate of the output layer and the true category.
# The loss function will cause the firing rate of the correct neuron in the output layer to approach 1 when the label i is given, and the firing rate of the other neurons to approach 0.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
out_fr = 0.
for t in range(args.T):
encoded_img = encoder(img)
out_fr += net(encoded_img)
out_fr = out_fr / args.T
loss = F.mse_loss(out_fr, label_onehot)
loss.backward()
optimizer.step()
train_samples += label.numel()
train_loss += loss.item() * label.numel()
# The correct rate is calculated as follows. The subscript i of the neuron with the highest firing rate in the output layer is considered as the result of classification.
train_acc += (out_fr.argmax(1) == label).float().sum().item()
# After optimizing the parameters, the state of the network should be reset because the neurons of the SNN have “memory”.
functional.reset_net(net)
The complete code is in activation_based.examples.lif_fc_mnist.py
, where Tensorboard is used to save training logs. It can be run in the command line as follows:
$ python -m spikingjelly.activation_based.examples.lif_fc_mnist --help
usage: lif_fc_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N]
[-data-dir DATA_DIR] [-out-dir OUT_DIR]
[-resume RESUME] [-amp] [-opt {sgd,adam}]
[-momentum MOMENTUM] [-lr LR] [-tau TAU]
LIF MNIST Training
optional arguments:
-h, --help show this help message and exit
-T T simulating time-steps
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-opt {sgd,adam} use which optimizer. SGD or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
-tau TAU parameter tau of LIF neuron
It should be noted that the amount of memory required to train such an SNN is linearly related to the simulation time T. A larger T is equivalent to using a smaller simulation time step, and the training is more "refined", but not necessarily better. When T is too large, the SNN unfolds in time into a very deep network, which will cause the gradients to vanish or explode during BPTT. In addition, since we use the Poisson encoder, a large T is needed to ensure that the encoding noise is not too large.
Results of Training
We set tau=2.0,T=100,batch_size=64,lr=1e-3
, the corresponding command is:
python -m spikingjelly.activation_based.examples.lif_fc_mnist -tau 2.0 -T 100 -device cuda:0 -b 64 -epochs 100 -data-dir <PATH to MNIST> -amp -opt adam -lr 1e-3 -j 8
In order to speed up training, mixed-precision training is used. After training for 100 epochs, two npy files and a training log are produced. The highest accuracy on the test dataset is 92.9%. The accuracy curve visualized by matplotlib is as follows:
Select the first image in the test dataset:

The classification results are obtained by using the trained model:
Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
The voltages and spikes, obtained by the visualization functions in the visualizing module, are as follows.
Obviously, no neuron other than the one corresponding to the correct category is firing. The complete training code is in activation_based/examples/lif_fc_mnist.py.
Convolutional SNN to Classify FMNIST
Author: fangwei123456
In this tutorial, we will build a convolutional SNN to classify the Fashion-MNIST dataset. Images in the Fashion-MNIST dataset have the same shape as those in the MNIST dataset, which is 1 * 28 * 28.
Network Structure
We use the common convolutional network structure. More specifically, the network structure is:
{Conv2d-BatchNorm2d-IFNode-MaxPool2d}-{Conv2d-BatchNorm2d-IFNode-MaxPool2d}-{Linear-IFNode}
We build the network like the following codes:
# spikingjelly.activation_based.examples.conv_fashion_mnist
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from spikingjelly.activation_based import neuron, functional, surrogate, layer
from torch.utils.tensorboard import SummaryWriter
import os
import time
import argparse
from torch.cuda import amp
import sys
import datetime
from spikingjelly import visualizing
class CSNN(nn.Module):
def __init__(self, T: int, channels: int, use_cupy=False):
super().__init__()
self.T = T
self.conv_fc = nn.Sequential(
layer.Conv2d(1, channels, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(channels),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.MaxPool2d(2, 2), # 14 * 14
layer.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
layer.BatchNorm2d(channels),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.MaxPool2d(2, 2), # 7 * 7
layer.Flatten(),
layer.Linear(channels * 7 * 7, channels * 4 * 4, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan()),
layer.Linear(channels * 4 * 4, 10, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan()),
)
For faster training speed, we use the multi-step mode and use the cupy
backend if specified by use_cupy
in __init__
:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
def __init__(self, T: int, channels: int, use_cupy=False):
# ...
functional.set_step_mode(self, step_mode='m')
if use_cupy:
functional.set_backend(self, backend='cupy')
Recently, sending the image to the SNN directly has become a popular method in deep SNNs, which we will also use in this tutorial. In this case, the image-to-spike encoding is implemented by the first three layers of the network, which are {Conv2d-BatchNorm2d-IFNode}.
The input image has shape=[N, C, H, W]
. We add an additional time-step dimension, repeat it T
times, and get the input sequence with shape=[T, N, C, H, W]
. The output is defined by the firing rate of the last spiking neurons layer. Thus, the forward function is defined by:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
def forward(self, x: torch.Tensor):
# x.shape = [N, C, H, W]
x_seq = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1) # [N, C, H, W] -> [T, N, C, H, W]
x_seq = self.conv_fc(x_seq)
fr = x_seq.mean(0)
return fr
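As a quick sanity check of the forward pass above (a sketch; it assumes the CSNN class defined above is in scope, and the channel count is chosen arbitrarily):
import torch
from spikingjelly.activation_based import functional
net = CSNN(T=4, channels=32)
x = torch.rand([2, 1, 28, 28])   # [N, C, H, W]
fr = net(x)
print(fr.shape)                  # torch.Size([2, 10]): firing rates of the 10 output neurons
functional.reset_net(net)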
Training
How to define the training method, loss function, and classification result are identical to the last tutorial, and we will not introduce them in this tutorial. The only difference is we use the Fashion-MNIST dataset:
# spikingjelly.activation_based.examples.conv_fashion_mnist
train_set = torchvision.datasets.FashionMNIST(
root=args.data_dir,
train=True,
transform=torchvision.transforms.ToTensor(),
download=True)
test_set = torchvision.datasets.FashionMNIST(
root=args.data_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True)
We can use the following commands to print the training args:
(sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h
usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
[-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR] [-channels CHANNELS]
Classify Fashion-MNIST
optional arguments:
-h, --help show this help message and exit
-T T simulating time-steps
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of Fashion-MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-cupy use cupy backend
-opt OPT use which optimizer. SDG or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
-channels CHANNELS channels of CSNN
-save-es SAVE_ES dir for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
We can use the following commands to train. For faster training speed, we enable the AMP (automatic mixed precision) and the cupy
backend:
python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 128 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8
The outputs are:
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
CSNN(
(conv_fc): Sequential(
(0): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(6): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(8): Flatten(start_dim=1, end_dim=-1, step_mode=m)
(9): Linear(in_features=6272, out_features=2048, bias=False)
(10): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(11): Linear(in_features=2048, out_features=10, bias=False)
(12): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=cupy
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
Mkdir ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy.
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =0, train_loss = 0.0325, train_acc = 0.7875, test_loss = 0.0248, test_acc = 0.8543, max_test_acc = 0.8543
train speed = 7109.7899 images/s, test speed = 7936.2602 images/s
escape time = 2022-05-24 21:42:15
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =1, train_loss = 0.0217, train_acc = 0.8734, test_loss = 0.0201, test_acc = 0.8758, max_test_acc = 0.8758
train speed = 7712.5343 images/s, test speed = 7902.5029 images/s
escape time = 2022-05-24 21:43:13
...
Namespace(T=4, device='cuda:0', b=256, epochs=64, j=8, data_dir='/datasets/FashionMNIST/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='sgd', momentum=0.9, lr=0.1, channels=128)
./logs/T4_b256_sgd_lr0.1_c128_amp_cupy
epoch =63, train_loss = 0.0024, train_acc = 0.9941, test_loss = 0.0113, test_acc = 0.9283, max_test_acc = 0.9308
train speed = 7627.8147 images/s, test speed = 7868.9090 images/s
escape time = 2022-05-24 21:42:16
We get max_test_acc = 0.9308
. If we fine-tune the hyper-parameters, we will get higher accuracy.
The following figure shows the accuracy curves during training:
Visualizing Encoding
As mentioned above, we send images to SNN directly, and the encoding is implemented by the first {Conv2d-BatchNorm2d-IFNode}
in the SNN. Now let us extract the encoder {Conv2d-BatchNorm2d-IFNode}
, give images to the encoder, and visualize the output spikes:
# spikingjelly.activation_based.examples.conv_fashion_mnist
class CSNN(nn.Module):
# ...
def spiking_encoder(self):
return self.conv_fc[0:3]
def main():
# ...
if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
net.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
start_epoch = checkpoint['epoch'] + 1
max_test_acc = checkpoint['max_test_acc']
if args.save_es is not None and args.save_es != '':
encoder = net.spiking_encoder()
with torch.no_grad():
for img, label in test_data_loader:
img = img.to(args.device)
label = label.to(args.device)
# img.shape = [N, C, H, W]
img_seq = img.unsqueeze(0).repeat(net.T, 1, 1, 1, 1) # [N, C, H, W] -> [T, N, C, H, W]
spike_seq = encoder(img_seq)
functional.reset_net(encoder)
to_pil_img = torchvision.transforms.ToPILImage()
vs_dir = os.path.join(args.save_es, 'visualization')
os.mkdir(vs_dir)
img = img.cpu()
spike_seq = spike_seq.cpu()
img = F.interpolate(img, scale_factor=4, mode='bilinear')
# 28 * 28 is too small to read. So, we interpolate it to a larger size
for i in range(label.shape[0]):
vs_dir_i = os.path.join(vs_dir, f'{i}')
os.mkdir(vs_dir_i)
to_pil_img(img[i]).save(os.path.join(vs_dir_i, f'input.png'))
for t in range(net.T):
print(f'saving {i}-th sample with t={t}...')
# spike_seq.shape = [T, N, C, H, W]
visualizing.plot_2d_feature_map(spike_seq[t][i], 8, spike_seq.shape[2] // 8, 2, f'$S[{t}]$')
plt.savefig(os.path.join(vs_dir_i, f's_{t}.png'))
plt.savefig(os.path.join(vs_dir_i, f's_{t}.pdf'))
plt.savefig(os.path.join(vs_dir_i, f's_{t}.svg'))
plt.clf()
exit()
# ...
Let us load the trained model, set batch_size=4
, which means we only save 4 images and their spikes, and save data in ./logs
. The running commands are:
python -m spikingjelly.activation_based.examples.conv_fashion_mnist -T 4 -device cuda:0 -b 4 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -resume ./logs/T4_b256_sgd_lr0.1_c128_amp_cupy/checkpoint_latest.pth -save-es ./logs
Images and spikes will be saved in ./logs/visualization
. Here are two images and spikes encoded from them:


Neuromorphic Datasets Processing
Authors: fangwei123456
spikingjelly.datasets
provides frequently-used neuromorphic datasets, including N-MNIST 1, CIFAR10-DVS 2, DVS128 Gesture 3, N-Caltech101 1, ASLDVS 4, etc. All datasets are processed by SpikingJelly in the same method, which is friendly for developers to write codes for new datasets. In this tutorial, we will take DVS 128 Gesture dataset as an example to show how to use SpikingJelly to process neuromorphic datasets.
Download Automatically/Manually
SpikingJelly can download some datasets (e.g., CIFAR10-DVS) automatically. When we first use these datasets, SpikingJelly will download the dataset into the download folder in the root directory. The downloadable()
function of each dataset defines
whether this dataset can be downloaded automatically, and the resource_url_md5()
function defines the download url and
MD5 of each file. Here is an example:
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
print('CIFAR10-DVS downloadable', CIFAR10DVS.downloadable())
print('resource, url, md5\n', CIFAR10DVS.resource_url_md5())
print('DVS128Gesture downloadable', DVS128Gesture.downloadable())
print('resource, url, md5\n', DVS128Gesture.resource_url_md5())
The outputs are:
CIFAR10-DVS downloadable True
resource, url, md5
[('airplane.zip', 'https://ndownloader.figshare.com/files/7712788', '0afd5c4bf9ae06af762a77b180354fdd'), ('automobile.zip', 'https://ndownloader.figshare.com/files/7712791', '8438dfeba3bc970c94962d995b1b9bdd'), ('bird.zip', 'https://ndownloader.figshare.com/files/7712794', 'a9c207c91c55b9dc2002dc21c684d785'), ('cat.zip', 'https://ndownloader.figshare.com/files/7712812', '52c63c677c2b15fa5146a8daf4d56687'), ('deer.zip', 'https://ndownloader.figshare.com/files/7712815', 'b6bf21f6c04d21ba4e23fc3e36c8a4a3'), ('dog.zip', 'https://ndownloader.figshare.com/files/7712818', 'f379ebdf6703d16e0a690782e62639c3'), ('frog.zip', 'https://ndownloader.figshare.com/files/7712842', 'cad6ed91214b1c7388a5f6ee56d08803'), ('horse.zip', 'https://ndownloader.figshare.com/files/7712851', 'e7cbbf77bec584ffbf913f00e682782a'), ('ship.zip', 'https://ndownloader.figshare.com/files/7712836', '41c7bd7d6b251be82557c6cce9a7d5c9'), ('truck.zip', 'https://ndownloader.figshare.com/files/7712839', '89f3922fd147d9aeff89e76a2b0b70a7')]
DVS128Gesture downloadable False
resource, url, md5
[('DvsGesture.tar.gz', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '8a5c71fb11e24e5ca5b11866ca6c00a1'), ('gesture_mapping.csv', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '109b2ae64a0e1f3ef535b18ad7367fd1'), ('LICENSE.txt', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', '065e10099753156f18f51941e6e44b66'), ('README.txt', 'https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794', 'a0663d3b1d8307c329a43d949ee32d19')]
The DVS128 Gesture dataset cannot be downloaded automatically, but its resource_url_md5() will tell the user where to download it. The DVS128 Gesture dataset can be downloaded from https://ibm.ent.box.com/s/3hiq58ww1pbbjrinh367ykfdf60xsfm8/folder/50167556794.
The Box website does not allow downloading data from Python code without logging in, so the user has to download it manually.
Suppose we have downloaded the dataset into E:/datasets/DVS128Gesture/download
, then the directory structure is
.
|-- DvsGesture.tar.gz
|-- LICENSE.txt
|-- README.txt
`-- gesture_mapping.csv
Note
Different frameworks may use different pre-processing methods on the DVS128 Gesture dataset and cause different samples number. Refer to the API doc of spikingjelly.datasets.dvs128_gesture.DVS128Gesture
for more details.
Get Events Data
Let us create a train set. We set data_type='event'
to use Event data rather than frame data.
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='event')
SpikingJelly will do the following work when running these codes:
- Check whether the dataset exists. If the dataset exists, check the MD5 to ensure the dataset is complete, then extract the original data into the extract folder.
- Each sample in DVS128 Gesture is a video that records one actor displaying different gestures under different illumination conditions. Hence, an AER sample contains many gestures, and there is also an adjoint csv file that labels the time stamp of each gesture, so an AER sample does not belong to a single class but to multiple classes. SpikingJelly will use multiple threads to cut out and extract each gesture from these files.
Here are the terminal outputs:
The [D:/datasets/DVS128Gesture/download] directory for saving downloaded files already exists, check files...
Mkdir [D:/datasets/DVS128Gesture/extract].
Extract [D:/datasets/DVS128Gesture/download/DvsGesture.tar.gz] to [D:/datasets/DVS128Gesture/extract].
Mkdir [D:/datasets/DVS128Gesture/events_np].
Start to convert the origin data from [D:/datasets/DVS128Gesture/extract] to [D:/datasets/DVS128Gesture/events_np] in np.ndarray format.
Mkdir [('D:/datasets/DVS128Gesture//events_np//train', 'D:/datasets/DVS128Gesture//events_np//test').
Mkdir ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9'] in [D:/datasets/DVS128Gesture/events_np/train] and ['0', '1', '10', '2', '3', '4', '5', '6', '7', '8', '9'] in [D:/datasets/DVS128Gesture/events_np/test].
Start the ThreadPoolExecutor with max workers = [8].
Start to split [D:/datasets/DVS128Gesture/extract/DvsGesture/user02_fluorescent.aedat] to samples.
[D:/datasets/DVS128Gesture/events_np/train/0/user02_fluorescent_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/train/1/user02_fluorescent_0.npz] saved.
......
[D:/datasets/DVS128Gesture/events_np/test/8/user29_lab_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/test/9/user29_lab_0.npz] saved.
[D:/datasets/DVS128Gesture/events_np/test/10/user29_lab_0.npz] saved.
Used time = [1017.27s].
All aedat files have been split to samples and saved into [('D:/datasets/DVS128Gesture//events_np//train', 'D:/datasets/DVS128Gesture//events_np//test')].
We have to wait for a while because the cutting and extracting are very slow. An events_np folder will be created, containing the train/test sets:
|-- events_np
| |-- test
| `-- train
Print a sample:
event, label = train_set[0]
for k in event.keys():
print(k, event[k])
print('label', label)
The output is:
t [80048267 80048277 80048278 ... 85092406 85092538 85092700]
x [49 55 55 ... 60 85 45]
y [82 92 92 ... 96 86 90]
p [1 0 0 ... 1 0 0]
label 0
where event is a dictionary with keys ['t', 'x', 'y', 'p'], and label is the label of the sample. Note that the number of classes in DVS128 Gesture is 11.
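As a rough sanity check of one sample (the timestamps appear to be in microseconds, so the sample printed above lasts about 5 seconds):
event, label = train_set[0]
duration = event['t'][-1] - event['t'][0]   # about 5.04e6 for the sample printed above
print(f"number of events = {event['t'].size}, duration = {duration}, label = {label}")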
Get Frames Data
The event-to-frame integrating method for pre-processing neuromorphic datasets is widely used. We use the same method as 5 in SpikingJelly. Data in neuromorphic datasets are events \(E(x_{i}, y_{i}, t_{i}, p_{i})\) that represent the event's coordinate, time, and polarity. We split the \(N\) events into \(T\) slices with nearly the same number of events in each slice and integrate the events into frames. Note that \(T\) is also the simulating time-step. Denoting a two-channel frame as \(F(j)\) and a pixel at \((p, x, y)\) as \(F(j, p, x, y)\), the pixel value is integrated from the events whose indices are between \(j_{l}\) and \(j_{r}\):
\(j_{l} = \lfloor \frac{N}{T} \rfloor \cdot j, \quad j_{r} = \lfloor \frac{N}{T} \rfloor \cdot (j + 1)\) for \(j < T - 1\), and \(j_{r} = N\) for \(j = T - 1\)
\(F(j, p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\)
where \(\lfloor \cdot \rfloor\) is the floor operation, \(\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) is an indicator function and it equals 1 only when \((p, x, y) = (p_{i}, x_{i}, y_{i})\).
SpikingJelly will integrate events into frames when running the following codes:
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
The outputs from the terminal are:
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/1].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/10].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/2].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/3].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/4].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/5].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/6].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/7].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/8].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/9].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/0].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/1].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/10].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/2].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/3].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/4].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/5].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/6].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/7].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/8].
Mkdir [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9].
Start ThreadPoolExecutor with max workers = [8].
Start to integrate [D:/datasets/DVS128Gesture/events_np/test/0/user24_fluorescent_0.npz] to frames and save to [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
Start to integrate [D:/datasets/DVS128Gesture/events_np/test/0/user24_fluorescent_led_0.npz] to frames and save to [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/test/0].
......
Frames [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9/user23_lab_0.npz] saved.Frames [D:/datasets/DVS128Gesture/frames_number_20_split_by_number/train/9/user23_led_0.npz] saved.
Used time = [102.11s].
A frames_number_20_split_by_number folder will be created and contain the frame data.
Print a sample:
frame, label = train_set[0]
print(frame.shape)
The outputs are:
(20, 2, 128, 128)
Let us visualize a sample:
from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
We will get the images like:

Fixed Duration Integrating
Integrating by a fixed duration is closer to practical applications. For example, if we set the duration to 10 ms, then a sample with length L ms can be integrated into math.floor(L / 10) frames. However, the lengths of samples in neuromorphic datasets are not identical, so integrating with a fixed duration produces samples with different frame numbers. Fortunately, we can use spikingjelly.datasets.pad_sequence_collate and spikingjelly.datasets.padded_sequence_mask to pad/unpad the frames.
Example codes:
import torch
from torch.utils.data import DataLoader
from spikingjelly.datasets import pad_sequence_collate, padded_sequence_mask, dvs128_gesture

root='D:/datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(root, data_type='frame', duration=1000000, train=True)
for i in range(5):
    x, y = train_set[i]
    print(f'x[{i}].shape=[T, C, H, W]={x.shape}')

train_data_loader = DataLoader(train_set, collate_fn=pad_sequence_collate, batch_size=5)
for x, y, x_len in train_data_loader:
    print(f'x.shape=[N, T, C, H, W]={tuple(x.shape)}')
    print(f'x_len={x_len}')
    mask = padded_sequence_mask(x_len)  # mask.shape = [T, N]
    print(f'mask=\n{mask.t().int()}')
    break
The outputs are:
The directory [D:/datasets/DVS128Gesture\duration_1000000] already exists.
x[0].shape=[T, C, H, W]=(6, 2, 128, 128)
x[1].shape=[T, C, H, W]=(6, 2, 128, 128)
x[2].shape=[T, C, H, W]=(5, 2, 128, 128)
x[3].shape=[T, C, H, W]=(5, 2, 128, 128)
x[4].shape=[T, C, H, W]=(7, 2, 128, 128)
x.shape=[N, T, C, H, W]=(5, 7, 2, 128, 128)
x_len=tensor([6, 6, 5, 5, 7])
mask=
tensor([[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
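With the mask, we can, for example, average a multi-step output only over the valid (un-padded) time-steps. The following is a hedged sketch in which out is a random stand-in for a network output of shape [T, N, num_classes]:

import torch
from spikingjelly.datasets import padded_sequence_mask

T, N, num_classes = 7, 5, 11
x_len = torch.tensor([6, 6, 5, 5, 7])  # valid lengths, as returned by pad_sequence_collate
out = torch.rand([T, N, num_classes])  # stand-in for a multi-step network output
mask = padded_sequence_mask(x_len)     # [T, N], 1 for valid time-steps, 0 for padding
out_fr = (out * mask.unsqueeze(-1)).sum(0) / x_len.unsqueeze(-1)  # mean over valid steps only
print(out_fr.shape)  # torch.Size([5, 11])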
Custom Integrating Method
SpikingJelly also supports user-defined integrating methods. The user should provide a function custom_integrate_function and the directory name custom_integrated_frames_dir_name for saving the frames.
custom_integrate_function is a user-defined function whose inputs are events, H, W. events is a dict whose keys are ['t', 'x', 'y', 'p'] and whose values are numpy.ndarray. H is the height of the data and W is the width of the data; for example, H=128 and W=128 for the DVS128 Gesture dataset. The function should return frames.
custom_integrated_frames_dir_name can be None, in which case the directory name for saving frames will be set to custom_integrate_function.__name__.
For example, if we want to randomly split the events into two parts and integrate each part into one frame, we can define such a function:
import numpy as np
from typing import Dict
import spikingjelly.datasets as sjds

def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames = np.zeros([2, 2, H, W])
    t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
    return frames
Now let us use this function to create a frames dataset:
train_set = DVS128Gesture(root_dir, train=True, data_type='frame', custom_integrate_function=integrate_events_to_2_frames_randomly)
After the process is finished, there will be an integrate_events_to_2_frames_randomly directory in root_dir, which saves the frames integrated by our custom integrating function.
Now let us visualize the frames:
from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)

SpikingJelly provides more methods to integrate events to frames. Read the API doc for more details.
- 1
Orchard, Garrick, et al. “Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.” Frontiers in Neuroscience, vol. 9, 2015, pp. 437–437.
- 2
Li, Hongmin, et al. “CIFAR10-DVS: An Event-Stream Dataset for Object Classification.” Frontiers in Neuroscience, vol. 11, 2017, pp. 309–309.
- 3
Amir, Arnon, et al. “A Low Power, Fully Event-Based Gesture Recognition System.” 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2017, pp. 7388–7397.
- 4
Bi, Yin, et al. “Graph-Based Object Classification for Neuromorphic Vision Sensing.” 2019 IEEE/CVF International Conference on Computer Vision (ICCV), 2019, pp. 491–501.
- 5
Fang, Wei, et al. “Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks.” ArXiv: Neural and Evolutionary Computing, 2020.
Classify DVS Gesture
Author: fangwei123456
Translator: Qiu Haonan, fangwei123456
In Neuromorphic Datasets Processing, we have learned how to use neuromorphic datasets. Now let us build an SNN to classify them.
Network Structure
We will use the network defined in 1, which has the following structure:

All networks in 1 are defined in spikingjelly.activation_based.model.parametric_lif_net. The network structure for DVS Gesture is:
# spikingjelly.activation_based.model.parametric_lif_net
import torch
import torch.nn as nn
from .. import layer

class DVSGestureNet(nn.Module):
    def __init__(self, channels=128, spiking_neuron: callable = None, *args, **kwargs):
        super().__init__()

        conv = []
        for i in range(5):
            if conv.__len__() == 0:
                in_channels = 2
            else:
                in_channels = channels

            conv.append(layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False))
            conv.append(layer.BatchNorm2d(channels))
            conv.append(spiking_neuron(*args, **kwargs))
            conv.append(layer.MaxPool2d(2, 2))

        self.conv_fc = nn.Sequential(
            *conv,
            layer.Flatten(),
            layer.Dropout(0.5),
            layer.Linear(channels * 4 * 4, 512),
            spiking_neuron(*args, **kwargs),
            layer.Dropout(0.5),
            layer.Linear(512, 110),
            spiking_neuron(*args, **kwargs),
            layer.VotingLayer(10)
        )

    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
Train
The training method, loss function, and the way of obtaining classification results are identical to those in previous tutorials and will not be repeated here; we only introduce the differences.
We use the multi-step mode for faster training, and use the cupy backend if args.cupy is set:
# spikingjelly.activation_based.examples.classify_dvsg
import torch
import sys
import torch.nn.functional as F
from torch.cuda import amp
from spikingjelly.activation_based import functional, surrogate, neuron
from spikingjelly.activation_based.model import parametric_lif_net
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
import os
import argparse
import datetime

def main():
    # ...
    net = parametric_lif_net.DVSGestureNet(channels=args.channels, spiking_neuron=neuron.LIFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

    functional.set_step_mode(net, 'm')
    if args.cupy:
        functional.set_backend(net, 'cupy', instance=neuron.LIFNode)
    # ...
Define the dataset:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    train_set = DVS128Gesture(root=args.data_dir, train=True, data_type='frame', frames_number=args.T, split_by='number')
    test_set = DVS128Gesture(root=args.data_dir, train=False, data_type='frame', frames_number=args.T, split_by='number')
    # ...
Note that dimension 0 is always the batch dimension for data packed by DataLoader. So, the data we read from DataLoader has shape = [N, T, C, H, W]. We need to reshape the data to shape = [T, N, C, H, W] for the multi-step mode:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    for epoch in range(start_epoch, args.epochs):
        for frame, label in train_data_loader:
            optimizer.zero_grad()
            frame = frame.to(args.device)
            frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
            # ...

        with torch.no_grad():
            for frame, label in test_data_loader:
                frame = frame.to(args.device)
                frame = frame.transpose(0, 1)  # [N, T, C, H, W] -> [T, N, C, H, W]
                # ...
    # ...
DVS Gesture has 11 classes:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    label_onehot = F.one_hot(label, 11).float()
    # ...
DVSGestureNet does not output the firing rate, but the raw output with shape = [T, N, 11]: the networks in spikingjelly.activation_based.model.parametric_lif_net output spikes, rather than firing rates:
# spikingjelly.activation_based.model.parametric_lif_net
class DVSGestureNet(nn.Module):
    # ...
    def forward(self, x: torch.Tensor):
        return self.conv_fc(x)
Therefore, we need to average the output in the time-step dimension to get the firing rates, and then calculate the loss and accuracy by the firing rates:
# spikingjelly.activation_based.examples.classify_dvsg
def main():
    # ...
    out_fr = net(frame).mean(0)
    loss = F.mse_loss(out_fr, label_onehot)
    # ...
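Putting the pieces together, one training step looks roughly like the following sketch (the actual example script additionally uses automatic mixed precision and accuracy logging, which are omitted here):

# a simplified sketch of the inner training loop of classify_dvsg
for frame, label in train_data_loader:
    optimizer.zero_grad()
    frame = frame.to(args.device)
    frame = frame.transpose(0, 1)    # [N, T, C, H, W] -> [T, N, C, H, W]
    label = label.to(args.device)
    label_onehot = F.one_hot(label, 11).float()

    out_fr = net(frame).mean(0)      # firing rate, shape = [N, 11]
    loss = F.mse_loss(out_fr, label_onehot)
    loss.backward()
    optimizer.step()

    functional.reset_net(net)        # clear the neuron states before the next batch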
Train the network:
python -m spikingjelly.activation_based.examples.classify_dvsg -T 16 -device cuda:0 -b 16 -epochs 64 -data-dir /datasets/DVSGesture/ -amp -cupy -opt adam -lr 0.001 -j 8
The outputs are:
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
DVSGestureNet(
(conv_fc): Sequential(
(0): Conv2d(2, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(2): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(6): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(10): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(13): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(14): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=m)
(17): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=m)
(18): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(19): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False, step_mode=m)
(20): Flatten(start_dim=1, end_dim=-1, step_mode=m)
(21): Dropout(p=0.5)
(22): Linear(in_features=2048, out_features=512, bias=True)
(23): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(24): Dropout(p=0.5)
(25): Linear(in_features=512, out_features=110, bias=True)
(26): LIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=m, backend=cupy, tau=2.0
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(27): VotingLayer(voting_size=10, step_mode=m)
)
)
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
The directory [/datasets/DVSGesture/frames_number_16_split_by_number] already exists.
Mkdir ./logs/T16_b16_adam_lr0.001_c128_amp_cupy.
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 0, train_loss = 0.0666, train_acc = 0.3964, test_loss = 0.0514, test_acc = 0.6042, max_test_acc = 0.6042
train speed = 92.7646 images/s, test speed = 115.2935 images/s
escape time = 2022-05-25 21:31:54
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 1, train_loss = 0.0463, train_acc = 0.6036, test_loss = 0.0439, test_acc = 0.6319, max_test_acc = 0.6319
train speed = 101.5938 images/s, test speed = 120.5184 images/s
escape time = 2022-05-25 21:30:48
...
Namespace(T=16, device='cuda:0', b=16, epochs=64, j=8, data_dir='/datasets/DVSGesture/', out_dir='./logs', resume=None, amp=True, cupy=True, opt='adam', momentum=0.9, lr=0.001, channels=128)
./logs/T16_b16_adam_lr0.001_c128_amp_cupy
epoch = 63, train_loss = 0.0011, train_acc = 0.9991, test_loss = 0.0103, test_acc = 0.9375, max_test_acc = 0.9375
train speed = 100.4324 images/s, test speed = 121.0402 images/s
escape time = 2022-05-25 21:30:51
Finally, max_test_acc = 0.9375 is achieved. Higher accuracy can be achieved if the hyper-parameters are carefully adjusted and more training epochs are used.
The following figure shows the accuracy curves during the training process:
Recurrent Connection and Stateful Synapse
Author: fangwei123456
Self-connected Modules
Recurrent connection is the connection from outputs to inputs, e.g., the SRNN (recurrent networks of spiking neurons) in 1, which is shown in the following figure:

We can easily add recurrent connections to modules with SpikingJelly. Consider the simplest case: we add a recurrent connection to a spiking neurons layer so that its output \(s[t]\) at time-step \(t\) is added to the external input \(x[t+1]\) as the input to the neurons at the next time-step. We can use spikingjelly.activation_based.layer.ElementWiseRecurrentContainer to implement this idea. ElementWiseRecurrentContainer is a container that adds a recurrent connection to any sub_module. The connection can be specified as a user-defined element-wise operation \(z=f(x, y)\). Denote \(x[t]\) as the external input for the whole module (container and sub_module) at time-step \(t\), and \(i[t]\) and \(y[t]\) as the input and output of sub_module (note that \(y[t]\) is also the output of the whole module); then we can get

\(i[t] = f(x[t], y[t-1]), \quad y[t] = \mathrm{sub\_module}(i[t])\)
where \(f\) is a user-defined element-wise function. We regard \(y[-1] = 0\).
Let us use ElementWiseRecurrentContainer to wrap one IF neuron, with addition as the element-wise function. The external inputs are \(x[t]=[1.5, 0, ..., 0]\):
import torch
from spikingjelly.activation_based import layer, functional, neuron

T = 8
N = 1

def element_wise_add(x, y):
    return x + y

net = layer.ElementWiseRecurrentContainer(neuron.IFNode(), element_wise_add)
print(net)
x = torch.zeros([T, N])
x[0] = 1.5
for t in range(T):
    print(t, f'x[t]={x[t]}, s[t]={net(x[t])}')

functional.reset_net(net)
The outputs are:
ElementWiseRecurrentContainer(
element-wise function=<function element_wise_add at 0x00000158FC15ACA0>, step_mode=s
(sub_module): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
0 x[t]=tensor([1.5000]), s[t]=tensor([1.])
1 x[t]=tensor([0.]), s[t]=tensor([1.])
2 x[t]=tensor([0.]), s[t]=tensor([1.])
3 x[t]=tensor([0.]), s[t]=tensor([1.])
4 x[t]=tensor([0.]), s[t]=tensor([1.])
5 x[t]=tensor([0.]), s[t]=tensor([1.])
6 x[t]=tensor([0.]), s[t]=tensor([1.])
7 x[t]=tensor([0.]), s[t]=tensor([1.])
We can find that even when \(t \ge 1\), \(x[t]=0\), the neuron can still fire spikes because of the recurrent connection.
We can use spikingjelly.activation_based.layer.LinearRecurrentContainer
to implement the more complex recurrent connection.
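For example, a hedged usage sketch (the constructor arguments follow the usage shown later in this tutorial):

import torch
from spikingjelly.activation_based import layer, functional, neuron

# wrap an IF neuron layer: its 32-dimensional output is fed back through a learnable
# linear layer and added to the next external input
fc_rec = layer.LinearRecurrentContainer(
    neuron.IFNode(),
    in_features=32, out_features=32, bias=True
)
x = torch.rand([4, 32])   # [N, features], single-step mode
for t in range(8):
    y = fc_rec(x)         # y.shape = [4, 32]
functional.reset_net(fc_rec)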
Stateful Synapse
Some papers, e.g., 2 and 3 , use the stateful synapses. By placing spikingjelly.activation_based.layer.SynapseFilter
after the synapse to filter the output current, we can get the stateful synapse:
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, neuron

stateful_conv = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
    layer.SynapseFilter(tau=100.)
)
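As a quick sanity check (a sketch continuing the snippet above), the stateful synapse keeps the output shape of a plain convolution; only the current dynamics change:

functional.set_step_mode(stateful_conv, 'm')
x_seq = torch.rand([4, 1, 3, 32, 32])  # [T, N, C, H, W]
y_seq = stateful_conv(x_seq)
print(y_seq.shape)                     # torch.Size([4, 1, 16, 32, 32])
functional.reset_net(stateful_conv)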
Experiments on Sequential FashionMNIST
Now let us do some simple experiments on Sequential FashionMNIST to verify whether the recurrent connection or the stateful synapse can improve the network's memory ability. The Sequential FashionMNIST dataset is a modified FashionMNIST dataset in which images are sent to the network row by row or column by column, rather than as a whole. To classify correctly, the network should have a good memory ability. We will send images column by column, which is similar to how humans read a book from left to right:

The following figure shows the column that is being sent:

First, let us import some packages:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets
from spikingjelly.activation_based import neuron, surrogate, layer, functional
from torch.cuda import amp
import os, argparse
from torch.utils.tensorboard import SummaryWriter
import time
import datetime
import sys
Define the plain feedforward network PlainNet:
class PlainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
By adding a spikingjelly.activation_based.layer.SynapseFilter behind the first spiking neurons layer of PlainNet, we can get the network StatefulSynapseNet:
class StatefulSynapseNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            neuron.IFNode(surrogate_function=surrogate.ATan()),
            layer.SynapseFilter(tau=2., learnable=True),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
By adding a recurrent connection implemented by spikingjelly.activation_based.layer.LinearRecurrentContainer to PlainNet, we can get FeedBackNet:
class FeedBackNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            layer.Linear(28, 32),
            layer.LinearRecurrentContainer(
                neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True),
                in_features=32, out_features=32, bias=True
            ),
            layer.Linear(32, 10),
            neuron.IFNode(surrogate_function=surrogate.ATan())
        )

    def forward(self, x: torch.Tensor):
        return self.fc(x).mean(0)
The following figure shows the network structure of three networks:

The complete codes are saved in spikingjelly.activation_based.examples.rsnn_sequential_fmnist. We can run them with the following commands:
usage: rsnn_sequential_fmnist.py [-h] [-model MODEL] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR] [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]
Classify Sequential Fashion-MNIST
optional arguments:
-h, --help show this help message and exit
-model MODEL use which model, "plain", "ss" (StatefulSynapseNet) or "fb" (FeedBackNet)
-device DEVICE device
-b B batch size
-epochs N number of total epochs to run
-j N number of data loading workers (default: 4)
-data-dir DATA_DIR root dir of Fashion-MNIST dataset
-out-dir OUT_DIR root dir for saving logs and checkpoint
-resume RESUME resume from the checkpoint path
-amp automatic mixed precision training
-cupy use cupy backend
-opt OPT use which optimizer. SDG or Adam
-momentum MOMENTUM momentum for SGD
-lr LR learning rate
Train three networks:
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model plain
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model fb
python -m spikingjelly.activation_based.examples.rsnn_sequential_fmnist -device cuda:0 -b 256 -epochs 64 -data-dir /datasets/FashionMNIST/ -amp -cupy -opt sgd -lr 0.1 -j 8 -model ss
The following figures show the accuracy curves during training:
We can find that both StatefulSynapseNet and FeedBackNet have higher accuracy than PlainNet, indicating that recurrent connection and stateful synapse can promote the network's memory ability.
- 1
Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
- 2
Diehl P U, Cook M. Unsupervised learning of digit recognition using spike-timing-dependent plasticity[J]. Frontiers in computational neuroscience, 2015, 9: 99.
- 3
Fang H, Shrestha A, Zhao Z, et al. Exploiting Neuron and Synapse Filter Dynamics in Spatial Temporal Learning of Deep Spiking Neural Network[J].
Train large-scale SNNs
Author: fangwei123456
Usage of activation_based.model
spikingjelly.activation_based.model defines some classic networks, which we can use just like torchvision.models. For example, we can build the Spiking ResNet 1:
import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet

s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=False, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)

print(f's_resnet18={s_resnet18}')

with torch.no_grad():
    T = 4
    N = 1
    x_seq = torch.rand([T, N, 3, 224, 224])
    functional.set_step_mode(s_resnet18, 'm')
    y_seq = s_resnet18(x_seq)
    print(f'y_seq.shape={y_seq.shape}')
    functional.reset_net(s_resnet18)
The outputs are:
s_resnet18=SpikingResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False, step_mode=s)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False, step_mode=s)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn1): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False, step_mode=s)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode=s)
(sn2): IFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=True, step_mode=s, backend=torch
(surrogate_function): ATan(alpha=2.0, spiking=True)
)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1), step_mode=s)
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
y_seq.shape=torch.Size([4, 1, 1000])
The Spiking ResNet in SpikingJelly has the same network structure as that in torchvision. Their state_dict().keys() are identical, and we can load pre-trained weights by setting pretrained=True:
s_resnet18 = spiking_resnet.spiking_resnet18(pretrained=True, spiking_neuron=neuron.IFNode, surrogate_function=surrogate.ATan(), detach_reset=True)
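Since the keys match, one can also copy the weights from a torchvision model manually, as in this sketch:

import torchvision

tv_resnet18 = torchvision.models.resnet18(pretrained=True)
s_resnet18.load_state_dict(tv_resnet18.state_dict())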
Usage of activation_based.model.train_classify
spikingjelly.activation_based.model.train_classify is modified from the torchvision 0.12 references, and we can use this module to train easily.
spikingjelly.activation_based.model.train_classify.Trainer provides a flexible way to train; users can change its functions to implement the desired behaviors without much effort. For example, spikingjelly.activation_based.model.train_classify.Trainer.set_optimizer defines how to set the optimizer:
# spikingjelly.activation_based.model.train_classify
class Trainer:
    # ...
    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("sgd"):
            optimizer = torch.optim.SGD(
                parameters,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov="nesterov" in opt_name,
            )
        elif opt_name == "rmsprop":
            optimizer = torch.optim.RMSprop(
                parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
            )
        elif opt_name == "adamw":
            optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
        else:
            raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
        return optimizer

    def main(self, args):
        # ...
        optimizer = self.set_optimizer(args, parameters)
        # ...
If we want to add an optimizer, e.g., Adamax
, we can inherit the class and override this function:
class MyTrainer(train_classify.Trainer):
    def set_optimizer(self, args, parameters):
        opt_name = args.opt.lower()
        if opt_name.startswith("adamax"):
            optimizer = torch.optim.Adamax(parameters, lr=args.lr, weight_decay=args.weight_decay)
            return optimizer
        else:
            return super().set_optimizer(args, parameters)
Trainer.get_args_parser defines the args for training:
(pytorch-env) PS spikingjelly> python -m spikingjelly.activation_based.model.train_classify -h
usage: train_classify.py [-h] [--data-path DATA_PATH] [--model MODEL] [--device DEVICE] [-b BATCH_SIZE] [--epochs N] [-j N] [--opt OPT] [--lr LR] [--momentum M] [--wd W] [--norm-weight-decay NORM_WEIGHT_DECAY] [--label-smoothing LABEL_SMOOTHING]
[--mixup-alpha MIXUP_ALPHA] [--cutmix-alpha CUTMIX_ALPHA] [--lr-scheduler LR_SCHEDULER] [--lr-warmup-epochs LR_WARMUP_EPOCHS] [--lr-warmup-method LR_WARMUP_METHOD] [--lr-warmup-decay LR_WARMUP_DECAY]
[--lr-step-size LR_STEP_SIZE] [--lr-gamma LR_GAMMA] [--output-dir OUTPUT_DIR] [--resume RESUME] [--start-epoch N] [--cache-dataset] [--sync-bn] [--test-only] [--pretrained] [--auto-augment AUTO_AUGMENT]
[--random-erase RANDOM_ERASE] [--world-size WORLD_SIZE] [--dist-url DIST_URL] [--model-ema] [--model-ema-steps MODEL_EMA_STEPS] [--model-ema-decay MODEL_EMA_DECAY] [--interpolation INTERPOLATION]
[--val-resize-size VAL_RESIZE_SIZE] [--val-crop-size VAL_CROP_SIZE] [--train-crop-size TRAIN_CROP_SIZE] [--clip-grad-norm CLIP_GRAD_NORM] [--ra-sampler] [--ra-reps RA_REPS] [--prototype] [--weights WEIGHTS] [--seed SEED]
[--print-logdir] [--clean] [--disable-pinmemory] [--disable-amp] [--local_rank LOCAL_RANK] [--disable-uda]
PyTorch Classification Training
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
dataset path
--model MODEL model name
--device DEVICE device (Use cuda or cpu Default: cuda)
-b BATCH_SIZE, --batch-size BATCH_SIZE
images per gpu, the total batch size is $NGPU x batch_size
--epochs N number of total epochs to run
-j N, --workers N number of data loading workers (default: 16)
--opt OPT optimizer
--lr LR initial learning rate
--momentum M momentum
--wd W, --weight-decay W
weight decay (default: 0.)
--norm-weight-decay NORM_WEIGHT_DECAY
weight decay for Normalization layers (default: None, same value as --wd)
--label-smoothing LABEL_SMOOTHING
label smoothing (default: 0.1)
--mixup-alpha MIXUP_ALPHA
mixup alpha (default: 0.2)
--cutmix-alpha CUTMIX_ALPHA
cutmix alpha (default: 1.0)
--lr-scheduler LR_SCHEDULER
the lr scheduler (default: cosa)
--lr-warmup-epochs LR_WARMUP_EPOCHS
the number of epochs to warmup (default: 5)
--lr-warmup-method LR_WARMUP_METHOD
the warmup method (default: linear)
--lr-warmup-decay LR_WARMUP_DECAY
the decay for lr
--lr-step-size LR_STEP_SIZE
decrease lr every step-size epochs
--lr-gamma LR_GAMMA decrease lr by a factor of lr-gamma
--output-dir OUTPUT_DIR
path to save outputs
--resume RESUME path of checkpoint. If set to 'latest', it will try to load the latest checkpoint
--start-epoch N start epoch
--cache-dataset Cache the datasets for quicker initialization. It also serializes the transforms
--sync-bn Use sync batch norm
--test-only Only test the model
--pretrained Use pre-trained models from the modelzoo
--auto-augment AUTO_AUGMENT
auto augment policy (default: ta_wide)
--random-erase RANDOM_ERASE
random erasing probability (default: 0.1)
--world-size WORLD_SIZE
number of distributed processes
--dist-url DIST_URL url used to set up distributed training
--model-ema enable tracking Exponential Moving Average of model parameters
--model-ema-steps MODEL_EMA_STEPS
the number of iterations that controls how often to update the EMA model (default: 32)
--model-ema-decay MODEL_EMA_DECAY
decay factor for Exponential Moving Average of model parameters (default: 0.99998)
--interpolation INTERPOLATION
the interpolation method (default: bilinear)
--val-resize-size VAL_RESIZE_SIZE
the resize size used for validation (default: 232)
--val-crop-size VAL_CROP_SIZE
the central crop size used for validation (default: 224)
--train-crop-size TRAIN_CROP_SIZE
the random crop size used for training (default: 176)
--clip-grad-norm CLIP_GRAD_NORM
the maximum gradient norm (default None)
--ra-sampler whether to use Repeated Augmentation in training
--ra-reps RA_REPS number of repetitions for Repeated Augmentation (default: 4)
--prototype Use prototype model builders instead those from main area
--weights WEIGHTS the weights enum name to load
--seed SEED the random seed
--print-logdir print the dirs for tensorboard logs and pt files and exit
--clean delete the dirs for tensorboard logs and pt files
--disable-pinmemory not use pin memory in dataloader, which can help reduce memory consumption
--disable-amp not use automatic mixed precision training
--local_rank LOCAL_RANK
args for DDP, which should not be set by user
--disable-uda not set 'torch.use_deterministic_algorithms(True)', which can avoid the error raised by some functions that do not have a deterministic implementation
If we want to add some args, we can also inherit and override it:
class MyTrainer(train_classify.Trainer):
    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--do-something', type=str, help="do something")
        return parser  # remember to return the parser so that it can be used by main()
We can modify most functions in Trainer. We can use the following codes to train with Trainer or a user-defined trainer:
trainer = Trainer()
args = trainer.get_args_parser().parse_args()
trainer.main(args)
Trainer will calculate Acc@1, Acc@5, loss on the training and test datasets and save them with tensorboard. The model weights of the latest epoch and of the maximum test accuracy will also be saved. Trainer also supports Distributed Data Parallel (DDP) training.
Training on ImageNet
The default data loading function load_data will load the ImageNet 2 dataset. With Trainer and spikingjelly.activation_based.model.spiking_resnet, we can train large-scale SNNs easily. Here are the example codes:
# spikingjelly.activation_based.model.train_imagenet_example
import torch
from spikingjelly.activation_based import surrogate, neuron, functional
from spikingjelly.activation_based.model import spiking_resnet, train_classify

class SResNetTrainer(train_classify.Trainer):
    def preprocess_train_sample(self, args, x: torch.Tensor):
        # define how to process a train sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # define how to process a test sample before sending it to the model
        return x.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)  # [N, C, H, W] -> [T, N, C, H, W]

    def process_model_output(self, args, y: torch.Tensor):
        return y.mean(0)  # return firing rate

    def get_args_parser(self, add_help=True):
        parser = super().get_args_parser()
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--cupy', action="store_true", help="set the neurons to use cupy backend")
        return parser

    def get_tb_logdir_name(self, args):
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

    def load_model(self, args, num_classes):
        if args.model in spiking_resnet.__all__:
            model = spiking_resnet.__dict__[args.model](pretrained=args.pretrained, spiking_neuron=neuron.IFNode,
                                                        surrogate_function=surrogate.ATan(), detach_reset=True)
            functional.set_step_mode(model, step_mode='m')
            if args.cupy:
                functional.set_backend(model, 'cupy', neuron.IFNode)
            return model
        else:
            raise ValueError(f"args.model should be one of {spiking_resnet.__all__}")

if __name__ == "__main__":
    trainer = SResNetTrainer()
    args = trainer.get_args_parser().parse_args()
    trainer.main(args)
The codes are saved in spikingjelly.activation_based.model.train_imagenet_example. Training on a single GPU:
python -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
Training with DDP on two GPUs:
python -m torch.distributed.launch --nproc_per_node=2 -m spikingjelly.activation_based.model.train_imagenet_example --T 4 --model spiking_resnet18 --data-path /datasets/ImageNet0_03125 --batch-size 64 --lr 0.1 --lr-scheduler cosa --epochs 90
- 1
He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
- 2
Deng, Jia, et al. “Imagenet: A large-scale hierarchical image database.” 2009 IEEE conference on computer vision and pattern recognition. IEEE, 2009.
STDP Learning
Author: fangwei123456
Researchers of SNNs have always been interested in biological learning rules. SpikingJelly also provides STDP (Spike Timing Dependent Plasticity), which can be applied to convolutional or linear layers.
STDP(Spike Timing Dependent Plasticity)
STDP (Spike Timing Dependent Plasticity) was proposed by 1 and is a synaptic plasticity rule found in biological neural systems. Experiments on biological neural systems found that the weight of a synapse is influenced by the firing times of the spikes of the pre and post neurons. More specifically, STDP can be formulated as: if the pre neuron fires earlier and the post neuron fires later, the weight increases; if the pre neuron fires later and the post neuron fires earlier, the weight decreases.
The curve 2 that fits the experiments data is as follows:

We can use the following equation to describe STDP:

\(\Delta w = \begin{cases} A \exp\left(-\frac{|t_{pre} - t_{post}|}{\tau_{+}}\right), & t_{pre} \leq t_{post} \\ -B \exp\left(-\frac{|t_{pre} - t_{post}|}{\tau_{-}}\right), & t_{pre} > t_{post} \end{cases}\)

where \(A, B\) are the maximum magnitudes of the weight variation, and \(\tau_{+}, \tau_{-}\) are time constants.
However, the above equation is seldom used in practice because it requires recording all firing times of the pre and post neurons. The trace method 3 is a more popular way to implement STDP.
For the pre neuron \(i\) and the post neuron \(j\), we use the traces \(tr_{pre}[i]\) and \(tr_{post}[j]\) to track their firing. The update of the traces is similar to the LIF neuron:

\(tr_{pre}[i][t] = tr_{pre}[i][t-1] - \frac{tr_{pre}[i][t-1]}{\tau_{pre}} + s[i][t], \quad tr_{post}[j][t] = tr_{post}[j][t-1] - \frac{tr_{post}[j][t-1]}{\tau_{post}} + s[j][t]\)

where \(\tau_{pre}, \tau_{post}\) are the time constants of the pre and post neuron, and \(s[i][t], s[j][t]\) are the spikes at time-step \(t\) of the pre neuron \(i\) and the post neuron \(j\), which can only be 0 or 1.
The update of the weight is:

\(\Delta W[i][j][t] = F_{post}(w[i][j][t]) \cdot tr_{pre}[i][t] \cdot s[j][t] - F_{pre}(w[i][j][t]) \cdot tr_{post}[j][t] \cdot s[i][t]\)

where \(F_{pre}, F_{post}\) are functions that control how the weight changes.
STDP Learner
spikingjelly.activation_based.learning.STDPLearner can apply STDP learning to convolutional or linear layers. Please read its API doc first to learn how to use it.
Now let us use STDPLearner to build the simplest 1x1 SNN, with only one pre and one post neuron, and set the weight to 0.4:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt

torch.manual_seed(0)

def f_weight(x):
    return torch.clamp(x, -1, 1.)

tau_pre = 2.
tau_post = 2.
T = 128
N = 1
lr = 0.01
net = nn.Sequential(
    layer.Linear(1, 1, bias=False),
    neuron.IFNode()
)
nn.init.constant_(net[0].weight.data, 0.4)
STDPLearner can add the negative weight variation - delta_w * scale to the gradient of the weight, which makes it compatible with deep learning methods: we can use optimizers and learning rate schedulers together with STDPLearner.
In this example, we use the simplest parameter update rule:

\(W \leftarrow W - lr \cdot \nabla W\)

where \(\nabla W\) is - delta_w * scale. Thus, the optimizer will apply weight.data = weight.data - lr * weight.grad = weight.data + lr * delta_w * scale.
We can implement this parameter update rule with the plain torch.optim.SGD with momentum=0.:
optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)
Then we create the input spikes and set STDPLearner
:
in_spike = (torch.rand([T, N, 1]) > 0.7).float()
stdp_learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1], tau_pre=tau_pre, tau_post=tau_post,
                                    f_pre=f_weight, f_post=f_weight)
Then we send data to the network. Note that to plot the figure, we will squeeze() the data, reshaping it from shape = [T, N, 1] to shape = [T]:
out_spike = []
trace_pre = []
trace_post = []
weight = []
with torch.no_grad():
    for t in range(T):
        optimizer.zero_grad()
        out_spike.append(net(in_spike[t]).squeeze())
        stdp_learner.step(on_grad=True)  # add ``- delta_w * scale`` on grad
        optimizer.step()
        weight.append(net[0].weight.data.clone().squeeze())
        trace_pre.append(stdp_learner.trace_pre.squeeze())
        trace_post.append(stdp_learner.trace_post.squeeze())

in_spike = in_spike.squeeze()
out_spike = torch.stack(out_spike)
trace_pre = torch.stack(trace_pre)
trace_post = torch.stack(trace_post)
weight = torch.stack(weight)
The complete codes are available at spikingjelly/activation_based/examples/stdp_trace.py.
Let us plot in_spike, out_spike, trace_pre, trace_post, weight:
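A minimal plotting sketch (the complete figure code is in the example file above):

names = ['in_spike', 'out_spike', 'trace_pre', 'trace_post', 'weight']
datas = [in_spike, out_spike, trace_pre, trace_post, weight]
fig, axes = plt.subplots(len(names), 1, sharex=True, figsize=(8, 8))
for ax, name, data in zip(axes, names, datas):
    ax.plot(torch.arange(T), data)
    ax.set_ylabel(name)
axes[-1].set_xlabel('time-step')
plt.show()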
This figure is similar to Fig. 3 in 3 (note that they use j as the pre neuron and i as the post neuron, while we use the opposite symbols):

Combine STDP Learning with Gradient Descent
A widely used method with STDP is using gradient descent and STDP to train different layers in an SNN. With STDPLearner
, we can combine STDP learning with gradient descent easily.
Our goal is to build a deep SNN, train convolutional layers with STDP, and train linear layers with gradient descent. First, let us define the hyper-parameters:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional
T = 8
N = 2
C = 3
H = 32
W = 32
lr = 0.1
tau_pre = 2.
tau_post = 100.
step_mode = 'm'
Here we use the input with shape = [T, N, C, H, W] = [8, 2, 3, 32, 32]
.
Then we define the weight function and the SNN. Here we build a convolutional SNN with a multi-step mode:
def f_weight(x):
    return torch.clamp(x, -1, 1.)

net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)
We want to use STDP to train layer.Conv2d
while other layers are to be trained with gradient descent. We use instances_stdp
as the layers which are trained by STDP:
instances_stdp = (layer.Conv2d, )
We create an STDP learner for each layer in the SNN with the instance in instances_stdp
:
stdp_learners = []

for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )
Now we split the parameters into two groups: those from layers whose instances are in instances_stdp and those from the other layers, and assign the two groups to two different optimizers. Here we use Adam to optimize the parameters trained by gradient descent, and SGD to optimize the parameters trained by STDP:
params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)

params_stdp_set = set(params_stdp)
params_gradient_descent = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gradient_descent.append(p)

optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)
When we train the SNN in actual tasks, e.g., classifying CIFAR-10, we get samples from the dataset. But here we only want to implement an example. Hence, we create the samples manually:
x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])
Then we will use the two optimizers to update the parameters. Note that the following codes are different from the plain gradient descent we use before.
First, let us clear all gradients, do a forward, calculate the loss and do a backward:
optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()
Note that even though optimizer_gd will only update the parameters in params_gradient_descent, loss.backward() calculates and sets .grad for all parameters, including those whose weight variation we want to compute (via .grad) with STDP.
Thus, we need to clear the gradients of params_stdp:
optimizer_stdp.zero_grad()
Then we need to use STDPLearner
to get “gradients”, and use two optimizers to update all parameters:
for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()
All the learners (STDPLearner, for instance) inherit from MemoryModule and therefore have internal memories (trace_pre and trace_post for STDPLearner). In addition, the monitors inside the learners record the firing histories of the pre-synaptic and post-synaptic neurons; these histories may also be regarded as internal memories of the learners. We should call the reset() method to clear the internal memory promptly to avoid ever-growing memory consumption. We suggest resetting the learners together with the network after each batch:
functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()
The complete codes are as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
from spikingjelly.activation_based import learning, layer, neuron, functional

T = 8
N = 2
C = 3
H = 32
W = 32
lr = 0.1
tau_pre = 2.
tau_post = 100.
step_mode = 'm'

def f_weight(x):
    return torch.clamp(x, -1, 1.)

net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)

instances_stdp = (layer.Conv2d, )

stdp_learners = []

for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i+1], tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )

params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)

params_stdp_set = set(params_stdp)
params_gradient_descent = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gradient_descent.append(p)

optimizer_gd = Adam(params_gradient_descent, lr=lr)
optimizer_stdp = SGD(params_stdp, lr=lr, momentum=0.)

x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])

optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()

y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()

optimizer_stdp.zero_grad()

for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()

functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()
- 1
Bi, Guo-qiang, and Mu-ming Poo. “Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type.” Journal of neuroscience 18.24 (1998): 10464-10472.
- 2
Froemke, Robert C., et al. “Contribution of individual spikes in burst-induced long-term synaptic modification.” Journal of neurophysiology (2006).
- 3
Morrison, Abigail, Markus Diesmann, and Wulfram Gerstner. “Phenomenological models of synaptic plasticity based on spike timing.” Biological cybernetics 98.6 (2008): 459-478.
ANN2SNN
Author: DingJianhao, fangwei123456, Lv Liuzhenghao
This tutorial focuses on spikingjelly.activation_based.ann2snn and introduces how to convert a trained feedforward ANN to an SNN and simulate it with the SpikingJelly framework.
The ANN2SNN api references are here: api references.
There were two sets of implementations in earlier versions: ONNX-based and PyTorch-based. The current version is based on torch.fx, which is specially designed to transform nn.Module instances and natively decouples complex models when building the graph intermediate representation. Let's have a look!
Theoretical basis of ANN2SNN
Compared with ANN, the spikes generated by an SNN are discrete, which is conducive to efficient communication. Today, with the popularity of ANNs, directly training an SNN requires more resources.
Naturally, we would think of converting a mature, well-trained ANN to an SNN and hope that the SNN can achieve similar performance. This raises the question of how to build a bridge between ANN and SNN.
The mainstream way for SNNs is frequency (rate) encoding, so for the output layer we use the number of output spikes of a neuron to judge the category. Is there a relationship between the firing rate and the ANN activation?
Fortunately, there is a strong correlation between the nonlinear activation of ReLU neurons in an ANN and the firing rate of IF neurons in an SNN (reset by subtracting the threshold \(V_{threshold}\)), and we can use this feature for conversion. The neuron update method mentioned here is the Soft reset method mentioned in the Neuron tutorial.
Experiment: Relationship between IF neuron spiking frequency and input
We give a constant input to IF neurons and observe their output spikes and firing rates. First, import the relevant modules, create a new IF neuron layer, set the input, and plot the input \(x_{i}\) of each IF neuron:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np
plt.rcParams['figure.dpi'] = 200
if_node = neuron.IFNode(v_reset=None)
T = 128
x = torch.arange(-0.2, 1.2, 0.04)
plt.scatter(torch.arange(x.shape[0]), x)
plt.title('Input $x_{i}$ to IF neurons')
plt.xlabel('Neuron index $i$')
plt.ylabel('Input $x_{i}$')
plt.grid(linestyle='-.')
plt.show()
Next, send the input to the IF neuron layer, run it for T=128 time-steps, and observe the spikes and firing rate of each neuron:
s_list = []
for t in range(T):
    s_list.append(if_node(x).unsqueeze(0))

out_spikes = np.asarray(torch.cat(s_list))
visualizing.plot_1d_spikes(out_spikes, 'IF neurons\' spikes and firing rates', 't', 'Neuron index $i$')
plt.show()
It can be found that, within a certain range, the firing rate is proportional to the input \(x_{i}\).
Next, let us plot the firing rate of the IF neurons against the input \(x_{i}\) and compare it with \(\mathrm{ReLU}(x_{i})\):
plt.subplot(1, 2, 1)
firing_rate = np.mean(out_spikes, axis=1)
plt.plot(x, firing_rate)
plt.title('Input $x_{i}$ and firing rate')
plt.xlabel('Input $x_{i}$')
plt.ylabel('Firing rate')
plt.grid(linestyle='-.')
plt.subplot(1, 2, 2)
plt.plot(x, x.relu())
plt.title('Input $x_{i}$ and ReLU($x_{i}$)')
plt.xlabel('Input $x_{i}$')
plt.ylabel('ReLU($x_{i}$)')
plt.grid(linestyle='-.')
plt.show()
It can be found that the two curves are almost the same. Note that the firing rate cannot be higher than 1, so the IF neuron cannot fit ReLU activations in the ANN that are larger than 1.
Theoretical basis of ANN2SNN
The literature 1 provides a theoretical basis for analyzing the conversion of ANN to SNN. The theory shows that the IF neuron in an SNN is an unbiased estimator of the ReLU activation function over time.
For the first layer of the neural network (the input layer), we discuss the relationship between the firing rate \(r\) of the SNN neurons and the corresponding activation in the ANN. Assume the input is constant, \(z \in [0,1]\). For the IF neuron reset by subtraction, its membrane potential \(V\) changes with time as follows:

\(V_{t} = V_{t-1} + z - V_{threshold} \theta_{t}\)

where \(V_{threshold}\) is the firing threshold, usually set to 1.0, and \(\theta_{t}\) is the output spike. The average firing rate in \(T\) time-steps can be obtained by summing the membrane potential updates:

\(V_{T} - V_{0} = zT - V_{threshold} \sum_{t=1}^{T} \theta_{t}\)

Moving all the items containing \(V_{t}\) to the left and dividing both sides by \(T\) gives:

\(\frac{V_{T} - V_{0}}{T} = z - V_{threshold} \frac{N}{T} = z - V_{threshold} r\)

where \(N\) is the number of spikes in the \(T\) time-steps and \(\frac{N}{T}\) is the firing rate \(r\). Using \(z = V_{threshold} a\), this becomes:

\(r = a - \frac{V_{T} - V_{0}}{T V_{threshold}}\)

Therefore, when the simulation time-step \(T\) approaches infinity:

\(r = a\)

Similarly, for the higher layers of the neural network, literature 1 further shows that the inter-layer firing rates satisfy a similar relation. For details, please refer to 1. The methods in ann2snn also mainly come from 1.
Converting to spiking neural network
Conversion mainly solves two problems:
ANNs use Batch Normalization for fast training and convergence. Batch normalization aims to normalize the ANN output to zero mean, which is contrary to the properties of SNNs. Therefore, the parameters of BN can be absorbed into the preceding parametric layers (Linear, Conv2d).
According to the conversion theory, the input and output of each layer of the ANN need to be limited to the range [0,1], which requires scaling the parameters (model normalization).
◆ BatchNorm parameter absorption
Assume that the parameters of BatchNorm are \(\gamma\) (BatchNorm.weight), \(\beta\) (BatchNorm.bias), \(\mu\) (BatchNorm.running_mean) and \(\sigma\) (BatchNorm.running_var, with \(\sigma = \sqrt{\mathrm{running\_var}}\)). For the precise parameter definitions, see torch.nn.BatchNorm1d.
Parameter modules (e.g. Linear) have parameters \(W\) and \(b\). BatchNorm parameter absorption transfers the parameters of BatchNorm into the \(W\) and \(b\) of the parameter module, so that the output of the new module for the same input is identical to the output when the BatchNorm was present.
The new parameters \(\bar{W}\) and \(\bar{b}\) are:

\(\bar{W} = \frac{\gamma}{\sigma} W\)

\(\bar{b} = \frac{\gamma}{\sigma} (b - \mu) + \beta\)
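As an illustration only, here is a minimal sketch of this absorption for a Conv2d followed by BatchNorm2d, assuming the standard PyTorch attribute names (the Converter performs this fusion internally, so this code is not required in practice):
import torch
import torch.nn as nn

@torch.no_grad()
def absorb_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # scale = gamma / sigma, with sigma = sqrt(running_var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    # W_bar = (gamma / sigma) * W
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    # b_bar = (gamma / sigma) * (b - mu) + beta
    b = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_(scale * (b - bn.running_mean) + bn.bias)
    return fused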
◆ Model Normalization
For a parameter module, suppose we have recorded its input and output tensors, and that the maximum value of its input tensor is \(\lambda_{pre}\) and the maximum value of its output tensor is \(\lambda\). Then the normalized weight \(\hat{W}\) is

\(\hat{W} = W \frac{\lambda_{pre}}{\lambda}\)

and the normalized bias \(\hat{b}\) is

\(\hat{b} = \frac{b}{\lambda}\)
Although the output of each ANN layer follows a certain distribution, the data often contains large outliers, which lowers the overall neuron firing rate. To address this, robust normalization changes the scaling factor from the maximum value of the tensor to its p-quantile; the quantile recommended in the literature is 99.9%.
So far, the operations applied to the neural network are numerically equivalent, so the current model should perform the same as the original model.
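As a minimal sketch of the scaling step above, assuming lambda_pre and lam have already been estimated from the recorded activations (e.g. as the maximum value or the 99.9% quantile):
import torch
import torch.nn as nn

@torch.no_grad()
def normalize_module(m: nn.Module, lambda_pre: float, lam: float):
    # W_hat = W * lambda_pre / lambda
    m.weight.mul_(lambda_pre / lam)
    # b_hat = b / lambda
    if m.bias is not None:
        m.bias.div_(lam)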
During the conversion, the ReLU activation functions in the original model are replaced by IF neurons. Average pooling in the ANN is converted to spatial downsampling; since IF neurons are equivalent to the ReLU activation, adding IF neurons after spatial downsampling or not has minimal effect on the results. There is currently no ideal solution for max pooling in ANNs; the best solution so far is to control the spike channel with a gating function based on momentum-accumulated spikes [1]. Here we still recommend using avgpool2d. When simulating, according to the conversion theory, the SNN needs a constant analog input; using a Poisson encoder will reduce the accuracy.
Implementation and optional configuration
The ann2snn framework was updated in April 2022. The two classes parser and simulator were removed and replaced by the Converter class, which is more concise and offers more conversion modes.
The framework was updated again in October 2022: a fuse method was added to the Converter class to fuse the conv layer and the BN layer.
◆ Converter class
This class is used to convert an ANN with ReLU activations into an SNN.
Three common conversion modes are implemented here:
The most common is the maximum-current switching mode (MaxNorm), which uses the activation limits of consecutive layers so that the highest firing rate corresponds to the maximum activation value. Using this mode requires setting the parameter mode to max [2].
The 99.9% current switching mode (RobustNorm) uses the 99.9% activation quantile as the upper activation limit. Using this mode requires setting the parameter mode to 99.9% [1].
In the scaling conversion mode, the user specifies a scaling factor via mode, and the current is limited by the scaled maximum activation. Using this mode requires setting the parameter mode to a float between 0 and 1.
The optional fuse_conv_bn feature is provided: you can set fuse_flag to True (the default) in order to fuse the conv layer and the BN layer.
After conversion, the ReLU modules are removed, and the new modules needed by the SNN, such as VoltageScaler and IFNode, are created and stored in the parent module snn tailor.
Since the type of the returned model is fx.GraphModule, you can use print(fx.GraphModule.graph) to view how the modules are linked and how the forward method works. More APIs can be found in GraphModule.
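For example, the following sketch (train_data_loader and model refer to the objects defined later in this tutorial) builds a converter that uses the 1/2 scaling mode and keeps the conv-BN fusion enabled:
from spikingjelly.activation_based import ann2snn

model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader, fuse_flag=True)
snn_model = model_converter(model)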
Classify MNIST
Build the ANN to be converted
Now we use ann2snn
to build a simple convolutional network to classify the MNIST dataset.
First define our network structure (see ann2snn.sample_models.mnist_cnn
):
class ANN(nn.Module):
def __init__(self):
super().__init__()
self.network = nn.Sequential(
nn.Conv2d(1, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),
nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),
nn.Conv2d(32, 32, 3, 1),
nn.BatchNorm2d(32, eps=1e-3),
nn.ReLU(),
nn.AvgPool2d(2, 2),
nn.Flatten(),
nn.Linear(32, 10),
nn.ReLU()
)
def forward(self,x):
x = self.network(x)
return x
Note: if you need to flatten a tensor, define an nn.Flatten module in the network and use the defined Flatten instead of the view function in the forward method.
Define our hyperparameters:
torch.random.manual_seed(0)
torch.cuda.manual_seed(0)
device = 'cuda'
dataset_dir = 'G:/Dataset/mnist'
batch_size = 100
T = 50
Here T is the number of time steps used later for SNN inference.
If you want to train, you also need to initialize the data loader, optimizer and loss function, for example:
lr = 1e-3
epochs = 10
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(ann.parameters(), lr=lr, weight_decay=5e-4)
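A minimal sketch of the training loop (assuming ann, device, train_data_loader, test_data_loader and the val helper follow the example script; names are illustrative):
for epoch in range(epochs):
    ann.train()
    for img, label in train_data_loader:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        loss = loss_function(ann(img), label)
        loss.backward()
        optimizer.step()
    ann.eval()
    acc = val(ann, device, test_data_loader)
    print('Epoch: %d Validating Accuracy: %.3f' % (epoch, acc))
    torch.save(ann.state_dict(), 'SJ-mnist-cnn_model-sample.pth')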
Train the ANN. In the example, our model is trained for 10 epochs. The test set accuracy changes during training are as follows:
Epoch: 0 100%|██████████| 600/600 [00:05<00:00, 112.04it/s]
Validating Accuracy: 0.972
Epoch: 1 100%|██████████| 600/600 [00:05<00:00, 105.43it/s]
Validating Accuracy: 0.986
Epoch: 2 100%|██████████| 600/600 [00:05<00:00, 107.49it/s]
Validating Accuracy: 0.987
Epoch: 3 100%|██████████| 600/600 [00:05<00:00, 109.26it/s]
Validating Accuracy: 0.990
Epoch: 4 100%|██████████| 600/600 [00:05<00:00, 103.98it/s]
Validating Accuracy: 0.984
Epoch: 5 100%|██████████| 600/600 [00:05<00:00, 100.42it/s]
Validating Accuracy: 0.989
Epoch: 6 100%|██████████| 600/600 [00:06<00:00, 96.24it/s]
Validating Accuracy: 0.991
Epoch: 7 100%|██████████| 600/600 [00:05<00:00, 104.97it/s]
Validating Accuracy: 0.992
Epoch: 8 100%|██████████| 600/600 [00:05<00:00, 106.45it/s]
Validating Accuracy: 0.991
Epoch: 9 100%|██████████| 600/600 [00:05<00:00, 111.93it/s]
Validating Accuracy: 0.991
After training, we load the saved weights and test the performance of the model:
model.load_state_dict(torch.load('SJ-mnist-cnn_model-sample.pth'))
acc = val(model, device, test_data_loader)
print('ANN Validating Accuracy: %.4f' % (acc))
The output is as follows:
100%|██████████| 200/200 [00:02<00:00, 89.44it/s]
ANN Validating Accuracy: 0.9870
Make the conversion with the converter
Converting with Converter is very simple: you only need to set the desired mode in the parameters. For example, to use MaxNorm, define an ann2snn.Converter first, and then forward the model through this object:
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
snn_model is the output SNN model. Viewing the network structure of snn_model (BatchNorm2d is absent because conv-BN fusion during the conversion absorbs the parameters of the BN layer into the conv layer):
ANN(
(network): Module(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(3): AvgPool2d(kernel_size=2, stride=2, padding=0)
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(7): AvgPool2d(kernel_size=2, stride=2, padding=0)
(8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(11): AvgPool2d(kernel_size=2, stride=2, padding=0)
(12): Flatten(start_dim=1, end_dim=-1)
(13): Linear(in_features=32, out_features=10, bias=True)
(15): Softmax(dim=1)
)
(snn tailor): Module(
(0): Module(
(0): VoltageScaler(0.240048)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(4.165831)
)
(1): Module(
(0): VoltageScaler(0.307485)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(3.252196)
)
(2): Module(
(0): VoltageScaler(0.141659)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(7.059210)
)
(3): Module(
(0): VoltageScaler(0.060785)
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False, step_mode=s, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
(2): VoltageScaler(16.451399)
)
)
)
The type of snn_model is GraphModule
, referring to GraphModule .
Call the GraphModule.graph.print_tabular()
method to view the graph of the intermediate representation of the model in tabular form:
#snn_model.graph.print_tabular()
opcode name target args kwargs
----------- -------------- -------------- ----------------- --------
placeholder x x () {}
call_module network_0 network.0 (x,) {}
call_module snn_tailor_0_1 snn tailor.0.0 (network_0,) {}
call_module snn_tailor_0_2 snn tailor.0.1 (snn_tailor_0_1,) {}
call_module snn_tailor_0_3 snn tailor.0.2 (snn_tailor_0_2,) {}
call_module network_3 network.3 (snn_tailor_0_3,) {}
call_module network_4 network.4 (network_3,) {}
call_module snn_tailor_1_1 snn tailor.1.0 (network_4,) {}
call_module snn_tailor_1_2 snn tailor.1.1 (snn_tailor_1_1,) {}
call_module snn_tailor_1_3 snn tailor.1.2 (snn_tailor_1_2,) {}
call_module network_7 network.7 (snn_tailor_1_3,) {}
call_module network_8 network.8 (network_7,) {}
call_module snn_tailor_2_1 snn tailor.2.0 (network_8,) {}
call_module snn_tailor_2_2 snn tailor.2.1 (snn_tailor_2_1,) {}
call_module snn_tailor_2_3 snn tailor.2.2 (snn_tailor_2_2,) {}
call_module network_11 network.11 (snn_tailor_2_3,) {}
call_module network_12 network.12 (network_11,) {}
call_module network_13 network.13 (network_12,) {}
call_module snn_tailor_3_1 snn tailor.3.0 (network_13,) {}
call_module snn_tailor_3_2 snn tailor.3.1 (snn_tailor_3_1,) {}
call_module snn_tailor_3_3 snn tailor.3.2 (snn_tailor_3_2,) {}
call_module network_15 network.15 (snn_tailor_3_3,) {}
output output output (network_15,) {}
Comparison of different converting modes
Following this example, we convert the model using the modes max, 99.9%, 1.0/2, 1.0/3, 1.0/4 and 1.0/5 respectively, and run inference for T time steps on each converted SNN to obtain its accuracy.
print('---------------------------------------------')
print('Converting using MaxNorm')
model_converter = ann2snn.Converter(mode='max', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_max_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_max_accs[-1]))
print('---------------------------------------------')
print('Converting using RobustNorm')
model_converter = ann2snn.Converter(mode='99.9%', dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_robust_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_robust_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/2 max(activation) as scales...')
model_converter = ann2snn.Converter(mode=1.0 / 2, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_two_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_two_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/3 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 3, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_three_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_three_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/4 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 4, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_four_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_four_accs[-1]))
print('---------------------------------------------')
print('Converting using 1/5 max(activation) as scales')
model_converter = ann2snn.Converter(mode=1.0 / 5, dataloader=train_data_loader)
snn_model = model_converter(model)
print('Simulating...')
mode_five_accs = val(snn_model, device, test_data_loader, T=T)
print('SNN accuracy (simulation %d time-steps): %.4f' % (T, mode_five_accs[-1]))
Observe the console output:
---------------------------------------------
Converting using MaxNorm
100%|██████████| 600/600 [00:04<00:00, 128.25it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.44it/s] SNN accuracy (simulation 50 time-steps): 0.9777
---------------------------------------------
Converting using RobustNorm
100%|██████████| 600/600 [00:19<00:00, 31.06it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.75it/s] SNN accuracy (simulation 50 time-steps): 0.9841
---------------------------------------------
Converting using 1/2 max(activation) as scales...
100%|██████████| 600/600 [00:04<00:00, 126.64it/s] ]Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.90it/s] SNN accuracy (simulation 50 time-steps): 0.9844
---------------------------------------------
Converting using 1/3 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 126.27it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.73it/s] SNN accuracy (simulation 50 time-steps): 0.9828
---------------------------------------------
Converting using 1/4 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 128.94it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.47it/s] SNN accuracy (simulation 50 time-steps): 0.9747
---------------------------------------------
Converting using 1/5 max(activation) as scales
100%|██████████| 600/600 [00:04<00:00, 121.18it/s] Simulating...
100%|██████████| 200/200 [00:13<00:00, 14.42it/s] SNN accuracy (simulation 50 time-steps): 0.9487
---------------------------------------------
The model conversion is very fast, and inference over the 200 test iterations takes only 11 s to complete (GTX 2080ti). Based on the accuracy of the model output over time, we can plot the accuracy curves for the different settings.
fig = plt.figure()
plt.plot(np.arange(0, T), mode_max_accs, label='mode: max')
plt.plot(np.arange(0, T), mode_robust_accs, label='mode: 99.9%')
plt.plot(np.arange(0, T), mode_two_accs, label='mode: 1.0/2')
plt.plot(np.arange(0, T), mode_three_accs, label='mode: 1.0/3')
plt.plot(np.arange(0, T), mode_four_accs, label='mode: 1.0/4')
plt.plot(np.arange(0, T), mode_five_accs, label='mode: 1.0/5')
plt.legend()
plt.xlabel('t')
plt.ylabel('Acc')
plt.show()

Different settings yield different results: some converge quickly during inference but reach a lower final accuracy, while others converge more slowly but reach a higher accuracy. Users can choose the setting according to their needs.
[1] Rueckauer B, Lungu I-A, Hu Y, Pfeiffer M and Liu S-C (2017). Conversion of continuous-valued deep networks to efficient event-driven networks for image classification. Frontiers in Neuroscience, 11:682.
[2] Diehl P U, et al. (2015). Fast-classifying, high-accuracy spiking deep networks through weight and threshold balancing. 2015 International Joint Conference on Neural Networks (IJCNN), IEEE.
[3] Rueckauer B, Lungu I-A, Hu Y and Pfeiffer M (2016). Theory and tools for the conversion of analog to spiking convolutional neural networks. arXiv preprint arXiv:1612.04052.
[4] Sengupta A, Ye Y, Wang R, Liu C and Roy K (2019). Going deeper in spiking neural networks: VGG and residual architectures. Frontiers in Neuroscience, 13, 95.
Legacy Tutorials
Author: fangwei123456
Because of the limited time and energy of the developers, not all tutorials can be updated along with the new version of SpikingJelly, and some tutorials have been pruned and absorbed into the new tutorials. Here we list some legacy tutorials that may be helpful.
The predecessor of Activation-based
Activation-based was called Clock-driven in previous versions of SpikingJelly. Here is a tutorial about Clock-driven:
Encoders
This tutorial has not been updated. The user can refer to the old tutorial for the moment:
ANN to SNN conversion
This tutorial has not been updated. The user can refer to the old tutorial for the moment:
Applications of SNNs on other tasks
Reinforcement Learning: Deep Q Learning
Reinforcement Learning: Advantage Actor Critic (A2C)
Reinforcement Learning: Proximal Policy Optimization (PPO)
Classifying Names with a Character-level Spiking LSTM
The predecessor of step mode:
The predecessor of CUPY backend:
Accelerate with CUDA-Enhanced Neuron and Layer-by-Layer Propagation
Call for Updating
We encourage users to update these tutorials to the master version of SpikingJelly and create a pull request.
Implement CUPY Neuron
Author: fangwei123456
This tutorial will introduce how to implement the cupy backend for spiking neurons. We suppose the reader:
Can implement simple element-wise CUDA kernels
Can implement custom backward with
torch.autograd.Function
Has read all API docs of spikingjelly.activation_based.auto_cuda.base, and can implement a 2D CUDA kernel with spikingjelly.activation_based.auto_cuda.base
Implement Forward Propagation Through Time
If we want to implement Forward Propagation Through Time (FPTT) by a python function, then the function should use the following input args:
v_init: shape = [N], the initial membrane potential at the current time-step (the membrane potential after neuronal firing at the last time-step), where N is the number of neurons. When the neurons are multidimensional, N should be the number of neurons after flattening
x_seq: shape = [T, N], the input of T time-steps
v_th: float, the threshold potential
If we use hard reset, we need an extra arg:
v_reset: float, the reset potential
The output of the python FPTT function should include:
spike_seq: shape = [T, N], the output spikes at T time-steps
v_seq: shape = [T, N], the membrane potential after neuronal firing at T time-steps. We output the membrane potential of all time-steps, rather than only the last time-step, because we may use this data.
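Putting the args and outputs above together, a minimal Python reference sketch of FPTT for a hard-reset IF neuron could look like this (only to illustrate the semantics; the actual backend uses the CUDA kernels introduced below):
import torch

def if_fptt_hard_reset(v_init: torch.Tensor, x_seq: torch.Tensor, v_th: float, v_reset: float):
    T = x_seq.shape[0]
    spike_seq = torch.zeros_like(x_seq)
    v_seq = torch.zeros_like(x_seq)
    v = v_init
    for t in range(T):
        h = v + x_seq[t]                          # neuronal charge of the IF neuron
        spike = (h >= v_th).to(x_seq)             # neuronal fire
        v = h * (1. - spike) + v_reset * spike    # hard reset
        spike_seq[t] = spike
        v_seq[t] = v
    return spike_seq, v_seq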
If we implement the FPTT by CUDA, we will use some extra args, which will be introduced later.
spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronFPTTKernel
is inherited from spikingjelly.activation_based.auto_cuda.base.CKernel2D
. NeuronFPTTKernel
is the base class for FPTT. Let us print its CUDA kernel declaration:
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronFPTTKernel(hard_reset=True, dtype='float')
for key, value in base_kernel.cparams.items():
print(f'key="{key}",'.ljust(20), f'value="{value}"'.ljust(20))
The outputs are:
key="numel", value="const int &"
key="N", value="const int &"
key="x_seq", value="const float *"
key="v_v_seq", value="float *"
key="h_seq", value="float *"
key="spike_seq", value="float *"
key="v_th", value="float &"
key="v_reset", value="float &"
Most args have been introduced before. The new args are:
numel: the number of elements in the input/output tensors, which is numel = T * N
N: the number of neurons
v_v_seq: shape = [T + 1, N], which is concatenated from v_init and v_seq
h_seq: shape = [T, N], the membrane potential after neuronal charging but before neuronal firing, which will be used in backward
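For example, in Python v_v_seq is simply the concatenation of v_init and v_seq (a sketch with arbitrary sizes):
import torch

T, N = 4, 3
v_init = torch.zeros([N])
v_seq = torch.rand([T, N])
v_v_seq = torch.cat((v_init.unsqueeze(0), v_seq), dim=0)  # shape = [T + 1, N]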
NeuronFPTTKernel
is the base class of neurons’ FPTT CUDA kernels. Similar to spikingjelly.activation_based.neuron.BaseNode
, it has implemented the neuronal fire and neuronal reset functions. When we want to implement a neuron FPTT kernel, we only need to inherit it and implement the neuronal charge function.
Firstly, let us check the full codes of NeuronFPTTKernel
:
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronFPTTKernel(hard_reset=True, dtype='float')
print(base_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h>
extern "C" __global__
void NeuronFPTTKernel_float_hard_reset(
const int & numel, const int & N, const float * x_seq, float * v_v_seq, float * h_seq, float * spike_seq, float & v_th, float & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
// neuronal_charge should be defined here!;
spike_seq[t] = (h_seq[t] - v_th) >= 0.0f ? 1.0f: 0.0f;
v_v_seq[t + dt] = h_seq[t] * (1.0f - spike_seq[t]) + v_reset * spike_seq[t];
}
}
}
We can find that this kernel is almost finished. We only need to add the neuronal charge function.
The neuronal_charge
function in NeuronFPTTKernel
is:
class NeuronFPTTKernel(base.CKernel2D):
# ...
def neuronal_charge(self) -> str:
"""
:return: CUDA code
:rtype: str
Returns CUDA code for calculating :math:`H[t] = f(X[t], V[t-1], ...)`.
This function should define how ``h_seq[t]`` is calculated by ``x_seq[t], v_v_seq[t]`` and other params if
the neuron needs.
For example, the IF neuron defines this function as:
.. code-block:: python
def neuronal_charge(self) -> str:
# note that v_v_seq[t] is v_seq[t - dt]
return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
"""
return '// neuronal_charge should be defined here!'
To implement the new neuron, we only need to define the neuronal_charge
function.
Take the IF neuron as the example, whose neuronal charge function is \(H[t] = V[t-1] + X[t]\). We can implement it as:
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeFPTTKernel(neuron_kernel.NeuronFPTTKernel):
def neuronal_charge(self) -> str:
# note that v_v_seq[t] is v_seq[t - dt]
return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
if_fptt_kernel = IFNodeFPTTKernel(hard_reset=True, dtype='float')
print(if_fptt_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeFPTTKernel_float_hard_reset(
const int & numel, const int & N, const float * x_seq, float * v_v_seq, float * h_seq, float * spike_seq, float & v_th, float & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
h_seq[t] = x_seq[t] + v_v_seq[t];
spike_seq[t] = (h_seq[t] - v_th) >= 0.0f ? 1.0f: 0.0f;
v_v_seq[t + dt] = h_seq[t] * (1.0f - spike_seq[t]) + v_reset * spike_seq[t];
}
}
}
The above codes have implemented a complete CUDA kernel. We can find that it is easy to implement the kernel with NeuronFPTTKernel
.
Note that we use cfunction.add
:
def neuronal_charge(self) -> str:
# note that v_v_seq[t] is v_seq[t - dt]
return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
We do not write codes like:
def neuronal_charge(self) -> str:
# note that v_v_seq[t] is v_seq[t - dt]
return 'h_seq[t] = x_seq[t] + v_v_seq[t];'
The reason is that the functions in spikingjelly.activation_based.auto_cuda.cfunction provide both float and half2 implementations. Thus, it is more convenient than manually writing CUDA code for different data types.
If we set dtype='half2'
, we will get the kernel of half2
:
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeFPTTKernel(neuron_kernel.NeuronFPTTKernel):
def neuronal_charge(self) -> str:
# note that v_v_seq[t] is v_seq[t - dt]
return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
if_fptt_kernel = IFNodeFPTTKernel(hard_reset=True, dtype='half2')
print(if_fptt_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeFPTTKernel_half2_hard_reset(
const int & numel, const int & N, const half2 * x_seq, half2 * v_v_seq, half2 * h_seq, half2 * spike_seq, half2 & v_th, half2 & v_reset
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
for(int t = index; t < numel; t += dt)
{
h_seq[t] = __hadd2(x_seq[t], v_v_seq[t]);
spike_seq[t] = __hgeu2(__hsub2(h_seq[t], v_th), __float2half2_rn(0.0f));
v_v_seq[t + dt] = __hfma2(h_seq[t], __hsub2(__float2half2_rn(1.0f), spike_seq[t]), __hmul2(v_reset, spike_seq[t]));
}
}
}
Implement Back Propagation Through Time
It is harder to implement Back Propagation Through Time (BPTT) than FPTT. Firstly, let us review how the forward of the neuron is defined in SpikingJelly. The FPTT has the formulation

\(H[t] = f(V[t-1], X[t])\)

\(S[t] = \Theta(H[t] - V_{threshold})\)

\(V[t] = H[t](1 - S[t]) + V_{reset}S[t]\) (hard reset)

where \(\Theta\) is the Heaviside step function. Correspondingly, the BPTT takes the gradients of \(S[t]\) and \(V[t]\) over all time steps and propagates them backward through time to obtain the gradients of \(X[t]\) and \(V[0]\).
Thus, the input args for the BPTT function are:
grad_spike_seq: shape = [T, N], the gradients of spike_seq
grad_v_seq: shape = [T, N], the gradients of v_seq
The outputs of the BPTT function are:
grad_x_seq: shape = [T, N], the gradients of x_seq
grad_v_init: shape = [N], the gradients of v_init
According to the forward, we can calculate the derivative of the reset step with respect to the hidden state as

\(\frac{\mathrm{d} V[t]}{\mathrm{d} H[t]} = 1 - S[t] + (1 - D_{reset})(V_{reset} - H[t])\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]}\) (hard reset)

where \(D_{reset}\) denotes whether we detach the neuronal reset: \(D_{reset} = 1\) if the reset is detached and \(D_{reset} = 0\) otherwise. Finally, we get the backward formulation:

\(\frac{\mathrm{d} L}{\mathrm{d} H[t]} = \frac{\partial L}{\partial S[t]}\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]} + \left(\frac{\partial L}{\partial V[t]} + \frac{\mathrm{d} L}{\mathrm{d} H[t+1]}\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}\right)\frac{\mathrm{d} V[t]}{\mathrm{d} H[t]}\)

\(\frac{\mathrm{d} L}{\mathrm{d} X[t]} = \frac{\mathrm{d} L}{\mathrm{d} H[t]}\frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}, \qquad \frac{\mathrm{d} L}{\mathrm{d} V[0]} = \frac{\mathrm{d} L}{\mathrm{d} H[1]}\frac{\mathrm{d} H[1]}{\mathrm{d} V[0]}\)

where \(\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}, \frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}\) are determined by the neuron’s charge function \(H[t] = f(V[t - 1], X[t])\), and \(\frac{\mathrm{d} S[t]}{\mathrm{d} H[t]}\) is determined by the surrogate function. The rest of the gradient computation is generic and can be reused for all kinds of neurons.
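For reference, here is a minimal Python sketch of this backward pass for the hard-reset IF neuron with the Sigmoid surrogate (alpha = 4.0) and without detaching the reset, mirroring the CUDA kernel generated below (names are illustrative):
import torch

def if_bptt_hard_reset(grad_spike_seq, grad_v_seq, h_seq, v_th: float, v_reset: float, alpha: float = 4.0):
    T = grad_spike_seq.shape[0]
    grad_x_seq = torch.zeros_like(grad_spike_seq)
    grad_h = torch.zeros_like(grad_spike_seq[0])
    for t in range(T - 1, -1, -1):
        over_th = h_seq[t] - v_th
        spike = (over_th >= 0.).to(h_seq)
        sg = torch.sigmoid(alpha * over_th)
        grad_s_to_h = (1. - sg) * sg * alpha                            # dS[t]/dH[t] from the surrogate
        grad_v_to_h = 1. - spike + (v_reset - h_seq[t]) * grad_s_to_h   # dV[t]/dH[t], reset not detached
        # dH[t+1]/dV[t] = 1 and dH[t]/dX[t] = 1 for the IF neuron
        grad_h = (grad_v_seq[t] + grad_h) * grad_v_to_h + grad_spike_seq[t] * grad_s_to_h
        grad_x_seq[t] = grad_h
    grad_v_init = grad_h
    return grad_x_seq, grad_v_init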
spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronBPTTKernel
has implemented the general compilation. Let us check its declaration:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
for key, value in base_kernel.cparams.items():
print(f'key="{key}",'.ljust(22), f'value="{value}"'.ljust(20))
The outputs are:
key="numel", value="const int &"
key="N", value="const int &"
key="grad_spike_seq", value="const float *"
key="grad_v_seq", value="const float *"
key="h_seq", value="const float *"
key="grad_x_seq", value="float *"
key="grad_v_init", value="float *"
key="v_th", value="float &"
key="v_reset", value="float &"
We have introduced these args before.
Note that we use NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, ...
because we need to define the surrogate function before applying backward.
Surrogate functions in SpikingJelly provide the cuda_codes
function to create CUDA codes for backward. Let us check this function in spikingjelly.activation_based.surrogate.Sigmoid
:
class Sigmoid(SurrogateFunctionBase):
# ...
def cuda_codes(self, y: str, x: str, dtype: str):
return cfunction.sigmoid_backward(y=y, x=x, alpha=self.alpha, dtype=dtype)
Now let us print its codes:
from spikingjelly.activation_based import surrogate
print(surrogate.Sigmoid().cuda_codes(y='grad_s', x='over_th', dtype='float'))
The outputs are:
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
grad_s = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
To implement the custom surrogate function with support for CUDA kernel, we need to define the cuda_codes
function by the following formulation:
class CustomSurrogateFunction:
# ...
def cuda_codes(self, y: str, x: str, dtype: str):
# ...
Now let us check the full codes of NeuronBPTTKernel
:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel
base_kernel = neuron_kernel.NeuronBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
print(base_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h>
extern "C" __global__
void NeuronBPTTKernel_float_hard_reset_nodetach_reset(
const int & N, const float * grad_spike_seq, float * grad_v_init, const float * grad_v_seq, float * grad_x_seq, const float * h_seq, const int & numel, float & v_reset, float & v_th
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
float grad_h = 0.0f;
for(int t = numel - N + index; t >= 0; t -= dt)
{
const float over_th = h_seq[t] - v_th;
const float spike_seq_t = over_th >= 0.0f ? 1.0f: 0.0f;
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
const float grad_s_to_h = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
float grad_v_to_h = (1.0f) - spike_seq_t;
{
float temp_var = v_reset - h_seq[t];
temp_var = temp_var * grad_s_to_h;
grad_v_to_h = temp_var + grad_v_to_h;
}
// grad_h_next_to_v should be defined here!;
grad_h = grad_h * grad_h_next_to_v;
grad_h = grad_v_seq[t] + grad_h;
grad_h = grad_h * grad_v_to_h;
{
float temp_var = grad_spike_seq[t] * grad_s_to_h;
grad_h = grad_h + temp_var;
}
// grad_h_to_x should be defined here!;
grad_x_seq[t] = grad_h * grad_h_to_x;
}
// grad_h_next_to_v should be defined here!;
grad_v_init[index] = grad_h * grad_h_next_to_v;
}
}
The comments in the above codes are what we should complete. These functions to be completed are defined in NeuronBPTTKernel
:
class NeuronBPTTKernel(base.CKernel2D):
# ...
def grad_h_next_to_v(self) -> str:
"""
:return: CUDA code
:rtype: str
Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t+1]}{\\mathrm{d} V[t]}`.
This function should define how ``grad_h_next_to_v`` is calculated. Note that ``grad_h_next_to_v`` has not been
declared. Thus, this function should also declare ``grad_h_next_to_v``.
For example, the IF neuron defines this function as:
.. code-block:: python
def grad_h_next_to_v(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
"""
return '// grad_h_next_to_v should be defined here!'
def grad_h_to_x(self) -> str:
"""
:return: CUDA code
:rtype: str
Returns CUDA code for calculating :math:`\\frac{\\mathrm{d} H[t]}{\\mathrm{d} X[t]}`.
This function should define how ``grad_h_to_x`` is calculated. Note that ``grad_h_to_x`` has not been
declared. Thus, this function should also declare ``grad_h_to_x``.
For example, the IF neuron defines this function as:
.. code-block:: python
def grad_h_to_x(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
"""
return '// grad_h_to_x should be defined here!'
For the IF neuron, \(\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}=1, \frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}=1\). Thus, we can implement the BPTT kernel easily:
class IFNodeBPTTKernel(neuron_kernel.NeuronBPTTKernel):
def grad_h_next_to_v(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
def grad_h_to_x(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
Then we can print the full codes of the BPTT kernel of the IF neuron:
from spikingjelly.activation_based import surrogate
from spikingjelly.activation_based.auto_cuda import neuron_kernel, cfunction
class IFNodeBPTTKernel(neuron_kernel.NeuronBPTTKernel):
def grad_h_next_to_v(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
def grad_h_to_x(self) -> str:
return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
kernel = IFNodeBPTTKernel(surrogate_function=surrogate.Sigmoid().cuda_codes, hard_reset=True, detach_reset=False, dtype='float')
print(kernel.full_codes)
#include <cuda_fp16.h>
extern "C" __global__
void IFNodeBPTTKernel_float_hard_reset_nodetach_reset(
const int & N, const float * grad_spike_seq, float * grad_v_init, const float * grad_v_seq, float * grad_x_seq, const float * h_seq, const int & numel, float & v_reset, float & v_th
)
{
const int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < N)
{
const int dt = N;
float grad_h = 0.0f;
for(int t = numel - N + index; t >= 0; t -= dt)
{
const float over_th = h_seq[t] - v_th;
const float spike_seq_t = over_th >= 0.0f ? 1.0f: 0.0f;
const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (4.0f) * over_th));
const float grad_s_to_h = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (4.0f);
float grad_v_to_h = (1.0f) - spike_seq_t;
{
float temp_var = v_reset - h_seq[t];
temp_var = temp_var * grad_s_to_h;
grad_v_to_h = temp_var + grad_v_to_h;
}
const float grad_h_next_to_v = 1.0f;
grad_h = grad_h * grad_h_next_to_v;
grad_h = grad_v_seq[t] + grad_h;
grad_h = grad_h * grad_v_to_h;
{
float temp_var = grad_spike_seq[t] * grad_s_to_h;
grad_h = grad_h + temp_var;
}
const float grad_h_to_x = 1.0f;
grad_x_seq[t] = grad_h * grad_h_to_x;
}
const float grad_h_next_to_v = 1.0f;
grad_v_init[index] = grad_h * grad_h_next_to_v;
}
}
Python Wrap
Now we need to use torch.autograd.Function
to wrap the FPTT and BPTT CUDA kernel.
spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronATGFBase
provides some useful functions to help us wrap the kernels. We suppose that the user has read the API docs of NeuronATGFBase.
Firstly, we should determine the input. In SpikingJelly, the CUDA kernels are used as input args, rather than being created inside the autograd Function (which was the approach before version 0.0.0.0.12). The forward function is defined as:
class IFNodeATGF(torch.autograd.Function):
@staticmethod
def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
Then, we will create py_dict
and use NeuronATGFBase.pre_forward
to preprocess it:
py_dict = {
'x_seq': x_seq,
'v_init': v_init,
'v_th': v_th,
'v_reset': v_reset
}
requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)
And we can call the forward CUDA kernel directly:
forward_kernel((blocks,), (threads,), py_dict)
Do not forget to save the params for backward:
NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
backward_kernel=backward_kernel)
Finally, we return the spikes and membrane potential of T
time-steps. Note that we should return v_v_seq[1:]
because v_v_seq[0]
is v_init
:
return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]
The full codes of the python forward autograd function are:
class IFNodeATGF(torch.autograd.Function):
@staticmethod
def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
py_dict = {
'x_seq': x_seq,
'v_init': v_init,
'v_th': v_th,
'v_reset': v_reset
}
requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)
forward_kernel((blocks,), (threads,), py_dict)
NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
backward_kernel=backward_kernel)
return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]
Now we need to implement the backward autograd function. Note that the input args for backward are the gradients of output args of forward. Thus, the input args are:
class IFNodeATGF(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):
We use NeuronATGFBase.pre_backward
to preprocess args to get the args for the CUDA kernel:
backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
And then we can call the backward kernel:
backward_kernel((blocks,), (threads,), py_dict)
Finally, we return the gradients. Note that the number of return args is identical to the number of input args for forward:
return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None
The full codes are:
class IFNodeATGF(torch.autograd.Function):
@staticmethod
def forward(ctx, x_seq: torch.Tensor, v_init: torch.Tensor, v_th: float, v_reset: float or None,
forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel):
py_dict = {
'x_seq': x_seq,
'v_init': v_init,
'v_th': v_th,
'v_reset': v_reset
}
requires_grad, blocks, threads, py_dict = NeuronATGFBase.pre_forward(py_dict)
forward_kernel((blocks,), (threads,), py_dict)
NeuronATGFBase.ctx_save(ctx, requires_grad, py_dict['h_seq'], blocks=blocks, threads=threads,
numel=py_dict['numel'], N=py_dict['N'], v_th=py_dict['v_th'], v_reset=py_dict['v_reset'],
backward_kernel=backward_kernel)
return py_dict['spike_seq'], py_dict['v_v_seq'][1:, ]
@staticmethod
def backward(ctx, grad_spike_seq: torch.Tensor, grad_v_seq: torch.Tensor):
backward_kernel, blocks, threads, py_dict = NeuronATGFBase.pre_backward(ctx, grad_spike_seq, grad_v_seq)
backward_kernel((blocks,), (threads,), py_dict)
return py_dict['grad_x_seq'], py_dict['grad_v_init'], None, None, None, None
Implement the CUPY backend
We have implemented IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF
. Now we can use them to implement the simplified IF neuron with CUPY backend.
Here are the codes:
from spikingjelly.activation_based.auto_cuda.neuron_kernel import IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF
# put sources of ``IFNodeFPTTKernel, IFNodeBPTTKernel, IFNodeATGF`` before the following codes
import torch
from typing import Callable
from spikingjelly.activation_based import base, surrogate
class CUPYIFNode(base.MemoryModule):
def __init__(self, v_threshold: float = 1., v_reset: float or None = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False):
super().__init__()
self.v_threshold = v_threshold
self.v_reset = v_reset
self.surrogate_function = surrogate_function
self.detach_reset = detach_reset
self.step_mode = 'm'
if v_reset is not None:
self.register_memory('v', v_reset)
else:
self.register_memory('v', 0.)
def multi_step_forward(self, x_seq: torch.Tensor):
if isinstance(self.v, float):
self.v = torch.zeros_like(x_seq[0])
hard_reset = self.v_reset is not None
if x_seq.dtype == torch.float:
dtype = 'float'
elif x_seq.dtype == torch.half:
dtype = 'half2'
forward_kernel = IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype)
backward_kernel = IFNodeBPTTKernel(surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype)
# All tensors will be regarded as 2D or 1D. Thus, we use flatten
spike_seq, v_seq = IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(), self.v_threshold, self.v_reset, forward_kernel, backward_kernel)
spike_seq = spike_seq.view(x_seq.shape)
self.v = v_seq[-1].view(x_seq.shape[1:])
return spike_seq
Let us check the output error compared with the python neuron:
from spikingjelly.activation_based import neuron
@torch.no_grad()
def max_error(x: torch.Tensor, y: torch.Tensor):
return (x - y).abs().max()
T = 8
N = 64
C = 32 * 32 * 32
device = 'cuda:0'
x_seq = torch.rand([T, N, C], device=device, requires_grad=True)
net_cupy = CUPYIFNode()
y_cupy = net_cupy(x_seq)
y_cupy.sum().backward()
x_grad_cupy = x_seq.grad.clone()
x_seq.grad.zero_()
net_torch = neuron.IFNode(backend='torch', step_mode='m')
y_torch = net_torch(x_seq)
y_torch.sum().backward()
x_grad_torch = x_seq.grad.clone()
print('max error of y_seq', max_error(y_cupy, y_torch))
print('max error of x_seq.grad', max_error(x_grad_cupy, x_grad_torch))
The outputs are:
max error of y_seq tensor(0., device='cuda:0')
max error of x_seq.grad tensor(1.3113e-06, device='cuda:0')
We can find that the error is almost zero, indicating that our implementation is correct.
Then let us evaluate the speed. The following experiment is running on NVIDIA Quadro RTX 6000
:
from spikingjelly.activation_based import neuron, cuda_utils, functional
def forward_backward(net: torch.nn.Module, x_seq: torch.Tensor):
y_seq = net(x_seq)
y_seq.sum().backward()
x_seq.grad.zero_()
functional.reset_net(net)
N = 64
C = 32 * 32 * 32
device = 'cuda:0'
net_cupy = CUPYIFNode()
net_torch = neuron.IFNode(backend='torch', step_mode='m')
repeats = 16
for dtype in [torch.float, torch.half]:
for T in [2, 4, 8, 16, 32]:
x_seq = torch.rand([T, N, C], device=device, requires_grad=True, dtype=dtype)
t_cupy = cuda_utils.cal_fun_t(repeats, device, forward_backward, net_cupy, x_seq)
t_torch = cuda_utils.cal_fun_t(repeats, device, forward_backward, net_torch, x_seq)
print(f'dtype={dtype}, T={T},'.ljust(30), f't_torch / t_cupy = {round(t_torch / t_cupy, 2)}')
The outputs are:
dtype=torch.float32, T=2, t_torch / t_cupy = 0.59
dtype=torch.float32, T=4, t_torch / t_cupy = 1.47
dtype=torch.float32, T=8, t_torch / t_cupy = 2.67
dtype=torch.float32, T=16, t_torch / t_cupy = 4.17
dtype=torch.float32, T=32, t_torch / t_cupy = 6.93
dtype=torch.float16, T=2, t_torch / t_cupy = 0.68
dtype=torch.float16, T=4, t_torch / t_cupy = 1.31
dtype=torch.float16, T=8, t_torch / t_cupy = 2.2
dtype=torch.float16, T=16, t_torch / t_cupy = 4.77
dtype=torch.float16, T=32, t_torch / t_cupy = 6.7
We can find that when T >= 4, our neuron with the CUPY kernel is much faster than the python neuron.
When T is small, the python neuron is faster due to the JIT acceleration used in SpikingJelly. This is because the JIT is faster when the operation is simple; for example, we can hardly write an element-wise CUDA kernel that is faster than the JIT.
Convert to Lava for Loihi Deployment
Author: fangwei123456
Thanks to AllenYolk and banzhuangonglxh for their contributions to lava_exchange
Introduction of Lava
Lava is a neuromorphic computing framework, which is mainly developed by Intel and supports deploying on Intel Loihi. Lava provides a sub-package Lava DL for deep learning, which can be used to build and train deep SNNs.
To deploy SNNs on Loihi, we need to use Lava. SpikingJelly provides conversion modules to convert the SNN trained by SpikingJelly to the Lava SNN format. And then we can run this SNN on Loihi. The workflow is:
SpikingJelly -> Lava DL -> Lava -> Loihi
The modules related to Lava are defined in spikingjelly.activation_based.lava_exchange
.
Basic Conversion
Data Format Conversion
The default data format in Lava DL is shape = [N, *, T]
, where N
is the batch dimension and T
is the time-step dimension. However, the module of SpikingJelly in multi-step mode (step_mode = 'm'
) uses the data format as shape = [T, N, *]
. Thus, lava_exchange
provides two conversion functions, TNX_to_NXT
and NXT_to_TNX
for conversion between two formats. Here is an example:
import torch
from spikingjelly.activation_based import lava_exchange
T = 6
N = 4
C = 2
x_seq = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq)
print(f'x_seq_la.shape=[N, C, T]={x_seq_la.shape}')
x_seq_sj = lava_exchange.NXT_to_TNX(x_seq_la)
print(f'x_seq_sj.shape=[T, N, C]={x_seq_sj.shape}')
The outputs are:
x_seq_la.shape=[N, C, T]=torch.Size([4, 2, 6])
x_seq_sj.shape=[T, N, C]=torch.Size([6, 4, 2])
Neuron Conversion
Neurons in SpikingJelly can be converted to neurons in Lava DL. Due to the limited time and energy of the developers, SpikingJelly currently only supports the IF neuron and the LIF neuron, which are two of the most popular neurons in spiking deep learning. Support for other neurons will be considered according to user requirements.
We can use to_lava_neuron
to convert. Here is an example:
import torch
from spikingjelly.activation_based import lava_exchange, neuron
if_sj = neuron.IFNode(v_threshold=1., v_reset=0., step_mode='m')
if_la = lava_exchange.to_lava_neuron(if_sj)
T = 8
N = 2
C = 1
x_seq_sj = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq_sj)
print('output of sj(reshaped to NXT):\n', lava_exchange.TNX_to_NXT(if_sj(x_seq_sj)))
print('output of lava:\n', if_la(x_seq_la))
The outputs are:
output of sj(reshaped to NXT):
tensor([[[0., 0., 1., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 1., 0., 1., 0., 1.]]])
output of lava:
tensor([[[0., 0., 1., 0., 1., 0., 0., 0.]],
[[0., 1., 0., 1., 0., 1., 0., 1.]]])
Here is an example of using the LIF neuron:
import torch
from spikingjelly.activation_based import lava_exchange, neuron
if_sj = neuron.LIFNode(tau=50., decay_input=False, v_threshold=1., v_reset=0., step_mode='m')
if_la = lava_exchange.to_lava_neuron(if_sj)
T = 8
N = 2
C = 1
x_seq_sj = torch.rand([T, N, C])
x_seq_la = lava_exchange.TNX_to_NXT(x_seq_sj)
print('output of sj:\n', lava_exchange.TNX_to_NXT(if_sj(x_seq_sj)))
print('output of lava:\n', if_la(x_seq_la))
The outputs are:
output of sj:
tensor([[[0., 1., 0., 1., 0., 0., 1., 0.]],
[[0., 0., 1., 0., 0., 1., 0., 1.]]])
output of lava:
tensor([[[0., 1., 0., 1., 0., 0., 1., 0.]],
[[0., 0., 1., 0., 0., 1., 0., 1.]]])
Synapse Conversion
The frequently-used convolutional layer, linear layer, and pooling layer can be converted. Note that
bias is not supported
Lava only supports sum pooling, which can be regarded as average pooling without the averaging step
Here is an example:
from spikingjelly.activation_based import lava_exchange, layer
conv = layer.Conv2d(3, 4, kernel_size=3, stride=1, bias=False)
fc = layer.Linear(4, 2, bias=False)
ap = layer.AvgPool2d(2, 2)
conv_la = lava_exchange.conv2d_to_lava_synapse_conv(conv)
fc_la = lava_exchange.linear_to_lava_synapse_dense(fc)
sp_la = lava_exchange.avgpool2d_to_lava_synapse_pool(ap)
print(f'conv_la={conv_la}')
print(f'fc_la={fc_la}')
print(f'sp_la={sp_la}')
The outputs are:
WARNING:root:The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer.
conv_la=Conv(3, 4, kernel_size=(3, 3, 1), stride=(1, 1, 1), bias=False)
fc_la=Dense(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
sp_la=Pool(1, 1, kernel_size=(2, 2, 1), stride=(2, 2, 1), bias=False)
Almost all synapses in Lava DL are based on torch.nn.Conv3d
. Thus, when we print them, we will find that kernel_size
and stride
are tuples with three elements.
BlockContainer
The workflow for using Lava DL is:
using Blocks in Lava DL to build and train the deep SNN
exporting the SNN to the hdf5 file
using Lava to read the hdf5 file and rebuild the SNN, then the SNN can run on Loihi or the CPU-simulated Loihi
For more details, please refer to Lava: Deep Learning.
Blocks can be regarded as the ensemble of a synapse layer and a neuron layer. For example, lava.lib.dl.slayer.block.cuba.Conv
is composed of a convolutional layer and a CUDA LIF neuron layer.
Note that Blocks is designed for SNN deployment. Thus, synapses and neuronal dynamics are quantized in Blocks. In other words, Blocks is not simply synapse + neuron, but quantize(synapse) + quantize(neuron).
SpikingJelly provides BlockContainer
to mimic Blocks
in Lava. The features of BlockContainer
are as follows:
supports surrogate gradient training
synapses and neuronal dynamics are quantized
the outputs are identical to Blocks of Lava DL when given the same inputs
supports converting to lava.lib.dl.slayer.block
For the moment, BlockContainer only supports lava_exchange.CubaLIFNode, but it also supports converting an IFNode or LIFNode given in the init args to a CubaLIFNode. Here is an example:
from spikingjelly.activation_based import lava_exchange, layer, neuron
fc_block_sj = lava_exchange.BlockContainer(
synapse=layer.Linear(8, 1, bias=False),
neu=neuron.IFNode(),
step_mode='m'
)
print('fc_block_sj=\n', fc_block_sj)
fc_block_la = fc_block_sj.to_lava_block()
print('fc_block_la=\n', fc_block_la)
The outputs are:
fc_block_sj=
BlockContainer(
(synapse): Linear(in_features=8, out_features=1, bias=False)
(neuron): CubaLIFNode(
v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
(surrogate_function): Sigmoid(alpha=4.0, spiking=True)
)
)
fc_block_la=
Dense(
(neuron): Neuron()
(synapse): Dense(8, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
)
MNIST CSNN Example
Now let us train a spiking convolutional SNN for classifying MNIST, and then convert this network to Lava DL format.
The SNN is defined as:
class MNISTNet(nn.Module):
def __init__(self, channels: int = 16):
super().__init__()
self.conv_fc = nn.Sequential(
lava_exchange.BlockContainer(
nn.Conv2d(1, channels, kernel_size=3, stride=1, padding=1, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
lava_exchange.BlockContainer(
nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
# 14 * 14
lava_exchange.BlockContainer(
nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
lava_exchange.BlockContainer(
nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
# 7 * 7
lava_exchange.BlockContainer(
nn.Flatten(),
None
),
lava_exchange.BlockContainer(
nn.Linear(channels * 7 * 7, 128, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
lava_exchange.BlockContainer(
nn.Linear(128, 10, bias=False),
neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)
),
)
def forward(self, x):
return self.conv_fc(x)
We add a conversion function to convert the SNN to Lava DL format, which can be used after training:
def to_lava(self):
ret = []
for i in range(self.conv_fc.__len__()):
m = self.conv_fc[i]
if isinstance(m, lava_exchange.BlockContainer):
ret.append(m.to_lava_block())
return nn.Sequential(*ret)
Then, we train this SNN. The training process is not much different from that of other SNNs. Note that the quantization inside lava_exchange.BlockContainer will reduce accuracy. An example of the training code is:
encoder = encoding.PoissonEncoder(step_mode='m')
# ...
for img, label in train_data_loader:
optimizer.zero_grad()
img = img.to(args.device)
label = label.to(args.device)
img = img.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)
fr = net(encoder(img)).mean(0)
loss = F.cross_entropy(fr, label)
loss.backward()
optimizer.step()
# ...
After training, we can convert this SNN to Lava DL and check the accuracy:
net_ladl = net.to_lava().to(args.device)
net_ladl.eval()
test_loss = 0
test_acc = 0
test_samples = 0
with torch.no_grad():
for img, label in test_data_loader:
img = img.to(args.device)
label = label.to(args.device)
img = img.unsqueeze(0).repeat(args.T, 1, 1, 1, 1)
img = encoder(img)
img = lava_exchange.TNX_to_NXT(img)
fr = net_ladl(img).mean(-1)
loss = F.cross_entropy(fr, label)
test_samples += label.numel()
test_loss += loss.item() * label.numel()
test_acc += (fr.argmax(1) == label).float().sum().item()
test_loss /= test_samples
test_acc /= test_samples
print('test acc[lava dl] =', test_acc)
Finally, we can export the SNN in Lava DL format to an hdf5 file, which can then be read by Lava. Lava can rebuild the SNN and run it on Loihi, or on the CPU-simulated Loihi. Refer to the Network Exchange (NetX) Library for more details.
The export function is:
def export_hdf5(net, filename):
# network export to hdf5 format
h = h5py.File(filename, 'w')
layer = h.create_group('layer')
for i, b in enumerate(net):
handle = layer.create_group(f'{i}')
b.export_hdf5(handle)
export_hdf5(net_ladl, os.path.join(args.out_dir, 'net_la.net'))
The complete codes are stored in spikingjelly.activation_based.examples.lava_mnist
. The arguments are defined as:
(lava-env) wfang@mlg-ThinkStation-P920:~/tempdir/w1$ python -m spikingjelly.activation_based.examples.lava_mnist -h
usage: lava_mnist.py [-h] [-T T] [-b B] [-device DEVICE] [-data-dir DATA_DIR]
[-channels CHANNELS] [-epochs EPOCHS] [-lr LR] [-out-dir OUT_DIR]
options:
-h, --help show this help message and exit
-T T simulating time-steps
-b B batch size
-device DEVICE device
-data-dir DATA_DIR root dir of the MNIST dataset
-channels CHANNELS channels of CSNN
-epochs EPOCHS training epochs
-lr LR learning rate
-out-dir OUT_DIR path for saving weights
When we run this script, it will first train an SNN, then convert the SNN to the Lava DL format and run inference, and finally export the SNN to the hdf5 file:
(lava-env) wfang@mlg-ThinkStation-P920:~/tempdir/w1$ python -m spikingjelly.activation_based.examples.lava_mnist -T 32 -device cuda:0 -b 128 -epochs 16 -data-dir /datasets/MNIST/ -lr 0.1 -channels 16
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
epoch = 0, train_loss = 1.7607, train_acc = 0.7245, test_loss = 1.5243, test_acc = 0.9443, max_test_acc = 0.9443
# ...
Namespace(T=32, b=128, device='cuda:0', data_dir='/datasets/MNIST/', channels=16, epochs=16, lr=0.1, out_dir='./')
epoch = 15, train_loss = 1.4743, train_acc = 0.9881, test_loss = 1.4760, test_acc = 0.9855, max_test_acc = 0.9860
finish training
test acc[sj] = 0.9855
test acc[lava dl] = 0.9863
save net.state_dict() to ./net.pt
save net_ladl.state_dict() to ./net_ladl.pt
export net_ladl to ./net_la.net
Citation
If you use SpikingJelly in your work, please cite it as follows:
@misc{SpikingJelly,
title = {SpikingJelly},
author = {Fang, Wei and Chen, Yanqi and Ding, Jianhao and Chen, Ding and Yu, Zhaofei and Zhou, Huihui and Tian, Yonghong and other contributors},
year = {2020},
howpublished = {\url{https://github.com/fangwei123456/spikingjelly}},
note = {Accessed: YYYY-MM-DD},
}
Note: To specify the version of framework you are using, the default value YYYY-MM-DD in the note field should be replaced with the date of the last change of the framework you are using, i.e. the date of the latest commit.
Publications using SpikingJelly are recorded in Publications using SpikingJelly. If you use SpikingJelly in your paper, you can also add it to this table by pull request.
About
Multimedia Learning Group, Institute of Digital Media (NELVT), Peking University and Peng Cheng Laboratory are the main developers of SpikingJelly.


The list of developers can be found at contributors.
spikingjelly.activation_based package
spikingjelly.activation_based.auto_cuda package
Module contents
- class spikingjelly.activation_based.auto_cuda.base.CKernel(kernel_name: str)[source]
Bases:
object
- Parameters
kernel_name (str) – the name of the kernel
The base python class for simplifying the use of custom CUDA kernels.
Some critical attributes:
- cparams:
a dict for saving parameter names and types.
- reserved_cnames:
a list for saving reserved variable names, which can not be used to name variables again.
Here is an example:
from spikingjelly.activation_based.auto_cuda import base
example_ck = base.CKernel(kernel_name='example_ck')
print(example_ck.full_codes)
The outputs are:
#include <cuda_fp16.h>
extern "C" __global__
void example_ck(
)
{
}
A CKernel is composed of four parts: declaration, head, core, and tail. When setting logging level <= DEBUG, some debug information will be added to the cuda codes or printed, and we can check where each part is. Here is an example:
import logging
logging.basicConfig(level=logging.DEBUG)
from spikingjelly.activation_based.auto_cuda import base
example_ck = base.CKernel(kernel_name='example_ck')
print(example_ck.full_codes)
The outputs are:
//------declaration start------
#include <cuda_fp16.h>
extern "C" __global__
void example_ck(
)
//------declaration end--------
//------head start------
{
//------head end--------
//------core start------
//------core end--------
//------tail start------
}
//------tail end--------
In most cases, CKernel is used as a base class. Refer to CKernel1D and CKernel2D for more details.
- check_attributes(**kwargs)[source]
- Parameters
kwargs (dict) – a dict of attributes
- Returns
whether every value in kwargs[key] is identical to self.__getattribute__(key)
- Return type
bool
This function can be used to check whether a CKernel has changed, by checking if any of its attributes has changed.
- property core
- set_contiguous(py_dict: dict)[source]
- Parameters
py_dict (dict) – a dict whose values are torch.Tensor or cupy.ndarray
Check if all values in py_dict are torch.Tensor or cupy.ndarray and contiguous. If not, this function will raise an error.
- get_device(py_dict: dict) -> int [source]
- Parameters
py_dict (dict) – a dict
Traverse the dict and return the device id of the first torch.Tensor met. If there is no torch.Tensor in py_dict, this function will raise an error.
- check_device(device: int, py_dict: dict)[source]
Check if the device id of each torch.Tensor or cupy.ndarray in py_dict is identical to device. If not, this function will raise an error.
- check_keys(py_dict: dict)[source]
- Parameters
py_dict (dict) – a dict
Check if the keys of py_dict are identical to the keys of self.cparams. If not, this function will raise an error.
- check_ctypes(py_dict: dict)[source]
- Parameters
py_dict (dict) – a dict
Check if each value in py_dict has the corresponding ctype in self.cparams, which includes:
torch.float or np.float32 —— 'const float' or 'float'
torch.half or np.float16 —— 'const half2' or 'half2'
np.int —— 'const int' or 'int'
If not, this function will raise an error.
- get_ptrs(py_dict: dict)[源代码]
-
Get the address of the first element of each
torch.Tensor
orcupy.ndarray
inpy_dict
.
- __call__(grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs)[源代码]
- 参数
Execute the CUDA kernel.
*args_1, **kwargs
are used as*args_1, **kwargs
incupy.RawKernel
.py_dict
should containkey: value
wherekey
is the cuda kernel function param name, andvalue
is the variable. This dict should be in one-to-one correspondence with self.cparams
.For example, if
self.cparams
is{ 'numel': 'const int &', 'x': 'const float *', 'y': 'const float *' }
Then
py_dict
should be { 'numel': numel, 'x': x, 'y': y }
where
numel, x, y
should betorch.Tensor
orcupy.ndarray
with the corresponding data type, e.g.,x
inpy_dict
should have data typetorch.float
becausex
inself.cparams
has the value 'const float *'
. The order of keys is arbitrary because this function sorts the keys to align formal and actual parameters.
- add_param(ctype: str, cname: str)[源代码]
-
Add a param to
self.cparams
.Note
When calling
self.__call__
, the params in the CUDA kernel are sorted in dictionary order. Thus, the user does not need to call add_param
in any specific order. Here is an example:
from spikingjelly.activation_based.auto_cuda import base example_ck = base.CKernel(kernel_name='example_ck') print('origin:') print(example_ck.full_codes) example_ck.add_param(ctype='const float*', cname='x') example_ck.add_param(ctype='const float*', cname='y') example_ck.add_param(ctype='float', cname='z') print('after:') print(example_ck.full_codes)
origin: #include <cuda_fp16.h> extern "C" __global__ void example_ck( const int & numel ) after: #include <cuda_fp16.h> extern "C" __global__ void example_ck( const int & numel, const float* x, const float* y, float z )
- property declaration
- property head
- property tail
- property full_codes
the full cuda codes :rtype: str
- Type
return
- class spikingjelly.activation_based.auto_cuda.base.CKernel1D(*args, **kwargs)[源代码]
基类:
CKernel
- 参数
kernel_name (str) – the name of kernel
The 1D (element-wise) CUDA kernel, which is extended from
CKernel
. All input/output tensors will be regarded as 1D tensors.Some critical attributes:
- cparams:
A dict for saving parameter names and types. The default value is
{'numel': 'const int &'}
.numel
represents the number of elements for element-wise operations, which is also the number of CUDA threads.- reserved_cnames:
A list of reserved variable names, which cannot be reused to name other variables. The default value is
['index']
.index
represents the index of element, which is also the cuda thread index.
Now let us check what the empty 1d kernel looks like:
from spikingjelly.activation_based.auto_cuda import base temp_kernel = base.CKernel1D(kernel_name='temp_kernel') print(temp_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h> extern "C" __global__ void temp_kernel( const int & numel ) { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < numel) { } }
With setting logging level, we can check each part of the kernel:
import logging logging.basicConfig(level=logging.DEBUG) from spikingjelly.activation_based.auto_cuda import base temp_kernel = base.CKernel1D(kernel_name='temp_kernel') print(temp_kernel.full_codes)
The outputs are:
//------declaration start------ #include <cuda_fp16.h> extern "C" __global__ void temp_kernel( const int & numel ) //------declaration end-------- //------head start------ { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < numel) { //------head end-------- //------core start------ //------core end-------- //------tail start------ } } //------tail end--------
self.core
can be specified by the user. For example, if we want to write a heaviside kernel, we can implement it easily with the cuda code
y[index] = x[index] >= 0.0f ? 1.0f: 0.0f;
, and add two paramsx, y
, which are inputs and outputs.
Here is the example:
from spikingjelly.activation_based.auto_cuda import base c_heaviside = base.CKernel1D(kernel_name='heaviside') c_heaviside.add_param(ctype='const float *', cname='x') c_heaviside.add_param(ctype='float *', cname='y') c_heaviside.core = ''' y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; ''' print(c_heaviside.full_codes)
The outputs are:
#include <cuda_fp16.h> extern "C" __global__ void heaviside( const int & numel, const float * x, float * y ) { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < numel) { y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; } }
Here is an example of how to execute the kernel:
import torch
import cupy
from spikingjelly.activation_based import cuda_utils

# c_heaviside is the kernel built in the previous example
device = 'cuda:0'
x = torch.rand([4, 4], device=device) - 0.5
y = torch.zeros_like(x)
numel = x.numel()
threads = 1024
blocks = cuda_utils.cal_blocks(numel, threads)
print('x=')
print(x)
with cuda_utils.DeviceEnvironment(device=x.get_device()):
    numel = cupy.asarray(numel)
    py_dict = {
        'numel': numel,
        'x': x,
        'y': y
    }
    c_heaviside((blocks, ), (threads, ), py_dict)
print('y=')
print(y)
The outputs are:
x= tensor([[-0.0423, -0.1383, -0.0238, 0.1018], [ 0.3422, 0.1449, -0.2938, -0.1858], [-0.3503, 0.0004, -0.4274, -0.2012], [-0.0227, 0.2229, -0.0776, 0.2687]], device='cuda:0') y= tensor([[0., 0., 0., 1.], [1., 1., 0., 0.], [0., 1., 0., 0.], [0., 1., 0., 1.]], device='cuda:0')
- property head
- property tail
- check_half2(py_dict: dict)[源代码]
- 参数
py_dict (dict) – a dict
Check value in
py_dict
. If the value istorch.Tensor
withvalue.dtype == torch.half
orcupy.ndarray
withvalue.dtype == np.float16
, this function will check whether the number of elements of value is even.We assert when using half dtype, the numel should be even because we will use
half2
in CUDA kernel.Note
CKernel1D.__call__
will pad half tensor to even numel before executing the kernel. Thus, the user does not need to worry about padding.
- __call__(grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs)[源代码]
- 参数
Execute the CUDA kernel.
*args_1, **kwargs
are used as*args_1, **kwargs
incupy.RawKernel
.py_dict
should containkey: value
wherekey
is the cuda kernel function param name, andvalue
is the variable. This dict should be in one-to-one correspondence with self.cparams
.For example, if
self.cparams
is{ 'numel': 'const int &', 'x': 'const float *', 'y': 'const float *' }
Then
py_dict
should be { 'numel': numel, 'x': x, 'y': y }
where
numel, x, y
should betorch.Tensor
orcupy.ndarray
with the corresponding data type, e.g.,x
inpy_dict
should have data typetorch.float
becausex
inself.cparams
has the value 'const float *'
. The order of keys is arbitrary because this function sorts the keys to align formal and actual parameters.
Note
All tensors in
py_dict
will be regarded as 1D.Note
If any tensor
x
inpy_dict
with data typetorch.half
ornp.float16
has an odd numel, it will be flattened and padded by x = [x, x[-1]]
before executing the CUDA kernel. After execution, padded values inx
will be removed, andx
will be reshaped to the original shape.
- simple_call(**kwargs)[源代码]
- 参数
kwargs (dict) – the dict that contains parameters for CUDA kernel
The simplified calling function, which is simplified from the standard calling function
CKernel1D.__call__
. Compared with
CKernel1D.__call__
, the device, numel, and the numbers of CUDA threads and blocks are calculated automatically from the tensors in kwargs
.Here is the example:
import torch from spikingjelly.activation_based import cuda_utils from spikingjelly.activation_based.auto_cuda import base c_heaviside = base.CKernel1D(kernel_name='heaviside') c_heaviside.add_param(ctype='const float *', cname='x') c_heaviside.add_param(ctype='float *', cname='y') c_heaviside.core = ''' y[index] = x[index] >= 0.0f ? 1.0f: 0.0f; ''' device = 'cuda:0' x = torch.rand([4, 4], device=device) - 0.5 y = torch.zeros_like(x) print('x=') print(x) c_heaviside.simple_call(x=x, y=y) print('y=') print(y)
The outputs are:
x= tensor([[-0.1706, 0.2063, -0.2077, 0.3335], [-0.0180, -0.2429, 0.3488, 0.1146], [ 0.0362, 0.1584, 0.4828, -0.1389], [-0.2684, 0.1898, 0.0560, 0.2058]], device='cuda:0') y= tensor([[0., 1., 0., 1.], [0., 0., 1., 1.], [1., 1., 1., 0.], [0., 1., 1., 1.]], device='cuda:0')
- class spikingjelly.activation_based.auto_cuda.base.CKernel2D(kernel_name: str, reverse: bool = False)[源代码]
基类:
CKernel
- 参数
The 2D CUDA kernel, which is extended from
CKernel
.All input/output tensors should have dimensions no more than 2. All 2D tensors will be regarded as
shape = [T, N]
, whereT
is the sequence length andN
is the number of elements of the data at one time-step.
Some critical attributes:
- cparams:
A dict for saving parameter names and types. The default value is
{'numel': 'const int &', 'N': 'const int &'}
.N
: the number of elements of the sequence data at one time-step (the numel of the 1st dimension)
numel
: the numel of elements in input/output tensors, which isT * N
- reserved_cnames:
A list of reserved variable names, which cannot be reused to name other variables. The default value is
['index', 'dt', 't']
.index
: the index in the 1st dimension, which is also the CUDA thread index
t
: the index in the 0th dimension
dt
: used in CUDA kernel as the time-step stride. Whenx[t_py][j]
in python code is identical tox[t]
in CUDA code, thenx[t_py + 1][j]
in python code is identical tox[t + dt]
in CUDA code.
Now let us check what the empty 2d kernel looks like:
from spikingjelly.activation_based.auto_cuda import base temp_kernel = base.CKernel2D(kernel_name='temp_kernel') print(temp_kernel.full_codes)
The outputs are:
#include <cuda_fp16.h> extern "C" __global__ void temp_kernel( const int & numel, const int & N ) { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < N) { const int dt = N; for(int t = index; t < numel; t += dt) { } } }
With setting logging level, we can check each part of the kernel:
import logging logging.basicConfig(level=logging.DEBUG) from spikingjelly.activation_based.auto_cuda import base temp_kernel = base.CKernel2D(kernel_name='temp_kernel') print(temp_kernel.full_codes)
The outputs are:
//------declaration start------ #include <cuda_fp16.h> extern "C" __global__ void temp_kernel( const int & numel, const int & N ) //------declaration end-------- //------head start------ { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < N) { const int dt = N; //------pre_core start------ //------pre_core end-------- for(int t = index; t < numel; t += dt) { //------head end-------- //------core start------ //------core end-------- //------tail start------ } //------post_core start------ //------post_core end-------- } } //------tail end--------
self.pre_core, self.post_core, self.core
can be specified by the user. Here is an example of how to implement the
cumsum
operation:import torch import cupy from spikingjelly.activation_based.auto_cuda import base from spikingjelly.activation_based import cuda_utils cumsum = base.CKernel2D(kernel_name='cumsum') cumsum.add_param(ctype='const float *', cname='x') cumsum.add_param(ctype='float *', cname='y') cumsum.core = ''' if (t - dt < 0) { y[t] = x[t]; } else { y[t] = x[t] + y[t - dt]; } ''' print(cumsum.full_codes) T = 4 N = 3 device = 'cuda:0' x = torch.randint(low=0, high=4, size=[T, N], device=device).float() y = torch.zeros_like(x) threads = 1024 blocks = cuda_utils.cal_blocks(N, threads) with cuda_utils.DeviceEnvironment(device=x.get_device()): numel = cupy.asarray(T * N) N = cupy.asarray(N) py_dict = { 'N': N, 'numel': numel, 'x': x, 'y': y } cumsum((blocks, ), (threads, ), py_dict) print('x=') print(x) print('y=') print(y)
The outputs are:
#include <cuda_fp16.h> extern "C" __global__ void cumsum( const int & numel, const int & N, const float * x, float * y ) { const int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < N) { const int dt = N; for(int t = index; t < numel; t += dt) { if (t - dt < 0) { y[t] = x[t]; } else { y[t] = x[t] + y[t - dt]; } } } }
x= tensor([[3., 0., 2.], [2., 0., 0.], [2., 3., 2.], [2., 1., 0.]], device='cuda:0') y= tensor([[3., 0., 2.], [5., 0., 2.], [7., 3., 4.], [9., 4., 4.]], device='cuda:0')
- property pre_core
- property post_core
- check_half2(py_dict: dict)[源代码]
- 参数
py_dict (dict) – a dict
Check value in
py_dict
. If the value istorch.Tensor
withvalue.dtype == torch.half
orcupy.ndarray
withvalue.dtype == np.float16
, this function will check whether the number of elements of value is even.If the tensor
x
is 1D, it will be padded whenx.numel() % 2 != 0
. If the tensorx
is 2D, it will be padded whenx.shape[1] % 2 != 0
We require that the numel is even when using the half dtype, because we will use
half2
in CUDA kernel.Note
CKernel2D.__call__
will pad half tensor to even numel before executing the kernel. Thus, the user does not need to worry about padding.
- __call__(grid: tuple, block: tuple, py_dict: dict, *args_1, **kwargs)[源代码]
- 参数
Execute the CUDA kernel.
*args_1, **kwargs
are used as*args_1, **kwargs
incupy.RawKernel
.py_dict
should containkey: value
wherekey
is the cuda kernel function param name, andvalue
is the variable. This dict should be in one-to-one correspondence with self.cparams
.For example, if
self.cparams
is{ 'numel': 'const int &', 'x': 'const float *', 'y': 'const float *' }
Then
py_dict
should be { 'numel': numel, 'x': x, 'y': y }
where
numel, x, y
should betorch.Tensor
orcupy.ndarray
with the corresponding data type, e.g.,x
inpy_dict
should have data typetorch.float
becausex
inself.cparams
has the value 'const float *'
. The order of keys is arbitrary because this function sorts the keys to align formal and actual parameters.
Note
All tensors in
py_dict
should be 1D or 2D.Note
If any 1D tensor
x
inpy_dict
with data typetorch.half
ornp.float16
has an odd numel, it will be flattened and padded by x = [x, x[-1]]
before executing the CUDA kernel.If any 2D tensor
x
with shape[T, N]
inpy_dict
with data typetorch.half
ornp.float16
and N
is odd, thenx
will be padded asx = [x, x[:, -1]]
, whose shape is[T, N + 1]
.After execution, padded values in
x
will be removed, andx
will be reshaped to the original shape.
- property head
- property tail
- simple_call(**kwargs)[源代码]
- 参数
kwargs (dict) – the dict that contains parameters for CUDA kernel
The simplified calling function, which is simplified from the standard calling function
CKernel2D.__call__
. Compared with
CKernel2D.__call__
, the device, N, numel, and the numbers of CUDA threads and blocks are calculated automatically from the tensors in kwargs
.Here is the example:
import torch import cupy from spikingjelly.activation_based.auto_cuda import base from spikingjelly.activation_based import cuda_utils cumsum = base.CKernel2D(kernel_name='cumsum') cumsum.add_param(ctype='const float *', cname='x') cumsum.add_param(ctype='float *', cname='y') cumsum.core = ''' if (t - dt < 0) { y[t] = x[t]; } else { y[t] = x[t] + y[t - dt]; } ''' T = 4 N = 3 device = 'cuda:0' x = torch.randint(low=0, high=4, size=[T, N], device=device).float() y = torch.zeros_like(x) cumsum.simple_call(x=x, y=y) print('x=') print(x) print('y=') print(y)
The outputs are:
x= tensor([[0., 2., 1.], [1., 3., 1.], [2., 2., 0.], [2., 0., 1.]], device='cuda:0') y= tensor([[0., 2., 1.], [1., 5., 2.], [3., 7., 2.], [5., 7., 3.]], device='cuda:0')
- class spikingjelly.activation_based.auto_cuda.base.CodeTyper(indent_num: int)[源代码]
基类:
object
- 参数
indent_num (int) – the number of indents
A CUDA code formatter with adding indents. The full code can be accessed by
self.codes
.Here is an example:
from spikingjelly.activation_based.auto_cuda import base, cfunction code0 = cfunction.if_else(z='z', x='x', y='y', mask='mask', dtype='float') code1 = cfunction.sigmoid_backward(y='y', x='x', alpha=2., dtype='float') codes = '' codes += code0 codes += code1 print('// Without CodeTyper:') print('// ------------------') print(codes) print('// ------------------') ctyper = base.CodeTyper(4) ctyper.append(code0) ctyper.append(code1) print('// With CodeTyper:') print('// ------------------') print(ctyper.codes) print('// ------------------')
// Without CodeTyper: // ------------------ z = x * mask + y * (1.0f - mask);const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (2.0f) * x)); y = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (2.0f); // ------------------ // With CodeTyper: // ------------------ z = x * mask + y * (1.0f - mask); const float sigmoid_backward__sigmoid_ax = 1.0f / (1.0f + expf(- (2.0f) * x)); y = (1.0f - sigmoid_backward__sigmoid_ax) * sigmoid_backward__sigmoid_ax * (2.0f); // ------------------
- class spikingjelly.activation_based.auto_cuda.base.CodeBlock(env: CodeTyper)[源代码]
基类:
object
- 参数
env (CodeTyper) – a CodeTyper
A tool for adding a CUDA code block in
CodeTyper.code
. It is helpful when we want to calculate by intermediate variables.Here is an example:
from spikingjelly.activation_based.auto_cuda import base ctyper = base.CodeTyper(4) with base.CodeBlock(ctyper): ctyper.append('// swap x and y') ctyper.append('float temp_var = x;') ctyper.append('x = y;') ctyper.append('y = temp_var;') print(ctyper.codes)
The outputs are:
{ // swap x and y; float temp_var = x; x = y; y = temp_var; }
- spikingjelly.activation_based.auto_cuda.cfunction.if_else(z: str, x: str, y: str, mask: str, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.if_else_else(w: str, x: str, y: str, z: str, mask_x: str, mask_y: str, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.greater_equal(z: str, x: str, y: str, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.greater_than(z: str, x: str, y: str, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.sigmoid(y: str, x: str, alpha: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.sigmoid_backward(y: str, x: str, alpha: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.atan_backward(y: str, x: str, alpha: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.piecewise_leaky_relu_backward(y: str, x: str, w: float, c: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.s2nn_backward(y: str, x: str, alpha: float, beta: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.q_pseudo_spike_backward(y: str, x: str, alpha: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.leaky_k_relu_backward(y: str, x: str, leak: float, k: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.fake_numerical_gradient_backward(y: str, x: str, alpha: float, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.cfunction.log_tailed_relu_backward(y: str, x: str, alpha: float, dtype: str)[源代码]
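The functions above are CUDA code generators rather than numeric functions: as the CodeTyper example below shows, each returns a CUDA code string for the named operation, where the string arguments are the CUDA variable names to substitute into the generated code. A minimal sketch:
from spikingjelly.activation_based.auto_cuda import cfunction

# generates: z = x * mask + y * (1.0f - mask);
print(cfunction.if_else(z='z', x='x', y='y', mask='mask', dtype='float'))

# generates the backward code of the sigmoid surrogate function with alpha = 2.0
print(cfunction.sigmoid_backward(y='y', x='x', alpha=2., dtype='float'))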
- spikingjelly.activation_based.auto_cuda.neuron_kernel.neuronal_hard_reset(v_next: str, h: str, spike: str, v_reset: str, dtype: str = 'float')[源代码]
- spikingjelly.activation_based.auto_cuda.neuron_kernel.neuronal_soft_reset(v_next: str, h: str, spike: str, v_th: str, dtype: str = 'float')[源代码]
- spikingjelly.activation_based.auto_cuda.neuron_kernel.neuronal_fire(spike: str, v: str, v_th: str, dtype: str = 'float')[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronFPTTKernel(hard_reset: bool, dtype: str)[源代码]
基类:
CKernel2D
- neuronal_charge() str [源代码]
- 返回
CUDA code
- 返回类型
Returns CUDA code for calculating \(H[t] = f(X[t], V[t-1], ...)\).
This function should define how
h_seq[t]
is calculated byx_seq[t], v_v_seq[t]
and other params if the neuron needs them. For example, the IF neuron defines this function as:
def neuronal_charge(self) -> str:
    # note that v_v_seq[t] is v_seq[t - dt]
    return cfunction.add(z='h_seq[t]', x='x_seq[t]', y='v_v_seq[t]', dtype=self.dtype)
- property core
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronBPTTKernel(surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
基类:
CKernel2D
- property pre_core
- property post_core
- grad_h_next_to_v() str [源代码]
- 返回
CUDA code
- 返回类型
Returns CUDA code for calculating \(\frac{\mathrm{d} H[t+1]}{\mathrm{d} V[t]}\).
This function should define how
grad_h_next_to_v
is calculated. Note thatgrad_h_next_to_v
has not been declared. Thus, this function should also declaregrad_h_next_to_v
. For example, the IF neuron defines this function as:
def grad_h_next_to_v(self) -> str: return cfunction.constant(y=f'const {self.dtype} grad_h_next_to_v', x=1., dtype=self.dtype)
- grad_h_to_x() str [源代码]
- 返回
CUDA code
- 返回类型
Returns CUDA code for calculating \(\frac{\mathrm{d} H[t]}{\mathrm{d} X[t]}\).
This function should define how
grad_h_to_x
is calculated. Note thatgrad_h_to_x
has not been declared. Thus, this function should also declaregrad_h_to_x
. For example, the IF neuron defines this function as:
def grad_h_to_x(self) -> str: return cfunction.constant(y=f'const {self.dtype} grad_h_to_x', x=1., dtype=self.dtype)
- property core
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.IFNodeFPTTKernel(hard_reset: bool, dtype: str)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.IFNodeBPTTKernel(surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
- spikingjelly.activation_based.auto_cuda.neuron_kernel.scalar_to_cupy(py_dict: dict, ref: str = 'x_seq')[源代码]
- spikingjelly.activation_based.auto_cuda.neuron_kernel.new_tensors(news: tuple, py_dict: dict, ref: str = 'x_seq')[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.NeuronATGFBase[源代码]
基类:
object
- static pre_forward(py_dict: dict)[源代码]
- 参数
py_dict (dict) – a dict built from the neuron’s forward autograd function. It should at least contain
x_seq, v_init, v_reset
- 返回
requires_grad, blocks, threads, py_dict
- requires_grad: bool
if any tensor in
py_dict
requires grad, thenrequires_grad = True
;elserequires_grad = False
- blocks: int
CUDA param used in calling CUDA kernel
- threads: int
CUDA param used in calling CUDA kernel. The default value is
spikingjelly.configure.cuda_threads
- py_dict: dict
Compared with the input
py_dict
, the returnedpy_dict
will:convert all
float/int
scalars inpy_dict
tocupy.ndarray
add
h_seq, spike_seq, v_v_seq
topy_dict
.h_seq, spike_seq
are zero tensors with the same shape as x_seq
.v_v_seq
is concatenated fromv_init
andv_seq
, where v_seq is a zero tensor with the same shape as x_seq
add
N, numel
topy_dict
. Note thatx_seq.shape = [T, N]
andnumel = T * N
. A specific case is thatx_seq.dtype == torch.half
, thenN = math.ceil(N / 2)
, andnumel = N * x_seq.shape[0]
. Note thatN, numel
in the returnedpy_dict
arecupy.ndarray
- 返回类型
- static ctx_save(ctx, requires_grad: bool, *args, **kwargs)[源代码]
- 参数
ctx –
ctx
intorch.autograd.Function
requires_grad (bool) – if any tensor in forward params requires grad
args – tensors that need to be saved by
ctx.save_for_backward
kwargs – items that need to be saved by
ctx.xx = xx
Saves
*args, **kwargs
inctx
byctx.save_for_backward(*args)
andctx.xx = xx
for allxx
inkwargs.items()
.
- static pre_backward(ctx, grad_spike_seq: Tensor, grad_v_seq: Tensor)[源代码]
- 参数
ctx –
ctx
intorch.autograd.Function
grad_spike_seq (torch.Tensor) – gradients of
spike_seq
grad_v_seq (torch.Tensor) – gradients of
v_seq
- 返回
backward_kernel, blocks, threads, py_dict
- backward_kernel: NeuronBPTTKernel
The CUDA kernel used for backward. It should be provided in
ctx.backward_kernel
- blocks: int
CUDA param used in calling CUDA kernel. It should be provided in
ctx.blocks
- threads: int
CUDA param used in calling CUDA kernel. It should be provided in
ctx.threads
- 返回类型
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.IFNodeATGF(*args, **kwargs)[源代码]
基类:
Function
- static forward(ctx, x_seq: Tensor, v_init: Tensor, v_th: float, v_reset: float, forward_kernel: IFNodeFPTTKernel, backward_kernel: IFNodeBPTTKernel)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.LIFNodeFPTTKernel(decay_input: bool, hard_reset: bool, dtype: str)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.LIFNodeBPTTKernel(decay_input: bool, surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.LIFNodeATGF(*args, **kwargs)[源代码]
基类:
Function
- static forward(ctx, x_seq: Tensor, v_init: Tensor, v_th: float, v_reset: float, decay: float, forward_kernel: LIFNodeFPTTKernel, backward_kernel: LIFNodeBPTTKernel)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.ParametricLIFNodeFPTTKernel(decay_input: bool, hard_reset: bool, dtype: str)[源代码]
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.ParametricLIFNodeBPTTKernel(decay_input: bool, surrogate_function: Callable, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
-
- property head
- property pre_core
- property core
- property tail
- class spikingjelly.activation_based.auto_cuda.neuron_kernel.ParametricLIFNodeATGF(*args, **kwargs)[源代码]
基类:
Function
- static forward(ctx, x_seq: Tensor, v_init: Tensor, v_th: float, v_reset: float, decay: Tensor, forward_kernel: ParametricLIFNodeFPTTKernel, backward_kernel: ParametricLIFNodeBPTTKernel)[源代码]
spikingjelly.activation_based.base package
Module contents
- spikingjelly.activation_based.base.check_backend_library(backend: str)[源代码]
-
- 参数
backend (str) –
'torch'
,'cupy'
或'lava'
检查某个后端的python库是否已经安装。若未安装则此函数会报错。
- 参数
backend (str) –
'torch'
,'cupy'
or'lava'
Check whether the python lib for backend is installed. If not, this function will raise an error.
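A minimal usage sketch (torch is always available because SpikingJelly depends on it; the 'cupy' check only passes if CuPy is installed):
from spikingjelly.activation_based import base

base.check_backend_library('torch')  # expected to pass because torch is installed
base.check_backend_library('cupy')   # raises an error if the cupy library is not installed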
- class spikingjelly.activation_based.base.StepModule[源代码]
基类:
object
- property step_mode
-
- 返回
模块当前使用的步进模式
- 返回类型
- 返回
the current step mode of this module
- 返回类型
- class spikingjelly.activation_based.base.SingleModule[源代码]
基类:
StepModule
只支持单步的模块 (
step_mode == 's'
The module that only supports single-step mode (
step_mode == 's'
)
- class spikingjelly.activation_based.base.MultiStepModule[源代码]
基类:
StepModule
只支持多步的模块 (
step_mode == 'm'
The module that only supports multi-step mode (
step_mode == 'm'
)
- class spikingjelly.activation_based.base.MemoryModule[源代码]
基类:
Module
,StepModule
MemoryModule
是SpikingJelly中所有有状态(记忆)模块的基类。MemoryModule
is the base class of all stateful modules in SpikingJelly.- property supported_backends
-
返回支持的后端,默认情况下只有 (‘torch’, )
Return the supported backends. The default return value is (‘torch’, )
- property backend
- abstract single_step_forward(x: Tensor, *args, **kwargs)[源代码]
-
- 参数
x (torch.Tensor) – input tensor with ``shape = [N, *] ``
本模块的单步的前向传播函数
- 参数
x (torch.Tensor) – input tensor with ``shape = [N, *] ``
The single-step forward function for this module
- multi_step_forward(x_seq: Tensor, *args, **kwargs)[源代码]
-
- 参数
x (torch.Tensor) – input tensor with ``shape = [T, N, *] ``
本模块的多步的前向传播函数,通过调用
T
次single_step_forward(x[t], *args, **kwargs)
实现- 参数
x (torch.Tensor) – input tensor with ``shape = [T, N, *] ``
The multi-step forward function for this module, which is implemented by calling
single_step_forward(x[t], *args, **kwargs)
overT
times
- register_memory(name: str, value)[源代码]
-
- 参数
name (str) – 变量的名字
value (any) – 变量的值
将变量存入用于保存有状态变量(例如脉冲神经元的膜电位)的字典中。这个变量的重置值会被设置为
value
。每次调用self.reset()
函数后,self.name
都会被重置为value
。- 参数
name (str) – variable’s name
value (any) – variable’s value
Register the variable to memory dict, which saves stateful variables (e.g., the membrane potential of a spiking neuron). The reset value of this variable will be
value
.self.name
will be set tovalue
after each calling ofself.reset()
.
- memories()[源代码]
-
- 返回
返回一个所有状态变量的迭代器
- 返回类型
Iterator
- 返回
an iterator over all stateful variables
- 返回类型
Iterator
- named_memories()[源代码]
-
- 返回
返回一个所有状态变量及其名称的迭代器
- 返回类型
Iterator
- 返回
an iterator over all stateful variables and their names
- 返回类型
Iterator
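A sketch of how these methods fit together, assuming a LIFNode (a MemoryModule whose membrane potential is a registered memory) and assuming named_memories() yields (name, value) pairs like named_parameters():
import torch
from spikingjelly.activation_based import neuron

lif = neuron.LIFNode()       # LIFNode is a MemoryModule
lif(torch.rand([4, 2]))      # running the neuron updates its stateful memory (the membrane potential)

for name, value in lif.named_memories():
    print(name, value)

lif.reset()                  # every registered memory is reset to its registered reset value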
spikingjelly.activation_based.cuda_utils package
Module contents
- spikingjelly.activation_based.cuda_utils.cpu_timer(f: Callable, *args, **kwargs)[源代码]
-
计算在CPU上执行
f(*args, **kwargs)
所需的时间- 参数
f (Callable) – 函数
- 返回
用时,单位是毫秒
- 返回类型
Returns the used time for calling
f(*args, **kwargs)
in CPU- 参数
f (Callable) – a function
- 返回
used time in milliseconds
- 返回类型
- spikingjelly.activation_based.cuda_utils.cuda_timer(device: device, f: Callable, *args, **kwargs)[源代码]
-
计算在CUDA上执行
f(*args, **kwargs)
所需的时间- 参数
device (torch.device or int) –
f
运行的CUDA设备f (Callable) – 函数
- 返回
用时,单位是毫秒
- 返回类型
Returns the used time for calling
f(*args, **kwargs)
in CUDA- 参数
device (torch.device or int) – on which cuda device that
f
is runningf (Callable) – a function
- 返回
used time in milliseconds
- 返回类型
- spikingjelly.activation_based.cuda_utils.cal_fun_t(n: int, device: str, f: Callable, *args, **kwargs)[源代码]
-
测量在
device
上执行n
次f(*args, **kwargs)
的平均用时备注
当
n > 1
时,实际上会执行2n
次,然后返回后n
次的平均用时,以减小误差。- 参数
n (int) – 重复的次数
device (str or torch.device or int) –
f
执行的设备,可以为 ‘cpu’ 或CUDA设备f (Callable) – 函数
- 返回
用时,单位是毫秒
- 返回类型
Returns the used time averaged by calling
f(*args, **kwargs)
overn
timesNote
If
n > 1
, this function will callf
for2n
times and return the average used time by the lastn
times to reduce the measurement error.
n (int) – repeat times
device (str or torch.device or int) – on which cuda device that
f
is running. It can be 'cpu' or a CUDA device
f (Callable) – function
- 返回
used time in milliseconds
- 返回类型
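A usage sketch of the timing helper above; the function and its arguments are passed positionally, matching the signature cal_fun_t(n, device, f, *args, **kwargs):
import torch
from spikingjelly.activation_based import cuda_utils

def f(x, y):
    return (x @ y).sum()

x = torch.rand([64, 64])
y = torch.rand([64, 64])

# average time (in milliseconds) of calling f(x, y) for n times on the CPU
t = cuda_utils.cal_fun_t(16, 'cpu', f, x, y)
print(f'average time = {t} ms')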
- spikingjelly.activation_based.cuda_utils.cal_blocks(numel: int, threads: int = -1)[源代码]
-
- 参数
- 返回
blocks的数量
- 返回类型
此函数返回 blocks的数量,用来按照
kernel((blocks,), (configure.cuda_threads,), ...)
调用cupy.RawKernel
- 参数
- 返回
the number of blocks
- 返回类型
Returns the number of blocks to call
cupy.RawKernel
bykernel((blocks,), (threads,), ...)
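A short sketch, mirroring how cal_blocks is used in the CKernel examples above:
from spikingjelly.activation_based import cuda_utils

numel = 10000
threads = 1024
blocks = cuda_utils.cal_blocks(numel, threads)
print(blocks)   # enough blocks so that blocks * threads can cover numel elements
# the result is then used as kernel((blocks, ), (threads, ), py_dict)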
- spikingjelly.activation_based.cuda_utils.get_contiguous(*args)[源代码]
-
将
*args
中所有的torch.Tensor
或cupy.ndarray
进行连续化。备注
连续化的操作无法in-place,因此本函数返回一个新的list。
- 返回
一个元素全部为连续的
torch.Tensor
或cupy.ndarray
的list
- 返回类型
- 返回
a list that contains the contiguous
torch.Tensor
orcupy.ndarray
- 返回类型
Makes
torch.Tensor
orcupy.ndarray
in*args
to be contiguousNote
The making contiguous operation can not be done in-place. Hence, this function will return a new list.
- spikingjelly.activation_based.cuda_utils.wrap_args_to_raw_kernel(device: int, *args)[源代码]
-
此函数可以包装
torch.Tensor
和cupy.ndarray
并将其作为cupy.RawKernel.__call__
的args
- 参数
device (int) – on which CUDA device the raw kernel will run
- 返回
a
tuple
that contains args to callcupy.RawKernel
- 返回类型
This function can wrap
torch.Tensor
orcupy.ndarray
toargs
incupy.RawKernel.__call__
- class spikingjelly.activation_based.cuda_utils.DeviceEnvironment(device: int)[源代码]
基类:
object
这个模块可以被用作在指定的
device
上执行CuPy函数的上下文,用来避免 torch.cuda.current_device() 被CuPy意外改变( https://github.com/cupy/cupy/issues/6569 )。代码示例:
with DeviceEnvironment(device): kernel((blocks,), (configure.cuda_threads,), ...)
- 参数
device (int) – the CUDA device
This module is used as a context to make CuPy use the specific device, and avoids torch.cuda.current_device() is changed by CuPy ( https://github.com/cupy/cupy/issues/6569 ).
Codes example:
with DeviceEnvironment(device): kernel((blocks,), (configure.cuda_threads,), ...)
spikingjelly.activation_based.encoding package
Module contents
- class spikingjelly.activation_based.encoding.StatelessEncoder(step_mode='s')[源代码]
基类:
Module
,StepModule
无状态编码器的基类。无状态编码器
encoder = StatelessEncoder()
,直接调用encoder(x)
即可将x
编码为spike
。The base class of stateless encoder. The stateless encoder
encoder = StatelessEncoder()
can encodex
tospike
byencoder(x)
.- abstract forward(x: Tensor)[源代码]
-
- 参数
x (torch.Tensor) – 输入数据
- 返回
spike
, shape 与x.shape
相同- 返回类型
- 参数
x (torch.Tensor) – input data
- 返回
spike
, whose shape is same withx.shape
- 返回类型
- class spikingjelly.activation_based.encoding.StatefulEncoder(T: int, step_mode='s')[源代码]
基类:
MemoryModule
- 参数
T (int) – 编码周期。通常情况下,与SNN的仿真周期(总步长)一致
有状态编码器的基类。有状态编码器
encoder = StatefulEncoder(T)
,编码器会在首次调用encoder(x)
时对x
进行编码。在第t
次调用encoder(x)
时会输出spike[t % T]
encoder = StatefulEncoder(T)
s_list = []
for t in range(T):
    s_list.append(encoder(x))  # s_list[t] == spike[t]
- 参数
T (int) – the encoding period. It is usually the same as the total number of simulation time-steps of the SNN
The base class of stateful encoder. The stateful encoder
encoder = StatefulEncoder(T)
will encodex
tospike
at the first time of callingencoder(x)
. It will outputspike[t % T]
at thet
-th callingencoder = StatefulEncoder(T) s_list = [] for t in range(T): s_list.append(encoder(x)) # s_list[t] == spike[t]
- single_step_forward(x: Optional[Tensor] = None)[源代码]
-
- 参数
x (torch.Tensor) – 输入数据
- 返回
spike
, shape 与x.shape
相同- 返回类型
- 参数
x (torch.Tensor) – input data
- 返回
spike
, whose shape is same withx.shape
- 返回类型
- abstract single_step_encode(x: Tensor)[源代码]
-
- 参数
x (torch.Tensor) – 输入数据
- 返回
spike
, shape 与x.shape
相同- 返回类型
- 参数
x (torch.Tensor) – input data
- 返回
spike
, whose shape is same withx.shape
- 返回类型
- class spikingjelly.activation_based.encoding.PeriodicEncoder(spike: Tensor, step_mode='s')[源代码]
-
- 参数
spike (torch.Tensor) – 输入脉冲
周期性编码器,在第
t
次调用时输出spike[t % T]
,其中T = spike.shape[0]
警告
不要忘记调用reset,因为这个编码器是有状态的。
- 参数
spike (torch.Tensor) – the input spike
The periodic encoder that outputs
spike[t % T]
att
-th calling, whereT = spike.shape[0]
Warning
Do not forget to reset the encoder because the encoder is stateful!
- class spikingjelly.activation_based.encoding.LatencyEncoder(T: int, enc_function='linear', step_mode='s')[源代码]
-
延迟编码器,将
0 <= x <= 1
的输入转化为在0 <= t_f <= T-1
时刻发放的脉冲。输入的强度越大,发放越早。
当 enc_function == 'linear' 时
\[t_f(x) = (T - 1)(1 - x)\]
当 enc_function == 'log' 时
\[t_f(x) = (T - 1) - \ln(\alpha x + 1)\]
其中 \(\alpha\) 满足 \(t_f(1) = T - 1\)。
实例代码:
x = torch.rand(size=[8, 2])
print('x', x)
T = 20
encoder = LatencyEncoder(T)
for t in range(T):
    print(encoder(x))
警告
必须确保
0 <= x <= 1
。警告
不要忘记调用reset,因为这个编码器是有状态的。
- 参数
The latency encoder will encode 0 <= x <= 1 to a spike whose firing time is 0 <= t_f <= T-1. A larger x causes an earlier firing time.
If enc_function == 'linear'
\[t_f(x) = (T - 1)(1 - x)\]
If enc_function == 'log'
\[t_f(x) = (T - 1) - \ln(\alpha x + 1)\]
where \(\alpha\) satisfies \(t_f(1) = T - 1\)
Example:
x = torch.rand(size=[8, 2])
print('x', x)
T = 20
encoder = LatencyEncoder(T)
for t in range(T):
    print(encoder(x))
Warning
The user must assert
0 <= x <= 1
.Warning
Do not forget to reset the encoder because the encoder is stateful!
- 当
- class spikingjelly.activation_based.encoding.PoissonEncoder(step_mode='s')[源代码]
-
无状态的泊松编码器。输出脉冲的发放概率与输入
x
相同。警告
必须确保
0 <= x <= 1
。
The Poisson encoder outputs spikes whose firing probability equals the input
x
.
Warning
The user must assert
0 <= x <= 1
.
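A minimal usage sketch (the encoder is stateless, so it can be called directly and does not need to be reset):
import torch
from spikingjelly.activation_based import encoding

encoder = encoding.PoissonEncoder()
x = torch.rand([2, 3])     # the input must satisfy 0 <= x <= 1
spike = encoder(x)         # each element fires with probability equal to the corresponding element of x
print(spike)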
- class spikingjelly.activation_based.encoding.WeightedPhaseEncoder(K: int, step_mode='s')[源代码]
-
- 参数
K (int) – 编码周期。通常情况下,与SNN的仿真周期(总步长)一致
Kim J, Kim H, Huh S, et al. Deep neural networks with weighted spikes[J]. Neurocomputing, 2018, 311: 373-386.
带权的相位编码,一种基于二进制表示的编码方法。
将输入按照二进制各位展开,从高位到低位遍历输入进行脉冲编码。相比于频率编码,每一位携带的信息量更多。编码相位数为 \(K\) 时, 可以对于处于区间 \([0, 1-2^{-K}]\) 的数进行编码。以下为原始论文中的示例:
Phase (K=8)                |  1   |  2   |  3   |  4   |  5   |  6   |  7   |  8
Spike weight \(\omega(t)\) | 2^-1 | 2^-2 | 2^-3 | 2^-4 | 2^-5 | 2^-6 | 2^-7 | 2^-8
192/256                    |  1   |  1   |  0   |  0   |  0   |  0   |  0   |  0
1/256                      |  0   |  0   |  0   |  0   |  0   |  0   |  0   |  1
128/256                    |  1   |  0   |  0   |  0   |  0   |  0   |  0   |  0
255/256                    |  1   |  1   |  1   |  1   |  1   |  1   |  1   |  1
警告
不要忘记调用reset,因为这个编码器是有状态的。
- 参数
K (int) – the encoding period. It is usually the same as the total number of simulation time-steps of the SNN
The weighted phase encoder, which is based on the binary representation. It will expand
x
as a binary number and encode it from the most significant bit to the least significant bit. With K phases, it can encode \(x \in [0, 1-2^{-K}]\) to different spikes. Here is the example from the original paper:
Phase (K=8)                |  1   |  2   |  3   |  4   |  5   |  6   |  7   |  8
Spike weight \(\omega(t)\) | 2^-1 | 2^-2 | 2^-3 | 2^-4 | 2^-5 | 2^-6 | 2^-7 | 2^-8
192/256                    |  1   |  1   |  0   |  0   |  0   |  0   |  0   |  0
1/256                      |  0   |  0   |  0   |  0   |  0   |  0   |  0   |  1
128/256                    |  1   |  0   |  0   |  0   |  0   |  0   |  0   |  0
255/256                    |  1   |  1   |  1   |  1   |  1   |  1   |  1   |  1
Warning
Do not forget to reset the encoder because the encoder is stateful!
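A hedged decoding sketch based on the table above: if the spike at the t-th phase is weighted by \(2^{-(t+1)}\) and accumulated over K phases, the input is recovered (exactly for inputs representable with K bits). The loop assumes the encoder outputs the most significant bit first, as in the table:
import torch
from spikingjelly.activation_based import encoding

K = 8
x = torch.tensor([[192 / 256, 1 / 256]])    # values in [0, 1 - 2 ** (-K)]
encoder = encoding.WeightedPhaseEncoder(K)

decoded = torch.zeros_like(x)
for t in range(K):
    decoded += encoder(x) * 2. ** (-t - 1)  # spike weight omega(t) = 2^{-(t+1)}
print(decoded)                               # expected to recover x

encoder.reset()                              # the encoder is stateful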
spikingjelly.activation_based.functional package
Module contents
- spikingjelly.activation_based.functional.reset_net(net: Module)[源代码]
-
- 参数
net – 任何属于
nn.Module
子类的网络- 返回
None
将网络的状态重置。做法是遍历网络中的所有
Module
,若 m 为 base.MemoryModule
的实例或者是拥有reset()
方法,则调用m.reset()
。- 参数
net – Any network inherits from
nn.Module
- 返回
None
Reset the whole network. Walk through every
Module
asm
, and callm.reset()
if thism
isbase.MemoryModule
orm
hasreset()
.
- spikingjelly.activation_based.functional.set_step_mode(net: Module, step_mode: str)[源代码]
-
- 参数
net (nn.Module) – 一个神经网络
step_mode (str) – ‘s’ (单步模式) 或 ‘m’ (多步模式)
- 返回
None
将
net
中所有模块的步进模式设置为step_mode
。备注
spikingjelly.activation_based.layer.StepModeContainer
,spikingjelly.activation_based.layer.ElementWiseRecurrentContainer
,spikingjelly.activation_based.layer.LinearRecurrentContainer
的子模块(不包含包装器本身)的step_mode
不会被改变。- 参数
net (nn.Module) – a network
step_mode (str) – ‘s’ (single-step) or ‘m’ (multi-step)
- 返回
None
Set
step_mode
for all modules innet
.Note
The step_mode of the submodules (not including the container itself) of
spikingjelly.activation_based.layer.StepModeContainer
,spikingjelly.activation_based.layer.ElementWiseRecurrentContainer
,spikingjelly.activation_based.layer.LinearRecurrentContainer
will not be changed.
- spikingjelly.activation_based.functional.set_backend(net: ~torch.nn.modules.module.Module, backend: str, instance: object = (<class 'torch.nn.modules.module.Module'>, ))[源代码]
-
- 参数
- 返回
None
将
net
中 所有类型为instance
的模块后端更改为backend
- 参数
- 返回
None
Sets backends of all modules whose instance is
instance
innet
tobackend
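A minimal sketch; whether a given backend is actually usable depends on each module's supported_backends, and the 'cupy' backend is assumed to require CuPy to be installed:
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional

net = nn.Sequential(
    nn.Linear(8, 4),
    neuron.IFNode(step_mode='m'),
)
# set the backend of every module that is an instance of neuron.IFNode to 'cupy'
functional.set_backend(net, backend='cupy', instance=neuron.IFNode)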
- spikingjelly.activation_based.functional.detach_net(net: Module)[源代码]
-
- 参数
net – 任何属于
nn.Module
子类的网络- 返回
None
将网络与之前的时间步的计算图断开。做法是遍历网络中的所有
Module
,若 m 为 base.MemoryModule
的实例或者是拥有detach()
方法,则调用m.detach()
。- 参数
net – Any network inherits from
nn.Module
- 返回
None
Detach the computation graph of the whole network from previous time-steps. Walk through every
Module
asm
, and callm.detach()
if thism
isbase.MemoryModule
orm
hasdetach()
.
- spikingjelly.activation_based.functional.spike_similar_loss(spikes: Tensor, labels: Tensor, kernel_type='linear', loss_type='mse', *args)[源代码]
-
- 参数
- 返回
shape=[1]的tensor,相似损失
将N个数据输入到输出层有M个神经元的SNN,运行T步,得到shape=[N, M, T]的脉冲。这N个数据的标签为shape=[N, C]的
labels
。用shape=[N, N]的矩阵
sim
表示实际相似度矩阵,sim[i][j] == 1
表示数据i与数据j相似,反之亦然。若labels[i]
与labels[j]
共享至少同一个标签,则认为他们相似,否则不相似。用shape=[N, N]的矩阵
sim_p
表示输出相似度矩阵,sim_p[i][j]
的取值为0到1,值越大表示数据i与数据j的脉冲越相似。使用内积来衡量两个脉冲之间的相似性,
kernel_type
是计算内积时,所使用的核函数种类:‘linear’,线性内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}}\)。
‘sigmoid’,Sigmoid内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{sigmoid}(\alpha \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})\),其中 \(\alpha = args[0]\)。
‘gaussian’,高斯内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{exp}(- \frac{||\boldsymbol{x_{i}} - \boldsymbol{y_{j}}||^{2}}{2\sigma^{2}})\),其中 \(\sigma = args[0]\)。
当使用Sigmoid或高斯内积时,内积的取值范围均在[0, 1]之间;而使用线性内积时,为了保证内积取值仍然在[0, 1]之间,会进行归一化:按照 \(\text{sim_p}[i][j]=\frac{\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}})}{||\boldsymbol{x_{i}}|| · ||\boldsymbol{y_{j}}||}\)。
对于相似的数据,根据输入的
loss_type
,返回度量sim
与sim_p
差异的损失:‘mse’ – 返回sim与sim_p的均方误差(也就是l2误差)。
‘l1’ – 返回sim与sim_p的l1误差。
‘bce’ – 返回sim与sim_p的二值交叉熵误差。
备注
脉冲向量稀疏、离散,最好先使用高斯核进行平滑,然后再计算相似度。
- 参数
spikes – shape=[N, M, T], output spikes corresponding to a batch of N inputs
labels – shape=[N, C], labels of inputs,
labels[i][k] == 1
means the i-th input belongs to the k-th category and vice versa. Multi-label input is allowed.kernel_type (str) – Type of kernel function used when calculating inner products. The inner product is the similarity measure of two spikes.
loss_type (str) – Type of loss returned. Can be: ‘mse’, ‘l1’, ‘bce’
args – Extra parameters for inner product
- 返回
shape=[1], similarity loss
An SNN whose output layer consists of M neurons will receive a batch of N inputs at each time-step (from 0 to T-1) and output a spike tensor of shape=[N, M, T]. The label is a tensor of shape=[N, C].
The groundtruth similarity matrix
sim
has a shape of [N, N].sim[i][j] == 1
indicates that input i is similar to input j and vice versa. If and only iflabels[i]
andlabels[j]
have at least one common label, they are viewed as similar.The output similarity matrix
sim_p
has a shape of [N, N]. The value ofsim_p[i][j]
ranges from 0 to 1, represents the similarity between output spike from both input i and input j.The similarity is measured by inner product of two spikes.
kernel_type
is the type of kernel function when calculating inner product:‘linear’, Linear kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}}\).
‘sigmoid’, Sigmoid kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{sigmoid}(\alpha \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})\), where \(\alpha = args[0]\).
‘gaussian’, Gaussian kernel,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{exp}(- \frac{||\boldsymbol{x_{i}} - \boldsymbol{y_{j}}||^{2}}{2\sigma^{2}})\), where \(\sigma = args[0]\).
When Sigmoid or Gaussian kernel is applied, the inner product naturally lies in \([0, 1]\). To make the value consistent when using linear kernel, the result will be normalized as: \(\text{sim_p}[i][j]=\frac{\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}})}{||\boldsymbol{x_{i}}|| · ||\boldsymbol{y_{j}}||}\).
For similar data, return the specified discrepancy loss between
sim
andsim_p
according toloss_type
.‘mse’ – Return the Mean-Square Error (squared L2 norm) between sim and sim_p.
‘l1’ – Return the L1 error between sim and sim_p.
‘bce’ – Return the Binary Cross Entropy between sim and sim_p.
Note
Since spike vectors are usually discrete and sparse, it would be better to apply a Gaussian filter first to smooth the vectors before calculating similarities.
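A minimal calling sketch (the extra kernel parameter, here the Gaussian sigma, is passed positionally through *args):
import torch
from spikingjelly.activation_based import functional

N, M, T, C = 4, 10, 8, 3
spikes = (torch.rand([N, M, T]) >= 0.5).float()      # output spikes of the SNN
labels = torch.randint(0, 2, [N, C]).float()         # multi-label ground truth
loss = functional.spike_similar_loss(spikes, labels, 'gaussian', 'mse', 1.0)
print(loss)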
- spikingjelly.activation_based.functional.kernel_dot_product(x: Tensor, y: Tensor, kernel='linear', *args)[源代码]
-
- 参数
x – shape=[N, M]的tensor,看作是N个M维向量
y – shape=[N, M]的tensor,看作是N个M维向量
kernel (str) – 计算内积时所使用的核函数
args – 用于计算内积的额外的参数
- 返回
ret, shape=[N, N]的tensor,
ret[i][j]
表示x[i]
和y[j]
的内积
计算批量数据
x
和y
在核空间的内积。记2个M维tensor分别为 \(\boldsymbol{x_{i}}\) 和 \(\boldsymbol{y_{j}}\),kernel
定义了不同形式的内积:‘linear’,线性内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}}\)。
‘polynomial’,多项式内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = (\boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})^{d}\),其中 \(d = args[0]\)。
‘sigmoid’,Sigmoid内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{sigmoid}(\alpha \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})\),其中 \(\alpha = args[0]\)。
‘gaussian’,高斯内积,\(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{exp}(- \frac{||\boldsymbol{x_{i}} - \boldsymbol{y_{j}}||^{2}}{2\sigma^{2}})\),其中 \(\sigma = args[0]\)。
- 参数
x – Tensor of shape=[N, M]
y – Tensor of shape=[N, M]
kernel (str) – Type of kernel function used when calculating inner products.
args – Extra parameters for inner product
- 返回
ret, Tensor of shape=[N, N],
ret[i][j]
is inner product ofx[i]
andy[j]
.
Calculate inner product of
x
andy
in kernel space. These 2 M-dim tensors are denoted by \(\boldsymbol{x_{i}}\) and \(\boldsymbol{y_{j}}\).kernel
determine the kind of inner product:‘linear’ – Linear kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}}\).
‘polynomial’ – Polynomial kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = (\boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})^{d}\), where \(d = args[0]\).
‘sigmoid’ – Sigmoid kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{sigmoid}(\alpha \boldsymbol{x_{i}}^{T}\boldsymbol{y_{j}})\), where \(\alpha = args[0]\).
‘gaussian’ – Gaussian kernel, \(\kappa(\boldsymbol{x_{i}}, \boldsymbol{y_{j}}) = \mathrm{exp}(- \frac{||\boldsymbol{x_{i}} - \boldsymbol{y_{j}}||^{2}}{2\sigma^{2}})\), where \(\sigma = args[0]\).
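A minimal sketch of the different kernels (the extra kernel parameter is passed through *args):
import torch
from spikingjelly.activation_based import functional

x = torch.rand([4, 16])
y = torch.rand([4, 16])

ret_linear = functional.kernel_dot_product(x, y, 'linear')          # ret[i][j] = <x[i], y[j]>
ret_poly = functional.kernel_dot_product(x, y, 'polynomial', 2.0)   # d = 2
ret_gauss = functional.kernel_dot_product(x, y, 'gaussian', 1.0)    # sigma = 1
print(ret_linear.shape)  # torch.Size([4, 4])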
- spikingjelly.activation_based.functional.set_threshold_margin(output_layer: BaseNode, label_one_hot: Tensor, eval_threshold=1.0, threshold0=0.9, threshold1=1.1)[源代码]
-
- 参数
- 返回
None
对于用来分类的网络,为输出层神经元的电压阈值设置一定的裕量,以获得更好的分类性能。
类别总数为C,网络的输出层共有C个神经元。网络在训练时,当输入真实类别为i的数据,输出层中第i个神经元的电压阈值会被设置成
threshold1
,而其他神经元的电压阈值会被设置成threshold0
。而在测试(推理)时,输出层中神经元的电压阈值被统一设置成eval_threshold
。- 参数
output_layer – The output layer of classification network, where the shape of output should be [batch_size, C]
label_one_hot – Labels in one-hot format, shape=[batch_size, C]
eval_threshold (float) – Voltage threshold of neurons in output layer when evaluating (inference)
threshold0 (float) – Voltage threshold of the corresponding neurons of negative samples in output layer when training
threshold1 (float) – Voltage threshold of the corresponding neurons of positive samples in output layer when training
- 返回
None
Set voltage threshold margin for neurons in the output layer to reach better performance in classification task.
When there are C different classes, the output layer contains C neurons. During training, when the input with groundtruth label i are sent into the network, the voltage threshold of the i-th neurons in the output layer will be set to
threshold1
and the remaining will be set tothreshold0
.During inference, the voltage thresholds of ALL neurons in the output layer will be set to
eval_threshold
.
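A minimal sketch, assuming the output layer is a spiking neuron layer such as neuron.IFNode; whether the margins or eval_threshold are applied is assumed to depend on output_layer.training:
import torch
import torch.nn.functional as F
from spikingjelly.activation_based import neuron, functional

C = 10
output_layer = neuron.IFNode()                       # the output layer with C neurons per sample
label = torch.randint(0, C, [16])
label_one_hot = F.one_hot(label, C).float()

output_layer.train()
functional.set_threshold_margin(output_layer, label_one_hot,
                                eval_threshold=1.0, threshold0=0.9, threshold1=1.1)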
- spikingjelly.activation_based.functional.redundant_one_hot(labels: Tensor, num_classes: int, n: int)[源代码]
-
- 参数
- 返回
shape=[batch_size, num_classes * n]的tensor
对数据进行冗余的one-hot编码,每一类用
n
个1和(num_classes - 1) * n
个0来编码。示例:
>>> num_classes = 3 >>> n = 2 >>> labels = torch.randint(0, num_classes, [4]) >>> labels tensor([0, 1, 1, 0]) >>> codes = functional.redundant_one_hot(labels, num_classes, n) >>> codes tensor([[1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 0., 0.], [0., 0., 1., 1., 0., 0.], [1., 1., 0., 0., 0., 0.]])
- 参数
- 返回
Tensor of shape=[batch_size, num_classes * n]
Redundant one-hot encoding for data. Each class is encoded to
n
1’s and(num_classes - 1) * n
0’se.g.:
>>> num_classes = 3 >>> n = 2 >>> labels = torch.randint(0, num_classes, [4]) >>> labels tensor([0, 1, 1, 0]) >>> codes = functional.redundant_one_hot(labels, num_classes, n) >>> codes tensor([[1., 1., 0., 0., 0., 0.], [0., 0., 1., 1., 0., 0.], [0., 0., 1., 1., 0., 0.], [1., 1., 0., 0., 0., 0.]])
- spikingjelly.activation_based.functional.first_spike_index(spikes: Tensor)[源代码]
-
- 参数
spikes – shape=[*, T],表示任意个神经元在t=0, 1, …, T-1,共T个时刻的输出脉冲
- 返回
index, shape=[*, T],为
True
的位置表示该神经元首次释放脉冲的时刻
输入若干个神经元的输出脉冲,返回一个与输入相同shape的
bool
类型的index。index为True
的位置,表示该神经元首次释放脉冲的时刻。示例:
>>> spikes = (torch.rand(size=[2, 3, 8]) >= 0.8).float() >>> spikes tensor([[[0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 1., 0.], [0., 1., 0., 0., 0., 1., 0., 1.]], [[0., 0., 1., 1., 0., 0., 0., 1.], [1., 1., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0.]]]) >>> first_spike_index(spikes) tensor([[[False, False, False, False, False, False, False, False], [ True, False, False, False, False, False, False, False], [False, True, False, False, False, False, False, False]], [[False, False, True, False, False, False, False, False], [ True, False, False, False, False, False, False, False], [False, False, False, True, False, False, False, False]]])
- 参数
spikes – shape=[*, T], indicates the output spikes of some neurons when t=0, 1, …, T-1.
- 返回
index, shape=[*, T], the index of
True
represents the moment of first spike.
Return an
index
tensor of the same shape of input tensor, which is the output spike of some neurons. The index ofTrue
represents the moment of first spike.e.g.:
>>> spikes = (torch.rand(size=[2, 3, 8]) >= 0.8).float() >>> spikes tensor([[[0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 1., 0.], [0., 1., 0., 0., 0., 1., 0., 1.]], [[0., 0., 1., 1., 0., 0., 0., 1.], [1., 1., 0., 0., 1., 0., 0., 0.], [0., 0., 0., 1., 0., 0., 0., 0.]]]) >>> first_spike_index(spikes) tensor([[[False, False, False, False, False, False, False, False], [ True, False, False, False, False, False, False, False], [False, True, False, False, False, False, False, False]], [[False, False, True, False, False, False, False, False], [ True, False, False, False, False, False, False, False], [False, False, False, True, False, False, False, False]]])
- spikingjelly.activation_based.functional.multi_step_forward(x_seq: Tensor, single_step_module: Module)[源代码]
-
- 参数
x_seq (Tensor) –
shape=[T, batch_size, ...]
的输入tensorsingle_step_module (torch.nn.Module or list[nn.Module] or tuple[nn.Module] or torch.nn.Sequential or Callable) – 一个或多个单步模块
- 返回
shape=[T, batch_size, ...]
的输出tensor- 返回类型
在单步模块
single_step_module
上使用多步前向传播。- 参数
x_seq (torch.Tensor) – the input tensor with
shape=[T, batch_size, ...]
single_step_module (torch.nn.Module or list[nn.Module] or tuple[nn.Module] or torch.nn.Sequential or Callable) – one or many single-step modules
- 返回
the output tensor with
shape=[T, batch_size, ...]
- 返回类型
torch.Tensor
Applies multi-step forward on
single_step_module
.
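A minimal sketch with a single-step LIF neuron:
import torch
from spikingjelly.activation_based import neuron, functional

T, N = 4, 2
x_seq = torch.rand([T, N, 8])
lif = neuron.LIFNode()                               # a single-step module by default
y_seq = functional.multi_step_forward(x_seq, lif)
print(y_seq.shape)                                   # torch.Size([4, 2, 8])
functional.reset_net(lif)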
- spikingjelly.activation_based.functional.chunk_multi_step_forward(split_size: int, x_seq: Tensor, multi_step_module: Module)[源代码]
-
- 参数
split_size (int) – 分割的尺寸
x_seq (Tensor) – 输入
multi_step_module (nn.Module) – 一个使用多步传播模式的网络
- 返回
输出
- 返回类型
Tensor
将
shape = [T, *]
的输入x_seq
拆分成多个shape = [split_size, *]
的小tensor(若T % split_size != 0
,最后 一个tensor的shape[0]
会小于split_size
),然后逐个输入到multi_step_module
中,再将输出重新拼接为shape = [split_size, *]
。chunk_multi_step_forward
可以在使用很大的T
进行不带梯度的推理(例如ANN2SNN)时使用,能够减少内存消耗量。示例代码:
import torch import torch.nn as nn from spikingjelly.activation_based import neuron, layer, functional net = nn.Sequential( layer.Linear(8, 4), neuron.IFNode(step_mode='m'), layer.Linear(4, 2), neuron.IFNode(step_mode='m'), ) x_seq = torch.rand([1024, 8]) with torch.no_grad(): y_seq = functional.chunk_multi_step_forward(16, x_seq, net) print(y_seq.shape) # torch.Size([1024, 2])
- 参数
split_size (int) – the split size
x_seq (Tensor) – the input tensor
multi_step_module (nn.Module) –
- 返回
the output tensor
- 返回类型
Tensor
Splits the input
x_seq
withshape = [T, *]
to many tensor chunks withshape = [split_size, *]
(ifT % split_size != 0
,shape[0]
of the last tensor chunk will be smaller thansplit_size
), and sends chunks tomulti_step_module
, then concatenates the outputs toshape = [split_size, *]
.chunk_multi_step_forward
can be used for inference with a largeT
(e.g., ANN2SNN) to reduce the memory consumption.Codes example:
import torch import torch.nn as nn from spikingjelly.activation_based import neuron, layer, functional net = nn.Sequential( layer.Linear(8, 4), neuron.IFNode(step_mode='m'), layer.Linear(4, 2), neuron.IFNode(step_mode='m'), ) x_seq = torch.rand([1024, 8]) with torch.no_grad(): y_seq = functional.chunk_multi_step_forward(16, x_seq, net) print(y_seq.shape) # torch.Size([1024, 2])
- spikingjelly.activation_based.functional.seq_to_ann_forward(x_seq: Tensor, stateless_module: Module)[源代码]
-
- 参数
x_seq (Tensor) –
shape=[T, batch_size, ...]
的输入tensorstateless_module (torch.nn.Module or list or tuple or torch.nn.Sequential or Callable) – 单个或多个无状态网络层
- 返回
the output tensor with
shape=[T, batch_size, ...]
- 返回类型
Tensor
- 参数
x_seq (Tensor) – the input tensor with
shape=[T, batch_size, ...]
stateless_module (torch.nn.Module or list or tuple or torch.nn.Sequential or Callable) – one or many stateless modules
- 返回
the output tensor with
shape=[T, batch_size, ...]
- 返回类型
Tensor
Applies forward on stateless modules.
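A minimal sketch with a stateless nn.Linear layer:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional

T, N = 4, 2
x_seq = torch.rand([T, N, 8])
fc = nn.Linear(8, 4)                                 # a stateless module
y_seq = functional.seq_to_ann_forward(x_seq, fc)
print(y_seq.shape)                                   # torch.Size([4, 2, 4])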
- spikingjelly.activation_based.functional.fused_conv2d_weight_of_convbn2d(conv2d: Conv2d, bn2d: BatchNorm2d)[源代码]
-
- 参数
conv2d (torch.nn.Conv2d) – 一个2D卷积层
bn2d (torch.nn.BatchNorm2d) – 一个2D的BN层
- 返回
the weight of this fused module
- 返回类型
Tensor
{Conv2d-BatchNorm2d}
模块可以合并为一个单个的{Conv2d}
,其中``BatchNorm2d`` 的参数会被吸收进Conv2d
。 本函数返回合并后的卷积的权重。备注
这里按照
conv2d.bias
为None
进行处理。原因参见 Disable bias for convolutions directly followed by a batch norm 。- 参数
conv2d (torch.nn.Conv2d) – a Conv2d layer
bn2d (torch.nn.BatchNorm2d) – a BatchNorm2d layer
- 返回
the weight of this fused module
- 返回类型
Tensor
A
{Conv2d-BatchNorm2d}
can be fused to a{Conv2d}
module withBatchNorm2d
‘s parameters being absorbed intoConv2d
. This function returns the weight of this fused module.Note
We assert
conv2d.bias
isNone
. See Disable bias for convolutions directly followed by a batch norm for more details.
- spikingjelly.activation_based.functional.fused_conv2d_bias_of_convbn2d(conv2d: Conv2d, bn2d: BatchNorm2d)[源代码]
-
- 参数
conv2d (torch.nn.Conv2d) – 一个2D卷积层
bn2d (torch.nn.BatchNorm2d) – 一个2D的BN层
- 返回
the bias of this fused module
- 返回类型
Tensor
{Conv2d-BatchNorm2d}
模块可以合并为一个单个的{Conv2d}
,其中``BatchNorm2d`` 的参数会被吸收进Conv2d
。 本函数返回合并后的卷积的偏置项。备注
这里按照
conv2d.bias
为None
进行处理。原因参见 Disable bias for convolutions directly followed by a batch norm 。- 参数
conv2d (torch.nn.Conv2d) – a Conv2d layer
bn2d (torch.nn.BatchNorm2d) – a BatchNorm2d layer
- 返回
the bias of this fused module
- 返回类型
Tensor
A
{Conv2d-BatchNorm2d}
can be fused to a{Conv2d}
module withBatchNorm2d
‘s parameters being absorbed intoConv2d
. This function returns the bias of this fused module.Note
We assert
conv2d.bias
isNone
. See Disable bias for convolutions directly followed by a batch norm for more details.
- spikingjelly.activation_based.functional.scale_fused_conv2d_weight_of_convbn2d(conv2d: Conv2d, bn2d: BatchNorm2d, k=None, b=None)[源代码]
-
- 参数
conv2d (torch.nn.Conv2d) – 一个2D卷积层
bn2d (torch.nn.BatchNorm2d) – 一个2D的BN层
- 返回
the weight of this fused module
- 返回类型
Tensor
{Conv2d-BatchNorm2d}
模块可以合并为一个单个的{Conv2d}
,其中``BatchNorm2d`` 的参数会被吸收进Conv2d
。 本函数对{Conv2d-BatchNorm2d}
模块整体的等效权重进行weight = k * weight + b
的线性变换。备注
这里按照
conv2d.bias
为None
进行处理。原因参见 Disable bias for convolutions directly followed by a batch norm 。- 参数
conv2d (torch.nn.Conv2d) – a Conv2d layer
bn2d (torch.nn.BatchNorm2d) – a BatchNorm2d layer
- 返回
the weight of this fused module
- 返回类型
Tensor
A
{Conv2d-BatchNorm2d}
can be fused to a{Conv2d}
module withBatchNorm2d
‘s parameters being absorbed intoConv2d
. This function applies a linear transformweight = k * weight + b
on the equivalent weight of the whole{Conv2d-BatchNorm2d}
.Note
We assert
conv2d.bias
isNone
. See Disable bias for convolutions directly followed by a batch norm for more details.
- spikingjelly.activation_based.functional.scale_fused_conv2d_bias_of_convbn2d(conv2d: Conv2d, bn2d: BatchNorm2d, k=None, b=None)[源代码]
-
- 参数
conv2d (torch.nn.Conv2d) – 一个2D卷积层
bn2d (torch.nn.BatchNorm2d) – 一个2D的BN层
- 返回
the bias of this fused module
- 返回类型
Tensor
{Conv2d-BatchNorm2d}
模块可以合并为一个单个的{Conv2d}
,其中``BatchNorm2d`` 的参数会被吸收进Conv2d
。 本函数对{Conv2d-BatchNorm2d}
模块整体的等效偏置项进行bias = k * bias + b
的线性变换。备注
这里按照
conv2d.bias
为None
进行处理。原因参见 Disable bias for convolutions directly followed by a batch norm 。- 参数
conv2d (torch.nn.Conv2d) – a Conv2d layer
bn2d (torch.nn.BatchNorm2d) – a BatchNorm2d layer
- 返回
the bias of this fused module
- 返回类型
Tensor
A
{Conv2d-BatchNorm2d}
can be fused to a{Conv2d}
module withBatchNorm2d
‘s parameters being absorbed intoConv2d
. This function applies a linear transformbias = k * bias + b
on the equivalent bias of the whole{Conv2d-BatchNorm2d}
.Note
We assert
conv2d.bias
isNone
. See Disable bias for convolutions directly followed by a batch norm for more details.
- spikingjelly.activation_based.functional.fuse_convbn2d(conv2d: Conv2d, bn2d: BatchNorm2d)[源代码]
-
- 参数
conv2d (torch.nn.Conv2d) – 一个2D卷积层
bn2d (torch.nn.BatchNorm2d) – 一个2D的BN层
- 返回
the fused Conv2d module
- 返回类型
Tensor
{Conv2d-BatchNorm2d}
模块可以合并为一个单个的{Conv2d}
,其中``BatchNorm2d`` 的参数会被吸收进Conv2d
。 本函数对返回这个等效的合并后的{Conv2d}
。备注
这里按照
conv2d.bias
为None
进行处理。原因参见 Disable bias for convolutions directly followed by a batch norm 。- 参数
conv2d (torch.nn.Conv2d) – a Conv2d layer
bn2d (torch.nn.BatchNorm2d) – a BatchNorm2d layer
- 返回
the fused Conv2d module
- 返回类型
Tensor
A
{Conv2d-BatchNorm2d}
can be fused to a{Conv2d}
module withBatchNorm2d
‘s parameters being absorbed intoConv2d
. This function returns the fused{Conv2d}
merged by{Conv2d-BatchNorm2d}
.Note
We assert
conv2d.bias
isNone
. See Disable bias for convolutions directly followed by a batch norm for more details.
- spikingjelly.activation_based.functional.temporal_efficient_training_cross_entropy(x_seq: Tensor, target: Tensor)[源代码]
-
- 参数
x_seq (torch.Tensor) –
shape=[T, N, C, *]
的预测值,其中C
是类别总数target (torch.Tensor) –
shape=[N]
的真实值,其中target[i]
是真实类别
- 返回
the temporal efficient training cross entropy
- 返回类型
Temporal efficient training (TET) 交叉熵损失, 是每个时间步的交叉熵损失的平均。
示例代码:
def tet_ce_for_loop_version(x_seq: torch.Tensor, target: torch.LongTensor): loss = 0. for t in range(x_seq.shape[0]): loss += F.cross_entropy(x_seq[t], target) return loss / x_seq.shape[0] T = 8 N = 4 C = 10 x_seq = torch.rand([T, N, C]) target = torch.randint(low=0, high=C - 1, size=[N]) print(f'max error = {(tet_ce_for_loop_version(x_seq, target) - temporal_efficient_training_cross_entropy(x_seq, target)).abs().max()}') # max error < 1e-6
- 参数
x_seq (torch.Tensor) – the predicted value with
shape=[T, N, C, *]
, whereC
is the number of classestarget (torch.Tensor) – the ground truth tensor with
shape=[N]
, wheretarget[i]
is the label
- 返回
the temporal efficient training cross entropy
- 返回类型
The temporal efficient training (TET) cross entropy, which is the mean of cross entropy of each time-step.
Codes example:
def tet_ce_for_loop_version(x_seq: torch.Tensor, target: torch.LongTensor): loss = 0. for t in range(x_seq.shape[0]): loss += F.cross_entropy(x_seq[t], target) return loss / x_seq.shape[0] T = 8 N = 4 C = 10 x_seq = torch.rand([T, N, C]) target = torch.randint(low=0, high=C - 1, size=[N]) print(f'max error = {(tet_ce_for_loop_version(x_seq, target) - temporal_efficient_training_cross_entropy(x_seq, target)).abs().max()}') # max error < 1e-6
Note
The TET cross entropy is proposed by Temporal Efficient Training of Spiking Neural Network via Gradient Re-weighting.
- spikingjelly.activation_based.functional.kaiming_normal_conv_linear_weight(net: Module)[源代码]
-
- 参数
net – 任何属于
nn.Module
子类的网络- 返回
None
使用kaiming normal初始化 net 中的所有 torch.nn._ConvNd 和
torch.nn.Linear
的权重(不包括偏置项)。参见torch.nn.init.kaiming_normal_
。- 参数
net – Any network inherits from
nn.Module
- 返回
None
Initialize all weights (not including biases) of
torch.nn._ConvNd
andtorch.nn.Linear
innet
by the kaiming normal. Seetorch.nn.init.kaiming_normal_
for more details.
- spikingjelly.activation_based.functional.delay(x_seq: Tensor, delay_steps: int)[源代码]
-
- 参数
x_seq (torch.Tensor) – 输入的序列,
shape = [T, *]
delay_steps (int) – 延迟的时间步数
- 返回
延迟后的序列
- 返回类型
延迟函数,可以用来延迟输入,使得
y[t] = x[t - delay_steps]
。缺失的数据用0填充。代码示例:
x = torch.rand([5, 2]) x[3:].zero_() x.requires_grad = True y = delay(x, 1) print('x=') print(x) print('y=') print(y) y.sum().backward() print('x.grad=') print(x.grad)
输出为:
x= tensor([[0.1084, 0.5698], [0.4563, 0.3623], [0.0556, 0.4704], [0.0000, 0.0000], [0.0000, 0.0000]], requires_grad=True) y= tensor([[0.0000, 0.0000], [0.1084, 0.5698], [0.4563, 0.3623], [0.0556, 0.4704], [0.0000, 0.0000]], grad_fn=<CatBackward0>) x.grad= tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.]])
- 参数
x_seq (torch.Tensor) – the input sequence with
shape = [T, *]
delay_steps (int) – the number of delayed time-steps
- 返回
the delayed sequence
- 返回类型
A delay function that can delay inputs and makes
y[t] = x[t - delay_steps]
. The nonexistent data will be regarded as 0.Codes example:
x = torch.rand([5, 2]) x[3:].zero_() x.requires_grad = True y = delay(x, 1) print('x=') print(x) print('y=') print(y) y.sum().backward() print('x.grad=') print(x.grad)
The outputs are:
x= tensor([[0.1084, 0.5698], [0.4563, 0.3623], [0.0556, 0.4704], [0.0000, 0.0000], [0.0000, 0.0000]], requires_grad=True) y= tensor([[0.0000, 0.0000], [0.1084, 0.5698], [0.4563, 0.3623], [0.0556, 0.4704], [0.0000, 0.0000]], grad_fn=<CatBackward0>) x.grad= tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.]])
- spikingjelly.activation_based.functional.fptt_online_training_init_w_ra(optimizer: Optimizer) list [源代码]
- spikingjelly.activation_based.functional.fptt_online_training(model: Module, optimizer: Optimizer, x_seq: Tensor, target_seq: Tensor, f_loss_t: Callable, alpha: float, w_ra: list) None [源代码]
- 参数
model (nn.Module) – the neural network
optimizer (torch.optim.Optimizer) – the optimizer for the network
x_seq (torch.Tensor) – the input sequence
target_seq (torch.Tensor) – the output sequence
f_loss_t (Callable) – the loss function, which should has the formulation of
def f_loss_t(x_t, y_t) -> torch.Tensor
alpha (float) – the hyper-parameter
w_ra (list) – the running average of params, which can be initialized by
spikingjelly.activation_based.functional.fptt_online_training_init_w_ra
The FPTT online learning method proposed by Training Recurrent Neural Networks via Forward Propagation Through Time and used for SNN in Accurate online training of dynamical spiking neural networks through Forward Propagation Through Time .
Example:
from spikingjelly.activation_based import neuron net = nn.Sequential( nn.Linear(8, 4), neuron.IFNode(), nn.Linear(4, 2), neuron.IFNode() ) optimizer = torch.optim.SGD(net.parameters(), lr=0.1) T = 4 N = 2 w_ra = fptt_online_training_init_w_ra(optimizer) for epoch in range(2): x_seq = torch.rand([T, N, 8]) target_seq = torch.rand([T, N, 2]) fptt_online_training(model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq, f_loss_t=F.mse_loss, alpha=0.1, w_ra=w_ra) functional.reset_net(net)
spikingjelly.activation_based.lava_exchange package
Module contents
- class spikingjelly.activation_based.lava_exchange.step_quantize_atgf(*args, **kwargs)[源代码]
基类:
Function
- spikingjelly.activation_based.lava_exchange.quantize_8b(x, scale, descale=False)[源代码]
Denote
k
as anint
,x[i]
will be quantized to the nearest2 * k / scale
, andk = {-128, -127, ..., 126, 127}
.
- class spikingjelly.activation_based.lava_exchange.BatchNorm2d(num_features: int, eps: float = 1e-05, momentum: float = 0.1, track_running_stats: bool = True, weight_exp_bits: int = 3, pre_hook_fx: ~typing.Callable = <function BatchNorm2d.<lambda>>)[源代码]
基类:
Module
- class spikingjelly.activation_based.lava_exchange.LeakyIntegratorStep(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.lava_exchange.CubaLIFNode(current_decay: Union[float, Tensor], voltage_decay: Union[float, Tensor], v_threshold: float = 1.0, v_reset: float = 0.0, scale=64, requires_grad=False, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), norm: Optional[BatchNorm2d] = None, detach_reset=False, step_mode='s', backend='torch', store_v_seq: bool = False, store_i_seq: bool = False)[源代码]
基类:
BaseNode
API in English
- 参数
current_decay (float | torch.Tensor) – 电流衰减常数
voltage_decay (float | torch.Tensor) – 电压衰减常数
v_threshold (float) – 神经元阈值电压。默认为1。
v_reset (float, None) – 重置电压,默认为0
scale (float) – 量化参数,控制神经元的量化精度(参考了lava-dl的cuba.Neuron)。默认为
1<<6
。 等效于``w_scale=int(scale)``,s_scale=int(scale * (1<<6))
,p_scale=1<<12
。requires_grad (bool) – 指明
current_decay
和voltage_decay
两个神经元参数是否可学习(是否需要梯度),默认为False
。detach_reset (bool) – 是否将reset的计算图分离,默认为
False
。step_mode (str) – 步进模式,可以为 ‘s’ (单步)或 ‘m’ (多步),默认为 ‘s’ 。
backend (str) – 使用哪种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。目前只支持torchstore_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.voltage_state
。 通常设置成False
,可以节省内存。store_i_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电流值self.i_seq
。设置为False
时计算完成后只保留最后一个时刻的电流,即shape = [N, *]
的self.current_state
。 通常设置成False
,可以节省内存。
\[I[t] = (1 - \alpha_{I})I[t-1] + X[t]\]
\[V[t] = (1 - \alpha_{V})V[t-1] + I[t]\]
- 参数
current_decay (float | torch.Tensor) – current decay constant
voltage_decay (float | torch.Tensor) – voltage decay constant
v_threshold (float) – threshold of the neurons in this layer. Default to 1.
v_reset (float) – reset potential of the neurons in this layer, 0 by default
scale (float) – quantization precision (ref: lava-dl cuba.Neuron). Default to
1<<6
. Equivalent tow_scale=int(scale)
,s_scale=int(scale * (1<<6))
,p_scale=1<<12
.requires_grad (bool) – whether
current_decay
andvoltage_decay
are learnable. Default toFalse
.detach_reset (bool) – whether to detach the computational graph of reset in backward pass. Default to
False
.step_mode (str) – the step mode, which can be s (single-step) or m (multi-step). Default to ‘s’ .
backend (str) – backend for this neuron layer. Different step_mode values may support different backends. The user can print self.supported_backends to check what backends are supported by the current step_mode. Only torch is supported.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.voltage_state
withshape = [N, *]
, which can reduce the memory consumption. Default toFalse
.store_i_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the current at each time-step toself.i_seq
withshape = [T, N, *]
. If set toFalse
, only the current at last time-step will be stored toself.current_state
withshape = [N, *]
, which can reduce the memory consumption. Default toFalse
.
\[I[t] = (1 - \alpha_{I})I[t-1] + X[t]\]
\[V[t] = (1 - \alpha_{V})V[t-1] + I[t]\]
- property scale
scale
- Type
Read-only attribute
- property s_scale
s_scale
- Type
Read-only attribute
- property p_scale
p_scale
- Type
Read-only attribute
- property store_i_seq
- property supported_backends
- spikingjelly.activation_based.lava_exchange.lava_neuron_forward(lava_neuron: Module, x_seq: Tensor, v: Tensor)[源代码]
- spikingjelly.activation_based.lava_exchange.step_quantize(x: Tensor, step: float = 1.0)[源代码]
- 参数
x (torch.Tensor) – the input tensor
step (float) – the quantize step
- 返回
quantized tensor
- 返回类型
The step quantize function. Here is an example:
# plt.style.use(['science', 'muted', 'grid']) fig = plt.figure(dpi=200, figsize=(6, 4)) x = torch.arange(-4, 4, 0.001) plt.plot(x, lava_exchange.step_quantize(x, 2.), label='quantize(x, step=2)') plt.plot(x, x, label='y=x', ls='-.') plt.legend() plt.grid(ls='--') plt.title('step quantize') plt.xlabel('Input') plt.ylabel('Output') plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.svg') plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.pdf')
- spikingjelly.activation_based.lava_exchange.linear_to_lava_synapse_dense(fc: Linear)[源代码]
- 参数
fc (nn.Linear) – a pytorch linear layer without bias
- 返回
a lava slayer dense synapse
- 返回类型
slayer.synapse.Dense
Codes example:
T = 4 N = 2 layer_nn = nn.Linear(8, 4, bias=False) layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn) x_seq = torch.rand([T, N, 8]) with torch.no_grad(): y_nn = functional.seq_to_ann_forward(x_seq, layer_nn) y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.conv2d_to_lava_synapse_conv(conv2d_nn: Conv2d)[源代码]
- 参数
conv2d_nn (nn.Conv2d) – a pytorch conv2d layer without bias
- 返回
a lava slayer conv synapse
- 返回类型
slayer.synapse.Conv
Codes example:
T = 4 N = 2 layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False) layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn) x_seq = torch.rand([T, N, 3, 28, 28]) with torch.no_grad(): y_nn = functional.seq_to_ann_forward(x_seq, layer_nn) y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.avgpool2d_to_lava_synapse_pool(pool2d_nn: AvgPool2d)[源代码]
- 参数
pool2d_nn (nn.AvgPool2d) – a pytorch AvgPool2d layer
- 返回
a lava slayer pool layer
- 返回类型
slayer.synapse.Pool
Warning
The lava slayer pool layer applies sum pooling, rather than average pooling.
T = 4 N = 2 layer_nn = nn.AvgPool2d(kernel_size=2, stride=2) layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn) x_seq = torch.rand([T, N, 3, 28, 28]) with torch.no_grad(): y_nn = functional.seq_to_ann_forward(x_seq, layer_nn) y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) / 4. print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.to_lava_block_dense(fc: Linear, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[源代码]
- spikingjelly.activation_based.lava_exchange.to_lava_block_conv(conv2d_nn: Conv2d, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[源代码]
- spikingjelly.activation_based.lava_exchange.to_lava_block_pool(pool2d_nn: AvgPool2d, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[源代码]
- spikingjelly.activation_based.lava_exchange.to_lava_blocks(net: list)[源代码]
Supported layer types:
input : {shape, type}
flatten: {shape, type}
average: {shape, type}
concat : {shape, type, layers}
dense : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)}
pool : {shape, type, neuron, kernelSize, stride, padding, dilation, weight}
conv : {shape, type, neuron, inChannels, outChannels, kernelSize, stride,
- class spikingjelly.activation_based.lava_exchange.SumPool2d(kernel_size, stride=None, padding=0, dilation=1)[源代码]
基类:
Module
x = torch.rand([4, 2, 4, 16, 16]) with torch.no_grad(): sp_sj = SumPool2d(kernel_size=2, stride=2) y_sj = functional.seq_to_ann_forward(x, sp_sj) sp_la = slayer.synapse.Pool(kernel_size=2, stride=2) y_la = lava_exchange.NXT_to_TNX(sp_la(lava_exchange.TNX_to_NXT(x))) print((y_sj - y_la).abs().sum())
spikingjelly.activation_based.layer package
Module contents
- class spikingjelly.activation_based.layer.StepModeContainer(stateful: bool, *args)[源代码]
基类:
Sequential
,StepModule
- class spikingjelly.activation_based.layer.Conv1d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int]], stride: Union[int, Tuple[int]] = 1, padding: Union[str, int, Tuple[int]] = 0, dilation: Union[int, Tuple[int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
Conv1d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.Conv1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.Conv1d
for other parameters’ API
- class spikingjelly.activation_based.layer.Conv2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[str, int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
Conv2d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.Conv2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.Conv2d
for other parameters’ API
- class spikingjelly.activation_based.layer.Conv3d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[str, int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
Conv3d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.Conv3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.Conv3d
for other parameters’ API
- class spikingjelly.activation_based.layer.ConvTranspose1d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int]], stride: Union[int, Tuple[int]] = 1, padding: Union[int, Tuple[int]] = 0, output_padding: Union[int, Tuple[int]] = 0, groups: int = 1, bias: bool = True, dilation: Union[int, Tuple[int]] = 1, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
ConvTranspose1d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.ConvTranspose1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.ConvTranspose1d
for other parameters’ API
- class spikingjelly.activation_based.layer.ConvTranspose2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, output_padding: Union[int, Tuple[int, int]] = 0, groups: int = 1, bias: bool = True, dilation: int = 1, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
ConvTranspose2d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.ConvTranspose2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.ConvTranspose2d
for other parameters’ API
- class spikingjelly.activation_based.layer.ConvTranspose3d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[int, Tuple[int, int, int]] = 0, output_padding: Union[int, Tuple[int, int, int]] = 0, groups: int = 1, bias: bool = True, dilation: Union[int, Tuple[int, int, int]] = 1, padding_mode: str = 'zeros', step_mode: str = 's')[源代码]
基类:
ConvTranspose3d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.ConvTranspose3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.ConvTranspose3d
for other parameters’ API
- class spikingjelly.activation_based.layer.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode='s')[源代码]
-
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.BatchNorm1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.BatchNorm1d
for other parameters’ API
- class spikingjelly.activation_based.layer.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode='s')[源代码]
-
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.BatchNorm2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.BatchNorm2d
for other parameters’ API
- class spikingjelly.activation_based.layer.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, step_mode='s')[源代码]
-
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.BatchNorm3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.BatchNorm3d
for other parameters’ API
- class spikingjelly.activation_based.layer.GroupNorm(num_groups: int, num_channels: int, eps: float = 1e-05, affine: bool = True, step_mode='s')[源代码]
基类:
GroupNorm
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.GroupNorm
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.GroupNorm
for other parameters’ API
- class spikingjelly.activation_based.layer.MaxPool1d(kernel_size: Union[int, Tuple[int]], stride: Optional[Union[int, Tuple[int]]] = None, padding: Union[int, Tuple[int]] = 0, dilation: Union[int, Tuple[int]] = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode='s')[源代码]
基类:
MaxPool1d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.MaxPool1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.MaxPool1d
for other parameters’ API
- class spikingjelly.activation_based.layer.MaxPool2d(kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode='s')[源代码]
基类:
MaxPool2d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.MaxPool2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.MaxPool2d
for other parameters’ API
- class spikingjelly.activation_based.layer.MaxPool3d(kernel_size: Union[int, Tuple[int, int, int]], stride: Optional[Union[int, Tuple[int, int, int]]] = None, padding: Union[int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, return_indices: bool = False, ceil_mode: bool = False, step_mode='s')[源代码]
基类:
MaxPool3d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.MaxPool3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.MaxPool3d
for other parameters’ API
- class spikingjelly.activation_based.layer.AvgPool1d(kernel_size: Union[int, Tuple[int]], stride: Optional[Union[int, Tuple[int]]] = None, padding: Union[int, Tuple[int]] = 0, ceil_mode: bool = False, count_include_pad: bool = True, step_mode='s')[源代码]
基类:
AvgPool1d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AvgPool1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AvgPool1d
for other parameters’ API
- class spikingjelly.activation_based.layer.AvgPool2d(kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None, step_mode='s')[源代码]
基类:
AvgPool2d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AvgPool2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AvgPool2d
for other parameters’ API
- class spikingjelly.activation_based.layer.AvgPool3d(kernel_size: Union[int, Tuple[int, int, int]], stride: Optional[Union[int, Tuple[int, int, int]]] = None, padding: Union[int, Tuple[int, int, int]] = 0, ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None, step_mode='s')[源代码]
基类:
AvgPool3d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AvgPool3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AvgPool3d
for other parameters’ API
- class spikingjelly.activation_based.layer.AdaptiveAvgPool1d(output_size, step_mode='s')[源代码]
基类:
AdaptiveAvgPool1d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AdaptiveAvgPool1d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AdaptiveAvgPool1d
for other parameters’ API
- class spikingjelly.activation_based.layer.AdaptiveAvgPool2d(output_size, step_mode='s')[源代码]
基类:
AdaptiveAvgPool2d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AdaptiveAvgPool2d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AdaptiveAvgPool2d
for other parameters’ API
- class spikingjelly.activation_based.layer.AdaptiveAvgPool3d(output_size, step_mode='s')[源代码]
基类:
AdaptiveAvgPool3d
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.AdaptiveAvgPool3d
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.AdaptiveAvgPool3d
for other parameters’ API
- class spikingjelly.activation_based.layer.Linear(in_features: int, out_features: int, bias: bool = True, step_mode='s')[源代码]
基类:
Linear
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.Linear
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.Linear
for other parameters’ API
- class spikingjelly.activation_based.layer.Flatten(start_dim: int = 1, end_dim: int = -1, step_mode='s')[源代码]
基类:
Flatten
,StepModule
- 参数
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
其他的参数API参见
torch.nn.Flatten
- 参数
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
Refer to
torch.nn.Flatten
for other parameters’ API
- class spikingjelly.activation_based.layer.NeuNorm(in_channels, height, width, k=0.9, shared_across_channels=False, step_mode='s')[源代码]
基类:
MemoryModule
- 参数
in_channels – 输入数据的通道数
height – 输入数据的高
width – 输入数据的宽
k – 动量项系数
shared_across_channels – 可学习的权重
w
是否在通道这一维度上共享。设置为True
可以大幅度节省内存step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
Direct Training for Spiking Neural Networks: Faster, Larger, Better 中提出的NeuNorm层。NeuNorm层必须放在二维卷积层后的脉冲神经元后,例如:
Conv2d -> LIF -> NeuNorm
要求输入的尺寸是
[batch_size, in_channels, height, width]
。in_channels
是输入到NeuNorm层的通道数,也就是论文中的 \(F\)。k
是动量项系数,相当于论文中的 \(k_{\tau 2}\)。论文中的 \(\frac{v}{F}\) 会根据 \(k_{\tau 2} + vF = 1\) 自动算出。
- 参数
in_channels – channels of input
height – height of input
width – width of input
k – momentum factor
shared_across_channels – whether the learnable parameter
w
is shared over channel dim. If setTrue
, the consumption of memory can decrease largelystep_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
The NeuNorm layer is proposed in Direct Training for Spiking Neural Networks: Faster, Larger, Better.
It should be placed after the spiking neurons that follow a convolution layer, e.g.,
Conv2d -> LIF -> NeuNorm
The input should be a 4-D tensor with
shape = [batch_size, in_channels, height, width]
.in_channels
is the channels of input,which is \(F\) in the paper.k
is the momentum factor,which is \(k_{\tau 2}\) in the paper.\(\frac{v}{F}\) will be calculated by \(k_{\tau 2} + vF = 1\) autonomously.
- class spikingjelly.activation_based.layer.Dropout(p=0.5, step_mode='s')[源代码]
基类:
MemoryModule
与
torch.nn.Dropout
的几乎相同。区别在于,在每一轮的仿真中,被设置成0的位置不会发生改变;直到下一轮运行,即网络调用reset()函数后,才会按照概率去重新决定,哪些位置被置0。小技巧
这种Dropout最早由 Enabling Spike-based Backpropagation for Training Deep Neural Network Architectures 一文进行详细论述:
There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of \(p\)) are disconnected from the network while weighting by its posterior probability (\(1-p\)). However, in SNNs, each iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate the output error and modify the network parameters only at the last time step. For dropout to be effective in our training method, it has to be ensured that the set of connected units within an iteration of mini-batch data is not changed, such that the neural network is constituted by the same random subset of units during each forward propagation within a single iteration. On the other hand, if the units are randomly connected at each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire time window within an iteration.
- 参数
This layer is almost the same as
torch.nn.Dropout
. The difference is that elements that have been zeroed at the first step of a simulation will always be zero. The indexes of zeroed elements will be updated only after reset()
has been called and a new simulation is started.Tip
This kind of Dropout is firstly described in Enabling Spike-based Backpropagation for Training Deep Neural Network Architectures:
There is a subtle difference in the way dropout is applied in SNNs compared to ANNs. In ANNs, each epoch of training has several iterations of mini-batches. In each iteration, randomly selected units (with dropout ratio of \(p\)) are disconnected from the network while weighting by its posterior probability (\(1-p\)). However, in SNNs, each iteration has more than one forward propagation depending on the time length of the spike train. We back-propagate the output error and modify the network parameters only at the last time step. For dropout to be effective in our training method, it has to be ensured that the set of connected units within an iteration of mini-batch data is not changed, such that the neural network is constituted by the same random subset of units during each forward propagation within a single iteration. On the other hand, if the units are randomly connected at each time-step, the effect of dropout will be averaged out over the entire forward propagation time within an iteration. Then, the dropout effect would fade-out once the output error is propagated backward and the parameters are updated at the last time step. Therefore, we need to keep the set of randomly connected units for the entire time window within an iteration.
- class spikingjelly.activation_based.layer.Dropout2d(p=0.2, step_mode='s')[源代码]
基类:
Dropout
与
torch.nn.Dropout2d
的几乎相同。区别在于,在每一轮的仿真中,被设置成0的位置不会发生改变;直到下一轮运行,即网络调用reset()函数后,才会按照概率去重新决定,哪些位置被置0。关于SNN中Dropout的更多信息,参见 layer.Dropout。
- 参数
This layer is almost the same as
torch.nn.Dropout2d
. The difference is that elements that have been zeroed at the first step of a simulation will always be zero. The indexes of zeroed elements will be updated only after reset()
has been called and a new simulation is started.For more information about Dropout in SNN, refer to layer.Dropout.
- class spikingjelly.activation_based.layer.SynapseFilter(tau=100.0, learnable=False, step_mode='s')[源代码]
基类:
MemoryModule
- 参数
tau – 突触上电流衰减的时间常数
learnable – 时间常数在训练过程中是否是可学习的。若为
True
,则tau
会被设定成时间常数的初始值step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
具有滤波性质的突触。突触的输出电流满足,当没有脉冲输入时,输出电流指数衰减:
\[\tau \frac{\mathrm{d} I(t)}{\mathrm{d} t} = - I(t)\]当有新脉冲输入时,输出电流自增1:
\[I(t) = I(t) + 1\]记输入脉冲为 \(S(t)\),则离散化后,统一的电流更新方程为:
\[I(t) = I(t-1) - (1 - S(t)) \frac{1}{\tau} I(t-1) + S(t)\]这种突触能将输入脉冲进行平滑,简单的示例代码和输出结果:
T = 50 in_spikes = (torch.rand(size=[T]) >= 0.95).float() lp_syn = SynapseFilter(tau=10.0) pyplot.subplot(2, 1, 1) pyplot.bar(torch.arange(0, T).tolist(), in_spikes, label='in spike') pyplot.xlabel('t') pyplot.ylabel('spike') pyplot.legend() out_i = [] for i in range(T): out_i.append(lp_syn(in_spikes[i])) pyplot.subplot(2, 1, 2) pyplot.plot(out_i, label='out i') pyplot.xlabel('t') pyplot.ylabel('i') pyplot.legend() pyplot.show()
输出电流不仅取决于当前时刻的输入,还取决于之前的输入,使得该突触具有了一定的记忆能力。
这种突触偶有使用,例如:
Unsupervised learning of digit recognition using spike-timing-dependent plasticity
另一种视角是将其视为一种输入为脉冲,并输出其电压的LIF神经元。并且该神经元的发放阈值为 \(+\infty\) 。
神经元最后累计的电压值一定程度上反映了该神经元在整个仿真过程中接收脉冲的数量,从而替代了传统的直接对输出脉冲计数(即发放频率)来表示神经元活跃程度的方法。因此通常用于最后一层,在以下文章中使用:
Enabling spike-based backpropagation for training deep neural network architectures
- 参数
tau – time constant that determines the decay rate of current in the synapse
learnable – whether time constant is learnable during training. If
True
, thentau
will be the initial value of time constantstep_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
The synapse filter that can filter input current. The output current will decay when there is no input spike:
\[\tau \frac{\mathrm{d} I(t)}{\mathrm{d} t} = - I(t)\]The output current will increase 1 when there is a new input spike:
\[I(t) = I(t) + 1\]Denote the input spike as \(S(t)\), then the discrete current update equation is as followed:
\[I(t) = I(t-1) - (1 - S(t)) \frac{1}{\tau} I(t-1) + S(t)\]This synapse can smooth input. Here is the example and output:
T = 50 in_spikes = (torch.rand(size=[T]) >= 0.95).float() lp_syn = SynapseFilter(tau=10.0) pyplot.subplot(2, 1, 1) pyplot.bar(torch.arange(0, T).tolist(), in_spikes, label='in spike') pyplot.xlabel('t') pyplot.ylabel('spike') pyplot.legend() out_i = [] for i in range(T): out_i.append(lp_syn(in_spikes[i])) pyplot.subplot(2, 1, 2) pyplot.plot(out_i, label='out i') pyplot.xlabel('t') pyplot.ylabel('i') pyplot.legend() pyplot.show()
The output current is not only determined by the present input but also by the previous input, which makes this synapse have memory.
This synapse is sometimes used, e.g.:
Unsupervised learning of digit recognition using spike-timing-dependent plasticity
Another view is regarding this synapse as a LIF neuron with a \(+\infty\) threshold voltage.
The final output of this synapse (or the final voltage of this LIF neuron) represents the accumulation of input spikes, which substitutes for the traditional firing rate as an indicator of the excitatory level. So, it can be used in the last layer of the network, e.g.:
Enabling spike-based backpropagation for training deep neural network architectures
- class spikingjelly.activation_based.layer.DropConnectLinear(in_features: int, out_features: int, bias: bool = True, p: float = 0.5, samples_num: int = 1024, invariant: bool = False, activation: Module = ReLU(), state_mode='s')[源代码]
基类:
MemoryModule
- 参数
in_features (int) – 每个输入样本的特征数
out_features (int) – 每个输出样本的特征数
bias (bool) – 若为
False
,则本层不会有可学习的偏置项。 默认为True
p (float) – 每个连接被断开的概率。默认为0.5
samples_num (int) – 在推理时,从高斯分布中采样的数据数量。默认为1024
invariant (bool) – 若为
True
,线性层会在第一次执行前向传播时被按概率断开,断开后的线性层会保持不变,直到reset()
函数 被调用,线性层恢复为完全连接的状态。完全连接的线性层,调用reset()
函数后的第一次前向传播时被重新按概率断开。 若为False
,在每一次前向传播时线性层都会被重新完全连接再按概率断开。 阅读 layer.Dropout 以 获得更多关于此参数的信息。 默认为False
activation (None or nn.Module) – 在线性层后的激活层
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
DropConnect,由 Regularization of Neural Networks using DropConnect 一文提出。DropConnect与Dropout非常类似,区别在于DropConnect是以概率
p
断开连接,而Dropout是将输入以概率置0。备注
在使用DropConnect进行推理时,输出的tensor中的每个元素,都是先从高斯分布中采样,通过激活层激活,再在采样数量上进行平均得到的。 详细的流程可以在 Regularization of Neural Networks using DropConnect 一文中的 Algorithm 2 找到。激活层
activation
在中间的步骤起作用,因此我们将其作为模块的成员。- 参数
in_features (int) – size of each input sample
out_features (int) – size of each output sample
bias (bool) – If set to
False
, the layer will not learn an additive bias. Default:True
p (float) – probability of an connection to be zeroed. Default: 0.5
samples_num (int) – number of samples drawn from the Gaussian during inference. Default: 1024
invariant (bool) – If set to
True
, the connections will be dropped at the first time of forward and the dropped connections will remain unchanged untilreset()
is called and the connections recovery to fully-connected status. Then the connections will be re-dropped at the first time of forward afterreset()
. If set toFalse
, the connections will be re-dropped at every forward. See layer.Dropout for more information to understand this parameter. Default:False
activation (None or nn.Module) – the activation layer after the linear layer
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
DropConnect, which is proposed by Regularization of Neural Networks using DropConnect, is similar with Dropout but drop connections of a linear layer rather than the elements of the input tensor with probability
p
.Note
When inference with DropConnect, every elements of the output tensor are sampled from a Gaussian distribution, activated by the activation layer and averaged over the sample number
samples_num
. See Algorithm 2 in Regularization of Neural Networks using DropConnect for more details. Note that activation is an intermediate process. This is the reason why we includeactivation
as a member variable of this module.- reset_parameters() None [源代码]
-
- 返回
None
- 返回类型
None
初始化模型中的可学习参数。
- 返回
None
- 返回类型
None
Initialize the learnable parameters of this module.
- class spikingjelly.activation_based.layer.PrintShapeModule(ext_str='PrintShapeModule')[源代码]
基类:
Module
- 参数
ext_str (str) – 额外打印的字符串
只打印
ext_str
和输入的shape
,不进行任何操作的网络层,可以用于debug。- 参数
ext_str (str) – extra strings for printing
This layer will not do any operation but print
ext_str
and the shape of input, which can be used for debugging.
- class spikingjelly.activation_based.layer.ElementWiseRecurrentContainer(sub_module: Module, element_wise_function: Callable, step_mode='s')[源代码]
基类:
MemoryModule
- 参数
sub_module (torch.nn.Module) – 被包含的模块
element_wise_function (Callable) – 用户自定义的逐元素函数,应该形如
z=f(x, y)
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
使用逐元素运算的自连接包装器。记
sub_module
的输入输出为 \(i[t]\) 和 \(y[t]\) (注意 \(y[t]\) 也是整个模块的输出), 整个模块的输入为 \(x[t]\),则\[i[t] = f(x[t], y[t-1])\]其中 \(f\) 是用户自定义的逐元素函数。我们默认 \(y[-1] = 0\)。
备注
sub_module
输入和输出的尺寸需要相同。示例代码:
T = 8 net = ElementWiseRecurrentContainer(neuron.IFNode(v_reset=None), element_wise_function=lambda x, y: x + y) print(net) x = torch.zeros([T]) x[0] = 1.5 for t in range(T): print(t, f'x[t]={x[t]}, s[t]={net(x[t])}') functional.reset_net(net)
- 参数
sub_module (torch.nn.Module) – the contained module
element_wise_function (Callable) – the user-defined element-wise function, which should have the format
z=f(x, y)
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
A container that use a element-wise recurrent connection. Denote the inputs and outputs of
sub_module
as \(i[t]\) and \(y[t]\) (Note that \(y[t]\) is also the outputs of this module), and the inputs of this module as \(x[t]\), then\[i[t] = f(x[t], y[t-1])\]where \(f\) is the user-defined element-wise function. We set \(y[-1] = 0\).
Note
The shape of inputs and outputs of
sub_module
must be the same.Codes example:
T = 8 net = ElementWiseRecurrentContainer(neuron.IFNode(v_reset=None), element_wise_function=lambda x, y: x + y) print(net) x = torch.zeros([T]) x[0] = 1.5 for t in range(T): print(t, f'x[t]={x[t]}, s[t]={net(x[t])}') functional.reset_net(net)
- class spikingjelly.activation_based.layer.LinearRecurrentContainer(sub_module: Module, in_features: int, out_features: int, bias: bool = True, step_mode='s')[源代码]
基类:
MemoryModule
- 参数
sub_module (torch.nn.Module) – 被包含的模块
in_features (int) – 输入的特征数量
out_features (int) – 输出的特征数量
bias (bool) – 若为
False
,则线性自连接不会带有可学习的偏置项step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
使用线性层的自连接包装器。记
sub_module
的输入和输出为 \(i[t]\) 和 \(y[t]\) (注意 \(y[t]\) 也是整个模块的输出), 整个模块的输入记作 \(x[t]\) ,则\[\begin{split}i[t] = \begin{pmatrix} x[t] \\ y[t-1]\end{pmatrix} W^{T} + b\end{split}\]其中 \(W, b\) 是线性层的权重和偏置项。默认 \(y[-1] = 0\)。
\(x[t]\) 应该
shape = [N, *, in_features]
,\(y[t]\) 则应该shape = [N, *, out_features]
。备注
自连接是由
torch.nn.Linear(in_features + out_features, in_features, bias)
实现的。in_features = 4 out_features = 2 T = 8 N = 2 net = LinearRecurrentContainer( nn.Sequential( nn.Linear(in_features, out_features), neuron.LIFNode(), ), in_features, out_features) print(net) x = torch.rand([T, N, in_features]) for t in range(T): print(t, net(x[t])) functional.reset_net(net)
- 参数
sub_module (torch.nn.Module) – the contained module
in_features (int) – size of each input sample
out_features (int) – size of each output sample
bias (bool) – If set to
False
, the linear recurrent layer will not learn an additive biasstep_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
A container that use a linear recurrent connection. Denote the inputs and outputs of
sub_module
as \(i[t]\) and \(y[t]\) (Note that \(y[t]\) is also the outputs of this module), and the inputs of this module as \(x[t]\), then\[\begin{split}i[t] = \begin{pmatrix} x[t] \\ y[t-1]\end{pmatrix} W^{T} + b\end{split}\]where \(W, b\) are the weight and bias of the linear connection. We set \(y[-1] = 0\).
\(x[t]\) should have the shape
[N, *, in_features]
, and \(y[t]\) has the shape[N, *, out_features]
.Note
The recurrent connection is implement by
torch.nn.Linear(in_features + out_features, in_features, bias)
.in_features = 4 out_features = 2 T = 8 N = 2 net = LinearRecurrentContainer( nn.Sequential( nn.Linear(in_features, out_features), neuron.LIFNode(), ), in_features, out_features) print(net) x = torch.rand([T, N, in_features]) for t in range(T): print(t, net(x[t])) functional.reset_net(net)
- class spikingjelly.activation_based.layer.ThresholdDependentBatchNorm1d(alpha: float, v_th: float, *args, **kwargs)[源代码]
基类:
_ThresholdDependentBatchNormBase
*args, **kwargs
中的参数与torch.nn.BatchNorm1d
的参数相同。Going Deeper With Directly-Trained Larger Spiking Neural Networks 一文提出 的Threshold-Dependent Batch Normalization (tdBN)。
- 参数
Other parameters in
*args, **kwargs
are same with those oftorch.nn.BatchNorm1d
.The Threshold-Dependent Batch Normalization (tdBN) proposed in Going Deeper With Directly-Trained Larger Spiking Neural Networks.
- class spikingjelly.activation_based.layer.ThresholdDependentBatchNorm2d(alpha: float, v_th: float, *args, **kwargs)[源代码]
基类:
_ThresholdDependentBatchNormBase
*args, **kwargs
中的参数与torch.nn.BatchNorm2d
的参数相同。Going Deeper With Directly-Trained Larger Spiking Neural Networks 一文提出 的Threshold-Dependent Batch Normalization (tdBN)。
- 参数
Other parameters in
*args, **kwargs
are same with those oftorch.nn.BatchNorm2d
.The Threshold-Dependent Batch Normalization (tdBN) proposed in Going Deeper With Directly-Trained Larger Spiking Neural Networks.
- class spikingjelly.activation_based.layer.ThresholdDependentBatchNorm3d(alpha: float, v_th: float, *args, **kwargs)[源代码]
基类:
_ThresholdDependentBatchNormBase
*args, **kwargs
中的参数与torch.nn.BatchNorm3d
的参数相同。Going Deeper With Directly-Trained Larger Spiking Neural Networks 一文提出 的Threshold-Dependent Batch Normalization (tdBN)。
- 参数
Other parameters in
*args, **kwargs
are same with those oftorch.nn.BatchNorm3d
.The Threshold-Dependent Batch Normalization (tdBN) proposed in Going Deeper With Directly-Trained Larger Spiking Neural Networks.
- class spikingjelly.activation_based.layer.TemporalWiseAttention(T: int, reduction: int = 16, dimension: int = 4)[源代码]
-
- 参数
T – 输入数据的时间步长
reduction – 压缩比
dimension – 输入数据的维度。当输入数据为[T, N, C, H, W]时, dimension = 4;输入数据维度为[T, N, L]时,dimension = 2。
Temporal-Wise Attention Spiking Neural Networks for Event Streams Classification 中提出 的MultiStepTemporalWiseAttention层。MultiStepTemporalWiseAttention层必须放在二维卷积层之后脉冲神经元之前,例如:
Conv2d -> MultiStepTemporalWiseAttention -> LIF
输入的尺寸是
[T, N, C, H, W]
或者[T, N, L]
,经过MultiStepTemporalWiseAttention层,输出为[T, N, C, H, W]
或者[T, N, L]
。reduction
是压缩比,相当于论文中的 \(r\)。- 参数
T – timewindows of input
reduction – reduction ratio
dimension – Dimensions of input. If the input dimension is [T, N, C, H, W], dimension = 4; when the input dimension is [T, N, L], dimension = 2.
The MultiStepTemporalWiseAttention layer is proposed in Temporal-Wise Attention Spiking Neural Networks for Event Streams Classification.
It should be placed after the convolution layer and before the spiking neurons, e.g.,
Conv2d -> MultiStepTemporalWiseAttention -> LIF
The dimension of the input is
[T, N, C, H, W]
or[T, N, L]
, after the MultiStepTemporalWiseAttention layer, the output dimension is[T, N, C, H, W]
or[T, N, L]
.reduction
is the reduction ratio,which is \(r\) in the paper.
- class spikingjelly.activation_based.layer.MultiDimensionalAttention(T: int, C: int, reduction_t: int = 16, reduction_c: int = 16, kernel_size=3)[源代码]
-
- 参数
T – 输入数据的时间步长
C – 输入数据的通道数
reduction_t – 时间压缩比
reduction_c – 通道压缩比
kernel_size – 空间注意力机制的卷积核大小
Attention Spiking Neural Networks 中提出 的MA-SNN模型以及MultiStepMultiDimensionalAttention层。
您可以从以下链接中找到MA-SNN的示例项目:
- https://github.com/MA-SNN/MA-SNN
- https://github.com/ridgerchu/SNN_Attention_VGG
输入的尺寸是
[T, N, C, H, W]
,经过MultiStepMultiDimensionalAttention层,输出为[T, N, C, H, W]
。- 参数
T – timewindows of input
C – channel number of input
reduction_t – temporal reduction ratio
reduction_c – channel reduction ratio
kernel_size – convolution kernel size of SpatialAttention
The MA-SNN model and MultiStepMultiDimensionalAttention layer are proposed in Attention Spiking Neural Networks (https://ieeexplore.ieee.org/document/10032591).
You can find the example projects of MA-SNN in the following links:
- https://github.com/MA-SNN/MA-SNN
- https://github.com/ridgerchu/SNN_Attention_VGG
The dimension of the input is
[T, N, C, H, W]
, after the MultiStepMultiDimensionalAttention layer, the output dimension is[T, N, C, H, W]
.
- class spikingjelly.activation_based.layer.VotingLayer(voting_size: int = 10, step_mode='s')[源代码]
基类:
Module
,StepModule
投票层,对
shape = [..., C * voting_size]
的输入在最后一维上做kernel_size = voting_size, stride = voting_size
的平均池化- 参数
Applies average pooling with
kernel_size = voting_size, stride = voting_size
on the last dimension of the input withshape = [..., C * voting_size]
- class spikingjelly.activation_based.layer.Delay(delay_steps: int, step_mode='s')[源代码]
基类:
MemoryModule
延迟层,可以用来延迟输入,使得
y[t] = x[t - delay_steps]
。缺失的数据用0填充。代码示例:
delay_layer = Delay(delay_steps=1, step_mode='m') x = torch.rand([5, 2]) x[3:].zero_() x.requires_grad = True y = delay_layer(x) print('x=') print(x) print('y=') print(y) y.sum().backward() print('x.grad=') print(x.grad)
输出为:
x= tensor([[0.2510, 0.7246], [0.5303, 0.3160], [0.2531, 0.5961], [0.0000, 0.0000], [0.0000, 0.0000]], requires_grad=True) y= tensor([[0.0000, 0.0000], [0.2510, 0.7246], [0.5303, 0.3160], [0.2531, 0.5961], [0.0000, 0.0000]], grad_fn=<CatBackward0>) x.grad= tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.]])
- 参数
A delay layer that can delay inputs and makes
y[t] = x[t - delay_steps]
. The nonexistent data will be regarded as 0.Codes example:
delay_layer = Delay(delay_steps=1, step_mode='m') x = torch.rand([5, 2]) x[3:].zero_() x.requires_grad = True y = delay_layer(x) print('x=') print(x) print('y=') print(y) y.sum().backward() print('x.grad=') print(x.grad)
The outputs are:
x= tensor([[0.2510, 0.7246], [0.5303, 0.3160], [0.2531, 0.5961], [0.0000, 0.0000], [0.0000, 0.0000]], requires_grad=True) y= tensor([[0.0000, 0.0000], [0.2510, 0.7246], [0.5303, 0.3160], [0.2531, 0.5961], [0.0000, 0.0000]], grad_fn=<CatBackward0>) x.grad= tensor([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [0., 0.]])
- property delay_steps
spikingjelly.activation_based.learning package
Module contents
- spikingjelly.activation_based.learning.stdp_linear_single_step(fc: ~torch.nn.modules.linear.Linear, in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], trace_post: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- spikingjelly.activation_based.learning.mstdp_linear_single_step(fc: ~torch.nn.modules.linear.Linear, in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], trace_post: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- spikingjelly.activation_based.learning.mstdpet_linear_single_step(fc: ~torch.nn.modules.linear.Linear, in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], trace_post: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], tau_pre: float, tau_post: float, tau_trace: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- spikingjelly.activation_based.learning.stdp_conv2d_single_step(conv: ~torch.nn.modules.conv.Conv2d, in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~torch.Tensor], trace_post: ~typing.Optional[~torch.Tensor], tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- spikingjelly.activation_based.learning.stdp_conv1d_single_step(conv: ~torch.nn.modules.conv.Conv1d, in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~torch.Tensor], trace_post: ~typing.Optional[~torch.Tensor], tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- spikingjelly.activation_based.learning.stdp_multi_step(layer: ~typing.Union[~torch.nn.modules.linear.Linear, ~torch.nn.modules.conv.Conv1d, ~torch.nn.modules.conv.Conv2d], in_spike: ~torch.Tensor, out_spike: ~torch.Tensor, trace_pre: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], trace_post: ~typing.Optional[~typing.Union[float, ~torch.Tensor]], tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function <lambda>>, f_post: ~typing.Callable = <function <lambda>>)[源代码]
- class spikingjelly.activation_based.learning.STDPLearner(step_mode: str, synapse: ~typing.Union[~torch.nn.modules.conv.Conv2d, ~torch.nn.modules.linear.Linear], sn: ~spikingjelly.activation_based.neuron.BaseNode, tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function STDPLearner.<lambda>>, f_post: ~typing.Callable = <function STDPLearner.<lambda>>)[源代码]
基类:
MemoryModule
- class spikingjelly.activation_based.learning.MSTDPLearner(step_mode: str, batch_size: float, synapse: ~typing.Union[~torch.nn.modules.conv.Conv2d, ~torch.nn.modules.linear.Linear], sn: ~spikingjelly.activation_based.neuron.BaseNode, tau_pre: float, tau_post: float, f_pre: ~typing.Callable = <function MSTDPLearner.<lambda>>, f_post: ~typing.Callable = <function MSTDPLearner.<lambda>>)[源代码]
基类:
MemoryModule
- class spikingjelly.activation_based.learning.MSTDPETLearner(step_mode: str, synapse: ~typing.Union[~torch.nn.modules.conv.Conv2d, ~torch.nn.modules.linear.Linear], sn: ~spikingjelly.activation_based.neuron.BaseNode, tau_pre: float, tau_post: float, tau_trace: float, f_pre: ~typing.Callable = <function MSTDPETLearner.<lambda>>, f_post: ~typing.Callable = <function MSTDPETLearner.<lambda>>)[源代码]
基类:
MemoryModule
spikingjelly.activation_based.monitor package
Module contents
- class spikingjelly.activation_based.monitor.OutputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Optional[~typing.Any] = None, function_on_output: ~typing.Callable = <function OutputMonitor.<lambda>>)[源代码]
基类:
BaseMonitor
- 参数
net (nn.Module) – 一个神经网络
instance (Any or tuple) – 被监视的模块的数据类型。若为
None
则表示类型为type(net)
function_on_output (Callable) – 作用于被监控的模块输出的自定义的函数
对
net
中所有类型为instance
的模块的输出使用function_on_output
作用后,记录到类型为 list` 的self.records
中。 可以通过self.enable()
和self.disable()
来启用或停用这个监视器。 可以通过self.clear_recorded_data()
来清除已经记录的数据。阅读监视器的教程以获得更多信息。
示例代码:
class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.OutputMonitor(net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[0., 0., 0., 1.]]), tensor([[0., 0.]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[0., 0., 0., 1.]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[0., 0., 0., 1.]])]
- 参数
net (nn.Module) – a network
instance (Any or tuple) – the instance of modules to be monitored. If
None
, it will be regarded astype(net)
function_on_output (Callable) – the function that applies on the monitored modules’ outputs
Applies
function_on_output
on outputs of all modules whose instances areinstance
innet
, and records the data intoself.records
, which is alist
. Callself.enable()
orself.disable()
to enable or disable the monitor. Callself.clear_recorded_data()
to clear the recorded data.Refer to the tutorial about the monitor for more details.
Codes example:
class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.OutputMonitor(net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[0., 0., 0., 1.]]), tensor([[0., 0.]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[0., 0., 0., 1.]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[0., 0., 0., 1.]])]
- class spikingjelly.activation_based.monitor.InputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Optional[~typing.Any] = None, function_on_input: ~typing.Callable = <function InputMonitor.<lambda>>)[源代码]
基类:
BaseMonitor
- 参数
net (nn.Module) – 一个神经网络
instance (Any or tuple) – 被监视的模块的数据类型。若为
None
则表示类型为type(net)
function_on_input (Callable) – 作用于被监控的模块输入的自定义的函数
对
net
中所有类型为instance
的模块的输入使用function_on_input
作用后,记录到类型为 list` 的self.records
中。 可以通过self.enable()
和self.disable()
来启用或停用这个监视器。 可以通过self.clear_recorded_data()
来清除已经记录的数据。阅读监视器的教程以获得更多信息。
示例代码:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.InputMonitor(net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[1.0165, 1.1934, 0.9347, 0.9539]]), tensor([[0.9115, 0.9508]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[1.0165, 1.1934, 0.9347, 0.9539]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[1.0165, 1.1934, 0.9347, 0.9539]])]
- 参数
net (nn.Module) – a network
instance (Any or tuple) – the instance of modules to be monitored. If
None
, it will be regarded astype(net)
function_on_input (Callable) – the function that applies on the monitored modules’ inputs
Applies
function_on_input
on inputs of all modules whose instances areinstance
innet
, and records the data intoself.records
, which is alist
. Callself.enable()
orself.disable()
to enable or disable the monitor. Callself.clear_recorded_data()
to clear the recorded data.Refer to the tutorial about the monitor for more details.
Codes example:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.InputMonitor(net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[1.0165, 1.1934, 0.9347, 0.9539]]), tensor([[0.9115, 0.9508]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[1.0165, 1.1934, 0.9347, 0.9539]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[1.0165, 1.1934, 0.9347, 0.9539]])]
- class spikingjelly.activation_based.monitor.AttributeMonitor(attribute_name: str, pre_forward: bool, net: ~torch.nn.modules.module.Module, instance: ~typing.Optional[~typing.Any] = None, function_on_attribute: ~typing.Callable = <function AttributeMonitor.<lambda>>)[源代码]
基类:
BaseMonitor
- 参数
对
net
中所有类型为instance
的模块m
的成员m.attribute_name
使用function_on_attribute
作用后,记录到类型为 list` 的self.records
。 可以通过self.enable()
和self.disable()
来启用或停用这个监视器。 可以通过self.clear_recorded_data()
来清除已经记录的数据。阅读监视器的教程以获得更多信息。
示例代码:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.AttributeMonitor('v', False, net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([0.0000, 0.6854, 0.0000, 0.7968]), tensor([0.4472, 0.0000])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([0.0000, 0.6854, 0.0000, 0.7968]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([0.0000, 0.6854, 0.0000, 0.7968])]
- 参数
attribute_name (str) – the monitored attribute’s name
pre_forward (bool) – If
True
, recording the attribute before forward, otherwise recording the attribute after forwardnet (nn.Module) – a network
instance (Any or tuple) – the instance of modules to be monitored. If
None
, it will be regarded astype(net)
function_on_attribute (Callable) – the function that applies on each monitored module’s attribute
Applies
function_on_attribute
onm.attribute_name
of each monitored modulem
whose instance isinstance
innet
, and records the data intoself.records
, which is alist
. Callself.enable()
orself.disable()
to enable or disable the monitor. Callself.clear_recorded_data()
to clear the recorded data.Refer to the tutorial about the monitor for more details.
Codes example:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.AttributeMonitor('v', False, net, instance=neuron.IFNode) with torch.no_grad(): y = net(torch.rand([1, 8])) print(f'mtor.records={mtor.records}') # mtor.records=[tensor([0.0000, 0.6854, 0.0000, 0.7968]), tensor([0.4472, 0.0000])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([0.0000, 0.6854, 0.0000, 0.7968]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([0.0000, 0.6854, 0.0000, 0.7968])]
- class spikingjelly.activation_based.monitor.GradInputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Optional[~typing.Any] = None, function_on_grad_input: ~typing.Callable = <function GradInputMonitor.<lambda>>)[源代码]
基类:
BaseMonitor
- 参数
net (nn.Module) – 一个神经网络
instance (Any or tuple) – 被监视的模块的数据类型。若为
None
则表示类型为type(net)
function_on_grad_input (Callable) – 作用于被监控的模块的输入的梯度的函数
对
net
中所有类型为instance
的模块的输入的梯度使用function_on_grad_input
作用后,记录到类型为 list 的self.records
中。 可以通过self.enable()
和self.disable()
来启用或停用这个监视器。 可以通过self.clear_recorded_data()
来清除已经记录的数据。阅读监视器的教程以获得更多信息。
备注
对于一个模块,输入为 \(X\),输出为 \(Y\),损失为 \(L\),则
GradInputMonitor
记录的是对输入的梯度 \(\frac{\partial L}{\partial X}\)。示例代码:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = layer.Linear(8, 4)
        self.sn1 = neuron.IFNode()
        self.fc2 = layer.Linear(4, 2)
        self.sn2 = neuron.IFNode()
        functional.set_step_mode(self, 'm')

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.sn1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.sn2(x_seq)
        return x_seq

net = Net()
for param in net.parameters():
    param.data.abs_()

mtor = monitor.GradInputMonitor(net, instance=neuron.IFNode)

# gradients are only recorded during the backward pass, so run forward and backward
net(torch.rand([1, 8])).sum().backward()

print(f'mtor.records={mtor.records}')
# mtor.records is a list of the gradients w.r.t. the inputs of the monitored IFNode layers
print(f'mtor[0]={mtor[0]}')
# mtor[0] is the first recorded gradient tensor
print(f'mtor.monitored_layers={mtor.monitored_layers}')
# mtor.monitored_layers=['sn1', 'sn2']
print(f"mtor['sn1']={mtor['sn1']}")
# mtor['sn1'] is the list of gradients recorded for sn1
- 参数
net (nn.Module) – a network
instance (Any or tuple) – the instance of modules to be monitored. If
None
, it will be regarded astype(net)
function_on_grad_input (Callable) – the function that applies on the grad of monitored modules’ inputs
Applies
function_on_grad_input
on grad of inputs of all modules whose instances areinstance
innet
, and records the data intoself.records
, which is alist
. Callself.enable()
orself.disable()
to enable or disable the monitor. Callself.clear_recorded_data()
to clear the recorded data.Refer to the tutorial about the monitor for more details.
Note
Denote the input and output of the monitored module as \(X\) and \(Y\), and the loss is \(L\), then
GradInputMonitor
will record the gradient of input, which is \(\frac{\partial L}{\partial X}\).Codes example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import monitor, neuron, functional, layer

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = layer.Linear(8, 4)
        self.sn1 = neuron.IFNode()
        self.fc2 = layer.Linear(4, 2)
        self.sn2 = neuron.IFNode()
        functional.set_step_mode(self, 'm')

    def forward(self, x_seq: torch.Tensor):
        x_seq = self.fc1(x_seq)
        x_seq = self.sn1(x_seq)
        x_seq = self.fc2(x_seq)
        x_seq = self.sn2(x_seq)
        return x_seq

net = Net()
for param in net.parameters():
    param.data.abs_()

mtor = monitor.GradInputMonitor(net, instance=neuron.IFNode)

# gradients are only recorded during the backward pass, so run forward and backward
net(torch.rand([1, 8])).sum().backward()

print(f'mtor.records={mtor.records}')
# mtor.records is a list of the gradients w.r.t. the inputs of the monitored IFNode layers
print(f'mtor[0]={mtor[0]}')
# mtor[0] is the first recorded gradient tensor
print(f'mtor.monitored_layers={mtor.monitored_layers}')
# mtor.monitored_layers=['sn1', 'sn2']
print(f"mtor['sn1']={mtor['sn1']}")
# mtor['sn1'] is the list of gradients recorded for sn1
- class spikingjelly.activation_based.monitor.GradOutputMonitor(net: ~torch.nn.modules.module.Module, instance: ~typing.Optional[~typing.Any] = None, function_on_grad_output: ~typing.Callable = <function GradOutputMonitor.<lambda>>)[源代码]
基类:
BaseMonitor
- 参数
net (nn.Module) – 一个神经网络
instance (Any or tuple) – 被监视的模块的数据类型。若为
None
则表示类型为type(net)
function_on_grad_output (Callable) – 作用于被监控的模块的输出的梯度的函数
对
net
中所有类型为instance
的模块的输出的梯度使用function_on_grad_output
作用后,记录到类型为 list 的self.records
中。 可以通过self.enable()
和self.disable()
来启用或停用这个监视器。 可以通过self.clear_recorded_data()
来清除已经记录的数据。阅读监视器的教程以获得更多信息。
备注
对于一个模块,输入为 \(X\),输出为 \(Y\),损失为 \(L\),则
GradOutputMonitor
记录的是对输出的梯度 \(\frac{\partial L}{\partial Y}\)。示例代码:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.GradOutputMonitor(net, instance=neuron.IFNode) net(torch.rand([1, 8])).sum().backward() print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[1., 1.]]), tensor([[0.1372, 0.1081, 0.0880, 0.1089]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[1., 1.]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
- 参数
net (nn.Module) – a network
instance (Any or tuple) – the instance of modules to be monitored. If
None
, it will be regarded astype(net)
function_on_grad_output (Callable) – the function that applies on the grad of monitored modules’ outputs
Applies
function_on_grad_output
on grad of outputs of all modules whose instances areinstance
innet
, and records the data intoself.records
, which is alist
. Callself.enable()
orself.disable()
to enable or disable the monitor. Callself.clear_recorded_data()
to clear the recorded data.Refer to the tutorial about the monitor for more details.
Note
Denote the input and output of the monitored module as \(X\) and \(Y\), and the loss is \(L\), then
GradOutputMonitor
will record the gradient of output, which is \(\frac{\partial L}{\partial Y}\).Codes example:
import torch import torch.nn as nn from spikingjelly.activation_based import monitor, neuron, functional, layer class Net(nn.Module): def __init__(self): super().__init__() self.fc1 = layer.Linear(8, 4) self.sn1 = neuron.IFNode() self.fc2 = layer.Linear(4, 2) self.sn2 = neuron.IFNode() functional.set_step_mode(self, 'm') def forward(self, x_seq: torch.Tensor): x_seq = self.fc1(x_seq) x_seq = self.sn1(x_seq) x_seq = self.fc2(x_seq) x_seq = self.sn2(x_seq) return x_seq net = Net() for param in net.parameters(): param.data.abs_() mtor = monitor.GradOutputMonitor(net, instance=neuron.IFNode) net(torch.rand([1, 8])).sum().backward() print(f'mtor.records={mtor.records}') # mtor.records=[tensor([[1., 1.]]), tensor([[0.1372, 0.1081, 0.0880, 0.1089]])] print(f'mtor[0]={mtor[0]}') # mtor[0]=tensor([[1., 1.]]) print(f'mtor.monitored_layers={mtor.monitored_layers}') # mtor.monitored_layers=['sn1', 'sn2'] print(f"mtor['sn1']={mtor['sn1']}") # mtor['sn1']=[tensor([[0.1372, 0.1081, 0.0880, 0.1089]])]
- class spikingjelly.activation_based.monitor.GPUMonitor(log_dir: Optional[str] = None, gpu_ids: tuple = (0,), interval: float = 600.0, start_now=True)[源代码]
基类:
Thread
- 参数
GPU监视器,可以开启一个新的线程来记录
gpu_ids
的使用率和显存使用情况,每interval
秒记录一次数据。警告
在主线程的工作完成后一定要调用GPU监视器的
stop()
函数,否则主线程不会退出。Codes example:
import time
from spikingjelly.activation_based.monitor import GPUMonitor

gm = GPUMonitor(interval=1)
time.sleep(2)  # make the main thread sleep
gm.stop()

# The outputs are:
# 2022-04-28 10:52:25
# utilization.gpu [%], memory.used [MiB]
# 0 %, 376 MiB
- 参数
log_dir (str) – the directory for saving logs with tensorboard. If it is None, this module will print logs
gpu_ids (tuple) – the id of GPUs to be monitored, e.g.,
(0, 1, 2, 3)
. The default value is(0, )
interval (float) – the recording interval (in seconds)
start_now – if true, the monitor will start to record now. Otherwise, it will start after the user call
start()
manually
The GPU monitor, which starts a new thread to record the utilization and memory used of
gpu_ids
everyinterval
seconds.Warning
Do not forget to call this module’s
stop()
after the main thread finishes its job, otherwise the main thread will never stop!Codes example:
import time
from spikingjelly.activation_based.monitor import GPUMonitor

gm = GPUMonitor(interval=1)
time.sleep(2)  # make the main thread sleep
gm.stop()

# The outputs are:
# 2022-04-28 10:52:25
# utilization.gpu [%], memory.used [MiB]
# 0 %, 376 MiB
spikingjelly.activation_based.neuron package
Module contents
- class spikingjelly.activation_based.neuron.BaseNode(v_threshold: float = 1.0, v_reset: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
MemoryModule
- 参数
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
可微分SNN神经元的基类。
- 参数
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
This class is the base class of differentiable spiking neurons.
- property store_v_seq
- abstract neuronal_charge(x: Tensor)[源代码]
定义神经元的充电差分方程。子类必须实现这个函数。
Define the charge difference equation. The sub-class must implement this function.
- neuronal_fire()[源代码]
-
根据当前神经元的电压、阈值,计算输出脉冲。
Calculate out spikes of neurons by their current membrane potential and threshold voltage.
- neuronal_reset(spike)[源代码]
-
根据当前神经元释放的脉冲,对膜电位进行重置。
Reset the membrane potential according to neurons’ output spikes.
- single_step_forward(x: Tensor)[源代码]
-
- 参数
x (torch.Tensor) – 输入到神经元的电压增量
- 返回
神经元的输出脉冲
- 返回类型
按照充电、放电、重置的顺序进行前向传播。
- 参数
x (torch.Tensor) – increment of voltage inputted to neurons
- 返回
out spikes of neurons
- 返回类型
Forward by the order of neuronal_charge, neuronal_fire, and neuronal_reset.
- class spikingjelly.activation_based.neuron.AdaptBaseNode(v_threshold: float = 1.0, v_reset: float = 0.0, v_rest: float = 0.0, w_rest: float = 0.0, tau_w: float = 2.0, a: float = 0.0, b: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- static jit_hard_reset(v: Tensor, w: Tensor, spike_d: Tensor, v_reset: float, b: float, spike: Tensor)[源代码]
- static jit_soft_reset(v: Tensor, w: Tensor, spike_d: Tensor, v_threshold: float, b: float, spike: Tensor)[源代码]
- class spikingjelly.activation_based.neuron.IFNode(v_threshold: float = 1.0, v_reset: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。其阈下神经动力学方程为:
\[H[t] = V[t-1] + X[t]\]- 参数
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
The Integrate-and-Fire neuron, which can be seen as an ideal integrator. Unlike the LIF neuron, the voltage of the IF neuron does not decay. Its sub-threshold neural dynamics are as follows:
\[H[t] = V[t-1] + X[t]\]- property supported_backends
- static jit_eval_single_step_forward_hard_reset(x: Tensor, v: Tensor, v_threshold: float, v_reset: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float)[源代码]
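Codes example (an illustrative sketch added here, not part of the upstream docstring; it only relies on the constructor arguments and step_mode semantics documented above):

import torch
from spikingjelly.activation_based import neuron, functional

if_node = neuron.IFNode(v_threshold=1., v_reset=0.)   # single-step ('s') by default

x = torch.rand([4]) * 0.6
for t in range(3):
    # H[t] = V[t-1] + X[t]; a spike is fired once the membrane potential reaches v_threshold
    print(if_node(x), if_node.v)
functional.reset_net(if_node)      # clear the membrane potential before reusing the neuron

if_node.step_mode = 'm'            # switch to multi-step mode and feed a [T, N, ...] sequence
x_seq = torch.rand([8, 2, 4]) * 0.6
spike_seq = if_node(x_seq)         # spike_seq.shape = [8, 2, 4]
functional.reset_net(if_node)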
- class spikingjelly.activation_based.neuron.LIFNode(tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.0, v_reset: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
tau (float) – 膜电位时间常数
decay_input (bool) – 输入是否也会参与衰减
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]若
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]- 参数
tau (float) – membrane time constant
decay_input (bool) – whether the input will decay
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator. Its subthreshold neural dynamics are as follows:
IF
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]IF
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]- property supported_backends
- static jit_eval_single_step_forward_hard_reset_decay_input(x: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_single_step_forward_hard_reset_no_decay_input(x: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_single_step_forward_soft_reset_decay_input(x: Tensor, v: Tensor, v_threshold: float, tau: float)[源代码]
- static jit_eval_single_step_forward_soft_reset_no_decay_input(x: Tensor, v: Tensor, v_threshold: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset_decay_input(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset_no_decay_input(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(x_seq: Tensor, v: Tensor, v_threshold: float, v_reset: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_soft_reset_decay_input(x_seq: Tensor, v: Tensor, v_threshold: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(x_seq: Tensor, v: Tensor, v_threshold: float, tau: float)[源代码]
- static jit_eval_multi_step_forward_soft_reset_no_decay_input(x_seq: Tensor, v: Tensor, v_threshold: float, tau: float)[源代码]
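Codes example (an illustrative sketch added here, not part of the upstream docstring; it assumes only the store_v_seq / step_mode behavior described above):

import torch
from spikingjelly.activation_based import neuron, functional

T, N = 8, 2
lif = neuron.LIFNode(tau=2., decay_input=True, step_mode='m', store_v_seq=True)

x_seq = torch.rand([T, N, 4])
spike_seq = lif(x_seq)        # spike_seq.shape = [8, 2, 4]
print(lif.v_seq.shape)        # torch.Size([8, 2, 4]): membrane potential at every time-step
print(lif.v.shape)            # torch.Size([2, 4]): membrane potential at the last time-step
functional.reset_net(lif)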
- class spikingjelly.activation_based.neuron.ParametricLIFNode(init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.0, v_reset: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
init_tau (float) – 膜电位时间常数的初始值
decay_input (bool) – 输入是否也会参与衰减
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存cupy_fp32_inference (bool) – 若为 True,在 eval 模式下,使用float32,却在GPU上运行,并且 cupy 已经安装,则会自动使用 cupy 进行加速。 这个选项的优先权高于
backend
Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks 提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]若
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]其中 \(\frac{1}{\tau} = {\rm Sigmoid}(w)\),\(w\) 是可学习的参数。
- 参数
init_tau (float) – the initial value of membrane time constant
decay_input (bool) – whether the input will decay
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumptioncupy_fp32_inference (bool) – If True, if this module is in eval mode, using float32, running on GPU, and cupy is installed, then this module will use cupy to accelerate. This option has priority over
backend
The Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks and can be seen as a leaky integrator. Its subthreshold neural dynamics are as follows:
IF
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]IF
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]where \(\frac{1}{\tau} = {\rm Sigmoid}(w)\), \(w\) is a learnable parameter.
- property supported_backends
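Codes example (an illustrative sketch added here, not part of the upstream docstring; the attribute name w for the learnable parameter is assumed from the formula \(\frac{1}{\tau} = {\rm Sigmoid}(w)\) above):

import torch
from spikingjelly.activation_based import neuron, functional

plif = neuron.ParametricLIFNode(init_tau=2.)

# 1 / tau = sigmoid(w), so w (assumed attribute name) is learnable and receives gradients during training
print(plif.w, plif.w.requires_grad)

plif(torch.rand([4])).sum().backward()
print(plif.w.grad)            # w is optimized together with the other network parameters
functional.reset_net(plif)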
- class spikingjelly.activation_based.neuron.QIFNode(tau: float = 2.0, v_c: float = 0.8, a0: float = 1.0, v_threshold: float = 1.0, v_rest: float = 0.0, v_reset: float = -0.1, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
tau (float) – 膜电位时间常数
v_c (float) – 关键电压
a0 (float) –
v_threshold (float) – 神经元的阈值电压
v_rest (float) – 静息电位
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
Quadratic Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,也是指数积分发放神经元(Exponential Integrate-and-Fire)的近似版本。其阈下神经动力学方程为:
\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))\]- 参数
tau (float) – membrane time constant
v_c (float) – critical voltage
a0 (float) –
v_threshold (float) – threshold voltage of neurons
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
The Quadratic Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire model and also an approximation of the Exponential Integrate-and-Fire model. Its subthreshold neural dynamics are as follows:
\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))\]- property supported_backends
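Codes example (an illustrative sketch added here, not part of the upstream docstring; it simply runs the single-step neuron over a sequence with functional.multi_step_forward):

import torch
from spikingjelly.activation_based import neuron, functional

# the quadratic term a0 * (V[t-1] - V_rest) * (V[t-1] - V_c) drives the charging, as in the equation above
qif = neuron.QIFNode(tau=2., v_c=0.8, a0=1., v_threshold=1., v_rest=0., v_reset=-0.1)

x_seq = torch.rand([8, 2, 4])
spike_seq = functional.multi_step_forward(x_seq, qif)   # spike_seq.shape = [8, 2, 4]
functional.reset_net(qif)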
- class spikingjelly.activation_based.neuron.EIFNode(tau: float = 2.0, delta_T: float = 1.0, theta_rh: float = 0.8, v_threshold: float = 1.0, v_rest: float = 0.0, v_reset: float = -0.1, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
tau (float) – 膜电位时间常数
delta_T (float) – 陡峭度参数
theta_rh (float) – 基强度电压阈值
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
Exponential Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,是由HH神经元模型(Hodgkin-Huxley model)简化后推导出的一维模型。在 \(\Delta_T\to 0\) 时退化为LIF模型。其阈下神经动力学方程为:
\[H[t] = V[t-1] + \frac{1}{\tau}\left(X[t] - (V[t-1] - V_{rest}) + \Delta_T\exp\left(\frac{V[t-1] - \theta_{rh}}{\Delta_T}\right)\right)\]- 参数
tau (float) – membrane time constant
delta_T (float) – sharpness parameter
theta_rh (float) – rheobase threshold
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
The Exponential Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire model and also a one-dimensional model derived from the Hodgkin-Huxley model. It degenerates to the LIF model when \(\Delta_T\to 0\). Its subthreshold neural dynamics are as follows:
\[H[t] = V[t-1] + \frac{1}{\tau}\left(X[t] - (V[t-1] - V_{rest}) + \Delta_T\exp\left(\frac{V[t-1] - \theta_{rh}}{\Delta_T}\right)\right)\]- property supported_backends
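Codes example (an illustrative sketch added here, not part of the upstream docstring; the layer and functional helpers used are the ones documented elsewhere in this reference):

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, functional

# a small layer-by-layer network using EIF neurons; delta_T -> 0 would recover LIF-like behavior
net = nn.Sequential(
    layer.Linear(8, 4),
    neuron.EIFNode(tau=2., delta_T=1., theta_rh=0.8),
)
functional.set_step_mode(net, step_mode='m')

x_seq = torch.rand([8, 2, 8])
y_seq = net(x_seq)            # y_seq.shape = [8, 2, 4]
functional.reset_net(net)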
- class spikingjelly.activation_based.neuron.IzhikevichNode(tau: float = 2.0, v_c: float = 0.8, a0: float = 1.0, v_threshold: float = 1.0, v_reset: float = 0.0, v_rest: float = -0.1, w_rest: float = 0.0, tau_w: float = 2.0, a: float = 0.0, b: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
-
- property supported_backends
- class spikingjelly.activation_based.neuron.LIAFNode(act: Callable, threshold_related: bool, *args, **kwargs)[源代码]
基类:
LIFNode
- 参数
act (Callable) – 激活函数
threshold_related (bool) – 是否使用阈值依赖模式 (TR mode). 若为
True
则y = act(h - v_th)
, 否则y = act(h)
LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing 提出的LIAF神经元。LIAFNode和LIFNode的行为相同,但输出是
self.act(...)
而非脉冲。警告
The outputs of this neurons layer are not binary spikes.
- 参数
act (Callable) – the activation function
threshold_related (bool) – whether the neuron uses threshold related (TR mode). If
True
,y = act(h - v_th)
, otherwisey = act(h)
Other parameters in *args, **kwargs are the same as those of
LIFNode
.The LIAF neuron proposed in LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing. LIAFNode has the same behavior as LIFNode, but outputs
self.act(...)
rather than spikes.Warning
The outputs of this neurons layer are not binary spikes.
- property supported_backends
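Codes example (an illustrative sketch added here, not part of the upstream docstring; torch.relu is just one possible choice for act):

import torch
from spikingjelly.activation_based import neuron, functional

# LIAFNode keeps the LIF dynamics but outputs act(h) (or act(h - v_th) in TR mode) instead of spikes
liaf = neuron.LIAFNode(act=torch.relu, threshold_related=False, tau=2.)

y = liaf(torch.rand([2, 4]))
print(y)                      # analog values, not binary spikes
functional.reset_net(liaf)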
- class spikingjelly.activation_based.neuron.KLIFNode(scale_reset: bool = False, tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.0, v_reset: float = 0.0, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False)[源代码]
基类:
BaseNode
- 参数
scale_reset (bool) – 是否在
neuronal_reset
时将v
进行缩放tau (float) – 膜电位时间常数
decay_input (bool) – 输入是否也会参与衰减
v_threshold (float) – 神经元的阈值电压
v_reset (float) – 神经元的重置电压。如果不为
None
,当神经元释放脉冲后,电压会被重置为v_reset
; 如果设置为None
,当神经元释放脉冲后,电压会被减去v_threshold
surrogate_function (Callable) – 反向传播时用来计算脉冲函数梯度的替代函数
detach_reset (bool) – 是否将reset过程的计算图分离
step_mode (str) – 步进模式,可以为 ‘s’ (单步) 或 ‘m’ (多步)
backend (str) – 使用那种后端。不同的
step_mode
可能会带有不同的后端。可以通过打印self.supported_backends
查看当前 使用的步进模式支持的后端。在支持的情况下,使用'cupy'
后端是速度最快的store_v_seq (bool) – 在使用
step_mode = 'm'
时,给与shape = [T, N, *]
的输入后,是否保存中间过程的shape = [T, N, *]
的各个时间步的电压值self.v_seq
。设置为False
时计算完成后只保留最后一个时刻的电压,即shape = [N, *]
的self.v
。 通常设置成False
,可以节省内存
KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential 提出的K-based Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]若
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]注意,KLIF神经元的放电和重置与普通的神经元不同,为:
\[ \begin{align}\begin{aligned}F[t] &= \mathrm{ReLU}(kH[t])\\S[t] &= \Theta(F[t] - V_{th})\end{aligned}\end{align} \]如果
scale_reset == False
,则\[\begin{split}V[t] = \begin{cases} F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\ F[t] - S[t]V_{th}, soft~~reset \end{cases}\end{split}\]如果
scale_reset == True
,则\[\begin{split}V[t] = \begin{cases} \frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\ \frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset \end{cases}\end{split}\]- 参数
scale_reset (bool) – whether scale
v
inneuronal_reset
tau (float) – membrane time constant
decay_input (bool) – whether the input will decay
v_threshold (float) – threshold of this neurons layer
v_reset (float) – reset voltage of this neurons layer. If not
None
, the neuron’s voltage will be set tov_reset
after firing a spike. IfNone
, the neuron’s voltage will subtractv_threshold
after firing a spikesurrogate_function (Callable) – the function for calculating surrogate gradients of the heaviside step function in backward
detach_reset (bool) – whether detach the computation graph of reset in backward
step_mode (str) – the step mode, which can be s (single-step) or m (multi-step)
backend (str) – backend for this neurons layer. Different
step_mode
may support for different backends. The user can
print
self.supported_backends
and check what backends are supported by the currentstep_mode
. If supported, using'cupy'
backend will have the fastest training speed.
- 参数
store_v_seq (bool) – when using
step_mode = 'm'
and given input withshape = [T, N, *]
, this option controls whether storing the voltage at each time-step toself.v_seq
withshape = [T, N, *]
. If set toFalse
, only the voltage at last time-step will be stored toself.v
withshape = [N, *]
, which can reduce the memory consumption
The K-based Leaky Integrate-and-Fire neuron proposed by KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential, which can be seen as a leaky integrator. Its subthreshold neural dynamics are as follows:
IF
decay_input == True
:\[H[t] = V[t-1] + \frac{1}{\tau}(X[t] - (V[t-1] - V_{reset}))\]IF
decay_input == False
:\[H[t] = V[t-1] - \frac{1}{\tau}(V[t-1] - V_{reset}) + X[t]\]Note that the neuronal fire and reset of the KLIF neuron is different from native neurons:
\[ \begin{align}\begin{aligned}F[t] &= \mathrm{ReLU}(kH[t])\\S[t] &= \Theta(F[t] - V_{th})\end{aligned}\end{align} \]If
scale_reset == False
, then\[\begin{split}V[t] = \begin{cases} F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\ F[t] - S[t]V_{th}, soft~~reset \end{cases}\end{split}\]Elif
scale_reset == True
, then\[\begin{split}V[t] = \begin{cases} \frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\ \frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset \end{cases}\end{split}\]- static neuronal_charge_decay_input(x: Tensor, v: Tensor, v_reset: float, tau: float, k: Tensor)[源代码]
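Codes example (an illustrative sketch added here, not part of the upstream docstring; it only relies on the fire/reset equations above):

import torch
from spikingjelly.activation_based import neuron, functional

klif = neuron.KLIFNode(scale_reset=True, tau=2., step_mode='m')

x_seq = torch.rand([8, 2, 4])
spike_seq = klif(x_seq)
print(spike_seq.unique())     # the outputs are still binary spikes; only the charge and reset involve k
functional.reset_net(klif)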
spikingjelly.activation_based.neuron_kernel package
Module contents
- class spikingjelly.activation_based.neuron_kernel.MultiStepIFNodePTT(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.neuron_kernel.MultiStepLIFNodePTT(*args, **kwargs)[源代码]
基类:
Function
- static create_fptt_kernel(decay_input: bool, hard_reset: bool, dtype: str, kernel_name_prefix: str = 'LIFNode')[源代码]
- static create_bptt_kernel(sg_cuda_code_fun, decay_input: bool, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
- class spikingjelly.activation_based.neuron_kernel.MultiStepParametricLIFNodePTT(*args, **kwargs)[源代码]
基类:
Function
- static create_bptt_kernel(sg_cuda_code_fun, decay_input: bool, hard_reset: bool, detach_reset: bool, dtype: str)[源代码]
- spikingjelly.activation_based.neuron_kernel.check_multi_step_neuron_output_and_grad(device, multi_step_neuron, shape=[65, 15, 511], *neu_args, **neu_kwargs)[源代码]
- class spikingjelly.activation_based.neuron_kernel.MultiStepQIFNodePTT(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.neuron_kernel.MultiStepIzhikevichNodePTT(*args, **kwargs)[源代码]
基类:
Function
spikingjelly.activation_based.quantize package
Module contents
- spikingjelly.activation_based.quantize.round(x: Tensor)[源代码]
- 参数
x (torch.Tensor) – the input tensor
- 返回
the output tensor
- 返回类型
Apply
y = torch.round(x)
with re-defining gradient as \(\frac{\partial y}{\partial x} = 1\).
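Codes example (an illustrative sketch added here, not part of the upstream docstring; ceil and floor below behave the same way with respect to gradients):

import torch
from spikingjelly.activation_based import quantize

x = torch.rand(4, requires_grad=True)
y = quantize.round(x)          # forward: torch.round(x)
y.sum().backward()
print(x.grad)                  # all ones: the gradient is passed straight through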
- spikingjelly.activation_based.quantize.ceil(x: Tensor)[源代码]
- 参数
x (torch.Tensor) – the input tensor
- 返回
the output tensor
- 返回类型
Apply
y = torch.ceil(x)
with re-defining gradient as \(\frac{\partial y}{\partial x} = 1\).
- spikingjelly.activation_based.quantize.floor(x: Tensor)[源代码]
- 参数
x (torch.Tensor) – the input tensor
- 返回
the output tensor
- 返回类型
Apply
y = torch.floor(x)
with re-defining gradient as \(\frac{\partial y}{\partial x} = 1\).
- spikingjelly.activation_based.quantize.clamp_backward(grad_output: Tensor, x: Tensor, min_value: float, max_value: float)[源代码]
- spikingjelly.activation_based.quantize.clamp(x: Tensor, min_value: float, max_value: float)[源代码]
- 参数
x (torch.Tensor) – the input tensor
min_value (float) – lower-bound of the range to be clamped to
max_value (torch.Tensor) – upper-bound of the range to be clamped to
- 返回
the output tensor
- 返回类型
Apply
y = torch.clamp(x, min_value, max_value)
with re-defining gradient as:\[\begin{split}\frac{\partial y}{\partial x} = \begin{cases} 1, \rm{min\_value} \leq x \leq \rm{max\_value} \\ 0, \rm{otherwise} \end{cases}\end{split}\]
- spikingjelly.activation_based.quantize.step_quantize(x: Tensor, step: float)[源代码]
- 参数
x (torch.Tensor) – the input tensor
step (float) – the quantize step
- 返回
the quantized tensor
- 返回类型
Quantize
x
to the nearesti * step
, wherei
is an integer.Note that the gradient is defined by \(\frac{\partial y}{\partial x} = 1\).
- spikingjelly.activation_based.quantize.k_bit_quantize(x: Tensor, k: int)[源代码]
- 参数
x (torch.Tensor) – a float tensor whose range is
[0, 1]
.k (int) – the bit number of output
- 返回
y = round((2 ** k - 1) * x) / (2 ** k - 1)
- 返回类型
The k-bit quantizer defined in DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients.
The input whose range is
[0, 1]
will be quantized to the nearesti / (2 ** k - 1)
, wherei = 0, 1, ..., (2 ** k - 1)
.Note that the gradient is defined by \(\frac{\partial y}{\partial x} = 1\).
To clamp the input whose range is
(-inf, inf)
to range(0, 1)
, usingtorch.sigmoid
,torch.nn.Hardtanh
orclamp_*
functions (e.g.,spikingjelly.activation_based.quantize.clamp_by_linear
) inspikingjelly.activation_based.quantize
.Codes example:
import torch
from spikingjelly.activation_based.quantize import k_bit_quantize

x = torch.rand(8)
y = k_bit_quantize(x, 2)
print(f'x={x}')
print(f'y={y}')
# x=tensor([0.6965, 0.5697, 0.9883, 0.0438, 0.1332, 0.7613, 0.9704, 0.2384])
# y=tensor([0.6667, 0.6667, 1.0000, 0.0000, 0.0000, 0.6667, 1.0000, 0.3333])
- spikingjelly.activation_based.quantize.clamp_by_linear(x: Tensor, eps: float = 1e-05)[源代码]
- 参数
x (torch.Tensor) – the input tensor to be normed, whose range is
(-inf, inf)
eps (float) – a value added to the denominator for numerical stability. The default value is
1e-5
- 返回
the normed tensor, whose range is
[0., 1.]
- 返回类型
Using the linear transform to clamp the input range from
(-inf, inf)
to[0., 1.]
:\[y = \frac{x - \rm{min}(x)}{\rm{max}(x) - \rm{min}(x) + eps}\]
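Codes example (an illustrative sketch added here, not part of the upstream docstring; it shows the intended use as a pre-processing step for k_bit_quantize):

import torch
from spikingjelly.activation_based import quantize

x = torch.randn(8) * 10.                 # arbitrary range
y = quantize.clamp_by_linear(x)          # linearly rescaled to [0., 1.]
print(y.min().item(), y.max().item())    # approximately 0. and 1.
yq = quantize.k_bit_quantize(y, k=2)     # now a valid input for the k-bit quantizer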
spikingjelly.activation_based.rnn package
Module contents
- spikingjelly.activation_based.rnn.bidirectional_rnn_cell_forward(cell: Module, cell_reverse: Module, x: Tensor, states: Tensor, states_reverse: Tensor)[源代码]
- 参数
cell (nn.Module) – 正向RNN cell,输入是正向序列
cell_reverse (nn.Module) – 反向的RNN cell,输入是反向序列
x (torch.Tensor) –
shape = [T, batch_size, input_size]
的输入states (torch.Tensor) – 正向RNN cell的起始状态 若RNN cell只有单个隐藏状态,则
shape = [batch_size, hidden_size]
; 否则shape = [states_num, batch_size, hidden_size]
states_reverse – 反向RNN cell的起始状态 若RNN cell只有单个隐藏状态,则
shape = [batch_size, hidden_size]
; 否则shape = [states_num, batch_size, hidden_size]
- 返回
y, ss, ss_r
- y: torch.Tensor
shape = [T, batch_size, 2 * hidden_size]
的输出。y[t]
由正向cell在t
时刻和反向cell在T - t - 1
时刻的输出拼接而来- ss: torch.Tensor
shape
与states
相同,正向cell在T-1
时刻的状态- ss_r: torch.Tensor
shape
与states_reverse
相同,反向cell在0
时刻的状态
计算单个正向和反向RNN cell沿着时间维度的循环并输出结果和两个cell的最终状态。
- class spikingjelly.activation_based.rnn.SpikingRNNCellBase(input_size: int, hidden_size: int, bias=True)[源代码]
基类:
Module
Spiking RNN Cell 的基类。
- 参数
备注
所有权重和偏置项都会按照 \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) 进行初始化。 其中 \(k = \frac{1}{\text{hidden_size}}\).
The base class of Spiking RNN Cell.
- 参数
Note
All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{hidden_size}}\).
- class spikingjelly.activation_based.rnn.SpikingRNNBase(input_size, hidden_size, num_layers, bias=True, dropout_p=0, invariant_dropout_mask=False, bidirectional=False, *args, **kwargs)[源代码]
基类:
Module
多层 脉冲 RNN的基类。
- 参数
input_size (int) – 输入
x
的特征数hidden_size (int) – 隐藏状态
h
的特征数num_layers (int) – 内部RNN的层数,例如
num_layers = 2
将会创建堆栈式的两层RNN,第1层接收第0层的输出作为输入, 并计算最终输出bias (bool) – 若为
False
, 则内部的隐藏层不会带有偏置项b_ih
和b_hh
。 默认为True
dropout_p (float) – 若非
0
,则除了最后一层,每个RNN层后会增加一个丢弃概率为dropout_p
的 Dropout 层。 默认为0
invariant_dropout_mask (bool) – 若为
False
,则使用普通的 Dropout;若为True
,则使用SNN中特有的,mask 不 随着时间变化的 Dropout`,参见Dropout
。默认为False
bidirectional (bool) – 若为
True
,则使用双向RNN。默认为False
args – 子类使用的额外参数
kwargs – 子类使用的额外参数
The base-class of a multi-layer spiking RNN.
- 参数
input_size (int) – The number of expected features in the input
x
hidden_size (int) – The number of features in the hidden state
h
num_layers (int) – Number of recurrent layers. E.g., setting
num_layers=2
would mean stacking two LSTMs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final resultsbias (bool) – If
False
, then the layer does not use bias weights b_ih and b_hh. Default:True
dropout_p (float) – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to
dropout
. Default: 0invariant_dropout_mask (bool) – If
False
,use the naive Dropout;IfTrue
,use the dropout in SNN that mask doesn’t change in different time steps, seeDropout
for more information. Default:False
bidirectional (bool) – If
True
, becomes a bidirectional LSTM. Default:False
args – additional arguments for sub-class
kwargs – additional arguments for sub-class
- create_cells(*args, **kwargs)[源代码]
-
- 参数
args – 子类使用的额外参数
kwargs – 子类使用的额外参数
- 返回
若
self.bidirectional == True
则会返回正反两个堆栈式RNN;否则返回单个堆栈式RNN- 返回类型
nn.Sequential
- 参数
args – additional arguments for sub-class
kwargs – additional arguments for sub-class
- 返回
If
self.bidirectional == True
, return a RNN for forward direction and a RNN for reverse direction; else, return a single stacking RNN- 返回类型
nn.Sequential
- static base_cell()[源代码]
-
- 返回
构成该RNN的基本RNN Cell。例如对于
SpikingLSTM
, 返回的是SpikingLSTMCell
- 返回类型
nn.Module
- 返回
The base cell of this RNN. E.g., in
SpikingLSTM
this function will returnSpikingLSTMCell
- 返回类型
nn.Module
- static states_num()[源代码]
-
- 返回
状态变量的数量。例如对于
SpikingLSTM
,由于其输出是h
和c
, 因此返回2
;而对于SpikingGRU
,由于其输出是h
,因此返回1
- 返回类型
- 返回
The states number. E.g., for
SpikingLSTM
the output areh
andc
, this function will return2
; forSpikingGRU
the output ish
, this function will return1
- 返回类型
- forward(x: Tensor, states=None)[源代码]
-
- 参数
x (torch.Tensor) –
shape = [T, batch_size, input_size]
,输入序列states (torch.Tensor or tuple) –
self.states_num()
为1
时是单个tensor, 否则是一个tuple,包含self.states_num()
个tensors。 所有的tensor的尺寸均为shape = [num_layers * num_directions, batch, hidden_size]
, 包含self.states_num()
个初始状态 如果RNN是双向的,num_directions
为2
, 否则为1
- 返回
output, output_states output: torch.Tensor
shape = [T, batch, num_directions * hidden_size]
,最后一层在所有时刻的输出- output_states: torch.Tensor or tuple
self.states_num()
为1
时是单个tensor, 否则是一个tuple,包含self.states_num()
个tensors。 所有的tensor的尺寸均为shape = [num_layers * num_directions, batch, hidden_size]
, 包含self.states_num()
个最后时刻的状态
- 参数
x (torch.Tensor) –
shape = [T, batch_size, input_size]
, tensor containing the features of the input sequencestates (torch.Tensor or tuple) – a single tensor when
self.states_num()
is1
, otherwise a tuple withself.states_num()
tensors.shape = [num_layers * num_directions, batch, hidden_size]
for all tensors, containing theself.states_num()
initial states for each element in the batch. If the RNN is bidirectional,num_directions
should be2
, else it should be1
- 返回
output, output_states output: torch.Tensor
shape = [T, batch, num_directions * hidden_size]
, tensor containing the output features from the last layer of the RNN, for eacht
- output_states: torch.Tensor or tuple
a single tensor when
self.states_num()
is1
, otherwise a tuple withself.states_num()
tensors.shape = [num_layers * num_directions, batch, hidden_size]
for all tensors, containing theself.states_num()
states fort = T - 1
- class spikingjelly.activation_based.rnn.SpikingLSTMCell(input_size: int, hidden_size: int, bias=True, surrogate_function1=Erf(alpha=2.0, spiking=True), surrogate_function2=None)[源代码]
-
脉冲 长短时记忆 (LSTM) cell, 最先由 Long Short-Term Memory Spiking Networks and Their Applications 一文提出。
\[\begin{split}i &= \Theta(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ f &= \Theta(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ g &= \Theta(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ o &= \Theta(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ c' &= f * c + i * g \\ h' &= o * c'\end{split}\]其中 \(\Theta\) 是heaviside阶跃函数(脉冲函数), and \(*\) 是Hadamard点积,即逐元素相乘。
- 参数
input_size (int) – 输入
x
的特征数hidden_size (int) – 隐藏状态
h
的特征数bias (bool) – 若为
False
, 则内部的隐藏层不会带有偏置项b_ih
和b_hh
。 默认为True
surrogate_function1 (spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – 反向传播时用来计算脉冲函数梯度的替代函数, 计算
i
,f
,o
反向传播时使用surrogate_function2 (None or spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – 反向传播时用来计算脉冲函数梯度的替代函数, 计算
g
反向传播时使用。 若为None
, 则设置成surrogate_function1
。默认为None
备注
所有权重和偏置项都会按照 \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) 进行初始化。 其中 \(k = \frac{1}{\text{hidden_size}}\).
示例代码:
import torch
from spikingjelly.activation_based import rnn

T = 6
batch_size = 2
input_size = 3
hidden_size = 4
lstm_cell = rnn.SpikingLSTMCell(input_size, hidden_size)  # named lstm_cell to avoid shadowing the imported rnn module
input = torch.randn(T, batch_size, input_size) * 50
h = torch.randn(batch_size, hidden_size)
c = torch.randn(batch_size, hidden_size)
output = []
for t in range(T):
    h, c = lstm_cell(input[t], (h, c))
    output.append(h)
print(output)
A spiking long short-term memory (LSTM) cell, which is firstly proposed in Long Short-Term Memory Spiking Networks and Their Applications.
\[\begin{split}i &= \Theta(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ f &= \Theta(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ g &= \Theta(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ o &= \Theta(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ c' &= f * c + i * g \\ h' &= o * c'\end{split}\]where \(\Theta\) is the heaviside function, and \(*\) is the Hadamard product.
- 参数
input_size (int) – The number of expected features in the input
x
hidden_size (int) – The number of features in the hidden state
h
bias (bool) – If
False
, then the layer does not use bias weightsb_ih
andb_hh
. Default:True
surrogate_function1 (spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – surrogate function for replacing gradient of spiking functions during back-propagation, which is used for generating
i
,f
,o
surrogate_function2 (None or spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – surrogate function for replacing gradient of spiking functions during back-propagation, which is used for generating
g
. IfNone
, the surrogate function for generatingg
will be set assurrogate_function1
. Default:None
Note
All the weights and biases are initialized from \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) where \(k = \frac{1}{\text{hidden_size}}\).
Examples:
import torch
from spikingjelly.activation_based import rnn

T = 6
batch_size = 2
input_size = 3
hidden_size = 4
lstm_cell = rnn.SpikingLSTMCell(input_size, hidden_size)  # named lstm_cell to avoid shadowing the imported rnn module
input = torch.randn(T, batch_size, input_size) * 50
h = torch.randn(batch_size, hidden_size)
c = torch.randn(batch_size, hidden_size)
output = []
for t in range(T):
    h, c = lstm_cell(input[t], (h, c))
    output.append(h)
print(output)
- forward(x: Tensor, hc=None)[源代码]
-
- 参数
x (torch.Tensor) –
shape = [batch_size, input_size]
的输入hc (tuple or None) –
(h_0, c_0) h_0 : torch.Tensor
shape = [batch_size, hidden_size]
,起始隐藏状态- c_0torch.Tensor
shape = [batch_size, hidden_size]
,起始细胞状态
如果不提供 (h_0, c_0),则 h_0 和 c_0 都默认为0
- 返回
(h_1, c_1) : h_1 : torch.Tensor
shape = [batch_size, hidden_size]
,下一个时刻的隐藏状态- c_1torch.Tensor
shape = [batch_size, hidden_size]
,下一个时刻的细胞状态
- 返回类型
- 参数
x (torch.Tensor) – the input tensor with
shape = [batch_size, input_size]
hc (tuple or None) –
(h_0, c_0) h_0 : torch.Tensor
shape = [batch_size, hidden_size]
, tensor containing the initial hidden state for each element in the batch- c_0torch.Tensor
shape = [batch_size, hidden_size]
, tensor containing the initial cell state for each element in the batch
If (h_0, c_0) is not provided, both
h_0
andc_0
default to zero
- 返回
(h_1, c_1) : h_1 : torch.Tensor
shape = [batch_size, hidden_size]
, tensor containing the next hidden state for each element in the batch- c_1torch.Tensor
shape = [batch_size, hidden_size]
, tensor containing the next cell state for each element in the batch
- 返回类型
- class spikingjelly.activation_based.rnn.SpikingLSTM(input_size, hidden_size, num_layers, bias=True, dropout_p=0, invariant_dropout_mask=False, bidirectional=False, surrogate_function1=Erf(alpha=2.0, spiking=True), surrogate_function2=None)[源代码]
-
多层`脉冲` 长短时记忆LSTM, 最先由 Long Short-Term Memory Spiking Networks and Their Applications 一文提出。
每一层的计算按照
\[\begin{split}i_{t} &= \Theta(W_{ii} x_{t} + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_{t} &= \Theta(W_{if} x_{t} + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_{t} &= \Theta(W_{ig} x_{t} + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_{t} &= \Theta(W_{io} x_{t} + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_{t} &= f_{t} * c_{t-1} + i_{t} * g_{t} \\ h_{t} &= o_{t} * c_{t}\end{split}\]其中 \(h_{t}\) 是 \(t\) 时刻的隐藏状态,\(c_{t}\) 是 \(t\) 时刻的细胞状态,\(h_{t-1}\) 是该层 \(t-1\) 时刻的隐藏状态或起始状态,\(i_{t}\),\(f_{t}\),\(g_{t}\),\(o_{t}\) 分别是输入,遗忘,细胞,输出门,\(\Theta\) 是heaviside阶跃函数(脉冲函数),\(*\) 是Hadamard点积,即逐元素相乘。
- 参数
input_size (int) – 输入
x
的特征数hidden_size (int) – 隐藏状态
h
的特征数num_layers (int) – 内部RNN的层数,例如
num_layers = 2
将会创建堆栈式的两层RNN,第1层接收第0层的输出作为输入, 并计算最终输出bias (bool) – 若为
False
, 则内部的隐藏层不会带有偏置项b_ih
和b_hh
。 默认为True
dropout_p (float) – 若非
0
,则除了最后一层,每个RNN层后会增加一个丢弃概率为dropout_p
的 Dropout 层。 默认为0
invariant_dropout_mask (bool) – 若为
False
,则使用普通的 Dropout;若为True
,则使用SNN中特有的,mask 不随着时间变化的 Dropout,参见Dropout
。默认为False
bidirectional (bool) – 若为
True
,则使用双向RNN。默认为False
surrogate_function1 (spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – 反向传播时用来计算脉冲函数梯度的替代函数, 计算
i
,f
,o
反向传播时使用surrogate_function2 (None or spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – 反向传播时用来计算脉冲函数梯度的替代函数, 计算
g
反向传播时使用。 若为None
, 则设置成surrogate_function1
。默认为None
The spiking multi-layer long short-term memory (LSTM), which is firstly proposed in Long Short-Term Memory Spiking Networks and Their Applications.
For each element in the input sequence, each layer computes the following function:
\[\begin{split}i_{t} &= \Theta(W_{ii} x_{t} + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ f_{t} &= \Theta(W_{if} x_{t} + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ g_{t} &= \Theta(W_{ig} x_{t} + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ o_{t} &= \Theta(W_{io} x_{t} + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ c_{t} &= f_{t} * c_{t-1} + i_{t} * g_{t} \\ h_{t} &= o_{t} * c_{t}\end{split}\]where \(h_t\) is the hidden state at time t, \(c_t\) is the cell state at time t, \(x_t\) is the input at time t, \(h_{t-1}\) is the hidden state of the layer at time t-1 or the initial hidden state at time 0, and \(i_t\), \(f_t\), \(g_t\), \(o_t\) are the input, forget, cell, and output gates, respectively. \(\Theta\) is the heaviside function, and \(*\) is the Hadamard product.
- 参数
input_size (int) – The number of expected features in the input
x
hidden_size (int) – The number of features in the hidden state
h
num_layers (int) – Number of recurrent layers. E.g., setting
num_layers=2
would mean stacking two LSTMs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final resultsbias (bool) – If
False
, then the layer does not use bias weights b_ih and b_hh. Default:True
dropout_p (float) – If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to
dropout
. Default: 0invariant_dropout_mask (bool) – If
False
,use the naive Dropout;IfTrue
,use the dropout in SNN that mask doesn’t change in different time steps, seeDropout
for more information. Default:False
bidirectional (bool) – If
True
, becomes a bidirectional LSTM. Default:False
surrogate_function1 (spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – surrogate function for replacing gradient of spiking functions during back-propagation, which is used for generating
i
,f
,o
surrogate_function2 (None or spikingjelly.activation_based.surrogate.SurrogateFunctionBase) – surrogate function for replacing gradient of spiking functions during back-propagation, which is used for generating
g
. IfNone
, the surrogate function for generatingg
will be set assurrogate_function1
. Default:None
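Codes example (an illustrative sketch added here, not part of the upstream docstring; the unpacking of the returned states follows the forward() description of SpikingRNNBase above, where an LSTM has two state tensors):

import torch
from spikingjelly.activation_based import rnn

T, batch_size = 6, 2
net = rnn.SpikingLSTM(input_size=3, hidden_size=4, num_layers=2)

x = torch.randn(T, batch_size, 3) * 50.
output, states = net(x)          # states is assumed to be the (h, c) pair at the last time-step
print(output.shape)              # torch.Size([6, 2, 4])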
- class spikingjelly.activation_based.rnn.SpikingVanillaRNNCell(input_size: int, hidden_size: int, bias=True, surrogate_function=Erf(alpha=2.0, spiking=True))[源代码]
- class spikingjelly.activation_based.rnn.SpikingVanillaRNN(input_size, hidden_size, num_layers, bias=True, dropout_p=0, invariant_dropout_mask=False, bidirectional=False, surrogate_function=Erf(alpha=2.0, spiking=True))[源代码]
- class spikingjelly.activation_based.rnn.SpikingGRUCell(input_size: int, hidden_size: int, bias=True, surrogate_function1=Erf(alpha=2.0, spiking=True), surrogate_function2=None)[源代码]
spikingjelly.activation_based.spike_op package
Module contents
- spikingjelly.activation_based.spike_op.spike_linear(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None) Tensor [源代码]
-
torch.nn.functional.linear
在输入为脉冲时的特例。备注
在CUDA设备上训练时拥有比
torch.nn.functional.linear
更低的显存消耗。警告
spike 中的任何元素都必须为0或1。
A specific case of
torch.nn.functional.linear
with inputs are spikes.Note
This function has less memory consumption than
torch.nn.functional.linear
when training on CUDA devices.Warning
Any element in spike must be 0 or 1.
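Codes example (an illustrative sketch added here, not part of the upstream docstring; the same pattern applies to spike_conv1d/2d/3d below):

import torch
from spikingjelly.activation_based import spike_op

weight = torch.randn(4, 8)
bias = torch.randn(4)
spike = (torch.rand(16, 8) >= 0.5).float()       # every element must be 0 or 1
y = spike_op.spike_linear(spike, weight, bias)   # same result as torch.nn.functional.linear(spike, weight, bias)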
- spikingjelly.activation_based.spike_op.spike_conv1d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[int, Size, List[int], Tuple[int, ...]] = 1, padding: str = 'valid', dilation: Union[int, Size, List[int], Tuple[int, ...]] = 1, groups: int = 1) Tensor [源代码]
-
torch.nn.functional.conv1d
在输入为脉冲时的特例。备注
在CUDA设备上训练时拥有比
torch.nn.functional.conv1d
更低的显存消耗。警告
spike 中的任何元素都必须为0或1。
A specific case of
torch.nn.functional.conv1d
with inputs are spikes.Note
This function has less memory consumption than
torch.nn.functional.conv1d
when training on CUDA devices.Warning
Any element in spike must be 0 or 1.
- spikingjelly.activation_based.spike_op.spike_conv2d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[int, Size, List[int], Tuple[int, ...]] = 1, padding: str = 'valid', dilation: Union[int, Size, List[int], Tuple[int, ...]] = 1, groups: int = 1) Tensor [源代码]
-
torch.nn.functional.conv2d
在输入为脉冲时的特例。备注
在CUDA设备上训练时拥有比
torch.nn.functional.conv2d
更低的显存消耗。警告
spike 中的任何元素都必须为0或1。
A specific case of
torch.nn.functional.conv2d
with inputs are spikes.Note
This function has less memory consumption than
torch.nn.functional.conv2d
when training on CUDA devices.Warning
Any element in spike must be 0 or 1.
- spikingjelly.activation_based.spike_op.spike_conv3d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[int, Size, List[int], Tuple[int, ...]] = 1, padding: str = 'valid', dilation: Union[int, Size, List[int], Tuple[int, ...]] = 1, groups: int = 1) Tensor [源代码]
-
torch.nn.functional.conv3d
在输入为脉冲时的特例。备注
在CUDA设备上训练时拥有比
torch.nn.functional.conv3d
更低的显存消耗。警告
spike 中的任何元素都必须为0或1。
A specific case of
torch.nn.functional.conv3d
with inputs are spikes.Note
This function has less memory consumption than
torch.nn.functional.conv3d
when training on CUDA devices.Warning
Any element in spike must be 0 or 1.
- class spikingjelly.activation_based.spike_op.SpikeLinear(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None)[源代码]
基类:
Linear
torch.nn.Linear
在输入为脉冲时的特例。备注
在CUDA设备上运行时拥有比
torch.nn.Linear
更低的显存消耗。警告
spike 中的任何元素都必须为0或1。
A specific case of
torch.nn.Linear
with inputs are spikes.Note
This function has less memory consumption than
torch.nn.Linear
when training on CUDA devices.Warning
Any element in spike must be 0 or 1.
- class spikingjelly.activation_based.spike_op.SpikeConv1d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int]], stride: Union[int, Tuple[int]] = 1, padding: Union[str, int, Tuple[int]] = 0, dilation: Union[int, Tuple[int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None)[源代码]
基类:
Conv1d
torch.nn.Conv1d 在输入为脉冲时的特例。

备注

在CUDA设备上运行时拥有比 torch.nn.Conv1d 更低的显存消耗。

警告

spike 中的任何元素都必须为0或1。

A specific case of torch.nn.Conv1d when the inputs are spikes.

Note

This module has less memory consumption than torch.nn.Conv1d when training on CUDA devices.

Warning

Any element in spike must be 0 or 1.
- class spikingjelly.activation_based.spike_op.SpikeConv2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]] = 1, padding: Union[str, int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None)[源代码]
基类:
Conv2d
torch.nn.Conv2d 在输入为脉冲时的特例。

备注

在CUDA设备上运行时拥有比 torch.nn.Conv2d 更低的显存消耗。

警告

spike 中的任何元素都必须为0或1。

A specific case of torch.nn.Conv2d when the inputs are spikes.

Note

This module has less memory consumption than torch.nn.Conv2d when training on CUDA devices.

Warning

Any element in spike must be 0 or 1.
- class spikingjelly.activation_based.spike_op.SpikeConv3d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int, int]], stride: Union[int, Tuple[int, int, int]] = 1, padding: Union[str, int, Tuple[int, int, int]] = 0, dilation: Union[int, Tuple[int, int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', device=None, dtype=None)[源代码]
基类:
Conv3d
torch.nn.Conv3d 在输入为脉冲时的特例。

备注

在CUDA设备上运行时拥有比 torch.nn.Conv3d 更低的显存消耗。

警告

spike 中的任何元素都必须为0或1。

A specific case of torch.nn.Conv3d when the inputs are spikes.

Note

This module has less memory consumption than torch.nn.Conv3d when training on CUDA devices.

Warning

Any element in spike must be 0 or 1.
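The module variants can be used as drop-in replacements for the corresponding torch.nn layers when their inputs are spikes. A minimal sketch (the shapes below are illustrative assumptions):

import torch
from spikingjelly.activation_based import neuron, spike_op

# an IF neuron emits 0/1 spikes, which satisfy the input requirement of SpikeConv2d
if_node = neuron.IFNode()
conv = spike_op.SpikeConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)

x = torch.rand(2, 3, 32, 32)
spike = if_node(x)     # binary tensor with the same shape as x
y = conv(spike)
print(y.shape)         # torch.Size([2, 8, 32, 32])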
spikingjelly.activation_based.surrogate package
Module contents
- spikingjelly.activation_based.surrogate.heaviside(x: Tensor)[源代码]
-
- 参数
x – 输入tensor
- 返回
输出tensor
heaviside阶跃函数,定义为
\[\begin{split}g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \\ \end{cases}\end{split}\]

阅读 HeavisideStepFunction 以获得更多信息。
- 参数
x – the input tensor
- 返回
the output tensor
The heaviside function, which is defined by
\[\begin{split}g(x) = \begin{cases} 1, & x \geq 0 \\ 0, & x < 0 \\ \end{cases}\end{split}\]

For more information, see HeavisideStepFunction.
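A quick check of the definition (the input values are arbitrary):

import torch
from spikingjelly.activation_based import surrogate

x = torch.tensor([-1.5, 0.0, 2.0])
print(surrogate.heaviside(x))  # tensor([0., 1., 1.])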
- spikingjelly.activation_based.surrogate.check_manual_grad(primitive_function, spiking_function, *args, **kwargs)[源代码]
- 参数
primitive_function (callable) – 梯度替代函数的原函数
spiking_function (callable) – 梯度替代函数
梯度替代函数的反向传播一般是手写的,可以用此函数去检查手写梯度是否正确。
此函数检查梯度替代函数spiking_function的反向传播,与原函数primitive_function的反向传播结果是否一致。“一致”被定义为,两者的误差不超过eps。
示例代码:
def s2nn_apply(x, alpha, beta):
    return surrogate.s2nn.apply(x, alpha, beta)

surrogate.check_manual_grad(surrogate.S2NN.primitive_function, s2nn_apply, alpha=4., beta=1.)
- spikingjelly.activation_based.surrogate.check_cuda_grad(neu, surrogate_function, device, *args, **kwargs)[源代码]
- class spikingjelly.activation_based.surrogate.SurrogateFunctionBase(alpha, spiking=True)[源代码]
基类:
Module
- class spikingjelly.activation_based.surrogate.MultiArgsSurrogateFunctionBase(spiking: bool, *args, **kwargs)[源代码]
基类:
Module
- spikingjelly.activation_based.surrogate.piecewise_quadratic_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.piecewise_quadratic(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.surrogate.PiecewiseQuadratic(alpha=1.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
反向传播时使用分段二次函数的梯度(三角形函数)的脉冲发放函数。反向传播为
\[\begin{split}g'(x) = \begin{cases} 0, & |x| > \frac{1}{\alpha} \\ -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} \end{cases}\end{split}\]对应的原函数为
\[\begin{split}g(x) = \begin{cases} 0, & x < -\frac{1}{\alpha} \\ -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ 1, & x > \frac{1}{\alpha} \\ \end{cases}\end{split}\]- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The piecewise quadratic surrogate spiking function. The gradient is defined by
\[\begin{split}g'(x) = \begin{cases} 0, & |x| > \frac{1}{\alpha} \\ -\alpha^2|x|+\alpha, & |x| \leq \frac{1}{\alpha} \end{cases}\end{split}\]The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} 0, & x < -\frac{1}{\alpha} \\ -\frac{1}{2}\alpha^2|x|x + \alpha x + \frac{1}{2}, & |x| \leq \frac{1}{\alpha} \\ 1, & x > \frac{1}{\alpha} \\ \end{cases}\end{split}\]
- spikingjelly.activation_based.surrogate.piecewise_exp_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.PiecewiseExp(alpha=1.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
反向传播时使用分段指数函数的梯度的脉冲发放函数。反向传播为
\[g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}\]对应的原函数为
\[\begin{split}g(x) = \begin{cases} \frac{1}{2}e^{\alpha x}, & x < 0 \\ 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 \end{cases}\end{split}\]- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The piecewise exponential surrogate spiking function. The gradient is defined by
\[g'(x) = \frac{\alpha}{2}e^{-\alpha |x|}\]The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} \frac{1}{2}e^{\alpha x}, & x < 0 \\ 1 - \frac{1}{2}e^{-\alpha x}, & x \geq 0 \end{cases}\end{split}\]
- spikingjelly.activation_based.surrogate.sigmoid_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.Sigmoid(alpha=4.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
反向传播时使用sigmoid的梯度的脉冲发放函数。反向传播为
\[g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)\]对应的原函数为
\[g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}\]- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The sigmoid surrogate spiking function. The gradient is defined by
\[g'(x) = \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x)\]The primitive function is defined by
\[g(x) = \mathrm{sigmoid}(\alpha x) = \frac{1}{1+e^{-\alpha x}}\]
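A short usage sketch (the tensor shape is illustrative; with spiking=True the forward output is binary while the backward pass uses the sigmoid-shaped surrogate gradient):

import torch
from spikingjelly.activation_based import surrogate

sg = surrogate.Sigmoid(alpha=4.0, spiking=True)

x = torch.randn(4, requires_grad=True)
s = sg(x)            # forward pass uses heaviside, so s only contains 0 and 1
s.sum().backward()
print(s)
print(x.grad)        # gradients come from the sigmoid surrogate function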
- spikingjelly.activation_based.surrogate.soft_sign_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.SoftSign(alpha=2.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
反向传播时使用soft sign的梯度的脉冲发放函数。反向传播为
\[g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}} = \frac{1}{2\alpha(\frac{1}{\alpha} + |x|)^{2}}\]对应的原函数为
\[g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1) = \frac{1}{2} (\frac{x}{\frac{1}{\alpha} + |x|} + 1)\]- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The soft sign surrogate spiking function. The gradient is defined by
\[g'(x) = \frac{\alpha}{2(1 + |\alpha x|)^{2}}\]The primitive function is defined by
\[g(x) = \frac{1}{2} (\frac{\alpha x}{1 + |\alpha x|} + 1)\]
- spikingjelly.activation_based.surrogate.atan_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.ATan(alpha=2.0, spiking=True)[源代码]
-
反向传播时使用反正切函数arc tangent的梯度的脉冲发放函数。反向传播为
\[g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}\]对应的原函数为
\[g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}\]The arc tangent surrogate spiking function. The gradient is defined by
\[g'(x) = \frac{\alpha}{2(1 + (\frac{\pi}{2}\alpha x)^2)}\]The primitive function is defined by
\[g(x) = \frac{1}{\pi} \arctan(\frac{\pi}{2}\alpha x) + \frac{1}{2}\]
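Surrogate functions are typically passed to spiking neurons via the surrogate_function argument; a minimal sketch:

import torch
from spikingjelly.activation_based import neuron, surrogate

lif = neuron.LIFNode(surrogate_function=surrogate.ATan(alpha=2.0))
x = torch.rand(8, 16)
spike = lif(x)   # forward uses heaviside; backward uses the ATan surrogate gradient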
- spikingjelly.activation_based.surrogate.nonzero_sign_log_abs_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.nonzero_sign_log_abs(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.surrogate.NonzeroSignLogAbs(alpha=1.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
警告
原函数的输出范围并不是(0, 1)。它的优势是反向传播的计算量特别小。
反向传播时使用NonzeroSignLogAbs的梯度的脉冲发放函数。反向传播为
\[g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}\]对应的原函数为
\[g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)\]其中
\[\begin{split}\mathrm{NonzeroSign}(x) = \begin{cases} 1, & x \geq 0 \\ -1, & x < 0 \\ \end{cases}\end{split}\]- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
Warning
The output range of the primitive function is not (0, 1). The advantage of this function is its small computational cost in the backward pass.
The NonzeroSignLogAbs surrogate spiking function. The gradient is defined by
\[g'(x) = \frac{\alpha}{1 + |\alpha x|} = \frac{1}{\frac{1}{\alpha} + |x|}\]The primitive function is defined by
\[g(x) = \mathrm{NonzeroSign}(x) \log (|\alpha x| + 1)\]where
\[\begin{split}\mathrm{NonzeroSign}(x) = \begin{cases} 1, & x \geq 0 \\ -1, & x < 0 \\ \end{cases}\end{split}\]
- spikingjelly.activation_based.surrogate.erf_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.Erf(alpha=2.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的平滑程度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
反向传播时使用高斯误差函数(erf)的梯度的脉冲发放函数。反向传播为
\[g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}\]对应的原函数为
\begin{split} g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ &= \frac{1}{2} \text{erfc}(-\alpha x) \\ &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt \end{split}- 参数
alpha – parameter to control smoothness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The Gaussian error (erf) surrogate spiking function. The gradient is defined by
\[g'(x) = \frac{\alpha}{\sqrt{\pi}}e^{-\alpha^2x^2}\]The primitive function is defined by
\begin{split} g(x) &= \frac{1}{2}(1-\text{erf}(-\alpha x)) \\ &= \frac{1}{2} \text{erfc}(-\alpha x) \\ &= \frac{1}{\sqrt{\pi}}\int_{-\infty}^{\alpha x}e^{-t^2}dt \end{split}
- spikingjelly.activation_based.surrogate.piecewise_leaky_relu_backward(grad_output: Tensor, x: Tensor, w: float, c: float)[源代码]
- class spikingjelly.activation_based.surrogate.piecewise_leaky_relu(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.surrogate.PiecewiseLeakyReLU(w=1.0, c=0.01, spiking=True)[源代码]
基类:
MultiArgsSurrogateFunctionBase
- 参数
w –
-w <= x <= w
时反向传播的梯度为1 / 2w
c –
x > w
或x < -w
时反向传播的梯度为c
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
分段线性的近似脉冲发放函数。梯度为
\[\begin{split}g'(x) = \begin{cases} \frac{1}{w}, & -w \leq x \leq w \\ c, & x < -w ~or~ x > w \end{cases}\end{split}\]对应的原函数为
\[\begin{split}g(x) = \begin{cases} cx + cw, & x < -w \\ \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ cx - cw + 1, & x > w \\ \end{cases}\end{split}\]该函数在文章 3 4 5 9 10 12 16 17 中使用。
- 参数
w – when
-w <= x <= w
the gradient is1 / 2w
c – when
x > w
orx < -w
the gradient isc
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The piecewise surrogate spiking function. The gradient is defined by
\[\begin{split}g'(x) = \begin{cases} \frac{1}{w}, & -w \leq x \leq w \\ c, & x < -w ~or~ x > w \end{cases}\end{split}\]The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} cx + cw, & x < -w \\ \frac{1}{2w}x + \frac{1}{2}, & -w \leq x \leq w \\ cx - cw + 1, & x > w \end{cases}\end{split}\]
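A small sketch that inspects the surrogate gradient directly (the input values are illustrative; it only shows that entries outside [-w, w] receive the gradient c while entries inside follow the piecewise formula):

import torch
from spikingjelly.activation_based import surrogate

sg = surrogate.PiecewiseLeakyReLU(w=1.0, c=0.01)

x = torch.tensor([-2.0, 0.5, 2.0], requires_grad=True)
sg(x).sum().backward()
# entries with |x| > w receive gradient c; entries with |x| <= w follow the piecewise formula
print(x.grad)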
- class spikingjelly.activation_based.surrogate.squarewave_fourier_series(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.surrogate.SquarewaveFourierSeries(n: int = 2, T_period: float = 8, spiking=True)[源代码]
- class spikingjelly.activation_based.surrogate.S2NN(alpha=4.0, beta=1.0, spiking=True)[源代码]
基类:
MultiArgsSurrogateFunctionBase
- 参数
alpha – 控制
x < 0
时梯度的参数beta – 控制
x >= 0
时梯度的参数spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks 提出的S2NN替代函数。反向传播为
\[\begin{split}g'(x) = \begin{cases} \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), & x < 0 \\ \frac{\beta}{x + 1}, & x \ge 0 \end{cases}\end{split}\]

对应的原函数为

\[\begin{split}g(x) = \begin{cases} \mathrm{sigmoid} (\alpha x), & x < 0 \\ \beta \mathrm{ln}(x + 1) + 1, & x \ge 0 \end{cases}\end{split}\]

- 参数
alpha – the param that controls the gradient when
x < 0
beta – the param that controls the gradient when
x >= 0
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The S2NN surrogate spiking function, which is proposed by S2NN: Time Step Reduction of Spiking Surrogate Gradients for Training Energy Efficient Single-Step Neural Networks. The gradient is defined by
\[\begin{split}g'(x) = \begin{cases} \alpha * (1 - \mathrm{sigmoid} (\alpha x)) \mathrm{sigmoid} (\alpha x), & x < 0 \\ \frac{\beta}{x + 1}, & x \ge 0 \end{cases}\end{split}\]

The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} \mathrm{sigmoid} (\alpha x), x < 0 \\ \beta \mathrm{ln}(x + 1) + 1, x \ge 0 \end{cases}\end{split}\]
- class spikingjelly.activation_based.surrogate.QPseudoSpike(alpha=2.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度函数尾部厚度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
Surrogate Gradients Design 提出的 \(q\)-PseudoSpike替代函数。反向传播为
\[g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}\]其中 \(\alpha>1\) 对应原文中的 \(q\)。
对应的原函数为
\[\begin{split}g(x) = \begin{cases} \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. \end{cases}\end{split}\]- 参数
alpha – parameter to control tail fatness of gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The \(q\)-PseudoSpike surrogate spiking function, which is first proposed in Surrogate Gradients Design. The gradient is defined by
\[g'(x) = (1+\frac{2|x|}{\alpha-1})^{-\alpha}\]where \(\alpha>1\) corresponds to \(q\) in paper.
The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} \frac{1}{2}(1-\frac{2x}{\alpha-1})^{1-\alpha}, & x < 0 \\ 1 - \frac{1}{2}(1+\frac{2x}{\alpha-1})^{1-\alpha}, & x \geq 0. \end{cases}\end{split}\]
- spikingjelly.activation_based.surrogate.leaky_k_relu_backward(grad_output: Tensor, x: Tensor, leak: float, k: float)[源代码]
- class spikingjelly.activation_based.surrogate.LeakyKReLU(spiking=True, leak: float = 0.0, k: float = 1.0)[源代码]
基类:
MultiArgsSurrogateFunctionBase
- 参数
spiking (bool) – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagationleak (float) – gradient when
x < 0
反向传播时使用LeakyKReLU的梯度的脉冲发放函数。反向传播为
\[\begin{split}g'(x) = \begin{cases} k, & x \geq 0 \\ leak, & x < 0 \\ \end{cases}\end{split}\]对应的原函数为
\[\begin{split}g(x) = \begin{cases} k \cdot x, & x \geq 0 \\ leak \cdot x, & x < 0 \\ \end{cases}\end{split}\]- 参数
The LeakyKReLU surrogate spiking function. The gradient is defined by
\[\begin{split}g'(x) = \begin{cases} k, & x \geq 0 \\ leak, & x < 0 \\ \end{cases}\end{split}\]The primitive function is defined by
\[\begin{split}g(x) = \begin{cases} k \cdot x, & x \geq 0 \\ leak \cdot x, & x < 0 \\ \end{cases}\end{split}\]
- spikingjelly.activation_based.surrogate.fake_numerical_gradient_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.fake_numerical_gradient(*args, **kwargs)[源代码]
基类:
Function
- spikingjelly.activation_based.surrogate.log_tailed_relu_backward(grad_output: Tensor, x: Tensor, alpha: float)[源代码]
- class spikingjelly.activation_based.surrogate.LogTailedReLU(alpha=0.0, spiking=True)[源代码]
-
- 参数
alpha – 控制反向传播时梯度的参数
spiking – 是否输出脉冲,默认为
True
,在前向传播时使用heaviside
而在反向传播使用替代梯度。若为False
则不使用替代梯度,前向传播时,使用反向传播时的梯度替代函数对应的原函数
Deep Learning with Low Precision by Half-wave Gaussian Quantization 提出的 Log-tailed ReLU替代函数。反向传播为
\[\begin{split}g'(x) = \begin{cases} \alpha, & x \leq 0 \\ 1, & 0 < x \leq 1 \\ \frac{1}{x}, & x > 1 \\ \end{cases}\end{split}\]

对应的原函数为

\[\begin{split}g(x) = \begin{cases} \alpha x, & x \leq 0 \\ x, & 0 < x \leq 1 \\ \log(x), & x > 1 \\ \end{cases}\end{split}\]

- 参数
alpha – parameter to control gradient
spiking – whether output spikes. The default is
True
which means that usingheaviside
in forward propagation and using surrogate gradient in backward propagation. IfFalse
, in forward propagation, using the primitive function of the surrogate gradient function used in backward propagation
The Log-tailed ReLU surrogate spiking function, which is first proposed in Deep Learning with Low Precision by Half-wave Gaussian Quantization. The gradient is defined by

\[\begin{split}g'(x) = \begin{cases} \alpha, & x \leq 0 \\ 1, & 0 < x \leq 1 \\ \frac{1}{x}, & x > 1 \\ \end{cases}\end{split}\]

The primitive function is defined by

\[\begin{split}g(x) = \begin{cases} \alpha x, & x \leq 0 \\ x, & 0 < x \leq 1 \\ \log(x), & x > 1 \\ \end{cases}\end{split}\]
References
- 1
Esser S K, Appuswamy R, Merolla P, et al. Backpropagation for energy-efficient neuromorphic computing[J]. Advances in Neural Information Processing Systems, 2015, 28: 1117-1125.
- 2
Esser S K, Merolla P A, Arthur J V, et al. Convolutional networks for fast, energy-efficient neuromorphic computing[J]. Proceedings of the National Academy of Sciences, 2016, 113(41): 11441-11446.
- 3
Yin S, Venkataramanaiah S K, Chen G K, et al. Algorithm and hardware design of discrete-time spiking neural networks based on back propagation with binary activations[C]//2017 IEEE Biomedical Circuits and Systems Conference (BioCAS). IEEE, 2017: 1-5.
- 4
Wu Y, Deng L, Li G, et al. Spatio-temporal backpropagation for training high-performance spiking neural networks[J]. Frontiers in Neuroscience, 2018, 12: 331.
- 5
Huh D, Sejnowski T J. Gradient descent for spiking neural networks[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1440-1450.
- 6
Shrestha S B, Orchard G. SLAYER: spike layer error reassignment in time[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 1419-1428.
- 7
Bellec G, Salaj D, Subramoney A, et al. Long short-term memory and learning-to-learn in networks of spiking neurons[C]//Proceedings of the 32nd International Conference on Neural Information Processing Systems. 2018: 795-805.
- 8
Zenke F, Ganguli S. Superspike: Supervised learning in multilayer spiking neural networks[J]. Neural Computation, 2018, 30(6): 1514-1541.
- 9
Wu Y, Deng L, Li G, et al. Direct training for spiking neural networks: Faster, larger, better[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2019, 33(01): 1311-1318.
- 10
Gu P, Xiao R, Pan G, et al. STCA: Spatio-Temporal Credit Assignment with Delayed Feedback in Deep Spiking Neural Networks[C]//IJCAI. 2019: 1366-1372.
- 11
Neftci E O, Mostafa H, Zenke F. Surrogate gradient learning in spiking neural networks: Bringing the power of gradient-based optimization to spiking neural networks[J]. IEEE Signal Processing Magazine, 2019, 36(6): 51-63.
- 12
Roy D, Chakraborty I, Roy K. Scaling deep spiking neural networks with binary stochastic activations[C]//2019 IEEE International Conference on Cognitive Computing (ICCC). IEEE, 2019: 50-58.
- 13
Panda P, Aketi S A, Roy K. Toward scalable, efficient, and accurate deep spiking neural networks with backward residual connections, stochastic softmax, and hybridization[J]. Frontiers in Neuroscience, 2020, 14.
- 14
Lotfi Rezaabad A, Vishwanath S. Long Short-Term Memory Spiking Networks and Their Applications[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-9.
- 15
Woźniak S, Pantazi A, Bohnstingl T, et al. Deep learning incorporating biologically inspired neural dynamics and in-memory computing[J]. Nature Machine Intelligence, 2020, 2(6): 325-336.
- 16
Cheng X, Hao Y, Xu J, et al. LISNN: Improving Spiking Neural Networks with Lateral Interactions for Robust Object Recognition[C]//IJCAI. 2020: 1519-1525.
- 17
Kaiser J, Mostafa H, Neftci E. Synaptic plasticity dynamics for deep continuous local learning (DECOLLE)[J]. Frontiers in Neuroscience, 2020, 14: 424.
- 18
Yin B, Corradi F, Bohté S M. Effective and efficient computation with multiple-timescale spiking recurrent neural networks[C]//International Conference on Neuromorphic Systems 2020. 2020: 1-8.
spikingjelly.activation_based.tensor_cache package
Module contents
- class spikingjelly.activation_based.tensor_cache.DataTypeConvertCUDACode[源代码]
基类:
object
- float2bool = '\n extern "C" __global__\n void float2bool(const float* fs, unsigned char* bs, const int &N)\n {\n // assert N == numel / 8 and numel % 8 == 0\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n bs[index] = 0;\n const int mem_offset = (index << 3);\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n bs[index] += ( ((unsigned char) fs[mem_offset + i]) << i);\n }\n }\n }\n '
- half2bool = '\n #include <cuda_fp16.h>\n extern "C" __global__\n void half2bool(const half* fs, unsigned char* bs, const int &N)\n {\n // assert N == numel / 8 and numel % 8 == 0\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n bs[index] = 0;\n const int mem_offset = (index << 3);\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n bs[index] += ( ((unsigned char) __half2float(fs[mem_offset + i])) << i);\n }\n }\n }\n '
- bool2float = '\n extern "C" __global__\n void bool2float(const unsigned char* bs, float* fs, const int &N)\n {\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n const int mem_offset = (index << 3);\n unsigned char compressed_v = bs[index];\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n fs[mem_offset + i] = (float) (compressed_v % 2);\n compressed_v = (compressed_v >> 1);\n }\n }\n }\n '
- bool2half = '\n #include <cuda_fp16.h>\n extern "C" __global__\n void bool2half(const unsigned char* bs, half* fs, const int &N)\n {\n const int index = blockIdx.x * blockDim.x + threadIdx.x;\n if (index < N)\n {\n const int mem_offset = (index << 3);\n unsigned char compressed_v = bs[index];\n #pragma unroll\n for(int i = 0; i < 8; i++)\n {\n fs[mem_offset + i] = __float2half((float) (compressed_v % 2));\n compressed_v = (compressed_v >> 1);\n }\n }\n }\n '
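These kernels pack eight 0/1 spike values into a single byte (and unpack them back) so that cached spike tensors occupy only a fraction of their original memory. The sketch below illustrates the same packing idea in pure PyTorch; the helper functions pack_spikes and unpack_spikes are illustrative names and are not part of SpikingJelly's API.

import torch

def pack_spikes(spike: torch.Tensor) -> torch.Tensor:
    """Pack eight 0/1 values into one uint8, mirroring the float2bool kernel."""
    flat = spike.to(torch.uint8).flatten()
    assert flat.numel() % 8 == 0
    bits = flat.view(-1, 8)
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    return (bits * weights).sum(dim=1, dtype=torch.uint8)

def unpack_spikes(packed: torch.Tensor) -> torch.Tensor:
    """Recover the 0/1 values from packed bytes, mirroring the bool2float kernel."""
    shifts = torch.arange(8)
    bits = (packed.to(torch.int64).unsqueeze(1) >> shifts) & 1
    return bits.flatten().to(torch.float32)

spike = (torch.rand(64) >= 0.5).to(torch.float32)
assert torch.equal(unpack_spikes(pack_spikes(spike)), spike)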
Subpackages
spikingjelly.activation_based.ann2snn package
Subpackages
Submodules
spikingjelly.activation_based.ann2snn.converter module
- class spikingjelly.activation_based.ann2snn.converter.Converter(dataloader, device=None, mode='Max', momentum=0.1, fuse_flag=True)[源代码]
基类:
Module
- 参数
Converter
用于将带有ReLU的ANN转换为SNN。ANN2SNN教程见此处 ANN转换SNN 。
目前支持三种转换模式,由参数mode进行设置。
转换后ReLU模块被删除,SNN需要的新模块(包括VoltageScaler、IFNode等)被创建并存放在snn tailor父模块中。
由于返回值的类型为fx.GraphModule,建议使用print(fx.GraphModule.graph)查看计算图及前向传播关系。更多API参见 GraphModule 。
警告
必须确保ANN中的
ReLU
为module而非function。您最好在ANN模型中使用平均池化而不是最大池化。否则,可能会损害转换后的SNN模型的性能。
- 参数
dataloader (Dataloader) – Dataloader for converting
device (str) – Device
mode (str, float) – Conversion mode. Now support three mode, MaxNorm(mode=’max’), RobustNorm(mode=’99.9%’), and scaling mode(mode=x, where 0<x<=1)
momentum (float) – Momentum value used by modules.VoltageHook
fuse_flag (bool) – Bool specifying if fusion of the conv and the bn happens, by default it happens.
Converter
is used to convert an ANN with ReLU activations into an SNN. The ANN2SNN tutorial is here: ANN2SNN.
Three common conversion methods are implemented here; they can be selected by the value of the parameter mode.
After converting, the ReLU modules will be removed, and the new modules needed by the SNN, such as VoltageScaler and IFNode, will be created and stored in the parent module ‘snn tailor’.
Since the returned model is an fx.GraphModule, you can print its graph attribute to view how the modules are linked and how the forward method works. More APIs are here: GraphModule. A usage sketch is given after this class reference.
警告
Make sure that
ReLU
is a module rather than a function. You’d better use
avgpool
rather thanmaxpool
in your ann model. If not, the performance of the converted snn model may be ruined.- forward(ann: Module)[源代码]
-
- 参数
ann (torch.nn.Module) – 待转换的ann
- 返回
转换得到的snn
- 返回类型
- 参数
ann (torch.nn.Module) – ann to be converted
- 返回
snn
- 返回类型
- static fuse(fx_model: GraphModule, fuse_flag: bool = True) GraphModule [源代码]
-
- 参数
fx_model (torch.fx.GraphModule) – 原模型
fuse_flag (bool) – 标志位,设置为True,则进行conv与bn的融合,反之不进行。
- 返回
conv层和bn层融合后的模型.
- 返回类型
fuse
用于conv与bn的融合。- 参数
fx_model (torch.fx.GraphModule) – Original fx_model
fuse_flag (bool) – Bool specifying if fusion of the conv and the bn happens, by default it happens.
- 返回
fx_model whose conv layer and bn layer have been fused.
- 返回类型
fuse
is used to fuse conv layer and bn layer.
- static set_voltagehook(fx_model: GraphModule, mode='Max', momentum=0.1) GraphModule [源代码]
-
- 参数
fx_model (torch.fx.GraphModule) – 原模型
mode (str, float) – 转换模式。目前支持三种模式,最大电流转换模式,99.9%电流转换模式,以及缩放转换模式
momentum (float) – 动量值,用于VoltageHook
- 返回
带有VoltageHook的模型.
- 返回类型
set_voltagehook
用于给模型添加VoltageHook模块。这里实现了常见的三种模式,同上。- 参数
fx_model (torch.fx.GraphModule) – Original fx_model
mode (str, float) – Conversion mode. Now support three mode, MaxNorm, RobustNorm(99.9%), and scaling mode
momentum (float) – momentum value used by VoltageHook
- 返回
fx_model with VoltageHook.
- 返回类型
set_voltagehook
is used to add VoltageHook to fx_model. Three common methods are implemented here, the same as Converter.mode.
- static replace_by_ifnode(fx_model: GraphModule) GraphModule [源代码]
-
- 参数
fx_model (torch.fx.GraphModule) – 原模型
- 返回
将ReLU替换为IF脉冲神经元后的模型.
- 返回类型
replace_by_ifnode
用于将模型的ReLU替换为IF脉冲神经元。- 参数
fx_model (torch.fx.GraphModule) – Original fx_model
- 返回
fx_model whose ReLU has been replaced by IF neuron.
- 返回类型
replace_by_ifnode
is used to replace ReLU with IF neuron.
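A minimal usage sketch (the toy ANN, the dummy dataloader and mode='max' are illustrative assumptions; in practice the dataloader should provide real training data for recording activation ranges):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from spikingjelly.activation_based import ann2snn

# a toy ANN that uses ReLU modules (not ReLU functions)
ann = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 28 * 28, 10),
)

# dummy (image, label) batches, used only to record activation ranges
dataset = TensorDataset(torch.rand(64, 1, 28, 28), torch.zeros(64, dtype=torch.long))
loader = DataLoader(dataset, batch_size=16)

converter = ann2snn.Converter(dataloader=loader, mode='max')
snn = converter(ann)          # an fx.GraphModule
print(snn.graph)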
spikingjelly.activation_based.ann2snn.modules module
- class spikingjelly.activation_based.ann2snn.modules.VoltageHook(scale=1.0, momentum=0.1, mode='Max')[源代码]
基类:
Module
- 参数
VoltageHook
被置于ReLU后,用于在ANN推理中确定激活的范围。- 参数
scale (float) – initial scaling value
momentum (float) – momentum value
mode (str, float) – The mode. Value “Max” means recording the maximum value of ANN activation, “99.9%” means recording the 99.9% percentile of ANN activation, and a float in (0, 1] means recording the corresponding multiple of the maximum activation value.
VoltageHook
is placed behind ReLU and used to determine the range of activations in ANN inference.- forward(x)[源代码]
-
- 参数
x (torch.Tensor) – 输入张量
- 返回
原输入张量
- 返回类型
不对输入张量做任何处理,只是抓取ReLU的激活值
- 参数
x (torch.Tensor) – input tensor
- 返回
original input tensor
- 返回类型
It doesn’t process input tensors, but hooks the activation values of ReLU.
- class spikingjelly.activation_based.ann2snn.modules.VoltageScaler(scale=1.0)[源代码]
基类:
Module
- 参数
scale (float) – 缩放值
VoltageScaler
用于SNN推理中缩放电流。- 参数
scale (float) – scaling value
VoltageScaler
is used for scaling current in SNN inference.- forward(x)[源代码]
-
- 参数
x (torch.Tensor) – 输入张量,亦即输入电流
- 返回
缩放后的电流
- 返回类型
- 参数
x (torch.Tensor) – input tensor, or input current
- 返回
current after scaling
- 返回类型
spikingjelly.activation_based.ann2snn.utils module
Module contents
spikingjelly.activation_based.examples package
Subpackages
- spikingjelly.activation_based.examples.common.multiprocessing_env.worker(remote, parent_remote, env_fn_wrapper)[源代码]
- class spikingjelly.activation_based.examples.common.multiprocessing_env.VecEnv(num_envs, observation_space, action_space)[源代码]
基类:
object
An abstract asynchronous, vectorized environment.
- reset()[源代码]
Reset all the environments and return an array of observations, or a tuple of observation arrays. If step_async is still doing work, that work will be cancelled and step_wait() should not be called until step_async() is invoked again.
- step_async(actions)[源代码]
Tell all the environments to start taking a step with the given actions. Call step_wait() to get the results of the step. You should not call this if a step_async run is already pending.
Submodules
spikingjelly.activation_based.examples.A2C module
spikingjelly.activation_based.examples.DQN_state module
spikingjelly.activation_based.examples.PPO module
spikingjelly.activation_based.examples.Spiking_A2C module
spikingjelly.activation_based.examples.Spiking_DQN_state module
spikingjelly.activation_based.examples.Spiking_PPO module
spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation module
代码作者: Yanqi Chen <chyq@pku.edu.cn>
A reproduction of the paper Enabling Spike-Based Backpropagation for Training Deep Neural Network Architectures.
This code reproduces a novel gradient-based training method for SNNs. To some extent, we follow the network structure and other implementation details of the authors’ implementation. Since the training method and neuron models are slightly different from those in this framework, we rewrite them in a compatible style.
This example assumes that you have at least one Nvidia GPU.
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu(*args, **kwargs)[源代码]
基类:
Function
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.BaseNode(v_threshold=1.0, v_reset=0.0, surrogate_function=<built-in method apply of FunctionMeta object>, monitor=False)[源代码]
基类:
Module
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.LIFNode(tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=<built-in method apply of FunctionMeta object>, fire=True)[源代码]
基类:
BaseNode
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.IFNode(v_threshold=0.75, v_reset=0.0, surrogate_function=<built-in method apply of FunctionMeta object>)[源代码]
基类:
BaseNode
spikingjelly.activation_based.examples.classify_dvsg module
spikingjelly.activation_based.examples.conv_fashion_mnist module
- class spikingjelly.activation_based.examples.conv_fashion_mnist.CSNN(T: int, channels: int, use_cupy=False)[源代码]
基类:
Module
- spikingjelly.activation_based.examples.conv_fashion_mnist.main()[源代码]
(sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h
usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
                             [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]

Classify Fashion-MNIST

optional arguments:
  -h, --help          show this help message and exit
  -T T                simulating time-steps
  -device DEVICE      device
  -b B                batch size
  -epochs N           number of total epochs to run
  -j N                number of data loading workers (default: 4)
  -data-dir DATA_DIR  root dir of Fashion-MNIST dataset
  -out-dir OUT_DIR    root dir for saving logs and checkpoint
  -resume RESUME      resume from the checkpoint path
  -amp                automatic mixed precision training
  -cupy               use cupy neuron and multi-step forward mode
  -opt OPT            use which optimizer. SGD or Adam
  -momentum MOMENTUM  momentum for SGD
  -save-es SAVE_ES    dir for saving a batch of spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
spikingjelly.activation_based.examples.dqn_cart_pole module
spikingjelly.activation_based.examples.lif_fc_mnist module
spikingjelly.activation_based.examples.rsnn_sequential_fmnist module
- class spikingjelly.activation_based.examples.rsnn_sequential_fmnist.StatefulSynapseNet[源代码]
基类:
Module
spikingjelly.activation_based.examples.speechcommands module
代码作者: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>
A reproduction of the paper Technical report: supervised training of convolutional spiking neural networks with PyTorch.
This code reproduces an audio recognition task using convolutional SNN. It provides comparable performance to ANN.
备注
To avoid heavy dependencies such as librosa, we implement MelScale ourselves. Two DCT types are provided: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.
Confusion matrix of TEST set after training (50 epochs):
Counts (rows: Ground Truth, columns: Prediction):

| Ground Truth \ Prediction | “Yes” | “Stop” | “No” | “Right” | “Up” | “Left” | “On” | “Down” | “Off” | “Go” | Other | Silence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| “Yes” | 234 | 0 | 2 | 0 | 0 | 3 | 0 | 0 | 0 | 1 | 16 | 0 |
| “Stop” | 0 | 233 | 0 | 1 | 5 | 0 | 0 | 0 | 0 | 1 | 9 | 0 |
| “No” | 0 | 1 | 223 | 1 | 0 | 1 | 0 | 5 | 0 | 9 | 12 | 0 |
| “Right” | 0 | 0 | 0 | 234 | 0 | 0 | 0 | 0 | 0 | 0 | 24 | 1 |
| “Up” | 0 | 4 | 0 | 0 | 249 | 0 | 0 | 0 | 8 | 0 | 11 | 0 |
| “Left” | 3 | 1 | 2 | 3 | 1 | 250 | 0 | 0 | 1 | 0 | 6 | 0 |
| “On” | 0 | 3 | 0 | 0 | 0 | 0 | 231 | 0 | 2 | 1 | 9 | 0 |
| “Down” | 0 | 0 | 7 | 0 | 0 | 1 | 2 | 230 | 0 | 4 | 8 | 1 |
| “Off” | 0 | 0 | 2 | 1 | 4 | 2 | 6 | 0 | 237 | 1 | 9 | 0 |
| “Go” | 0 | 2 | 5 | 0 | 0 | 2 | 0 | 1 | 5 | 220 | 16 | 0 |
| Other | 6 | 21 | 12 | 25 | 22 | 19 | 25 | 14 | 11 | 40 | 4072 | 1 |
| Silence | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 260 |
- spikingjelly.activation_based.examples.speechcommands.create_fb_matrix(n_freqs: int, f_min: float, f_max: float, n_mels: int, sample_rate: int, dct_type: Optional[str] = 'slaney') Tensor [源代码]
- class spikingjelly.activation_based.examples.speechcommands.MelScaleDelta(order, n_mels: int = 128, sample_rate: int = 16000, f_min: float = 0.0, f_max: Optional[float] = None, n_stft: Optional[int] = None, dct_type: Optional[str] = 'slaney')[源代码]
基类:
Module
spikingjelly.activation_based.examples.spiking_lstm_sequential_mnist module
spikingjelly.activation_based.examples.spiking_lstm_text module
Module contents
spikingjelly.activation_based.model package
Submodules
spikingjelly.activation_based.model.parametric_lif_net module
- class spikingjelly.activation_based.model.parametric_lif_net.MNISTNet(channels=128, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
Module
- class spikingjelly.activation_based.model.parametric_lif_net.FashionMNISTNet(channels=128, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
MNISTNet
- class spikingjelly.activation_based.model.parametric_lif_net.NMNISTNet(channels=128, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
MNISTNet
- class spikingjelly.activation_based.model.parametric_lif_net.CIFAR10Net(channels=256, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
Module
spikingjelly.activation_based.model.sew_resnet module
- class spikingjelly.activation_based.model.sew_resnet.SEWResNet(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
Module
- spikingjelly.activation_based.model.sew_resnet.sew_resnet18(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNet-18
- 返回类型
The spike-element-wise ResNet-18 “Deep Residual Learning in Spiking Neural Networks” modified by the ResNet-18 model from “Deep Residual Learning for Image Recognition”
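A minimal usage sketch (cnf='ADD', the IF neuron and the input shape are illustrative choices; extra kwargs are forwarded to the spiking neuron):

import torch
from spikingjelly.activation_based import neuron, surrogate, functional
from spikingjelly.activation_based.model import sew_resnet

net = sew_resnet.sew_resnet18(pretrained=False, cnf='ADD',
                              spiking_neuron=neuron.IFNode,
                              surrogate_function=surrogate.ATan())
x = torch.rand(1, 3, 224, 224)
y = net(x)                  # single-step forward by default
functional.reset_net(net)   # reset neuron states between samples
print(y.shape)              # torch.Size([1, 1000])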
- spikingjelly.activation_based.model.sew_resnet.sew_resnet34(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNet-34
- 返回类型
The spike-element-wise ResNet-34 “Deep Residual Learning in Spiking Neural Networks” modified by the ResNet-34 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.sew_resnet.sew_resnet50(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNet-50
- 返回类型
The spike-element-wise ResNet-50 “Deep Residual Learning in Spiking Neural Networks” modified by the ResNet-50 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.sew_resnet.sew_resnet101(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNet-101
- 返回类型
The spike-element-wise ResNet-101 “Deep Residual Learning in Spiking Neural Networks” modified by the ResNet-101 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.sew_resnet.sew_resnet152(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a single step neuron
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNet-152
- 返回类型
The spike-element-wise ResNet-152 “Deep Residual Learning in Spiking Neural Networks” modified by the ResNet-152 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.sew_resnet.sew_resnext50_32x4d(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a single step neuron
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNeXt-50 32x4d
- 返回类型
The spike-element-wise ResNeXt-50 32x4d “Deep Residual Learning in Spiking Neural Networks” modified by the ResNeXt-50 32x4d model from “Aggregated Residual Transformation for Deep Neural Networks”
- spikingjelly.activation_based.model.sew_resnet.sew_resnext101_32x8d(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a single step neuron
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking ResNeXt-101 32x8d
- 返回类型
The spike-element-wise ResNeXt-101 32x8d “Deep Residual Learning in Spiking Neural Networks” modified by the ResNeXt-101 32x8d model from “Aggregated Residual Transformation for Deep Neural Networks”
- spikingjelly.activation_based.model.sew_resnet.sew_wide_resnet50_2(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a single step neuron
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking Wide ResNet-50-2
- 返回类型
The spike-element-wise Wide ResNet-50-2 “Deep Residual Learning in Spiking Neural Networks” modified by the Wide ResNet-50-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048.
- spikingjelly.activation_based.model.sew_resnet.sew_wide_resnet101_2(pretrained=False, progress=True, cnf: Optional[str] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
cnf (str) – the name of spike-element-wise function
spiking_neuron (callable) – a single step neuron
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking Wide ResNet-101-2
- 返回类型
The spike-element-wise Wide ResNet-101-2 “Deep Residual Learning in Spiking Neural Networks” modified by the Wide ResNet-101-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048.
spikingjelly.activation_based.model.spiking_resnet module
- class spikingjelly.activation_based.model.spiking_resnet.SpikingResNet(block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
Module
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnet18(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNet-18
- 返回类型
A spiking version of ResNet-18 model from “Deep Residual Learning for Image Recognition”
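A minimal usage sketch (the LIF neuron and the input shape are illustrative choices):

import torch
from spikingjelly.activation_based import neuron, functional
from spikingjelly.activation_based.model import spiking_resnet

net = spiking_resnet.spiking_resnet18(spiking_neuron=neuron.LIFNode)
x = torch.rand(1, 3, 224, 224)
y = net(x)
functional.reset_net(net)
print(y.shape)   # torch.Size([1, 1000])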
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnet34(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNet-34
- 返回类型
A spiking version of ResNet-34 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnet50(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNet-50
- 返回类型
A spiking version of ResNet-50 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnet101(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNet-101
- 返回类型
A spiking version of ResNet-101 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnet152(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNet-152
- 返回类型
A spiking version of ResNet-152 model from “Deep Residual Learning for Image Recognition”
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnext50_32x4d(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNeXt-50 32x4d
- 返回类型
A spiking version of ResNeXt-50 32x4d model from “Aggregated Residual Transformation for Deep Neural Networks”
- spikingjelly.activation_based.model.spiking_resnet.spiking_resnext101_32x8d(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking ResNeXt-101 32x8d
- 返回类型
A spiking version of ResNeXt-101 32x8d model from “Aggregated Residual Transformation for Deep Neural Networks”
- spikingjelly.activation_based.model.spiking_resnet.spiking_wide_resnet50_2(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking Wide ResNet-50-2
- 返回类型
A spiking version of Wide ResNet-50-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048.
- spikingjelly.activation_based.model.spiking_resnet.spiking_wide_resnet101_2(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking Wide ResNet-101-2
- 返回类型
A spiking version of Wide ResNet-101-2 model from “Wide Residual Networks”
The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048.
spikingjelly.activation_based.model.spiking_vgg module
- class spikingjelly.activation_based.model.spiking_vgg.SpikingVGG(cfg, batch_norm=False, norm_layer=None, num_classes=1000, init_weights=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
基类:
Module
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg11(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking VGG-11
- 返回类型
A spiking version of VGG-11 model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg11_bn(pretrained=False, progress=True, norm_layer: Optional[callable] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
norm_layer (callable) – a batch norm layer
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking VGG-11 with norm layer
- 返回类型
A spiking version of VGG-11-BN model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg13(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking VGG-13
- 返回类型
A spiking version of VGG-13 model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg13_bn(pretrained=False, progress=True, norm_layer: Optional[callable] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
norm_layer (callable) – a batch norm layer
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking VGG-13 with norm layer
- 返回类型
A spiking version of VGG-13-BN model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg16(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking VGG-16
- 返回类型
A spiking version of VGG-16 model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg16_bn(pretrained=False, progress=True, norm_layer: Optional[callable] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
norm_layer (callable) – a batch norm layer
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking VGG-16 with norm layer
- 返回类型
A spiking version of VGG-16-BN model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg19(pretrained=False, progress=True, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
- 返回
Spiking VGG-19
- 返回类型
A spiking version of VGG-19 model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
- spikingjelly.activation_based.model.spiking_vgg.spiking_vgg19_bn(pretrained=False, progress=True, norm_layer: Optional[callable] = None, spiking_neuron: Optional[callable] = None, **kwargs)[源代码]
- 参数
pretrained (bool) – If True, the SNN will load parameters from the ANN pre-trained on ImageNet
progress (bool) – If True, displays a progress bar of the download to stderr
norm_layer (callable) – a batch norm layer
spiking_neuron (callable) – a spiking neuron layer
kwargs (dict) – kwargs for spiking_neuron
- 返回
Spiking VGG-19 with norm layer
- 返回类型
A spiking version of VGG-19-BN model from “Very Deep Convolutional Networks for Large-Scale Image Recognition”
spikingjelly.activation_based.model.train_classify module
- spikingjelly.activation_based.model.train_classify.set_deterministic(_seed_: int = 2020, disable_uda=False)[源代码]
spikingjelly.activation_based.model.train_imagenet module
Module contents
Module contents
spikingjelly.datasets package
Submodules
spikingjelly.datasets.asl_dvs module
- class spikingjelly.datasets.asl_dvs.ASLDVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[源代码]
-
The ASL-DVS dataset, which is proposed by Graph-based Object Classification for Neuromorphic Vision Sensing.
Refer to
spikingjelly.datasets.NeuromorphicDatasetFolder
for more details about params information.- static resource_url_md5() list [源代码]
- 返回
A list
url
thaturl[i]
is a tuple, which contains the i-th file’s name, download link, and MD5- 返回类型
- static downloadable() bool [源代码]
- 返回
Whether the dataset can be directly downloaded by Python code. If not, the user has to download it manually
- 返回类型
- static extract_downloaded_files(download_root: str, extract_root: str)[源代码]
- 参数
- 返回
None
This function defines how to extract download files.
- static load_origin_data(file_name: str) Dict [源代码]
- 参数
file_name (str) – path of the events file
- 返回
a dict whose keys are
['t', 'x', 'y', 'p']
and values arenumpy.ndarray
- 返回类型
Dict
This function defines how to read the origin binary data.
- static get_H_W() Tuple [源代码]
- 返回
A tuple
(H, W)
, whereH
is the height of the data and W is the width of the data. For example, this function returns (128, 128)
for the DVS128 Gesture dataset.- 返回类型
- static create_events_np_files(extract_root: str, events_np_root: str)[源代码]
- 参数
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root – Root directory path which saves events files in the
npz
format
- 返回
None
This function defines how to convert the origin binary data in
extract_root
tonpz
format and save converted files inevents_np_root
.
spikingjelly.datasets.cifar10_dvs module
- spikingjelly.datasets.cifar10_dvs.load_raw_events(fp, bytes_skip=0, bytes_trim=0, filter_dvs=False, times_first=False)[源代码]
- spikingjelly.datasets.cifar10_dvs.parse_raw_address(addr, x_mask=4190208, x_shift=12, y_mask=2143289344, y_shift=22, polarity_mask=2048, polarity_shift=11)[源代码]
- class spikingjelly.datasets.cifar10_dvs.CIFAR10DVS(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[源代码]
-
The CIFAR10-DVS dataset, which is proposed by CIFAR10-DVS: An Event-Stream Dataset for Object Classification (https://internal-journal.frontiersin.org/articles/10.3389/fnins.2017.00309/full).
Refer to
spikingjelly.datasets.NeuromorphicDatasetFolder
for more details about params information.- static resource_url_md5() list [源代码]
- 返回
A list
url
thaturl[i]
is a tuple, which contains the i-th file’s name, download link, and MD5- 返回类型
- static downloadable() bool [源代码]
- 返回
Whether the dataset can be directly downloaded by Python code. If not, the user has to download it manually
- 返回类型
- static extract_downloaded_files(download_root: str, extract_root: str)[源代码]
- 参数
- 返回
None
This function defines how to extract download files.
- static load_origin_data(file_name: str) Dict [源代码]
- 参数
file_name (str) – path of the events file
- 返回
a dict whose keys are
['t', 'x', 'y', 'p']
and values arenumpy.ndarray
- 返回类型
Dict
This function defines how to read the origin binary data.
- static get_H_W() Tuple [源代码]
- 返回
A tuple
(H, W)
, whereH
is the height of the data and W is the width of the data. For example, this function returns (128, 128)
for the DVS128 Gesture dataset.- 返回类型
- static create_events_np_files(extract_root: str, events_np_root: str)[源代码]
- 参数
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root – Root directory path which saves events files in the
npz
format
- 返回
None
This function defines how to convert the origin binary data in
extract_root
tonpz
format and save converted files inevents_np_root
.
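A minimal usage sketch (the root directory and the frame parameters are illustrative assumptions):

from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS

root_dir = 'D:/datasets/CIFAR10DVS'   # hypothetical path
data_set = CIFAR10DVS(root_dir, data_type='frame', frames_number=10, split_by='number')
frame, label = data_set[0]
print(frame.shape)   # [T, 2, H, W], e.g. [10, 2, 128, 128]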
spikingjelly.datasets.dvs128_gesture module
- class spikingjelly.datasets.dvs128_gesture.DVS128Gesture(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
-
The DVS128 Gesture dataset, which is proposed by A Low Power, Fully Event-Based Gesture Recognition System.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about parameter information.
Note
In SpikingJelly, there are 1176 train samples and 288 test samples. The total number of samples is 1464.
from spikingjelly.datasets import dvs128_gesture

data_dir = 'D:/datasets/DVS128Gesture'
train_set = dvs128_gesture.DVS128Gesture(data_dir, train=True)
test_set = dvs128_gesture.DVS128Gesture(data_dir, train=False)
print(f'train samples = {train_set.__len__()}, test samples = {test_set.__len__()}')
print(f'total samples = {train_set.__len__() + test_set.__len__()}')
# train samples = 1176, test samples = 288
# total samples = 1464
According to the original paper, the DvsGesture dataset comprises 1342 instances of a set of 11 hand and arm gestures. The difference may be caused by different pre-processing methods.
snnTorch reports the same numbers as SpikingJelly:
from snntorch.spikevision import spikedata

train_set = spikedata.DVSGesture("D:/datasets/DVS128Gesture/temp2", train=True, num_steps=500, dt=1000)
test_set = spikedata.DVSGesture("D:/datasets/DVS128Gesture/temp2", train=False, num_steps=1800, dt=1000)
print(f'train samples = {train_set.__len__()}, test samples = {test_set.__len__()}')
print(f'total samples = {train_set.__len__() + test_set.__len__()}')
# train samples = 1176, test samples = 288
# total samples = 1464
But tonic has different numbers, which are close to 1342:
import tonic

train_set = tonic.datasets.DVSGesture(save_to='D:/datasets/DVS128Gesture/temp', train=True)
test_set = tonic.datasets.DVSGesture(save_to='D:/datasets/DVS128Gesture/temp', train=False)
print(f'train samples = {train_set.__len__()}, test samples = {test_set.__len__()}')
print(f'total samples = {train_set.__len__() + test_set.__len__()}')
# train samples = 1077, test samples = 264
# total samples = 1341
Here we show how the 1176 train samples and 288 test samples are obtained in SpikingJelly.
The original dataset is split into the train and test sets by trials_to_train.txt and trials_to_test.txt:

trials_to_train.txt:
    user01_fluorescent.aedat
    user01_fluorescent_led.aedat
    ...
    user23_lab.aedat
    user23_led.aedat

trials_to_test.txt:
    user24_fluorescent.aedat
    user24_fluorescent_led.aedat
    ...
    user29_led.aedat
    user29_natural.aedat
SpikingJelly will read the txt files and get the aedat file names, e.g., user01_fluorescent.aedat. The corresponding label file name is regarded as user01_fluorescent_labels.csv:

user01_fluorescent_labels.csv:
    class	startTime_usec	endTime_usec
    1	80048239	85092709
    2	89431170	95231007
    3	95938861	103200075
    4	114845417	123499505
    5	124344363	131742581
    6	133660637	141880879
    7	142360393	149138239
    8	150717639	157362334
    8	157773346	164029864
    9	165057394	171518239
    10	172843790	179442817
    11	180675853	187389051
Then SpikingJelly will split the aedat file into samples by the time ranges and classes given in the csv file. In this example, the first sample user01_fluorescent_0.npz is sliced from the original events user01_fluorescent.aedat with 80048239 <= t < 85092709 and label=0. user01_fluorescent_0.npz will be saved in root/events_np/train/0.
- static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)[source]
- Parameters
download_root (str) – Root directory path which saves the downloaded dataset files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) Dict [source]
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static create_events_np_files(extract_root: str, events_np_root: str)[source]
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to npz format and save the converted files in events_np_root.
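As a usage sketch (the dataset path below is hypothetical), the constructor parameters documented above can be combined to build a frame-based version of the dataset:

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture

# integrate each sample's events into 20 frames, splitting the events by number
train_set = DVS128Gesture('D:/datasets/DVS128Gesture', train=True, data_type='frame',
                          frames_number=20, split_by='number')
frame, label = train_set[0]
print(frame.shape)  # expected: (20, 2, 128, 128)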
spikingjelly.datasets.es_imagenet module
- class spikingjelly.datasets.es_imagenet.ESImageNet(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
-
The ES-ImageNet dataset, which is proposed by ES-ImageNet: A Million Event-Stream Classification Dataset for Spiking Neural Networks.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about parameter information.
- static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)[source]
- Parameters
download_root (str) – Root directory path which saves the downloaded dataset files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static create_events_np_files(extract_root: str, events_np_root: str)[source]
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to npz format and save the converted files in events_np_root.
spikingjelly.datasets.n_caltech101 module
- class spikingjelly.datasets.n_caltech101.NCaltech101(root: str, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
-
The N-Caltech101 dataset, which is proposed by Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about parameter information.
- static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)[source]
- Parameters
download_root (str) – Root directory path which saves the downloaded dataset files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) Dict [source]
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() Tuple [source]
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)[source]
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to npz format and save the converted files in events_np_root.
spikingjelly.datasets.n_mnist module
- class spikingjelly.datasets.n_mnist.NMNIST(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
-
The N-MNIST dataset, which is proposed by Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about parameter information.
- static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- static extract_downloaded_files(download_root: str, extract_root: str)[source]
- Parameters
download_root (str) – Root directory path which saves the downloaded dataset files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- static load_origin_data(file_name: str) Dict [source]
- Parameters
file_name (str) – path of the events file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
- Return type
Dict
This function defines how to read the original binary data.
- static get_H_W() Tuple [source]
- Returns
A tuple (H, W), where H is the height of the data and W is the width of the data. For example, this function returns (128, 128) for the DVS128 Gesture dataset.
- Return type
Tuple
- static create_events_np_files(extract_root: str, events_np_root: str)[source]
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to npz format and save the converted files in events_np_root.
spikingjelly.datasets.shd module
- spikingjelly.datasets.shd.cal_fixed_frames_number_segment_index_shd(events_t: ndarray, split_by: str, frames_num: int) tuple [source]
- spikingjelly.datasets.shd.integrate_events_segment_to_frame_shd(x: ndarray, W: int, j_l: int = 0, j_r: int = -1) ndarray [source]
- spikingjelly.datasets.shd.integrate_events_by_fixed_frames_number_shd(events: Dict, split_by: str, frames_num: int, W: int) ndarray [source]
- spikingjelly.datasets.shd.integrate_events_file_to_frames_file_by_fixed_frames_number_shd(h5_file: h5py.File, i: int, output_dir: str, split_by: str, frames_num: int, W: int, print_save: bool = False) None [source]
- spikingjelly.datasets.shd.integrate_events_by_fixed_duration_shd(events: Dict, duration: int, W: int) ndarray [source]
- spikingjelly.datasets.shd.integrate_events_file_to_frames_file_by_fixed_duration_shd(h5_file: h5py.File, i: int, output_dir: str, duration: int, W: int, print_save: bool = False) None [source]
- spikingjelly.datasets.shd.custom_integrate_function_example(h5_file: h5py.File, i: int, output_dir: str, W: int)[source]
- class spikingjelly.datasets.shd.SpikingHeidelbergDigits(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
Bases:
Dataset
The Spiking Heidelberg Digits (SHD) dataset, which is proposed by The Heidelberg Spiking Data Sets for the Systematic Evaluation of Spiking Neural Networks.
Refer to spikingjelly.datasets.NeuromorphicDatasetFolder for more details about parameter information.
Note
Events in this dataset are in the format of (x, t) rather than (x, y, t, p). Thus, this dataset does not inherit from spikingjelly.datasets.NeuromorphicDatasetFolder directly, but their procedures are similar. spikingjelly.datasets.shd.custom_integrate_function_example is an example of custom_integrate_function, which is similar to the custom function for DVS Gesture in the Neuromorphic Datasets Processing tutorial.
- static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
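A minimal usage sketch (the dataset path below is hypothetical; it assumes the frame-related constructor parameters behave as documented for NeuromorphicDatasetFolder):

from spikingjelly.datasets.shd import SpikingHeidelbergDigits

shd_train = SpikingHeidelbergDigits('D:/datasets/SHD', train=True, data_type='frame',
                                    frames_number=20, split_by='number')
frame, label = shd_train[0]
print(frame.shape, label)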
spikingjelly.datasets.speechcommands module
- spikingjelly.datasets.speechcommands.load_speechcommands_item(relpath: str, path: str) Tuple[Tensor, int, str, str, int] [source]
- class spikingjelly.datasets.speechcommands.SPEECHCOMMANDS(label_dict: Dict, root: str, silence_cnt: Optional[int] = 0, silence_size: Optional[int] = 16000, transform: Optional[Callable] = None, url: Optional[str] = 'speech_commands_v0.02', split: Optional[str] = 'train', folder_in_archive: Optional[str] = 'SpeechCommands', download: Optional[bool] = False)[source]
Bases:
Dataset
- Parameters
label_dict (Dict) – the dict that maps labels to classes
root (str) – root directory of the dataset
silence_cnt (int, optional) – the number of silence samples
silence_size (int, optional) – the size of each silence sample
transform (Callable, optional) – A function/transform that takes in a raw audio
url (str, optional) – the version of the dataset, v0.02 by default
split (str, optional) – the dataset split, which can be "train", "test", or "val", "train" by default
folder_in_archive (str, optional) – the name of the extracted directory, "SpeechCommands" by default
download (bool, optional) – whether to download the data, False by default
The SpeechCommands dataset from Speech Commands: A Dataset for Limited-Vocabulary Speech Recognition, split according to the provided test and validation lists; both the v0.01 and v0.02 versions are supported.
The dataset contains audio of three groups of words:
Command words, 10 in total: "Yes", "No", "Up", "Down", "Left", "Right", "On", "Off", "Stop", "Go". For v0.02, 5 more are added: "Forward", "Backward", "Follow", "Learn", "Visual".
Digits from 0 to 9, 10 in total: "Zero", "One", "Two", "Three", "Four", "Five", "Six", "Seven", "Eight", "Nine".
Auxiliary words, which can be regarded as distractors, 10 in total: "Bed", "Bird", "Cat", "Dog", "Happy", "House", "Marvin", "Sheila", "Tree", "Wow".
The v0.01 version contains 30 classes with 64,727 audio clips in total, and the v0.02 version contains 35 classes with 105,829 audio clips. See the paper above and the dataset's README for more details.
The implementation is based on torchaudio with extended functionality, and also refers to the implementation of the original paper.
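A minimal construction sketch (the root path is hypothetical, and the label_dict below assumes a mapping from keyword strings to integer class indices, which is how the parameter is described above):

from spikingjelly.datasets.speechcommands import SPEECHCOMMANDS

# hypothetical mapping from keyword strings to integer class indices
label_dict = {'yes': 0, 'no': 1, 'up': 2, 'down': 3, 'left': 4,
              'right': 5, 'on': 6, 'off': 7, 'stop': 8, 'go': 9}
train_set = SPEECHCOMMANDS(label_dict, root='D:/datasets/SpeechCommands',
                           split='train', download=True)
print(len(train_set))
sample = train_set[0]  # first element should be the waveform tensor (see load_speechcommands_item above)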
Module contents
- spikingjelly.datasets.play_frame(x: Tensor, save_gif_to: Optional[str] = None) None [source]
- Parameters
x (torch.Tensor or np.ndarray) – frames with shape=[T, 2, H, W]
save_gif_to (str) – If None, this function will play the frames. Otherwise, this function will not play the frames but save them to a gif file at the path save_gif_to
- Returns
None
- spikingjelly.datasets.load_aedat_v3(file_name: str) Dict [source]
- Parameters
file_name (str) – path of the aedat v3 file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
- Return type
Dict
This function is written by referring to https://gitlab.com/inivation/dv/dv-python . It can be used for DVS128 Gesture.
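A minimal usage sketch (the .aedat file path below is hypothetical):

from spikingjelly.datasets import load_aedat_v3

events = load_aedat_v3('user01_fluorescent.aedat')  # hypothetical DVS128 Gesture recording
print(events['t'].shape, events['x'].shape, events['y'].shape, events['p'].shape)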
- spikingjelly.datasets.load_ATIS_bin(file_name: str) Dict [source]
- Parameters
file_name (str) – path of the ATIS binary file
- Returns
a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
- Return type
Dict
This function is written by referring to https://github.com/jackd/events-tfds . Each ATIS binary example is a separate binary file consisting of a list of events. Each event occupies 40 bits as described below:
bit 39 - 32: Xaddress (in pixels)
bit 31 - 24: Yaddress (in pixels)
bit 23: Polarity (0 for OFF, 1 for ON)
bit 22 - 0: Timestamp (in microseconds)
- spikingjelly.datasets.load_npz_frames(file_name: str) ndarray [source]
- Parameters
file_name (str) – path of the npz file that saves the frames
- Returns
frames
- Return type
np.ndarray
- spikingjelly.datasets.integrate_events_segment_to_frame(x: ndarray, y: ndarray, p: ndarray, H: int, W: int, j_l: int = 0, j_r: int = -1) ndarray [source]
- Parameters
x (numpy.ndarray) – x-coordinate of events
y (numpy.ndarray) – y-coordinate of events
p (numpy.ndarray) – polarity of events
H (int) – height of the frame
W (int) – width of the frame
j_l (int) – the start index of the integration interval, which is included
j_r (int) – the end index of the integration interval, which is not included
- Returns
frames
- Return type
np.ndarray
Denote a two-channel frame as \(F\) and a pixel at \((p, x, y)\) as \(F(p, x, y)\); the pixel value is integrated from the events whose indices are in \([j_{l}, j_{r})\):
\[F(p, x, y) = \sum_{i = j_{l}}^{j_{r} - 1} \mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\]
where \(\mathcal{I}_{p, x, y}(p_{i}, x_{i}, y_{i})\) is an indicator function that equals 1 only when \((p, x, y) = (p_{i}, x_{i}, y_{i})\).
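A minimal sketch with synthetic events (the event values below are random and only for illustration):

import numpy as np
from spikingjelly.datasets import integrate_events_segment_to_frame

# 1000 synthetic events on a 128 x 128 sensor
n = 1000
x = np.random.randint(0, 128, size=n)
y = np.random.randint(0, 128, size=n)
p = np.random.randint(0, 2, size=n)

# integrate the first 500 events into one two-channel frame
frame = integrate_events_segment_to_frame(x, y, p, H=128, W=128, j_l=0, j_r=500)
print(frame.shape)  # expected: (2, 128, 128)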
- spikingjelly.datasets.cal_fixed_frames_number_segment_index(events_t: ndarray, split_by: str, frames_num: int) tuple [source]
- Parameters
events_t (numpy.ndarray) – events' t
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
- Returns
a tuple (j_l, j_r)
- Return type
tuple
Denote frames_num as \(M\). If split_by is 'time', then
\[\begin{split}\Delta T & = [\frac{t_{N-1} - t_{0}}{M}] \\ j_{l} & = \mathop{\arg\min}\limits_{k} \{t_{k} | t_{k} \geq t_{0} + \Delta T \cdot j\} \\ j_{r} & = \begin{cases} \mathop{\arg\max}\limits_{k} \{t_{k} | t_{k} < t_{0} + \Delta T \cdot (j + 1)\} + 1, & j < M - 1 \cr N, & j = M - 1 \end{cases}\end{split}\]
If split_by is 'number', then
\[\begin{split}j_{l} & = [\frac{N}{M}] \cdot j \\ j_{r} & = \begin{cases} [\frac{N}{M}] \cdot (j + 1), & j < M - 1 \cr N, & j = M - 1 \end{cases}\end{split}\]
- spikingjelly.datasets.integrate_events_by_fixed_frames_number(events: Dict, split_by: str, frames_num: int, H: int, W: int) ndarray [source]
- Parameters
events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
H (int) – the height of the frame
W (int) – the width of the frame
- Returns
frames
- Return type
np.ndarray
Integrate events to frames with a fixed number of frames. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.
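A minimal sketch with synthetic events (random values, for illustration only), integrating them into 4 frames split by event number:

import numpy as np
from spikingjelly.datasets import integrate_events_by_fixed_frames_number

n = 1000
events = {
    't': np.sort(np.random.randint(0, 10 ** 6, size=n)),  # monotonically increasing timestamps
    'x': np.random.randint(0, 128, size=n),
    'y': np.random.randint(0, 128, size=n),
    'p': np.random.randint(0, 2, size=n),
}
frames = integrate_events_by_fixed_frames_number(events, split_by='number', frames_num=4, H=128, W=128)
print(frames.shape)  # expected: (4, 2, 128, 128)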
- spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_frames_number(loader: Callable, events_np_file: str, output_dir: str, split_by: str, frames_num: int, H: int, W: int, print_save: bool = False) None [source]
- Parameters
loader (Callable) – a function that can load events from events_np_file
events_np_file (str) – path of the events np file
output_dir (str) – output directory for saving the frames
split_by (str) – 'time' or 'number'
frames_num (int) – the number of frames
H (int) – the height of the frame
W (int) – the width of the frame
print_save (bool) – If True, this function will print the saved files' paths
- Returns
None
Integrate an events file to frames with a fixed number of frames and save them. See cal_fixed_frames_number_segment_index and integrate_events_segment_to_frame for more details.
- spikingjelly.datasets.integrate_events_by_fixed_duration(events: Dict, duration: int, H: int, W: int) ndarray [source]
- Parameters
events (Dict) – a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray
duration (int) – the time duration of each frame
H (int) – the height of the frame
W (int) – the width of the frame
- Returns
frames
- Return type
np.ndarray
Integrate events to frames with a fixed time duration for each frame.
- spikingjelly.datasets.integrate_events_file_to_frames_file_by_fixed_duration(loader: Callable, events_np_file: str, output_dir: str, duration: int, H: int, W: int, print_save: bool = False) None [source]
- Parameters
loader (Callable) – a function that can load events from events_np_file
events_np_file (str) – path of the events np file
output_dir (str) – output directory for saving the frames
duration (int) – the time duration of each frame
H (int) – the height of the frame
W (int) – the width of the frame
print_save (bool) – If True, this function will print the saved files' paths
- Returns
None
Integrate an events file to frames with a fixed time duration for each frame and save them.
- spikingjelly.datasets.create_same_directory_structure(source_dir: str, target_dir: str) None [source]
- Parameters
source_dir (str) – path of the source directory
target_dir (str) – path of the target directory
- Returns
None
Create the same directory structure in target_dir as that of source_dir.
- spikingjelly.datasets.split_to_train_test_set(train_ratio: float, origin_dataset: Dataset, num_classes: int, random_split: bool = False)[source]
- Parameters
train_ratio (float) – the ratio of the origin dataset that is split into the train set
origin_dataset (torch.utils.data.Dataset) – the origin dataset
num_classes (int) – total number of classes, e.g., 10 for the MNIST dataset
random_split (bool) – If False, the first train_ratio of samples in each class will be included in the train set, while the rest will be included in the test set. If True, this function will split the samples in each class randomly. The randomness is controlled by numpy.random.seed
- Returns
a tuple (train_set, test_set)
- Return type
tuple
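For datasets without an official train/test split, such as CIFAR10-DVS, a typical call looks like the following sketch (the dataset path is hypothetical):

from spikingjelly.datasets import split_to_train_test_set
from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS

# CIFAR10-DVS provides no train/test division, so split it 9:1 within each of the 10 classes
data_set = CIFAR10DVS('D:/datasets/CIFAR10DVS', data_type='frame', frames_number=10, split_by='number')
train_set, test_set = split_to_train_test_set(train_ratio=0.9, origin_dataset=data_set, num_classes=10)
print(len(train_set), len(test_set))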
- spikingjelly.datasets.pad_sequence_collate(batch: list)[source]
- Parameters
batch (list) – a list of samples that contains (x, y), where x is a list containing sequences with different lengths and y is the label
- Returns
batched samples (x_p, y, x_len), where x_p is the padded x with the same length, y is the label, and x_len is the length of each x
- Return type
tuple
This function can be used as the collate_fn for a DataLoader to process datasets with variable lengths, e.g., a NeuromorphicDatasetFolder with a fixed duration to integrate events to frames. Here is an example:

import torch
from spikingjelly.datasets import pad_sequence_collate

class VariableLengthDataset(torch.utils.data.Dataset):
    def __init__(self, n=1000):
        super().__init__()
        self.n = n

    def __getitem__(self, i):
        return torch.rand([i + 1, 2]), self.n - i - 1

    def __len__(self):
        return self.n

loader = torch.utils.data.DataLoader(VariableLengthDataset(n=32), batch_size=2,
                                     collate_fn=pad_sequence_collate, shuffle=True)

for i, (x_p, label, x_len) in enumerate(loader):
    print(f'x_p.shape={x_p.shape}, label={label}, x_len={x_len}')
    if i == 2:
        break

And the outputs are:

x_p.shape=torch.Size([2, 18, 2]), label=tensor([14, 30]), x_len=tensor([18, 2])
x_p.shape=torch.Size([2, 29, 2]), label=tensor([3, 6]), x_len=tensor([29, 26])
x_p.shape=torch.Size([2, 23, 2]), label=tensor([ 9, 23]), x_len=tensor([23, 9])
- spikingjelly.datasets.padded_sequence_mask(sequence_len: Tensor, T=None)[source]
- Parameters
sequence_len (torch.Tensor) – a tensor with shape = [N] that contains the sequence length of each batch element
T (int) – The maximum length of the sequences. If None, the maximum element in sequence_len will be used as T
- Returns
a bool mask with shape = [T, N], where the padded positions are False
- Return type
torch.Tensor
Here is an example:

import torch
from spikingjelly.datasets import padded_sequence_mask

x1 = torch.rand([2, 6])
x2 = torch.rand([3, 6])
x3 = torch.rand([4, 6])
x = torch.nn.utils.rnn.pad_sequence([x1, x2, x3])  # [T, N, *]
print('x.shape=', x.shape)
x_len = torch.as_tensor([x1.shape[0], x2.shape[0], x3.shape[0]])
mask = padded_sequence_mask(x_len)
print('mask.shape=', mask.shape)
print('mask=\n', mask)

And the outputs are:

x.shape= torch.Size([4, 3, 6])
mask.shape= torch.Size([4, 3])
mask=
 tensor([[ True,  True,  True],
        [ True,  True,  True],
        [False,  True,  True],
        [False, False,  True]])
- class spikingjelly.datasets.NeuromorphicDatasetFolder(root: str, train: Optional[bool] = None, data_type: str = 'event', frames_number: Optional[int] = None, split_by: Optional[str] = None, duration: Optional[int] = None, custom_integrate_function: Optional[Callable] = None, custom_integrated_frames_dir_name: Optional[str] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None)[source]
Bases:
DatasetFolder
- Parameters
root (str) – root path of the dataset
train (bool) – whether to use the train set. Set True or False for datasets that provide a train/test division, e.g., the DVS128 Gesture dataset. If the dataset does not provide a train/test division, e.g., CIFAR10-DVS, please set None and use the split_to_train_test_set function to get the train/test sets
data_type (str) – 'event' or 'frame'
frames_number (int) – the integrated frame number
split_by (str) – 'time' or 'number'
duration (int) – the time duration of each frame
custom_integrate_function (Callable) – a user-defined function whose inputs are events, H, W. events is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray; H is the height of the data and W is the width of the data. For example, H=128 and W=128 for the DVS128 Gesture dataset. The user should define how to integrate events to frames and return the frames
custom_integrated_frames_dir_name (str or None) – The name of the directory for saving frames integrated by custom_integrate_function. If custom_integrated_frames_dir_name is None, it will be set to custom_integrate_function.__name__
transform (callable) – a function/transform that takes in a sample and returns a transformed version, e.g., transforms.RandomCrop for images
target_transform (callable) – a function/transform that takes in the target and transforms it
The base class for neuromorphic datasets. Users can define a new dataset by inheriting this class and implementing all abstract methods; refer to spikingjelly.datasets.dvs128_gesture.DVS128Gesture for an example.
- If data_type == 'event': each sample in this dataset is a dict whose keys are ['t', 'x', 'y', 'p'] and values are numpy.ndarray.
- If data_type == 'frame' and frames_number is not None: events will be integrated to frames with a fixed number of frames. split_by defines how to split the events; see cal_fixed_frames_number_segment_index for more details.
- If data_type == 'frame', frames_number is None, and duration is not None: events will be integrated to frames with a fixed time duration.
- If data_type == 'frame', frames_number is None, duration is None, and custom_integrate_function is not None: events will be integrated by the user-defined function and saved to the custom_integrated_frames_dir_name directory in the root directory. Here is an example from SpikingJelly's tutorials:

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
from typing import Dict
import numpy as np
import spikingjelly.datasets as sjds

def integrate_events_to_2_frames_randomly(events: Dict, H: int, W: int):
    # randomly pick a split index and integrate the two halves of the event stream into two frames
    index_split = np.random.randint(low=0, high=events['t'].__len__())
    frames = np.zeros([2, 2, H, W])
    t, x, y, p = (events[key] for key in ('t', 'x', 'y', 'p'))
    frames[0] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, 0, index_split)
    frames[1] = sjds.integrate_events_segment_to_frame(x, y, p, H, W, index_split, events['t'].__len__())
    return frames

root_dir = 'D:/datasets/DVS128Gesture'
train_set = DVS128Gesture(root_dir, train=True, data_type='frame',
                          custom_integrate_function=integrate_events_to_2_frames_randomly)

from spikingjelly.datasets import play_frame
frame, label = train_set[500]
play_frame(frame)
- abstract static resource_url_md5() list [source]
- Returns
A list url where url[i] is a tuple containing the i-th file's name, download link, and MD5 checksum
- Return type
list
- abstract static downloadable() bool [source]
- Returns
Whether the dataset can be downloaded directly by Python code. If not, the user has to download it manually
- Return type
bool
- abstract static extract_downloaded_files(download_root: str, extract_root: str)[source]
- Parameters
download_root (str) – Root directory path which saves the downloaded dataset files
extract_root (str) – Root directory path which saves extracted files from downloaded files
- Returns
None
This function defines how to extract the downloaded files.
- abstract static create_events_np_files(extract_root: str, events_np_root: str)[source]
- Parameters
extract_root (str) – Root directory path which saves extracted files from downloaded files
events_np_root (str) – Root directory path which saves events files in the npz format
- Returns
None
This function defines how to convert the original binary data in extract_root to npz format and save the converted files in events_np_root.
- spikingjelly.datasets.random_temporal_delete(x_seq: Tensor, T_remain: int, batch_first)[source]
- Parameters
x_seq (torch.Tensor or np.ndarray) – a sequence with shape = [T, N, *], where T is the sequence length and N is the batch size
T_remain (int) – the length to remain
batch_first (bool) – if True, x_seq will be regarded as shape = [N, T, *]
- Returns
the sequence with length T_remain, which is obtained by randomly removing T - T_remain slices
- Return type
torch.Tensor or np.ndarray
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Code example:
import torch
from spikingjelly.datasets import random_temporal_delete

T = 8
T_remain = 5
N = 4
x_seq = torch.arange(0, N * T).view([N, T])
print('x_seq=\n', x_seq)
print('random_temporal_delete(x_seq)=\n', random_temporal_delete(x_seq, T_remain, batch_first=True))
Outputs:
x_seq=
 tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23],
        [24, 25, 26, 27, 28, 29, 30, 31]])
random_temporal_delete(x_seq)=
 tensor([[ 0,  1,  4,  6,  7],
        [ 8,  9, 12, 14, 15],
        [16, 17, 20, 22, 23],
        [24, 25, 28, 30, 31]])
- class spikingjelly.datasets.RandomTemporalDelete(T_remain: int, batch_first: bool)[source]
Bases:
Module
- Parameters
T_remain (int) – the length to remain
batch_first (bool) – if True, x_seq will be regarded as shape = [N, T, *]
The random temporal delete data augmentation used in Deep Residual Learning in Spiking Neural Networks. Refer to random_temporal_delete for more details.
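A minimal sketch of the module form (assuming, as documented above, that calling the module applies random_temporal_delete to its input):

import torch
from spikingjelly.datasets import RandomTemporalDelete

trans = RandomTemporalDelete(T_remain=5, batch_first=True)
x_seq = torch.rand([4, 8, 2])  # [N, T, *]
print(trans(x_seq).shape)      # expected: torch.Size([4, 5, 2])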
- spikingjelly.datasets.create_sub_dataset(source_dir: str, target_dir: str, ratio: float, use_soft_link=True, randomly=False)[source]
- Parameters
source_dir (str) – the directory path of the origin dataset
target_dir (str) – the directory path of the sub dataset
ratio (float) – the ratio of samples that the sub dataset will copy from the origin dataset
use_soft_link (bool) – if True, the sub dataset will be created with soft links; otherwise, the files will be copied
randomly (bool) – if True, the files copied from the origin dataset will be picked randomly. The randomness is controlled by numpy.random.seed
Create a sub dataset by copying ratio of the samples from the origin dataset.
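A minimal usage sketch (the directory paths below are hypothetical):

from spikingjelly.datasets import create_sub_dataset

# build a 10% subset of an events_np folder using soft links
create_sub_dataset(source_dir='D:/datasets/DVS128Gesture/events_np',
                   target_dir='D:/datasets/DVS128Gesture_sub/events_np',
                   ratio=0.1, use_soft_link=True, randomly=True)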
spikingjelly.timing_based package
Subpackages
spikingjelly.timing_based.examples package
Submodules
spikingjelly.timing_based.examples.tempotron_mnist module
- spikingjelly.timing_based.examples.tempotron_mnist.main()[source]
- Returns
None
Use a Gaussian tuning curve encoder to encode the images into spikes, and classify MNIST with a single-layer Tempotron.
This function initializes the network, starts training, and shows the accuracy on the test dataset.
Module contents
Submodules
spikingjelly.timing_based.encoding module
- class spikingjelly.timing_based.encoding.GaussianTuning(n, m, x_min: Tensor, x_max: Tensor)[source]
Bases:
object
- Parameters
n – the number of features, int
m – the number of neurons used to encode each feature, int
x_min – the minimum values of the n features, a tensor with shape=[n]
x_max – the maximum values of the n features, a tensor with shape=[n]
The Gaussian tuning curve encoder proposed in Bohte S M, Kok J N, La Poutre J A, et al. Error-backpropagation in temporally encoded networks of spiking neurons[J]. Neurocomputing, 2002, 48(1): 17-37.
The variables used by the encoder are placed on the same device as x_min.device.
spikingjelly.timing_based.neuron module
- class spikingjelly.timing_based.neuron.Tempotron(in_features, out_features, T, tau=15.0, tau_s=3.75, v_threshold=1.0)[source]
Bases:
Module
- Parameters
in_features – the number of inputs, with the same meaning as the in_features parameter of nn.Linear
out_features – the number of outputs, with the same meaning as the out_features parameter of nn.Linear
T – the simulation duration
tau – the membrane integration time constant of the LIF neuron
tau_s – the decay time constant of the synaptic current
v_threshold – the threshold voltage
The Tempotron model proposed in Gutig R, Sompolinsky H. The tempotron: a neuron that learns spike timing-based decisions[J]. Nature Neuroscience, 2006, 9(3): 420-428.
- static psp_kernel(t: Tensor, tau, tau_s)[source]
- Parameters
t – a tensor of time steps
tau – the membrane integration time constant of the LIF neuron
tau_s – the decay time constant of the synaptic current
- Returns
the post-synaptic membrane potential of the LIF neuron at time t
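A minimal sketch that samples the kernel at integer time steps, using the default time constants documented above:

import torch
from spikingjelly.timing_based import neuron

t = torch.arange(0, 100.0)
# post-synaptic potential kernel evaluated at each time step, with the default tau and tau_s
v = neuron.Tempotron.psp_kernel(t, tau=15.0, tau_s=3.75)
print(v.shape)  # expected: torch.Size([100])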
- static mse_loss(v_max, v_threshold, label, num_classes)[source]
- Parameters
v_max – the maximum voltage of the Tempotron neurons over the simulation duration, identical to the return value of the forward function when ret_type == 'v_max'; a tensor with shape=[batch_size, out_features]
v_threshold – the threshold voltage of the Tempotron, a float or a tensor with shape=[batch_size, out_features]
label – the ground-truth labels of the samples, a tensor with shape=[batch_size]
num_classes – the total number of classes, int
- Returns
the mean squared error between the voltages of the wrongly classified neurons and the threshold voltage
- forward(in_spikes: Tensor, ret_type)[source]
- Parameters
in_spikes – shape=[batch_size, in_features]. in_spikes[:, i] is the firing time of the i-th input spike, which lies between 0 and T, where T is the simulation duration; in_spikes[:, i] < 0 means no spike is fired
ret_type – the type of the return value, which can be 'v', 'v_max', or 'spikes'
- Returns
ret_type == 'v': a tensor with shape=[batch_size, out_features, T], the voltages of the out_features Tempotron neurons over the simulation duration T
ret_type == 'v_max': a tensor with shape=[batch_size, out_features], the peak voltages of the out_features Tempotron neurons over the simulation duration T
ret_type == 'spikes': a tensor out_spikes with shape=[batch_size, out_features], the firing times of the out_features Tempotron neurons; out_spikes[:, i] is the firing time of the i-th output spike, which lies between 0 and T, and out_spikes[:, i] < 0 means no spike is fired
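A minimal training-style sketch based on the signatures documented above (the random inputs are only for illustration, and passing v_threshold=1.0 matches the constructor's default):

import torch
from spikingjelly.timing_based import neuron

batch_size, in_features, out_features, T = 16, 32, 10, 100
tempotron = neuron.Tempotron(in_features, out_features, T)

in_spikes = torch.rand([batch_size, in_features]) * T  # random firing times in [0, T)
v_max = tempotron(in_spikes, ret_type='v_max')         # shape=[batch_size, out_features]
label = torch.randint(0, out_features, [batch_size])
loss = neuron.Tempotron.mse_loss(v_max, 1.0, label, out_features)
loss.backward()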
Module contents
spikingjelly.visualizing package
Module contents
- spikingjelly.visualizing.plot_2d_heatmap(array: ndarray, title: str, xlabel: str, ylabel: str, int_x_ticks=True, int_y_ticks=True, plot_colorbar=True, colorbar_y_label='magnitude', x_max=None, figsize=(12, 8), dpi=200)[source]
- Parameters
array – an arbitrary array with shape=[T, N]
title – the title of the heatmap
xlabel – the label of the x-axis
ylabel – the label of the y-axis
int_x_ticks – whether to show only integer ticks on the x-axis
int_y_ticks – whether to show only integer ticks on the y-axis
plot_colorbar – whether to plot a colorbar showing the mapping between colors and values
colorbar_y_label – the y-axis label of the colorbar
x_max – the maximum tick on the x-axis. If None, it is set to array.shape[1]
figsize – the size of the figure
dpi – the dpi of the figure
- Returns
the plotted figure
Plot a two-dimensional heatmap. It can be used to plot the membrane potentials of multiple neurons at different time steps. Example:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np

lif = neuron.LIFNode(tau=100.)
x = torch.rand(size=[32]) * 4
T = 50
s_list = []
v_list = []
for t in range(T):
    s_list.append(lif(x).unsqueeze(0))
    v_list.append(lif.v.unsqueeze(0))

s_list = torch.cat(s_list)
v_list = torch.cat(v_list)

visualizing.plot_2d_heatmap(array=np.asarray(v_list), title='Membrane Potentials', xlabel='Simulating Step',
                            ylabel='Neuron Index', int_x_ticks=True, x_max=T, dpi=200)
plt.show()
- spikingjelly.visualizing.plot_2d_bar_in_3d(array: ndarray, title: str, xlabel: str, ylabel: str, zlabel: str, int_x_ticks=True, int_y_ticks=True, int_z_ticks=False, dpi=200)[source]
- Parameters
array – an arbitrary array with shape=[T, N]
title – the title of the figure
xlabel – the label of the x-axis
ylabel – the label of the y-axis
zlabel – the label of the z-axis
int_x_ticks – whether to show only integer ticks on the x-axis
int_y_ticks – whether to show only integer ticks on the y-axis
int_z_ticks – whether to show only integer ticks on the z-axis
dpi – the dpi of the figure
- Returns
the plotted figure
Plot an arbitrary array with shape=[T, N] as a three-dimensional bar chart. It can be used to plot how the firing rates of multiple neurons change over time. Example:
import torch
from spikingjelly import visualizing
from matplotlib import pyplot as plt

Epochs = 5
N = 10
firing_rate = torch.zeros(Epochs, N)
init_firing_rate = torch.rand(size=[N])
for i in range(Epochs):
    firing_rate[i] = torch.softmax(init_firing_rate * (i + 1) ** 2, dim=0)

visualizing.plot_2d_bar_in_3d(firing_rate.numpy(), title='spiking rates of output layer', xlabel='neuron index',
                              ylabel='training epoch', zlabel='spiking rate', int_x_ticks=True, int_y_ticks=True,
                              int_z_ticks=False, dpi=200)
plt.show()
It can also be used to plot the voltages of multiple neurons at different time steps. Example:
import torch
from spikingjelly import visualizing
from matplotlib import pyplot as plt
from spikingjelly.activation_based import neuron

neuron_num = 4
T = 50
lif_node = neuron.LIFNode(tau=100.)
w = torch.rand([neuron_num]) * 10
v_list = []
for t in range(T):
    lif_node(w * torch.rand(size=[neuron_num]))
    v_list.append(lif_node.v.unsqueeze(0))

v_list = torch.cat(v_list)
visualizing.plot_2d_bar_in_3d(v_list, title='voltage of neurons', xlabel='neuron index',
                              ylabel='simulating step', zlabel='voltage', int_x_ticks=True, int_y_ticks=True,
                              int_z_ticks=False, dpi=200)
plt.show()
- spikingjelly.visualizing.plot_1d_spikes(spikes: asarray, title: str, xlabel: str, ylabel: str, int_x_ticks=True, int_y_ticks=True, plot_firing_rate=True, firing_rate_map_title='firing rate', figsize=(12, 8), dpi=200)[source]
- Parameters
spikes – a np array with shape=[T, N] whose elements are only 0 or 1, representing the spike data of N neurons over T time steps
title – the title of the figure
xlabel – the label of the x-axis
ylabel – the label of the y-axis
int_x_ticks – whether to show only integer ticks on the x-axis
int_y_ticks – whether to show only integer ticks on the y-axis
plot_firing_rate – whether to plot the firing rate of each neuron
firing_rate_map_title – the title of the firing rate map
figsize – the size of the figure
dpi – the dpi of the figure
- Returns
the plotted figure
Plot the spike data of N neurons over T time steps. It can be used to show when the N neurons fire over the T time steps. Example:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt
import numpy as np

lif = neuron.LIFNode(tau=100.)
x = torch.rand(size=[32]) * 4
T = 50
s_list = []
v_list = []
for t in range(T):
    s_list.append(lif(x).unsqueeze(0))
    v_list.append(lif.v.unsqueeze(0))

s_list = torch.cat(s_list)
v_list = torch.cat(v_list)

visualizing.plot_1d_spikes(spikes=np.asarray(s_list), title='Membrane Potentials', xlabel='Simulating Step',
                           ylabel='Neuron Index', dpi=200)
plt.show()
- spikingjelly.visualizing.plot_2d_feature_map(x3d: asarray, nrows, ncols, space, title: str, figsize=(12, 8), dpi=200)[source]
- Parameters
x3d – an array with shape=[C, W, H], i.e., C matrices of size W * H. Such matrices usually come from the output of spiking neurons after a convolutional layer
nrows – the number of rows to plot
ncols – the number of columns to plot
space – the spacing between matrices
title – the title of the figure
figsize – the size of the figure
dpi – the dpi of the figure
- Returns
a figure that plots all C matrices arranged into nrows rows and ncols columns
Plot all C matrices of size W * H and arrange them into nrows rows and ncols columns. Such matrices usually come from the output of spiking neurons after a convolutional layer, and this function can be used to visualize those outputs. Example:
from spikingjelly import visualizing
import numpy as np
from matplotlib import pyplot as plt

C = 48
W = 8
H = 8
spikes = (np.random.rand(C, W, H) > 0.8).astype(float)
# the first parameter of plot_2d_feature_map is x3d, so the array is passed positionally
visualizing.plot_2d_feature_map(spikes, nrows=6, ncols=8, space=2, title='Spiking Feature Maps', dpi=200)
plt.show()
- spikingjelly.visualizing.plot_one_neuron_v_s(v: ndarray, s: ndarray, v_threshold=1.0, v_reset=0.0, title='$V[t]$ and $S[t]$ of the neuron', figsize=(12, 8), dpi=200)[source]
- Parameters
v – an array with shape=[T] that stores the neuron's voltage at each time step
s – an array with shape=[T] that stores the spikes fired by the neuron at each time step
v_threshold – the threshold voltage of the neuron
v_reset – the reset voltage of the neuron, which can also be None
title – the title of the figure
figsize – the size of the figure
dpi – the dpi of the figure
- Returns
a figure
Plot how the voltage and the spikes of a single neuron change over time. Example:
import torch
from spikingjelly.activation_based import neuron
from spikingjelly import visualizing
from matplotlib import pyplot as plt

lif = neuron.LIFNode(tau=100.)
x = torch.Tensor([2.0])
T = 150
s_list = []
v_list = []
for t in range(T):
    s_list.append(lif(x))
    v_list.append(lif.v)

visualizing.plot_one_neuron_v_s(v_list, s_list, v_threshold=lif.v_threshold, v_reset=lif.v_reset, dpi=200)
plt.show()