import torch
import torch.nn as nn
import torch.nn.functional as F
import math
[文档]class BaseEncoder(nn.Module):
def __init__(self):
'''
所有编码器的基类。编码器将输入数据(例如图像)编码为脉冲数据。
'''
super().__init__()
[文档] def forward(self, x):
'''
:param x: 要编码的数据
:return: 编码后的脉冲,或者是None
将x编码为脉冲。少数编码器(例如ConstantEncoder)可以将x编码成时长为1个dt的脉冲,在这种情况下,本函数返回编码后的脉冲。
多数编码器(例如PeriodicEncoder),都是把x编码成时长为n个dt的脉冲out_spike,out_spike.shape=[n, *]。
因此编码一次后,需要调用n次step()函数才能将脉冲全部发放完毕,第index次调用step()会得到out_spike[index]。
'''
raise NotImplementedError
[文档] def step(self):
'''
:return: 1个dt的脉冲
多数编码器(例如PeriodicEncoder),编码一次x,需要经过多步仿真才能将数据输出,这种情况下则用step来获取每一步的数据。
'''
raise NotImplementedError
[文档] def reset(self):
'''
:return: None
将编码器的所有状态变量设置为初始状态。对于有状态的编码器,需要重写这个函数。
'''
pass
[文档]class PeriodicEncoder(BaseEncoder):
def __init__(self, out_spike):
'''
:param out_spike: shape=[T, *],PeriodicEncoder会不断的输出out_spike[0], out_spike[1], ..., out_spike[T-1],
out_spike[0], out_spike[1], ...
给定out_spike后,周期性的输出out_spike[0], out_spike[1], ..., out_spike[T-1]的编码器。
'''
super().__init__()
assert out_spike.dtype == torch.bool
self.out_spike = out_spike
self.T = out_spike.shape[0]
self.index = 0
[文档] def forward(self, x):
'''
:param x: 输入数据,实际上并不需要输入数据,因为out_spike在初始化时已经被指定了
:return: 调用step()后得到的返回值
'''
return self.step()
[文档] def step(self):
'''
:return: out_spike[index]
初始化时index=0,每调用一次,index则自增1,index为T时修改为0。
'''
index = self.index
self.index += 1
if self.index == self.T:
self.index = 0
return self.out_spike[index]
[文档] def set_out_spike(self, out_spike):
'''
:param out_spike: 新设定的out_spike,必须是torch.bool
:return: None
重新设定编码器的输出脉冲self.out_spike为out_spike。
'''
assert out_spike.dtype == torch.bool
self.out_spike = out_spike
self.T = out_spike.shape[0]
self.index = 0
[文档] def reset(self):
'''
:return: None
重置编码器的状态变量,对于PeriodicEncoder而言将索引index置0即可。
'''
self.index = 0
[文档]class LatencyEncoder(BaseEncoder):
def __init__(self, max_spike_time, function_type='linear', device='cpu'):
'''
:param max_spike_time: 最晚(最大)脉冲发放时间
:param function_type: 'linear'或'log'
:param device: 数据所在设备
延迟编码,刺激强度越大,脉冲发放越早。要求刺激强度已经被归一化到[0, 1]。
脉冲发放时间 :math:`t_i` 与刺激强度 :math:`x_i` 满足:
type='linear'
.. math::
t_i = (t_{max} - 1) * (1 - x_i)
type='log'
.. math::
t_i = (t_{max} - 1) - ln(\\alpha * x_i + 1)
:math:`\\alpha` 满足:
.. math::
(t_{max} - 1) - ln(\\alpha * 1 + 1) = 0
这导致此编码器很容易发生溢出,因为
.. math::
\\alpha = e^{t_{max} - 1} - 1
当 :math:`t_{max}` 较大时 :math:`\\alpha` 极大。
示例代码:
.. code-block:: python
x = torch.rand(size=[3, 2])
max_spike_time = 20
le = encoding.LatencyEncoder(max_spike_time)
le(x)
print(x)
print(le.spike_time)
for i in range(max_spike_time):
print(le.step())
'''
super().__init__()
self.device = device
assert isinstance(max_spike_time, int) and max_spike_time > 1
self.max_spike_time = max_spike_time
if function_type == 'log':
self.alpha = math.exp(max_spike_time - 1) - 1
elif function_type != 'linear':
raise NotImplementedError
self.type = function_type
self.spike_time = 0
self.out_spike = 0
self.index = 0
[文档] def forward(self, x):
'''
:param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1]
将输入数据x编码为max_spike_time个时刻的max_spike_time个脉冲。
'''
# 将输入数据转换为不同时刻发放的脉冲
if self.type == 'log':
self.spike_time = (self.max_spike_time - 1 - torch.log(self.alpha * x + 1)).round().long()
else:
self.spike_time = ((self.max_spike_time - 1) * (1 - x)).round().long()
self.out_spike = F.one_hot(self.spike_time,
num_classes=self.max_spike_time).bool() # [*, max_spike_time]
[文档] def step(self):
'''
:return: out_spike[index]
初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0。
'''
index = self.index
self.index += 1
if self.index == self.max_spike_time:
self.index = 0
return self.out_spike[..., self.index]
[文档] def reset(self):
'''
:return: None
重置LatencyEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0。
'''
self.spike_time = 0
self.out_spike = 0
self.index = 0
[文档]class PoissonEncoder(BaseEncoder):
def __init__(self):
'''
泊松频率编码,输出脉冲可以看作是泊松流,发放脉冲的概率即为刺激强度,要求刺激强度已经被归一化到[0, 1]。
示例代码:
.. code-block:: python
pe = encoding.PoissonEncoder()
x = torch.rand(size=[8])
print(x)
for i in range(10):
print(pe(x))
'''
super().__init__()
[文档] def forward(self, x):
'''
:param x: 要编码的数据,任意形状的tensor,要求x的数据范围必须在[0, 1]
将输入数据x编码为脉冲,脉冲发放的概率即为对应位置元素的值。
'''
out_spike = torch.rand_like(x).le(x)
# torch.rand_like(x)生成与x相同shape的介于[0, 1)之间的随机数, 这个随机数小于等于x中对应位置的元素,则发放脉冲
return out_spike
[文档]class GaussianTuningCurveEncoder(BaseEncoder):
def __init__(self, x_min, x_max, tuning_curve_num, max_spike_time, device='cpu'):
'''
:param x_min: float,或者是shape=[M]的tensor,表示M个特征的最小值
:param x_max: float,或者是shape=[M]的tensor,表示M个特征的最大值
:param tuning_curve_num: 编码每个特征使用的高斯函数(调谐曲线)数量
:param max_spike_time: 最大脉冲发放时间,所有数据都会被编码到[0, max_spike_time - 1]范围内的脉冲发放时间
:param device: 数据所在设备
Bohte S M, Kok J N, La Poutre H. Error-backpropagation in temporally encoded networks of spiking neurons[J]. Neurocomputing, 2002, 48(1-4): 17-37.
高斯调谐曲线编码,一种时域编码方法。
首先生成tuning_curve_num个高斯函数,这些高斯函数的对称轴在数据范围内均匀排列,对于每一个输入x,计算tuning_curve_num个\
高斯函数的值,使用这些函数值线性地生成tuning_curve_num个脉冲发放时间。
待编码向量是M维tensor,也就是有M个特征。
1个M维tensor会被编码成shape=[M, tuning_curve_num]的tensor,表示M * tuning_curve_num个神经元的脉冲发放时间。
需要注意的是,编码一次数据,经过max_spike_time步仿真,才能进行下一次的编码。
示例代码:
.. code-block:: python
x = torch.rand(size=[3, 2])
tuning_curve_num = 10
max_spike_time = 20
ge = encoding.GaussianTuningCurveEncoder(x.min(0)[0], x.max(0)[0], tuning_curve_num=tuning_curve_num, max_spike_time=max_spike_time)
ge(x)
for i in range(max_spike_time):
print(ge.step())
'''
super().__init__()
self.x_min = x_min
self.x_max = x_max
assert tuning_curve_num > 2
self.tuning_curve_num = tuning_curve_num
assert isinstance(max_spike_time, int) and max_spike_time > 1
self.max_spike_time = max_spike_time
self.device = device
if isinstance(x_min, torch.Tensor):
self.mu = torch.zeros(size=[x_min.shape[0], tuning_curve_num], dtype=torch.float, device=self.device)
else:
self.mu = torch.zeros(size=[1, tuning_curve_num], dtype=torch.float, device=self.device)
# 生成tuning_curve_num个高斯函数的方差和均值
self.sigma = 1 / 1.5 * (x_max - x_min) / (tuning_curve_num - 2)
for i in range(tuning_curve_num):
self.mu[:, i] = x_min + (2 * i - 3) / 2 * (x_max - x_min) / (tuning_curve_num - 2)
self.spike_time = 0
self.out_spike = 0
self.index = 0
[文档] def forward(self, x):
'''
:param x: 要编码的数据,shape=[batch_size, M]
将输入数据x编码为脉冲。
'''
assert self.index == 0
self.spike_time = torch.zeros(size=[x.shape[0], x.shape[1], self.tuning_curve_num], dtype=torch.float,
device=self.device)
for i in range(self.tuning_curve_num):
self.spike_time[:, :, i] = torch.exp(
-torch.pow(x - self.mu[:, i], 2) / 2 / (self.sigma ** 2)) # 数值在[0, 1]之间
self.spike_time = (-(self.max_spike_time - 1) * self.spike_time + (
self.max_spike_time - 1)).round().long() # [batch_size, M, tuning_curve_num]
self.out_spike = F.one_hot(self.spike_time,
num_classes=self.max_spike_time).bool() # [batch_size, M, tuning_curve_num, max_spike_time]
# 太晚发放的脉冲(最后时刻的脉冲)认为全部是0
self.out_spike[:, :, :, -1].zero_()
[文档] def step(self):
'''
:return: out_spike[index]
初始化时index=0,每调用一次,index则自增1,index为max_spike_time时修改为0。
'''
index = self.index
self.index += 1
if self.index == self.max_spike_time:
self.index = 0
return self.out_spike[:, :, :, index]
[文档] def reset(self):
'''
:return: None
重置GaussianTuningCurveEncoder的所有状态变量(包括spike_time,out_spike,index)为初始值0。
'''
self.spike_time = 0
self.out_spike = 0
self.index = 0
[文档]class IntervalEncoder(BaseEncoder):
def __init__(self, T_in, shape, device='cpu'):
'''
:param T_in: 脉冲发放的间隔
:param shape: 输出形状
:param device: 输出脉冲所在的设备
每隔 ``T_in`` 个步长就发放一次脉冲的编码器。
'''
super().__init__()
self.t = 0
self.T_in = T_in
self.out_spike = [torch.zeros(size=shape, device=device, dtype=torch.bool),
torch.ones(size=shape, device=device, dtype=torch.bool)]
[文档] def step(self):
if self.t == self.T_in:
self.t = 0
return self.out_spike[1]
else:
self.t += 1
return self.out_spike[0]
[文档] def reset(self):
self.t = 0
[文档]class WeightedPhaseEncoder(BaseEncoder):
def __init__(self, phase, period, device='cpu'):
'''
:param phase: 一个周期内用于编码的相数
:type phase: int
:param device: 输出脉冲所在的设备
:type device: str
Kim J, Kim H, Huh S, et al. Deep neural networks with weighted spikes[J]. Neurocomputing, 2018, 311: 373-386.
带权的相位编码,一种基于二进制表示的编码方法。
将输入按照二进制各位展开,从高位到低位遍历输入进行脉冲编码。相比于频率编码,每一位携带的信息量更多。编码相位数为 :math:`K` 时,可以对于处于区间 :math:`[0, 1-2^{-K}]` 的数进行编码。以下为原始论文中的示例:
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
| Phase (K=8) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
+==================================+================+================+================+================+================+================+================+================+
| Spike weight :math:`\omega(t)` | 2\ :sup:`-1` | 2\ :sup:`-2` | 2\ :sup:`-3` | 2\ :sup:`-4` | 2\ :sup:`-5` | 2\ :sup:`-6` | 2\ :sup:`-7` | 2\ :sup:`-8` |
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
| 192/256 | 1 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
| 1/256 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 |
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
| 128/256 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
| 255/256 | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
+----------------------------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+
'''
super().__init__()
self.t = 0
self.phase = phase
self.device = device
[文档] def forward(self, x):
'''
:param x: 要编码的数据,shape=[batch_size, *]
将输入数据x编码为一个周期内的脉冲。
'''
assert (x >= 0).all() and (x <= 1 - 2 ** (-self.phase)).all()
inputs = x.copy()
self.out_spike = torch.empty((self.phase,) + x.shape, device=self.device) # 编码为[phase, batch_size, *]
w = 0.5
for i in range(self.phase):
self.out_spike[i] = inputs >= w
inputs -= w * self.out_spike[i]
w *= 0.5
[文档] def step(self):
out = self.out_spike[self.t]
self.t += 1
if self.t == self.phase:
self.t = 0
return out
[文档] def reset(self):
self.t = 0