# STDP学习

## STDP(Spike Timing Dependent Plasticity)

STDP(Spike Timing Dependent Plasticity)最早由 Bi 和 Poo [1] 提出，是在生物实验中发现的一种突触可塑性机制。实验发现，突触权重受到突触连接的前神经元(pre)和后神经元(post)的脉冲发放时刻的影响，具体而言是：若pre神经元先于post神经元发放脉冲，则突触权重增大；若post神经元先于pre神经元发放脉冲，则突触权重减小。

STDP可以使用如下公式进行拟合：

$$\Delta w_{ij} = \begin{cases} A\exp\left(\dfrac{-|t_{i}-t_{j}|}{\tau_{+}}\right), & t_{i} \leq t_{j}, \; A > 0\\[4pt] B\exp\left(\dfrac{-|t_{i}-t_{j}|}{\tau_{-}}\right), & t_{i} > t_{j}, \; B < 0 \end{cases}$$

$$\begin{aligned} tr_{pre}[i][t] &= tr_{pre}[i][t-1] - \frac{tr_{pre}[i][t-1]}{\tau_{pre}} + s[i][t]\\ tr_{post}[j][t] &= tr_{post}[j][t-1] - \frac{tr_{post}[j][t-1]}{\tau_{post}} + s[j][t] \end{aligned}$$

$\Delta W[i][j][t] = F_{post}(w[i][j][t]) \cdot tr_{pre}[i][t] \cdot s[j][t] - F_{pre}(w[i][j][t]) \cdot tr_{post}[j][t] \cdot s[i][t]$

## STDP优化器

spikingjelly.activation_based.learning.STDPLearner 提供了STDP优化器的实现，支持卷积和全连接层，请读者先阅读其API文档以获取使用方法。

import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, layer, learning
from matplotlib import pyplot as plt
torch.manual_seed(0)

def f_weight(x):
    """Weight-dependence factor F(w) for STDP: clamp the input to [-1, 1]."""
    return torch.clamp(x, -1., 1.)

tau_pre = 2.   # time constant of the pre-synaptic trace
tau_post = 2.  # time constant of the post-synaptic trace
T = 128        # number of simulation time steps
N = 1          # batch size
lr = 0.01      # learning rate for the SGD optimizer
# A minimal 1-in/1-out network: one synapse (Linear) followed by one IF neuron.
net = nn.Sequential(
    layer.Linear(1, 1, bias=False),
    neuron.IFNode()
)
# Start from a fixed weight so the STDP-driven change is easy to observe.
nn.init.constant_(net[0].weight.data, 0.4)


STDPLearner 可以将权重更新量的相反数 `-delta_w * scale` 叠加到参数的梯度上，因而与深度学习的优化器完全兼容，按照如下方式更新权重：

$W = W - lr \cdot \nabla W$

optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.)


# Random binary input spikes with a firing probability of 0.3 per step.
in_spike = (torch.rand([T, N, 1]) > 0.7).float()
# Single-step STDP learner attached to the Linear synapse and the IF neuron.
stdp_learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                                    tau_pre=tau_pre, tau_post=tau_post,
                                    f_pre=f_weight, f_post=f_weight)


out_spike = []
trace_pre = []
trace_post = []
weight = []
for t in range(T):
    optimizer.zero_grad()
    out_spike.append(net(in_spike[t]).squeeze())
    # Accumulate -delta_w * scale into the parameter gradients,
    # then let the optimizer apply the update.
    stdp_learner.step(on_grad=True)
    optimizer.step()
    # Record the weight and both traces at every step for plotting.
    weight.append(net[0].weight.data.clone().squeeze())
    trace_pre.append(stdp_learner.trace_pre.squeeze())
    trace_post.append(stdp_learner.trace_post.squeeze())

in_spike = in_spike.squeeze()
out_spike = torch.stack(out_spike)
trace_pre = torch.stack(trace_pre)
trace_post = torch.stack(trace_post)
weight = torch.stack(weight)


将 in_spike, out_spike, trace_pre, trace_post, weight 画出，得到下图：

## 与梯度下降混合使用

import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import learning, layer, neuron, functional

T = 8            # number of time steps
N = 2            # batch size
C = 3            # input channels
H = 32           # input height
W = 32           # input width
lr = 0.1         # learning rate shared by both optimizers
tau_pre = 2.     # pre-synaptic trace time constant
tau_post = 100.  # post-synaptic trace time constant
step_mode = 'm'  # multi-step mode: the network consumes the whole [T, N, ...] sequence


def f_weight(x):
    """Weight-dependence factor F(w) for STDP: clamp the input to [-1, 1]."""
    return torch.clamp(x, -1., 1.)

# Small convolutional SNN; the Conv2d synapses will be trained by STDP,
# the Linear layers by gradient descent.
net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)


# Layer types whose weights are trained by STDP; all other parameters
# are trained by plain gradient descent.
instances_stdp = (layer.Conv2d, )


stdp_learners = []

