spikingjelly.activation_based.neuron.adapt 源代码

import logging
from typing import Optional

import torch

from .. import surrogate
from .base_node import BaseNode

try:
    from .. import cuda_kernel
except BaseException as e:
    logging.info(f"spikingjelly.activation_based.neuron: {e}")
    cuda_kernel = None


__all__ = ["AdaptBaseNode", "IzhikevichNode"]


[文档] class AdaptBaseNode(BaseNode): def __init__( self, v_threshold: float = 1.0, v_reset: Optional[float] = 0.0, v_rest: float = 0.0, w_rest: float = 0.0, tau_w: float = 2.0, a: float = 0.0, b: float = 0.0, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.Sigmoid(), detach_reset: bool = False, step_mode="s", backend="torch", store_v_seq: bool = False, ): """ **API Language:** :ref:`中文 <AdaptBaseNode.__init__-cn>` | :ref:`English <AdaptBaseNode.__init__-en>` ---- .. _AdaptBaseNode.__init__-cn: * **中文** 带适应性电流的脉冲神经元基类。在 :class:`BaseNode` 的基础上增加了膜电位恢复变量 :math:`w`,用于实现神经元适应性和脉冲频率适应性。 :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 重置电压。若为 ``None`` 则使用软重置 :type v_reset: Optional[float] :param v_rest: 静息电位 :type v_rest: float :param w_rest: 适应性电流的静息值 :type w_rest: float :param tau_w: 适应性电流的时间常数 :type tau_w: float :param a: 阈下耦合参数,控制亚阈值电位对适应电流的影响 :type a: float :param b: 脉冲触发跳跃幅度,控制脉冲后适应电流的增加量 :type b: float :param surrogate_function: 替代梯度函数 :type surrogate_function: surrogate.SurrogateFunctionBase :param detach_reset: 是否将重置过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可为 ``'s'`` (单步) 或 ``'m'`` (多步) :type step_mode: str :param backend: 后端 :type backend: str :param store_v_seq: 是否保存中间电压值 :type store_v_seq: bool ---- .. _AdaptBaseNode.__init__-en: * **English** Base neuron with adaptation current. Extends :class:`BaseNode` with a membrane recovery variable :math:`w` that provides spike-frequency adaptation. :param v_threshold: Threshold voltage of the neuron :type v_threshold: float :param v_reset: Reset voltage. If ``None``, uses soft reset :type v_reset: Optional[float] :param v_rest: Resting potential :type v_rest: float :param w_rest: Resting value of the adaptation current :type w_rest: float :param tau_w: Time constant of the adaptation current :type tau_w: float :param a: Subthreshold coupling parameter, controls subthreshold influence on adaptation current :type a: float :param b: Spike-triggered jump amplitude, controls adaptation current increase after each spike :type b: float :param surrogate_function: Surrogate gradient function :type surrogate_function: surrogate.SurrogateFunctionBase :param detach_reset: Whether to detach the reset computation graph :type detach_reset: bool :param step_mode: Step mode, can be ``'s'`` (single-step) or ``'m'`` (multi-step) :type step_mode: str :param backend: Backend for computation :type backend: str :param store_v_seq: Whether to store intermediate membrane potentials :type store_v_seq: bool :return: None :rtype: None """ # b: jump amplitudes # a: subthreshold coupling assert isinstance(w_rest, float) assert isinstance(v_rest, float) assert isinstance(tau_w, float) assert isinstance(a, float) assert isinstance(b, float) super().__init__( v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq, ) self.register_memory("w", w_rest) self.w_rest = w_rest self.v_rest = v_rest self.tau_w = tau_w self.a = a self.b = b
[文档] @staticmethod def jit_neuronal_adaptation( w: torch.Tensor, tau_w: float, a: float, v_rest: float, v: torch.Tensor ): return w + 1.0 / tau_w * (a * (v - v_rest) - w)
[文档] def neuronal_adaptation(self): """ **API Language:** :ref:`中文 <AdaptBaseNode.neuronal_adaptation-cn>` | :ref:`English <AdaptBaseNode.neuronal_adaptation-en>` ---- .. _AdaptBaseNode.neuronal_adaptation-cn: * **中文** 脉冲触发的适应性电流的更新 ---- .. _AdaptBaseNode.neuronal_adaptation-en: * **English** Spike-triggered update of adaptation current. """ self.w = self.jit_neuronal_adaptation( self.w, self.tau_w, self.a, self.v_rest, self.v )
[文档] @staticmethod def apply_hard_reset( v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_reset: float, b: float, spike: torch.Tensor, ): v = (1.0 - spike_d) * v + spike * v_reset w = w + b * spike return v, w
[文档] @staticmethod def apply_soft_reset( v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_threshold: float, b: float, spike: torch.Tensor, ): v = v - spike_d * v_threshold w = w + b * spike return v, w
[文档] def neuronal_reset(self, spike): """ **API Language:** :ref:`中文 <AdaptBaseNode.neuronal_reset-cn>` | :ref:`English <AdaptBaseNode.neuronal_reset-en>` ---- .. _AdaptBaseNode.neuronal_reset-cn: * **中文** 根据当前神经元释放的脉冲,对膜电位进行重置。 ---- .. _AdaptBaseNode.neuronal_reset-en: * **English** Reset the membrane potential according to neurons' output spikes. """ if self.detach_reset: spike_d = spike.detach() else: spike_d = spike if self.v_reset is None: # soft reset self.v, self.w = self.apply_soft_reset( self.v, self.w, spike_d, self.v_threshold, self.b, spike ) else: # hard reset self.v, self.w = self.apply_hard_reset( self.v, self.w, spike_d, self.v_reset, self.b, spike )
def extra_repr(self): return ( super().extra_repr() + f", v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}" )
[文档] def single_step_forward(self, x: torch.Tensor): """ **API Language:** :ref:`中文 <AdaptBaseNode.single_step_forward-cn>` | :ref:`English <AdaptBaseNode.single_step_forward-en>` ---- .. _AdaptBaseNode.single_step_forward-cn: * **中文** 按照充电、适应、放电、重置的顺序进行前向传播。 :param x: 输入到神经元的电压增量 :type x: torch.Tensor :return: 神经元的输出脉冲 :rtype: torch.Tensor ---- .. _AdaptBaseNode.single_step_forward-en: * **English** Forward by the order of ``neuronal_charge``, ``neuronal_adaptation``, ``neuronal_fire``, and ``neuronal_reset``. :param x: increment of voltage inputted to neurons :type x: torch.Tensor :return: out spikes of neurons :rtype: torch.Tensor """ self.v_float_to_tensor(x) self.w_float_to_tensor(x) self.neuronal_charge(x) self.neuronal_adaptation() spike = self.neuronal_fire() self.neuronal_reset(spike) return spike
def w_float_to_tensor(self, x: torch.Tensor): if isinstance(self.w, float): w_init = self.w self.w = torch.full_like(x.data, fill_value=w_init)
[文档] class IzhikevichNode(AdaptBaseNode): def __init__( self, tau: float = 2.0, v_c: float = 0.8, a0: float = 1.0, v_threshold: float = 1.0, v_reset: Optional[float] = 0.0, v_rest: float = -0.1, w_rest: float = 0.0, tau_w: float = 2.0, a: float = 0.0, b: float = 0.0, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.Sigmoid(), detach_reset: bool = False, step_mode="s", backend="torch", store_v_seq: bool = False, ): """ **API Language:** :ref:`中文 <IzhikevichNode.__init__-cn>` | :ref:`English <IzhikevichNode.__init__-en>` ---- .. _IzhikevichNode.__init__-cn: * **中文** Izhikevich 脉冲神经元模型。参数 :math:`\\tau` 控制膜电位时间常数,:math:`v_c` 和 :math:`a0` 控制非线性 dynamics。 继承了 :class:`AdaptBaseNode` 的适应性电流机制。 :param tau: 膜电位时间常数 :type tau: float :param v_c: 截止电压,控制非线性响应的阈值 :type v_c: float :param a0: 非线性系数 :type a0: float :param v_threshold: 阈值电压 :type v_threshold: float :param v_reset: 重置电压 :type v_reset: Optional[float] :param v_rest: 静息电位 :type v_rest: float :param w_rest: 适应性电流静息值 :type w_rest: float :param tau_w: 适应性电流时间常数 :type tau_w: float :param a: 阈下耦合参数 :type a: float :param b: 脉冲触发跳跃幅度 :type b: float :param surrogate_function: 替代梯度函数 :type surrogate_function: surrogate.SurrogateFunctionBase :param detach_reset: 是否分离重置计算图 :type detach_reset: bool :param step_mode: 步进模式 :type step_mode: str :param backend: 后端 :type backend: str :param store_v_seq: 是否保存中间电压值 :type store_v_seq: bool ---- .. _IzhikevichNode.__init__-en: * **English** Izhikevich spiking neuron model. The parameters :math:`\\tau`, :math:`v_c`, and :math:`a0` control membrane dynamics. Inherits the adaptation current mechanism from :class:`AdaptBaseNode`. :param tau: Membrane time constant :type tau: float :param v_c: Cutoff voltage controlling the nonlinear response threshold :type v_c: float :param a0: Nonlinear coefficient :type a0: float :param v_threshold: Threshold voltage :type v_threshold: float :param v_reset: Reset voltage :type v_reset: Optional[float] :param v_rest: Resting potential :type v_rest: float :param w_rest: Resting value of adaptation current :type w_rest: float :param tau_w: Time constant of adaptation current :type tau_w: float :param a: Subthreshold coupling parameter :type a: float :param b: Spike-triggered jump amplitude :type b: float :param surrogate_function: Surrogate gradient function :type surrogate_function: surrogate.SurrogateFunctionBase :param detach_reset: Whether to detach reset computation graph :type detach_reset: bool :param step_mode: Step mode, ``'s'`` or ``'m'`` :type step_mode: str :param backend: Backend :type backend: str :param store_v_seq: Whether to store intermediate membrane potentials :type store_v_seq: bool :return: None :rtype: None """ assert isinstance(tau, float) and tau > 1.0 assert a0 > 0 super().__init__( v_threshold, v_reset, v_rest, w_rest, tau_w, a, b, surrogate_function, detach_reset, step_mode, backend, store_v_seq, ) self.tau = tau self.v_c = v_c self.a0 = a0 def extra_repr(self): return super().extra_repr() + f", tau={self.tau}, v_c={self.v_c}, a0={self.a0}"
[文档] def neuronal_charge(self, x: torch.Tensor): self.v = ( self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c) - self.w) / self.tau )
@property def supported_backends(self): if self.step_mode == "s": return ("torch",) elif self.step_mode == "m": return ("torch", "cupy") else: raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.backend == "torch": return super().multi_step_forward(x_seq) elif self.backend == "cupy": self.v_float_to_tensor(x_seq[0]) self.w_float_to_tensor(x_seq[0]) spike_seq, v_seq, w_seq = cuda_kernel.multistep_izhikevich_ptt( x_seq.flatten(1), self.v.flatten(0), self.w.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, self.a, self.b, self.tau_w, self.v_c, self.a0, self.detach_reset, self.surrogate_function, ) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) w_seq = w_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() self.w = w_seq[-1].clone() return spike_seq else: raise ValueError(self.backend)