spikingjelly.clock_driven.encoding 源代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
[文档]class BaseEncoder(nn.Module): def __init__(self): ''' 所有编码器的基类。编码器将输入数据(例如图像)编码为脉冲数据。 ''' super().__init__()
[文档] def forward(self, x): ''' :param x: 要编码的数据 :return: 编码后的脉冲,或者是None 将x编码为脉冲。少数编码器(例如ConstantEncoder)可以将x编码成时长为1个dt的脉冲,在这种情况下,本函数返回编码后的脉冲。 多数编码器(例如PeriodicEncoder),都是把x编码成时长为n个dt的脉冲out_spike,out_spike.shape=[n, *]。 因此编码一次后,需要调用n次step()函数才能将脉冲全部发放完毕,第index次调用step()会得到out_spike[index]。 ''' raise NotImplementedError
[文档] def step(self): ''' :return: 1个dt的脉冲 多数编码器(例如PeriodicEncoder),编码一次x,需要经过多步仿真才能将数据输出,这种情况下则用step来获取每一步的数据。 ''' raise NotImplementedError
[文档] def reset(self): ''' :return: None 将编码器的所有状态变量设置为初始状态。对于有状态的编码器,需要重写这个函数。 ''' pass
[文档]class PeriodicEncoder(BaseEncoder): def __init__(self, out_spike): ''' :param out_spike: shape=[T, *],PeriodicEncoder会不断的输出out_spike[0], out_spike[1], ..., out_spike[T-1], out_spike[0], out_spike[1], ... 给定out_spike后,周期性的输出out_spike[0], out_spike[1], ..., out_spike[T-1]的编码器。 ''' super().__init__() assert out_spike.dtype == torch.bool self.out_spike = out_spike self.T = out_spike.shape[0] self.index = 0
[文档] def forward(self, x): ''' :param x: 输入数据,实际上并不需要输入数据,因为out_spike在初始化时已经被指定了 :return: 调用step()后得到的返回值 ''' return self.step()
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为T时修改为0。 ''' index = self.index self.index += 1 if self.index == self.T: self.index = 0 return self.out_spike[index]
[文档] def set_out_spike(self, out_spike): ''' :param out_spike: 新设定的out_spike,必须是torch.bool :return: None 重新设定编码器的输出脉冲self.out_spike为out_spike。 ''' assert out_spike.dtype == torch.bool self.out_spike = out_spike self.T = out_spike.shape[0] self.index = 0
[文档] def reset(self): ''' :return: None 重置编码器的状态变量,对于PeriodicEncoder而言将索引index置0即可。 ''' self.index = 0
[文档]class LatencyEncoder(BaseEncoder): def __init__(self, max_spike_time, function_type='linear', device='cpu'): ''' :param max_spike_time: 最晚(最大)脉冲发放时间 :param function_type: 'linear'或'log' :param device: 数据所在设备 延迟编码,刺激强度越大,脉冲发放越早。要求刺激强度已经被归一化到[0, 1]。 脉冲发放时间 :math:`t_i` 与刺激强度 :math:`x_i` 满足: type='linear' .. math:: t_i = (t_{max} - 1) * (1 - x_i) type='log' .. math:: t_i = (t_{max} - 1) - ln(\\alpha * x_i + 1) :math:`\\alpha` 满足: .. math:: (t_{max} - 1) - ln(\\alpha * 1 + 1) = 0 这导致此编码器很容易发生溢出,因为 .. math:: \\alpha = e^{t_{max} - 1} - 1 当 :math:`t_{max}` 较大时 :math:`\\alpha` 极大。 示例代码: .. code-block:: python x = torch.rand(size=[3, 2]) max_spike_time = 20 le = encoding.LatencyEncoder(max_spike_time) le(x) print(x) print(le.spike_time) for i in range(max_spike_time): print(le.step()) ''' super().__init__() self.device = device assert isinstance(max_spike_time, int) and max_spike_time > 1 self.max_spike_time = max_spike_time if function_type == 'log': self.alpha = math.exp(max_spike_time - 1) - 1 elif function_type != 'linear': raise NotImplementedError self.type = function_type self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档] def forward(self, x): ''' :param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1] 将输入数据x编码为max_spike_time个时刻的max_spike_time个脉冲。 ''' # 将输入数据转换为不同时刻发放的脉冲 if self.type == 'log': self.spike_time = (self.max_spike_time - 1 - torch.log(self.alpha * x + 1)).round().long() else: self.spike_time = ((self.max_spike_time - 1) * (1 - x)).round().long() self.out_spike = F.one_hot(self.spike_time, num_classes=self.max_spike_time).bool() # [*, max_spike_time]
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0。 ''' index = self.index self.index += 1 if self.index == self.max_spike_time: self.index = 0 return self.out_spike[..., self.index]
[文档] def reset(self): ''' :return: None 重置LatencyEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0。 ''' self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档]class PoissonEncoder(BaseEncoder): def __init__(self): ''' 泊松频率编码,输出脉冲可以看作是泊松流,发放脉冲的概率即为刺激强度,要求刺激强度已经被归一化到[0, 1]。 示例代码: .. code-block:: python pe = encoding.PoissonEncoder() x = torch.rand(size=[8]) print(x) for i in range(10): print(pe(x)) ''' super().__init__()
[文档] def forward(self, x): ''' :param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1] 将输入数据x编码为脉冲,脉冲发放的概率即为对应位置元素的值。 ''' out_spike = torch.rand_like(x).le(x) # torch.rand_like(x)生成与x相同shape的介于[0, 1)之间的随机数, 这个随机数小于等于x中对应位置的元素,则发放脉冲 return out_spike
[文档]class GaussianTuningCurveEncoder(BaseEncoder): def __init__(self, x_min, x_max, tuning_curve_num, max_spike_time, device='cpu'): ''' :param x_min: float,或者是shape=[M]的tensor,表示M个特征的最小值 :param x_max: float,或者是shape=[M]的tensor,表示M个特征的最大值 :param tuning_curve_num: 编码每个特征使用的高斯函数(调谐曲线)数量 :param max_spike_time: 最大脉冲发放时间,所有数据都会被编码到[0, max_spike_time - 1]范围内的脉冲发放时间 :param device: 数据所在设备 Bohte S M, Kok J N, La Poutre H. Error-backpropagation in temporally encoded networks of spiking neurons[J]. Neurocomputing, 2002, 48(1-4): 17-37. 高斯调谐曲线编码,一种时域编码方法。 首先生成tuning_curve_num个高斯函数,这些高斯函数的对称轴在数据范围内均匀排列,对于每一个输入x,计算tuning_curve_num个\ 高斯函数的值,使用这些函数值线性地生成tuning_curve_num个脉冲发放时间。 待编码向量是M维tensor,也就是有M个特征。 1个M维tensor会被编码成shape=[M, tuning_curve_num]的tensor,表示M * tuning_curve_num个神经元的脉冲发放时间。 需要注意的是,编码一次数据,经过max_spike_time步仿真,才能进行下一次的编码。 示例代码: .. code-block:: python x = torch.rand(size=[3, 2]) tuning_curve_num = 10 max_spike_time = 20 ge = encoding.GaussianTuningCurveEncoder(x.min(0)[0], x.max(0)[0], tuning_curve_num=tuning_curve_num, max_spike_time=max_spike_time) ge(x) for i in range(max_spike_time): print(ge.step()) ''' super().__init__() self.x_min = x_min self.x_max = x_max assert tuning_curve_num > 2 self.tuning_curve_num = tuning_curve_num assert isinstance(max_spike_time, int) and max_spike_time > 1 self.max_spike_time = max_spike_time self.device = device if isinstance(x_min, torch.Tensor): self.mu = torch.zeros(size=[x_min.shape[0], tuning_curve_num], dtype=torch.float, device=self.device) else: self.mu = torch.zeros(size=[1, tuning_curve_num], dtype=torch.float, device=self.device) # 生成tuning_curve_num个高斯函数的方差和均值 self.sigma = 1 / 1.5 * (x_max - x_min) / (tuning_curve_num - 2) for i in range(tuning_curve_num): self.mu[:, i] = x_min + (2 * i - 3) / 2 * (x_max - x_min) / (tuning_curve_num - 2) self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档] def forward(self, x): ''' :param x: 要编码的数据,shape=[batch_size, M] 将输入数据x编码为脉冲。 ''' assert self.index == 0 self.spike_time = torch.zeros(size=[x.shape[0], x.shape[1], self.tuning_curve_num], dtype=torch.float, device=self.device) for i in range(self.tuning_curve_num): self.spike_time[:, :, i] = torch.exp( -torch.pow(x - self.mu[:, i], 2) / 2 / (self.sigma ** 2)) # 数值在[0, 1]之间 self.spike_time = (-(self.max_spike_time - 1) * self.spike_time + ( self.max_spike_time - 1)).round().long() # [batch_size, M, tuning_curve_num] self.out_spike = F.one_hot(self.spike_time, num_classes=self.max_spike_time).bool() # [batch_size, M, tuning_curve_num, max_spike_time] # 太晚发放的脉冲(最后时刻的脉冲)认为全部是0 self.out_spike[:, :, :, -1].zero_()
[文档] def step(self): ''' :return: out_spike[index] 初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0。 ''' index = self.index self.index += 1 if self.index == self.max_spike_time: self.index = 0 return self.out_spike[:, :, :, index]
[文档] def reset(self): ''' :return: None 重置GaussianTuningCurveEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0。 ''' self.spike_time = 0 self.out_spike = 0 self.index = 0
[文档]class IntervalEncoder(BaseEncoder): def __init__(self, T_in, shape, device='cpu'): ''' :param T_in: 脉冲发放的间隔 :param shape: 输出形状 :param device: 输出脉冲所在的设备 每隔 ``T_in`` 个步长就发放一次脉冲的编码器。 ''' super().__init__() self.t = 0 self.T_in = T_in self.out_spike = [torch.zeros(size=shape, device=device, dtype=torch.bool), torch.ones(size=shape, device=device, dtype=torch.bool)]
[文档] def step(self): if self.t == self.T_in: self.t = 0 return self.out_spike[1] else: self.t += 1 return self.out_spike[0]
[文档] def reset(self): self.t = 0
[文档]class WeightedPhaseEncoder(BaseEncoder): def __init__(self, phase, period, device='cpu'): ''' :param phase: 一个周期内用于编码的相数 :type phase: int :param device: 输出脉冲所在的设备 :type device: str Kim J, Kim H, Huh S, et al. Deep neural networks with weighted spikes[J]. Neurocomputing, 2018, 311: 373-386. 带权的相位编码,一种基于二进制表示的编码方法。 将输入按照二进制各位展开,从高位到低位遍历输入进行脉冲编码。相比于频率编码,每一位携带的信息量更多。编码相位数为 :math:`K` 时,可以对于处于区间 :math:`[0, 1-2^{-K}]` 的数进行编码。以下为原始论文中的示例: +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ | Phase (K=8) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | +==================================+================+================+================+================+================+================+================+================+ | Spike weight :math:`\omega(t)` | 2\ :sup:`-1` | 2\ :sup:`-2` | 2\ :sup:`-3` | 2\ :sup:`-4` | 2\ :sup:`-5` | 2\ :sup:`-6` | 2\ :sup:`-7` | 2\ :sup:`-8` | +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ | 192/256 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ | 1/256 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ | 128/256 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ | 255/256 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | +----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+ ''' super().__init__() self.t = 0 self.phase = phase self.device = device
[文档] def forward(self, x): ''' :param x: 要编码的数据,shape=[batch_size, *] 将输入数据x编码为一个周期内的脉冲。 ''' assert (x >= 0).all() and (x <= 1 - 2 ** (-self.phase)).all() inputs = x.copy() self.out_spike = torch.empty((self.phase,) + x.shape, device=self.device) # 编码为[phase, batch_size, *] w = 0.5 for i in range(self.phase): self.out_spike[i] = inputs >= w inputs -= w * self.out_spike[i] w *= 0.5
[文档] def step(self): out = self.out_spike[self.t] self.t += 1 if self.t == self.phase: self.t = 0 return out
[文档] def reset(self): self.t = 0