spikingjelly.activation_based.neuron.dsr 源代码

import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from .. import base

__all__ = ["DSRIFNode", "DSRLIFNode"]


class DSRIFNode(base.MemoryModule):
    """Multi-step DSR integrate-and-fire neuron with an optionally learnable threshold."""

    def __init__(
        self,
        T: int = 20,
        v_threshold: float = 6.0,
        alpha: float = 0.5,
        v_threshold_training: bool = True,
        v_threshold_grad_scaling: float = 1.0,
        v_threshold_lower_bound: float = 0.01,
        step_mode="m",
        backend="torch",
        **kwargs,
    ):
        """
        DSR IF neuron, proposed in `Training High-Performance Low-Latency Spiking
        Neural Networks by Differentiation on Spike Representation
        <https://arxiv.org/pdf/2205.00459.pdf>`_. This model enables low-latency
        and high-performance SNN training via differentiable spike
        representations.

        :param T: number of time-steps, must be a positive ``int``
        :type T: int

        :param v_threshold: initial membrane potential threshold, must not be
            smaller than ``v_threshold_lower_bound``
        :type v_threshold: float

        :param alpha: scaling factor of the membrane potential threshold,
            must lie in ``(0, 1]``
        :type alpha: float

        :param v_threshold_training: whether the membrane potential threshold
            is learnable, default: ``True``
        :type v_threshold_training: bool

        :param v_threshold_grad_scaling: scaling factor applied to the gradient
            of the membrane potential threshold
        :type v_threshold_grad_scaling: float

        :param v_threshold_lower_bound: minimum allowable membrane potential
            threshold during training, must be positive
        :type v_threshold_lower_bound: float

        :param step_mode: step mode, only ``'m'`` (multi-step) is supported
        :type step_mode: str

        :param backend: backend of this neuron layer. Users can print
            ``self.supported_backends`` to check availability. DSR-IF only
            supports the ``'torch'`` backend
        :type backend: str

        :param kwargs: accepted for signature compatibility and ignored

        :raises AssertionError: if any argument violates the constraints above

        :return: None
        :rtype: None
        """
        assert isinstance(T, int) and T > 0, "T must be a positive int"
        # Validate the bound first so the comparison below never sees a
        # value of the wrong type.
        assert (
            isinstance(v_threshold_lower_bound, float)
            and v_threshold_lower_bound > 0.0
        ), "v_threshold_lower_bound must be a positive float"
        assert (
            isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
        ), "v_threshold must be a float >= v_threshold_lower_bound"
        assert (
            isinstance(alpha, float) and 0.0 < alpha <= 1.0
        ), "alpha must be a float in (0, 1]"
        assert step_mode == "m", "only step_mode='m' is supported"

        super().__init__()
        self.backend = backend
        self.step_mode = step_mode
        self.T = T
        if v_threshold_training:
            self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
        else:
            # Plain tensor attribute: it is not wrapped in nn.Parameter, so it
            # receives no optimizer updates.
            self.v_threshold = torch.tensor(v_threshold)
        self.alpha = alpha
        self.v_threshold_lower_bound = v_threshold_lower_bound
        self.v_threshold_grad_scaling = v_threshold_grad_scaling

    @property
    def supported_backends(self):
        """Return the backends this neuron supports.

        A tuple is returned (not the bare string ``"torch"``) so that a
        membership test such as ``name in self.supported_backends`` compares
        whole backend names instead of matching substrings.
        """
        return ("torch",)

    def extra_repr(self):
        """Return the hyper-parameter summary appended to the module repr."""
        return (
            f", T={self.T}"
            f", init_vth={self.v_threshold}"
            f", alpha={self.alpha}"
            f", vth_bound={self.v_threshold_lower_bound}"
            f", vth_g_scale={self.v_threshold_grad_scaling}"
        )

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Run the neuron over a whole input sequence.

        :param x_seq: input sequence; dimension 0 is iterated as the time
            axis. It must have ``self.T`` entries along that axis, because the
            backward pass returns ``self.T`` stacked copies of the gradient.
        :type x_seq: torch.Tensor

        :return: output spike sequence with one entry per input time-step;
            each element is either ``0`` or the current threshold value
        :rtype: torch.Tensor
        """
        with torch.no_grad():
            # Keep the threshold at or above its lower bound. ``clamp_`` is
            # exact, whereas ``relu(v - lb) + lb`` can perturb ``v`` through
            # floating-point rounding on every call.
            self.v_threshold.clamp_(min=self.v_threshold_lower_bound)
        return self.DSRIFFunction.apply(
            x_seq, self.T, self.v_threshold, self.alpha, self.v_threshold_grad_scaling
        )

    class DSRIFFunction(torch.autograd.Function):
        """Autograd function: IF dynamics in forward, rate-based gradient in backward."""

        @staticmethod
        def forward(
            ctx, inp, T=10, v_threshold=1.0, alpha=0.5, v_threshold_grad_scaling=1.0
        ):
            """Simulate the IF neuron along dimension 0 of ``inp``.

            At every step the input is accumulated into the membrane
            potential; wherever the potential reaches ``alpha * v_threshold``
            a spike of amplitude ``v_threshold`` is emitted and subtracted
            from the potential.

            :return: the stacked spikes, one slice per time-step
            """
            ctx.save_for_backward(inp)
            # zeros_like already matches the device and dtype of the input.
            mem_potential = torch.zeros_like(inp[0])
            spikes = []
            for x_t in inp.unbind(0):
                mem_potential = mem_potential + x_t
                spike = (
                    (mem_potential >= alpha * v_threshold).float() * v_threshold
                ).float()
                mem_potential = mem_potential - spike
                spikes.append(spike)
            output = torch.stack(spikes)

            ctx.T = T
            ctx.v_threshold = v_threshold
            ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
            return output

        @staticmethod
        def backward(ctx, grad_output):
            """Compute gradients from time-averaged input and output gradient.

            :return: gradients for ``(inp, T, v_threshold, alpha,
                v_threshold_grad_scaling)``; only ``inp`` and ``v_threshold``
                receive a gradient
            :raises RuntimeWarning: if summing the threshold gradient across
                processes with ``all_reduce`` fails
            """
            with torch.no_grad():
                inp = ctx.saved_tensors[0]
                T = ctx.T
                v_threshold = ctx.v_threshold
                v_threshold_grad_scaling = ctx.v_threshold_grad_scaling

                input_rate_coding = torch.mean(inp, 0)
                grad_output_coding = torch.mean(grad_output, 0) * T

                # The input gradient passes only where the mean input lies in
                # [0, v_threshold]; the same value is returned for all steps.
                input_grad = grad_output_coding.clone()
                input_grad[
                    (input_rate_coding < 0) | (input_rate_coding > v_threshold)
                ] = 0
                input_grad = torch.stack([input_grad] * T) / T

                # The threshold gradient collects the positions where the mean
                # input exceeds the threshold.
                v_threshold_grad = grad_output_coding.clone()
                v_threshold_grad[input_rate_coding <= v_threshold] = 0
                v_threshold_grad = (
                    torch.sum(v_threshold_grad) * v_threshold_grad_scaling
                )

                # all_reduce needs an initialized process group. Without the
                # extra checks, a single-process run on a host that merely
                # exposes several GPUs would always fail here.
                if (
                    v_threshold_grad.is_cuda
                    and torch.cuda.device_count() != 1
                    and dist.is_available()
                    and dist.is_initialized()
                ):
                    try:
                        dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
                    except Exception as err:
                        raise RuntimeWarning(
                            "Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel."
                        ) from err

                return input_grad, None, v_threshold_grad, None, None
class DSRLIFNode(base.MemoryModule):
    """Multi-step DSR leaky integrate-and-fire neuron with an optionally learnable threshold."""

    def __init__(
        self,
        T: int = 20,
        v_threshold: float = 1.0,
        tau: float = 2.0,
        delta_t: float = 0.05,
        alpha: float = 0.3,
        v_threshold_training: bool = True,
        v_threshold_grad_scaling: float = 1.0,
        v_threshold_lower_bound: float = 0.1,
        step_mode="m",
        backend="torch",
        **kwargs,
    ):
        """
        DSR LIF neuron, proposed in `Training High-Performance Low-Latency
        Spiking Neural Networks by Differentiation on Spike Representation
        <https://arxiv.org/pdf/2205.00459.pdf>`_. This model enables low-latency
        and high-performance SNN training by differentiating spike
        representations.

        :param T: number of time-steps, must be a positive ``int``
        :type T: int

        :param v_threshold: initial membrane potential threshold, must not be
            smaller than ``v_threshold_lower_bound``
        :type v_threshold: float

        :param tau: membrane time constant, must be positive
        :type tau: float

        :param delta_t: discretization step for the continuous-time LIF
            differential equation, must be positive
        :type delta_t: float

        :param alpha: scaling factor of the membrane potential threshold,
            must lie in ``(0, 1]``
        :type alpha: float

        :param v_threshold_training: whether the membrane potential threshold
            is learnable, default: ``True``
        :type v_threshold_training: bool

        :param v_threshold_grad_scaling: scaling factor applied to the gradient
            of the membrane potential threshold
        :type v_threshold_grad_scaling: float

        :param v_threshold_lower_bound: minimum allowable membrane potential
            threshold during training, must be positive
        :type v_threshold_lower_bound: float

        :param step_mode: step mode, only ``'m'`` (multi-step) is supported
        :type step_mode: str

        :param backend: backend of this neuron layer. Users can print
            ``self.supported_backends`` to check availability. DSR-LIF only
            supports the ``'torch'`` backend
        :type backend: str

        :param kwargs: accepted for signature compatibility and ignored

        :raises AssertionError: if any argument violates the constraints above

        :return: None
        :rtype: None
        """
        assert isinstance(T, int) and T > 0, "T must be a positive int"
        # Validate the bound first so the comparison below never sees a
        # value of the wrong type.
        assert (
            isinstance(v_threshold_lower_bound, float)
            and v_threshold_lower_bound > 0.0
        ), "v_threshold_lower_bound must be a positive float"
        assert (
            isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
        ), "v_threshold must be a float >= v_threshold_lower_bound"
        assert (
            isinstance(alpha, float) and 0.0 < alpha <= 1.0
        ), "alpha must be a float in (0, 1]"
        # Both values are used as divisors in the forward and backward passes.
        assert tau > 0.0 and delta_t > 0.0, "tau and delta_t must be positive"
        assert step_mode == "m", "only step_mode='m' is supported"

        super().__init__()
        self.backend = backend
        self.step_mode = step_mode
        self.T = T
        if v_threshold_training:
            self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
        else:
            # Plain tensor attribute: it is not wrapped in nn.Parameter, so it
            # receives no optimizer updates.
            self.v_threshold = torch.tensor(v_threshold)
        self.tau = tau
        self.delta_t = delta_t
        self.alpha = alpha
        self.v_threshold_lower_bound = v_threshold_lower_bound
        self.v_threshold_grad_scaling = v_threshold_grad_scaling

    @property
    def supported_backends(self):
        """Return the backends this neuron supports.

        A tuple is returned (not the bare string ``"torch"``) so that a
        membership test such as ``name in self.supported_backends`` compares
        whole backend names instead of matching substrings.
        """
        return ("torch",)

    def extra_repr(self):
        """Return the hyper-parameter summary appended to the module repr."""
        return (
            f", T={self.T}"
            f", init_vth={self.v_threshold}"
            f", tau={self.tau}"
            f", dt={self.delta_t}"
            f", alpha={self.alpha}"
            f", vth_bound={self.v_threshold_lower_bound}"
            f", vth_g_scale={self.v_threshold_grad_scaling}"
        )

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Run the neuron over a whole input sequence.

        :param x_seq: input sequence; dimension 0 is iterated as the time
            axis. It must have ``self.T`` entries along that axis, because the
            backward pass returns ``self.T`` stacked copies of the gradient.
        :type x_seq: torch.Tensor

        :return: output sequence with one entry per input time-step; each
            element is either ``0`` or ``v_threshold / delta_t``
        :rtype: torch.Tensor
        """
        with torch.no_grad():
            # Keep the threshold at or above its lower bound. ``clamp_`` is
            # exact, whereas ``relu(v - lb) + lb`` can perturb ``v`` through
            # floating-point rounding on every call.
            self.v_threshold.clamp_(min=self.v_threshold_lower_bound)
        return self.DSRLIFFunction.apply(
            x_seq,
            self.T,
            self.v_threshold,
            self.tau,
            self.delta_t,
            self.alpha,
            self.v_threshold_grad_scaling,
        )

    @classmethod
    def weight_rate_spikes(cls, data, tau, delta_t):
        """Return the exponentially weighted average of ``data`` over dimension 0.

        Step ``ii`` (1-based) out of ``T = data.shape[0]`` is weighted by
        ``exp(-(T - ii) * delta_t / tau)``, so later steps weigh more. The
        weighted sum is normalized by the sum of the weights.

        :param data: tensor whose dimension 0 is reduced
        :param tau: time constant used in the exponential weights
        :param delta_t: step size used in the exponential weights
        :return: tensor with the shape of ``data`` minus dimension 0
        """
        T = data.shape[0]
        # Move dimension 0 to the end so the 1-D weight vector broadcasts
        # against it. Using data.dim() also covers 1-D input.
        data_reshape = data.permute(list(range(1, data.dim())) + [0])
        weight = torch.tensor(
            [
                math.exp(-1 / tau * (delta_t * T - ii * delta_t))
                for ii in range(1, T + 1)
            ],
            device=data_reshape.device,
        )
        return (weight * data_reshape).sum(dim=-1) / weight.sum()

    class DSRLIFFunction(torch.autograd.Function):
        """Autograd function: LIF dynamics in forward, weighted-rate gradient in backward."""

        @staticmethod
        def forward(
            ctx,
            inp,
            T,
            v_threshold,
            tau,
            delta_t=0.05,
            alpha=0.3,
            v_threshold_grad_scaling=1.0,
        ):
            """Simulate the LIF neuron along dimension 0 of ``inp``.

            With ``beta = exp(-delta_t / tau)`` the membrane potential is
            updated as ``beta * v + (1 - beta) * x`` at every step; wherever it
            reaches ``alpha * v_threshold`` a spike of amplitude
            ``v_threshold`` is subtracted from the potential and
            ``spike / delta_t`` is emitted.

            :return: the stacked outputs, one slice per time-step
            """
            ctx.save_for_backward(inp)
            # zeros_like already matches the device and dtype of the input.
            mem_potential = torch.zeros_like(inp[0])
            beta = math.exp(-delta_t / tau)
            spikes = []
            for x_t in inp.unbind(0):
                mem_potential = beta * mem_potential + (1 - beta) * x_t
                spike = (
                    (mem_potential >= alpha * v_threshold).float() * v_threshold
                ).float()
                mem_potential = mem_potential - spike
                spikes.append(spike / delta_t)
            output = torch.stack(spikes)

            ctx.T = T
            ctx.v_threshold = v_threshold
            ctx.tau = tau
            ctx.delta_t = delta_t
            ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
            return output

        @staticmethod
        def backward(ctx, grad_output):
            """Compute gradients from exponentially weighted time averages.

            :return: gradients for ``(inp, T, v_threshold, tau, delta_t,
                alpha, v_threshold_grad_scaling)``; only ``inp`` and
                ``v_threshold`` receive a gradient
            :raises RuntimeWarning: if summing the threshold gradient across
                processes with ``all_reduce`` fails
            """
            inp = ctx.saved_tensors[0]
            T = ctx.T
            v_threshold = ctx.v_threshold
            delta_t = ctx.delta_t
            tau = ctx.tau
            v_threshold_grad_scaling = ctx.v_threshold_grad_scaling

            input_rate_coding = DSRLIFNode.weight_rate_spikes(inp, tau, delta_t)
            grad_output_coding = (
                DSRLIFNode.weight_rate_spikes(grad_output, tau, delta_t) * T
            )

            # The input gradient passes only where the weighted input rate is
            # strictly inside (0, upper_bound); it is shared by all steps.
            upper_bound = v_threshold / delta_t * tau
            indexes = (input_rate_coding > 0) & (input_rate_coding < upper_bound)
            input_grad = torch.zeros_like(grad_output_coding)
            input_grad[indexes] = grad_output_coding[indexes] / tau
            input_grad = torch.stack([input_grad] * T) / T

            # The threshold gradient collects the positions where the weighted
            # input rate exceeds the upper bound.
            v_threshold_grad = grad_output_coding.clone()
            v_threshold_grad[input_rate_coding <= upper_bound] = 0
            v_threshold_grad = (
                torch.sum(v_threshold_grad) * delta_t * v_threshold_grad_scaling
            )

            # all_reduce needs an initialized process group. Without the extra
            # checks, a single-process run on a host that merely exposes
            # several GPUs would always fail here.
            if (
                v_threshold_grad.is_cuda
                and torch.cuda.device_count() != 1
                and dist.is_available()
                and dist.is_initialized()
            ):
                try:
                    dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
                except Exception as err:
                    raise RuntimeWarning(
                        "Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel."
                    ) from err

            return input_grad, None, v_threshold_grad, None, None, None, None