SpikingFlow.learning package

Module contents

class SpikingFlow.learning.STDPModule(tf_module, connection_module, neuron_module, tau_pre, tau_post, learning_rate, f_w=<function STDPModule.<lambda>>)[源代码]

基类:torch.nn.modules.module.Module

参数
  • tf_module – connection.transform中的脉冲-电流转换器

  • connection_module – 突触

  • neuron_module – 神经元

  • tau_pre – pre脉冲的迹的时间常数

  • tau_post – post脉冲的迹的时间常数

  • learning_rate – 学习率

  • f_w – 权值函数,输入是权重w,输出一个float或者是与权重w相同shape的tensor

由tf_module,connection_module,neuron_module构成的STDP学习的基本单元

利用迹的方式(Morrison A, Diesmann M, Gerstner W. Phenomenological models of synaptic plasticity based on spike timing[J]. Biological cybernetics, 2008, 98(6): 459-478.)实现STDP学习,更新connection_module中的参数

pre脉冲到达时,权重减少trace_post * f_w(w) * learning_rate

post脉冲到达时,权重增加trace_pre * f_w(w) * learning_rate

示例代码

import SpikingFlow.simulating as simulating
import SpikingFlow.learning as learning
import SpikingFlow.connection as connection
import SpikingFlow.connection.transform as tf
import SpikingFlow.neuron as neuron
import torch
from matplotlib import pyplot

# 新建一个仿真器
sim = simulating.Simulator()

# 添加各个模块。为了更明显的观察到脉冲,我们使用IF神经元,而且把膜电阻设置的很大
# 突触的pre是2个输入,而post是1个输出,连接权重是shape=[1, 2]的tensor
sim.append(learning.STDPModule(tf.SpikeCurrent(amplitude=0.5),
                               connection.Linear(2, 1),
                               neuron.IFNode(shape=[1], r=50.0, v_threshold=1.0),
                               tau_pre=10.0,
                               tau_post=10.0,
                               learning_rate=1e-3
                               ))
# 新建list,分别保存pre的2个输入脉冲、post的1个输出脉冲,以及对应的连接权重
pre_spike_list0 = []
pre_spike_list1 = []
post_spike_list = []
w_list0 = []
w_list1 = []
T = 200

for t in range(T):
    if t < 100:
        # 前100步仿真,pre_spike[0]和pre_spike[1]都是发放一次1再发放一次0
        if t % 2 == 0:
            pre_spike = torch.ones(size=[2], dtype=torch.bool)
        else:
            pre_spike = torch.zeros(size=[2], dtype=torch.bool)
    else:
        # 后100步仿真,pre_spike[0]一直为0,而pre_spike[1]一直为1
        pre_spike = torch.zeros(size=[2], dtype=torch.bool)
        pre_spike[1] = True

    post_spike = sim.step(pre_spike)
    pre_spike_list0.append(pre_spike[0].float().item())
    pre_spike_list1.append(pre_spike[1].float().item())

    post_spike_list.append(post_spike.float().item())

    w_list0.append(sim.module_list[-1].module_list[2].w[:, 0].item())
    w_list1.append(sim.module_list[-1].module_list[2].w[:, 1].item())

# 画出pre_spike[0]
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list0, width=0.1, label='pre_spike[0]')
pyplot.legend()
pyplot.show()

# 画出pre_spike[1]
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list1, width=0.1, label='pre_spike[1]')
pyplot.legend()
pyplot.show()

# 画出post_spike
pyplot.bar(torch.arange(0, T).tolist(), post_spike_list, width=0.1, label='post_spike')
pyplot.legend()
pyplot.show()

# 画出2个输入与1个输出的连接权重w_0和w_1
pyplot.plot(w_list0, c='r', label='w[0]')
pyplot.plot(w_list1, c='g', label='w[1]')
pyplot.legend()
pyplot.show()
update_param(pre=True)[源代码]
forward(pre_spike)[源代码]
参数

pre_spike – 输入脉冲

返回

经过本module后的输出脉冲

需要注意的时,由于本module含有tf_module, connection_module, neuron_module三个module

因此在t时刻的输入,到t+3dt才能得到其输出

reset()[源代码]
返回

None

