Source code for spikingjelly.timing_based.encoding

import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussianTuning:
    def __init__(self, n, m, x_min: torch.Tensor, x_max: torch.Tensor):
        '''
        :param n: number of features, int
        :param m: number of neurons used to encode each feature, int
        :param x_min: minimum value of each of the n features, tensor with shape=[n]
        :param x_max: maximum value of each of the n features, tensor with shape=[n]

        The Gaussian tuning curve encoder proposed in:

        Bohte S M, Kok J N, La Poutre J A, et al. Error-backpropagation in temporally encoded networks of spiking neurons[J]. Neurocomputing, 2002, 48(1): 17-37.

        The tensors used by this encoder are created on the same device as x_min.
        '''
        assert m > 2
        self.m = m
        self.n = n
        # neuron indices i = 1, ..., m for each of the n features
        i = torch.arange(1, m + 1).unsqueeze(0).repeat(n, 1).float().to(x_min.device)  # shape=[n, m]
        # centers of the m Gaussian tuning curves: mu_i = x_min + (2i - 3) / 2 * (x_max - x_min) / (m - 2)
        self.mu = x_min.unsqueeze(-1).repeat(1, m) + \
                  (2 * i - 3) / 2 * \
                  (x_max.unsqueeze(-1).repeat(1, m) - x_min.unsqueeze(-1).repeat(1, m)) / (m - 2)  # shape=[n, m]
        # shared variance of every curve: sigma = (x_max - x_min) / (1.5 * (m - 2))
        self.sigma2 = (1 / 1.5 * (x_max - x_min) / (m - 2)).unsqueeze(-1).square().repeat(1, m)  # shape=[n, m]
    def encode(self, x: torch.Tensor, max_spike_time=50):
        '''
        :param x: shape=[batch_size, n, k], batch_size samples, each with n features, each feature holding k values
        :param max_spike_time: the maximum (latest) spike time, i.e., the length of the encoding time window
        :return: out_spikes, shape=[batch_size, n, k, m], each value encoded as the firing times of m neurons
        '''
        x_shape = x.shape
        # shape: [batch_size, n, k] -> [batch_size, k, n]
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(-1, x.shape[2])  # shape=[batch_size * k, n]
        x = x.unsqueeze(-1).repeat(1, 1, self.m)  # shape=[batch_size * k, n, m]
        # evaluate the m Gaussian tuning curves at x
        y = torch.exp(- (x - self.mu).square() / 2 / self.sigma2)  # shape=[batch_size * k, n, m]
        # stronger responses fire earlier; round to integer spike times
        out_spikes = (max_spike_time * (1 - y)).round()
        out_spikes[out_spikes >= max_spike_time] = -1  # -1 means the neuron never fires
        # shape: [batch_size * k, n, m] -> [batch_size, k, n, m]
        out_spikes = out_spikes.view(x_shape[0], x_shape[2], out_spikes.shape[1], out_spikes.shape[2])
        out_spikes = out_spikes.permute(0, 2, 1, 3)  # shape: [batch_size, k, n, m] -> [batch_size, n, k, m]
        return out_spikes
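
# --- Illustrative sketches below are not part of the library module; all
# --- values (n, m, x_min, x_max, batch sizes) are assumptions chosen only
# --- to demonstrate the class above.

if __name__ == '__main__':
    # Sanity check of the tuning-curve parameters: for x_min=0, x_max=1 and
    # m=4 neurons, the centers mu are evenly spaced by (x_max - x_min) / (m - 2)
    # and reach slightly beyond [x_min, x_max], as in Bohte et al. (2002).
    enc = GaussianTuning(n=1, m=4, x_min=torch.tensor([0.0]), x_max=torch.tensor([1.0]))
    print(enc.mu)      # tensor([[-0.2500,  0.2500,  0.7500,  1.2500]])
    print(enc.sigma2)  # tensor([[0.1111, 0.1111, 0.1111, 0.1111]])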
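    # Usage sketch for encode(): encode a random batch, taking the per-feature
    # ranges from the data itself; every scalar value in x is turned into the
    # firing times of m neurons.
    batch_size, n, k, m = 4, 3, 5, 10
    x = torch.rand(batch_size, n, k)
    x_min = x.amin(dim=(0, 2))  # shape=[n]
    x_max = x.amax(dim=(0, 2))  # shape=[n]
    encoder = GaussianTuning(n=n, m=m, x_min=x_min, x_max=x_max)
    spike_times = encoder.encode(x, max_spike_time=50)
    print(spike_times.shape)  # torch.Size([4, 3, 5, 10])
    # entries are integer firing times in [0, max_spike_time);
    # -1 marks neurons that stay silent within the encoding window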