SpikingFlow.encoding 源代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
[文档]class BaseEncoder(nn.Module): def __init__(self): ''' 所有编码器的基类 编码器将输入数据(例如图像)编码为脉冲数据 ''' super().__init__()
[文档] def forward(self, x): ''' :param x: 要编码的数据 :return: 编码后的脉冲,或者是None 将x编码为脉冲。少数编码器(例如ConstantEncoder)可以将x编码成时长为1个dt的脉冲,在这种情况下,本函数返回编码后的脉冲 多数编码器(例如PeriodicEncoder),都是把x编码成时长为n个dt的脉冲out_spike,out_spike.shape=[n, *] 因此编码一次后,需要调用n次step()函数才能将脉冲全部发放完毕 第index调用step()会得到out_spike[index] ''' raise NotImplementedError
[文档] def step(self): ''' :return: 1个dt的脉冲 多数编码器(例如PeriodicEncoder),编码一次x,需要经过多步仿真才能将数据输出,这种情况下则用step来获取每一步的数据 ''' raise NotImplementedError
[文档] def reset(self): ''' :return: None 将编码器的所有状态变量设置为初始状态。对于有状态的编码器,需要重写这个函数 ''' pass
[文档]class ConstantEncoder(BaseEncoder): def __init__(self): ''' 将输入简单转化为脉冲,输入中大于0的位置输出1,其他位置输出0 ''' super().__init__()
[文档] def forward(self, x: torch.Tensor): ''' :param x: 任意的tensor :return: x.bool(),大于0的位置为1,其他位置为0 ''' return x.bool()
[文档]class PeriodicEncoder(BaseEncoder): def __init__(self, out_spike): ''' :param out_spike: shape=[T, *],PeriodicEncoder会不断的输出out_spike[0], out_spike[1], ..., out_spike[T-1], out_spike[0], out_spike[1], ... 给定out_spike后,周期性的输出out_spike[0], out_spike[1], ..., out_spike[T-1]的编码器 ''' super().__init__() assert out_spike.dtype == torch.bool self.out_spike = out_spike self.T = out_spike.shape[0] self.index = 0
[文档] def forward(self, x): ''' :param x: 输入数据,实际上并不需要输入数据,因为out_spike在初始化时已经被指定了 :return: 调用step()后得到的返回值 ''' return self.step()
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为T时修改为0 ''' index = self.index self.index += 1 if self.index == self.T: self.index = 0 return self.out_spike[index]
[文档] def set_out_spike(self, out_spike): ''' :param out_spike: 新设定的out_spike,必须是torch.bool :return: None 重新设定编码器的输出脉冲self.out_spike为out_spike ''' assert out_spike.dtype == torch.bool self.out_spike = out_spike self.T = out_spike.shape[0] self.index = 0
[文档] def reset(self): ''' :return: None 重置编码器的状态变量,对于PeriodicEncoder而言将索引index置0即可 ''' self.index = 0
[文档]class LatencyEncoder(BaseEncoder): def __init__(self, max_spike_time, function_type='linear', device='cpu'): ''' :param max_spike_time: 最晚(最大)脉冲发放时间 :param function_type: 'linear'或'log' :param device: 数据所在设备 延迟编码,刺激强度越大,脉冲发放越早。要求刺激强度已经被归一化到[0, 1] 脉冲发放时间 :math:`t_i` 与刺激强度 :math:`x_i` 满足 type='linear' .. math:: t_i = (t_{max} - 1) * (1 - x_i) type='log' .. math:: t_i = (t_{max} - 1) - ln(alpha * x_i + 1) :math:`alpha` 满足 .. math:: (t_{max} - 1) - ln(alpha * 1 + 1) = 0 这导致此编码器很容易发生溢出,因为 .. math:: alpha = exp(t_{max} - 1) - 1 当 :math:`t_{max}` 较大时 :math:`alpha` 极大 示例代码 .. code-block:: python x = torch.rand(size=[3, 2]) max_spike_time = 20 le = encoding.LatencyEncoder(max_spike_time) le(x) print(x) print(le.spike_time) for i in range(max_spike_time): print(le.step()) ''' super().__init__() self.device = device assert isinstance(max_spike_time, int) and max_spike_time > 1 self.max_spike_time = max_spike_time if function_type == 'log': self.alpha = math.exp(max_spike_time - 1) - 1 elif function_type != 'linear': raise NotImplementedError self.type = function_type self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档] def forward(self, x): ''' :param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1] 将输入数据x编码为max_spike_time个时刻的max_spike_time个脉冲 ''' # 将输入数据转换为不同时刻发放的脉冲 if self.type == 'log': self.spike_time = (self.max_spike_time - 1 - torch.log(self.alpha * x + 1)).round().long() else: self.spike_time = (self.max_spike_time - 1) * (1 - x).round().long() self.out_spike = F.one_hot(self.spike_time, num_classes=self.max_spike_time).bool() # [*, max_spike_time] self.out_spike.transpose_(0, self.out_spike.shape[-1]) # [*, max_spike_time] -> [max_spike_time, *]
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0 ''' index = self.index self.index += 1 if self.index == self.max_spike_time: self.index = 0 return self.out_spike[self.index]
[文档] def reset(self): ''' :return: None 重置LatencyEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0 ''' self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档]class PoissonEncoder(BaseEncoder): def __init__(self): ''' 泊松频率编码,输出脉冲可以看作是泊松流,发放脉冲的概率即为刺激强度,要求刺激强度已经被归一化到[0, 1] 示例代码 .. code-block:: python pe = encoding.PoissonEncoder() x = torch.rand(size=[8]) print(x) for i in range(10): print(pe(x)) ''' super().__init__()
[文档] def forward(self, x): ''' :param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1] 将输入数据x编码为脉冲,脉冲发放的概率即为对应位置元素的值 ''' out_spike = torch.rand_like(x).le(x) # torch.rand_like(x)生成与x相同shape的介于[0, 1)之间的随机数, 这个随机数小于等于x中对应位置的元素,则发放脉冲 return out_spike
[文档]class GaussianTuningCurveEncoder(BaseEncoder): def __init__(self, x_min, x_max, tuning_curve_num, max_spike_time, device='cpu'): ''' :param x_min: float,或者是shape=[M]的tensor,表示M个特征的最小值 :param x_max: float,或者是shape=[M]的tensor,表示M个特征的最大值 :param tuning_curve_num: 编码每个特征使用的高斯函数(调谐曲线)数量 :param max_spike_time: 最大脉冲发放时间,所有数据都会被编码到[0, max_spike_time - 1]范围内的脉冲发放时间 :param device: 数据所在设备 Bohte S M, Kok J N, La Poutre H. Error-backpropagation in temporally encoded networks of spiking neurons[J]. Neurocomputing, 2002, 48(1-4): 17-37. 高斯调谐曲线编码,一种时域编码方法 首先生成tuning_curve_num个高斯函数,这些高斯函数的对称轴在数据范围内均匀排列 对于每一个输入x,计算tuning_curve_num个高斯函数的值,使用这些函数值线性地生成tuning_curve_num个脉冲发放时间 待编码向量是M维tensor,也就是有M个特征 1个M维tensor会被编码成shape=[M, tuning_curve_num]的tensor,表示M * tuning_curve_num个神经元的脉冲发放时间 需要注意的是,编码一次数据,经过max_spike_time步仿真,才能进行下一次的编码 示例代码 .. code-block:: python x = torch.rand(size=[3, 2]) tuning_curve_num = 10 max_spike_time = 20 ge = encoding.GaussianTuningCurveEncoder(x.min(0)[0], x.max(0)[0], tuning_curve_num=tuning_curve_num, max_spike_time=max_spike_time) ge(x) for i in range(max_spike_time): print(ge.step()) ''' super().__init__() self.x_min = x_min self.x_max = x_max assert tuning_curve_num > 2 self.tuning_curve_num = tuning_curve_num assert isinstance(max_spike_time, int) and max_spike_time > 1 self.max_spike_time = max_spike_time self.device = device if isinstance(x_min, torch.Tensor): self.mu = torch.zeros(size=[x_min.shape[0], tuning_curve_num], dtype=torch.float, device=self.device) else: self.mu = torch.zeros(size=[1, tuning_curve_num], dtype=torch.float, device=self.device) # 生成tuning_curve_num个高斯函数的方差和均值 self.sigma = 1 / 1.5 * (x_max - x_min) / (tuning_curve_num - 2) for i in range(tuning_curve_num): self.mu[:, i] = x_min + (2 * i - 3) / 2 * (x_max - x_min) / (tuning_curve_num - 2) self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档] def forward(self, x): ''' :param x: 要编码的数据,shape=[batch_size, M] 将输入数据x编码为脉冲 ''' assert self.index == 0 self.spike_time = torch.zeros(size=[x.shape[0], x.shape[1], self.tuning_curve_num], dtype=torch.float, device=self.device) for i in range(self.tuning_curve_num): self.spike_time[:, :, i] = torch.exp( -torch.pow(x - self.mu[:, i], 2) / 2 / (self.sigma ** 2)) # 数值在[0, 1]之间 self.spike_time = (-(self.max_spike_time - 1) * self.spike_time + ( self.max_spike_time - 1)).round().long() # [batch_size, M, tuning_curve_num] self.out_spike = F.one_hot(self.spike_time, num_classes=self.max_spike_time).bool() # [batch_size, M, tuning_curve_num, max_spike_time] # 太晚发放的脉冲(最后时刻的脉冲)认为全部是0 self.out_spike[:, :, :, -1].zero_()
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0 ''' index = self.index self.index += 1 if self.index == self.max_spike_time: self.index = 0 return self.out_spike[:, :, :, index]
[文档] def reset(self): ''' :return: None 重置GaussianTuningCurveEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0 ''' self.spike_time = 0 self.out_spike = 0 self.index = 0