将module自身,以及tf_module, connection_module, neuron_module的状态变量重置为初始值

training: bool
class SpikingFlow.learning.STDPUpdater(tau_pre, tau_post, learning_rate, f_w=<function STDPUpdater.<lambda>>)[源代码]

基类:object

参数
  • neuron_module – 神经元

  • tau_pre – pre脉冲的迹的时间常数

  • tau_post – post脉冲的迹的时间常数

  • learning_rate – 学习率

  • f_w – 权值函数,输入是权重w,输出一个float或者是与权重w相同shape的tensor

利用迹的方式(Morrison A, Diesmann M, Gerstner W. Phenomenological models of synaptic plasticity based on spike timing[J]. Biological cybernetics, 2008, 98(6): 459-478.)实现STDP学习,更新connection_module中的参数

pre脉冲到达时,权重减少trace_post * f_w(w) * learning_rate

post脉冲到达时,权重增加trace_pre * f_w(w) * learning_rate

与STDPModule类似,但需要手动给定前后脉冲,这也带来了更为灵活的使用方式,例如 不使用突触实际连接的前后神经元的脉冲,而是使用其他脉冲来指导某个突触的学习

示例代码

import SpikingFlow.simulating as simulating
import SpikingFlow.learning as learning
import SpikingFlow.connection as connection
import SpikingFlow.connection.transform as tf
import SpikingFlow.neuron as neuron
import torch
from matplotlib import pyplot

# 定义权值函数f_w
def f_w(x: torch.Tensor):
    x_abs = x.abs()
    return x_abs / (x_abs.sum() + 1e-6)

# 新建一个仿真器
sim = simulating.Simulator()

# 放入脉冲电流转换器、突触、LIF神经元
sim.append(tf.SpikeCurrent(amplitude=0.5))
sim.append(connection.Linear(2, 1))
sim.append(neuron.LIFNode(shape=[1], r=10.0, v_threshold=1.0, tau=100.0))

# 新建一个STDPUpdater
updater = learning.STDPUpdater(tau_pre=50.0,
                               tau_post=100.0,
                               learning_rate=1e-1,
                               f_w=f_w)

# 新建list,保存pre脉冲、post脉冲、突触权重w_00, w_01
pre_spike_list0 = []
pre_spike_list1 = []
post_spike_list = []
w_list0 = []
w_list1 = []

T = 500
for t in range(T):
    if t < 250:
        if t % 2 == 0:
            pre_spike = torch.ones(size=[2], dtype=torch.bool)
        else:
            pre_spike = torch.randint(low=0, high=2, size=[2]).bool()
    else:
        pre_spike = torch.zeros(size=[2], dtype=torch.bool)
        if t % 2 == 0:
            pre_spike[1] = True




    pre_spike_list0.append(pre_spike[0].float().item())
    pre_spike_list1.append(pre_spike[1].float().item())

    post_spike = sim.step(pre_spike)

    updater.update(sim.module_list[1], pre_spike, post_spike)

    post_spike_list.append(post_spike.float().item())

    w_list0.append(sim.module_list[1].w[:, 0].item())
    w_list1.append(sim.module_list[1].w[:, 1].item())

pyplot.figure(figsize=(8, 16))
pyplot.subplot(4, 1, 1)
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list0, width=0.1, label='pre_spike[0]')
pyplot.legend()

pyplot.subplot(4, 1, 2)
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list1, width=0.1, label='pre_spike[1]')
pyplot.legend()

pyplot.subplot(4, 1, 3)
pyplot.bar(torch.arange(0, T).tolist(), post_spike_list, width=0.1, label='post_spike')
pyplot.legend()

pyplot.subplot(4, 1, 4)
pyplot.plot(w_list0, c='r', label='w[0]')
pyplot.plot(w_list1, c='g', label='w[1]')
pyplot.legend()
pyplot.show()
reset()[源代码]
返回

None

将前后脉冲的迹设置为0

update(connection_module, pre_spike, post_spike, inverse=False)[源代码]
参数
  • connection_module – 突触

  • pre_spike – 输入脉冲

  • post_spike – 输出脉冲

  • inverse – 为True则进行Anti-STDP

指定突触的前后脉冲,进行STDP学习