# Attach one STDPLearner to each (synapse, spiking neuron) pair,
# i.e. each Conv2d followed by the IFNode right after it.
for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i + 1],
                                 tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )


# Collect the parameters that belong to STDP-trained layers.
params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)

# The remaining parameters are trained by gradient descent.
params_stdp_set = set(params_stdp)
params_gd = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gd.append(p)

optimizer_gd = torch.optim.Adam(params_gd, lr=lr)
optimizer_stdp = torch.optim.SGD(params_stdp, lr=lr, momentum=0.)


# One training step on random binary input and random targets.
x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])

# Backprop builds surrogate gradients for ALL parameters.
optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()

# Discard the surrogate gradients on the STDP-trained parameters:
# their updates will come from the STDP learners instead.
optimizer_stdp.zero_grad()

# Write the STDP updates (as negative gradients) into the STDP parameters.
for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()


以 STDPLearner 为代表的所有学习器都是 MemoryModule 的子类，其内部记忆状态包括了突触前后神经元的迹 trace_pre, trace_post ；另外，学习器内部用于记录神经元活动的监视器存储了突触前后神经元的发放历史；这些发放历史也可以视作学习器的内部记忆状态。因此，必须及时调用学习器的 reset() 方法，来清空其内部记忆状态，从而防止内存/显存消耗量随着训练而不断增长！通常的做法是：在每个batch结束后，将学习器和网络一起重置：

# Reset the network state and every learner's internal memory
# (traces + recorded spike history) at the end of each batch.
functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()


import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import learning, layer, neuron, functional

T = 8            # number of time steps
N = 2            # batch size
C = 3            # input channels
H = 32           # input height
W = 32           # input width
lr = 0.1         # learning rate shared by both optimizers
tau_pre = 2.     # pre-synaptic trace time constant
tau_post = 100.  # post-synaptic trace time constant
step_mode = 'm'  # multi-step mode

def f_weight(x):
    """Weight-dependence factor F(w) for STDP: clamp the input to [-1, 1]."""
    return torch.clamp(x, -1., 1.)

# Small convolutional SNN; Conv2d synapses are trained by STDP,
# Linear layers by gradient descent.
net = nn.Sequential(
    layer.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Conv2d(16, 16, kernel_size=3, stride=1, padding=1, bias=False),
    neuron.IFNode(),
    layer.MaxPool2d(2, 2),
    layer.Flatten(),
    layer.Linear(16 * 8 * 8, 64, bias=False),
    neuron.IFNode(),
    layer.Linear(64, 10, bias=False),
    neuron.IFNode(),
)

functional.set_step_mode(net, step_mode)

# Layer types trained by STDP.
instances_stdp = (layer.Conv2d, )

stdp_learners = []

# One learner per (Conv2d synapse, following IFNode) pair.
for i in range(net.__len__()):
    if isinstance(net[i], instances_stdp):
        stdp_learners.append(
            learning.STDPLearner(step_mode=step_mode, synapse=net[i], sn=net[i + 1],
                                 tau_pre=tau_pre, tau_post=tau_post,
                                 f_pre=f_weight, f_post=f_weight)
        )

# Split parameters: STDP-trained vs gradient-descent-trained.
params_stdp = []
for m in net.modules():
    if isinstance(m, instances_stdp):
        for p in m.parameters():
            params_stdp.append(p)

params_stdp_set = set(params_stdp)
params_gd = []
for p in net.parameters():
    if p not in params_stdp_set:
        params_gd.append(p)

optimizer_gd = torch.optim.Adam(params_gd, lr=lr)
optimizer_stdp = torch.optim.SGD(params_stdp, lr=lr, momentum=0.)

# One training step on random binary input and random targets.
x_seq = (torch.rand([T, N, C, H, W]) > 0.5).float()
target = torch.randint(low=0, high=10, size=[N])

optimizer_gd.zero_grad()
optimizer_stdp.zero_grad()
y = net(x_seq).mean(0)
loss = F.cross_entropy(y, target)
loss.backward()

# Discard surrogate gradients on STDP parameters; their updates come
# from the learners below.
optimizer_stdp.zero_grad()

for i in range(stdp_learners.__len__()):
    stdp_learners[i].step(on_grad=True)

optimizer_gd.step()
optimizer_stdp.step()

# Reset network state and learner memory at the end of the batch.
functional.reset_net(net)
for i in range(stdp_learners.__len__()):
    stdp_learners[i].reset()

[1]

Bi, Guo-qiang, and Mu-ming Poo. “Synaptic modifications in cultured hippocampal neurons: dependence on spike timing, synaptic strength, and postsynaptic cell type.” Journal of neuroscience 18.24 (1998): 10464-10472.

[2]

Froemke, Robert C., et al. “Contribution of individual spikes in burst-induced long-term synaptic modification.” Journal of neurophysiology (2006).

[3]

Morrison, Abigail, Markus Diesmann, and Wulfram Gerstner. “Phenomenological models of synaptic plasticity based on spike timing.” Biological cybernetics 98.6 (2008): 459-478.