SpikingFlow.learning package
Module contents
class SpikingFlow.learning.STDPModule(tf_module, connection_module, neuron_module, tau_pre, tau_post, learning_rate, f_w=<function STDPModule.<lambda>>)

Bases: torch.nn.modules.module.Module
Parameters
- tf_module – the spike-to-current converter from connection.transform
- connection_module – the synapse
- neuron_module – the neuron
- tau_pre – time constant of the pre-spike trace
- tau_post – time constant of the post-spike trace
- learning_rate – learning rate
- f_w – weight function; takes the weight w as input and returns a float or a tensor with the same shape as w
A basic unit for STDP learning, composed of tf_module, connection_module and neuron_module.

Implements STDP learning with traces (Morrison A, Diesmann M, Gerstner W. Phenomenological models of synaptic plasticity based on spike timing. Biological Cybernetics, 2008, 98(6): 459-478.), updating the parameters in connection_module.

When a pre spike arrives, the weight decreases by trace_post * f_w(w) * learning_rate.

When a post spike arrives, the weight increases by trace_pre * f_w(w) * learning_rate.
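The trace-based update rule above can be sketched in pure Python for a single synapse. This is a minimal illustration, not the SpikingFlow implementation: the exponential decay of the traces, the unit trace increment, and the default constant f_w are assumptions.

```python
import math

def stdp_step(w, trace_pre, trace_post, pre_spike, post_spike,
              tau_pre, tau_post, learning_rate, f_w=lambda w: 1.0):
    """One time step of trace-based STDP for a single synapse (a sketch)."""
    # traces decay exponentially with their time constants (dt = 1 assumed)
    trace_pre = trace_pre * math.exp(-1.0 / tau_pre)
    trace_post = trace_post * math.exp(-1.0 / tau_post)
    if pre_spike:
        trace_pre += 1.0
        # a pre spike depresses the weight by trace_post * f_w(w) * learning_rate
        w -= trace_post * f_w(w) * learning_rate
    if post_spike:
        trace_post += 1.0
        # a post spike potentiates the weight by trace_pre * f_w(w) * learning_rate
        w += trace_pre * f_w(w) * learning_rate
    return w, trace_pre, trace_post
```

With a pre spike followed shortly by a post spike, trace_pre is still large when the post spike arrives, so the weight is potentiated; the reverse order depresses it.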
Example code
import SpikingFlow.simulating as simulating
import SpikingFlow.learning as learning
import SpikingFlow.connection as connection
import SpikingFlow.connection.transform as tf
import SpikingFlow.neuron as neuron
import torch
from matplotlib import pyplot

# create a new simulator
sim = simulating.Simulator()

# add the modules. To observe the spikes more clearly, we use IF neurons and set the membrane resistance very large
# the synapse has 2 pre inputs and 1 post output; the connection weight is a tensor with shape=[1, 2]
sim.append(learning.STDPModule(tf.SpikeCurrent(amplitude=0.5),
                               connection.Linear(2, 1),
                               neuron.IFNode(shape=[1], r=50.0, v_threshold=1.0),
                               tau_pre=10.0,
                               tau_post=10.0,
                               learning_rate=1e-3))

# create lists to record the 2 pre input spikes, the 1 post output spike, and the corresponding connection weights
pre_spike_list0 = []
pre_spike_list1 = []
post_spike_list = []
w_list0 = []
w_list1 = []

T = 200
for t in range(T):
    if t < 100:
        # during the first 100 steps, pre_spike[0] and pre_spike[1] alternately fire 1 and 0
        if t % 2 == 0:
            pre_spike = torch.ones(size=[2], dtype=torch.bool)
        else:
            pre_spike = torch.zeros(size=[2], dtype=torch.bool)
    else:
        # during the last 100 steps, pre_spike[0] stays 0 while pre_spike[1] stays 1
        pre_spike = torch.zeros(size=[2], dtype=torch.bool)
        pre_spike[1] = True

    post_spike = sim.step(pre_spike)
    pre_spike_list0.append(pre_spike[0].float().item())
    pre_spike_list1.append(pre_spike[1].float().item())
    post_spike_list.append(post_spike.float().item())
    w_list0.append(sim.module_list[-1].module_list[2].w[:, 0].item())
    w_list1.append(sim.module_list[-1].module_list[2].w[:, 1].item())

# plot pre_spike[0]
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list0, width=0.1, label='pre_spike[0]')
pyplot.legend()
pyplot.show()

# plot pre_spike[1]
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list1, width=0.1, label='pre_spike[1]')
pyplot.legend()
pyplot.show()

# plot post_spike
pyplot.bar(torch.arange(0, T).tolist(), post_spike_list, width=0.1, label='post_spike')
pyplot.legend()
pyplot.show()

# plot the weights w_0 and w_1 connecting the 2 inputs to the 1 output
pyplot.plot(w_list0, c='r', label='w[0]')
pyplot.plot(w_list1, c='g', label='w[1]')
pyplot.legend()
pyplot.show()
forward(pre_spike)

Parameters
- pre_spike – the input spikes

Returns
- the output spikes after passing through this module

Note that since this module contains three sub-modules (tf_module, connection_module and neuron_module), the input at time t produces its output only at time t + 3dt.
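The 3-step latency can be illustrated with a toy pipeline in which each stage, like each of the three sub-modules, buffers its input for one time step. This is a hypothetical sketch, independent of the SpikingFlow code:

```python
class DelayStage:
    """A stage that outputs at step t what it received at step t-1."""
    def __init__(self):
        self.buffer = 0

    def step(self, x):
        out = self.buffer
        self.buffer = x
        return out

# chaining three one-step stages delays the signal by three steps,
# just as tf_module -> connection_module -> neuron_module does
stages = [DelayStage() for _ in range(3)]
outputs = []
for t in range(6):
    x = 1 if t == 0 else 0  # inject a single pulse at t = 0
    for s in stages:
        x = s.step(x)
    outputs.append(x)
print(outputs)  # the pulse injected at t = 0 emerges at t = 3
```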
training: bool
class SpikingFlow.learning.STDPUpdater(tau_pre, tau_post, learning_rate, f_w=<function STDPUpdater.<lambda>>)

Bases: object
Parameters
- tau_pre – time constant of the pre-spike trace
- tau_post – time constant of the post-spike trace
- learning_rate – learning rate
- f_w – weight function; takes the weight w as input and returns a float or a tensor with the same shape as w
Implements STDP learning with traces (Morrison A, Diesmann M, Gerstner W. Phenomenological models of synaptic plasticity based on spike timing. Biological Cybernetics, 2008, 98(6): 459-478.), updating the parameters of the given connection_module.

When a pre spike arrives, the weight decreases by trace_post * f_w(w) * learning_rate.

When a post spike arrives, the weight increases by trace_pre * f_w(w) * learning_rate.

Similar to STDPModule, but the pre and post spikes must be supplied manually. This allows more flexible usage: for example, a synapse's learning can be guided by spikes other than those of the neurons it actually connects.
Example code
import SpikingFlow.simulating as simulating
import SpikingFlow.learning as learning
import SpikingFlow.connection as connection
import SpikingFlow.connection.transform as tf
import SpikingFlow.neuron as neuron
import torch
from matplotlib import pyplot

# define the weight function f_w
def f_w(x: torch.Tensor):
    x_abs = x.abs()
    return x_abs / (x_abs.sum() + 1e-6)

# create a new simulator
sim = simulating.Simulator()

# add the spike-to-current converter, the synapse and the LIF neuron
sim.append(tf.SpikeCurrent(amplitude=0.5))
sim.append(connection.Linear(2, 1))
sim.append(neuron.LIFNode(shape=[1], r=10.0, v_threshold=1.0, tau=100.0))

# create an STDPUpdater
updater = learning.STDPUpdater(tau_pre=50.0, tau_post=100.0, learning_rate=1e-1, f_w=f_w)

# create lists to record the pre spikes, the post spikes, and the synapse weights w_00, w_01
pre_spike_list0 = []
pre_spike_list1 = []
post_spike_list = []
w_list0 = []
w_list1 = []

T = 500
for t in range(T):
    if t < 250:
        if t % 2 == 0:
            pre_spike = torch.ones(size=[2], dtype=torch.bool)
        else:
            pre_spike = torch.randint(low=0, high=2, size=[2]).bool()
    else:
        pre_spike = torch.zeros(size=[2], dtype=torch.bool)
        if t % 2 == 0:
            pre_spike[1] = True

    pre_spike_list0.append(pre_spike[0].float().item())
    pre_spike_list1.append(pre_spike[1].float().item())

    post_spike = sim.step(pre_spike)
    updater.update(sim.module_list[1], pre_spike, post_spike)

    post_spike_list.append(post_spike.float().item())
    w_list0.append(sim.module_list[1].w[:, 0].item())
    w_list1.append(sim.module_list[1].w[:, 1].item())

pyplot.figure(figsize=(8, 16))
pyplot.subplot(4, 1, 1)
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list0, width=0.1, label='pre_spike[0]')
pyplot.legend()
pyplot.subplot(4, 1, 2)
pyplot.bar(torch.arange(0, T).tolist(), pre_spike_list1, width=0.1, label='pre_spike[1]')
pyplot.legend()
pyplot.subplot(4, 1, 3)
pyplot.bar(torch.arange(0, T).tolist(), post_spike_list, width=0.1, label='post_spike')
pyplot.legend()
pyplot.subplot(4, 1, 4)
pyplot.plot(w_list0, c='r', label='w[0]')
pyplot.plot(w_list1, c='g', label='w[1]')
pyplot.legend()
pyplot.show()