# Source code of spikingjelly.activation_based.neuron.plif

import logging
import math
from typing import Optional

import torch
import torch.nn as nn

from .. import surrogate
from .base_node import BaseNode

# Optional accelerated backends. A failed import is logged and the module is
# set to ``None``; the corresponding backend then cannot be used. ``Exception``
# (not ``BaseException``) is caught so KeyboardInterrupt / SystemExit raised
# during import still propagate.
try:
    from ..cuda_kernel.auto_cuda import neuron_kernel as ac_neuron_kernel
except Exception as e:
    logging.info("spikingjelly.activation_based.neuron: %s", e)
    ac_neuron_kernel = None

try:
    from .. import triton_kernel
except Exception as e:
    logging.info("spikingjelly.activation_based.neuron: %s", e)
    triton_kernel = None


# Names exported by ``from ... import *``: only the PLIF neuron class.
__all__ = ["ParametricLIFNode"]


class ParametricLIFNode(BaseNode):
    def __init__(
        self,
        init_tau: float = 2.0,
        decay_input: bool = True,
        v_threshold: float = 1.0,
        v_reset: Optional[float] = 0.0,
        surrogate_function: surrogate.SurrogateFunctionBase = surrogate.Sigmoid(),
        detach_reset: bool = False,
        step_mode="s",
        backend="torch",
        store_v_seq: bool = False,
    ):
        """
        **API Language:**
        :ref:`中文 <ParametricLIFNode.__init__-cn>` | :ref:`English <ParametricLIFNode.__init__-en>`

        ----

        .. _ParametricLIFNode.__init__-cn:

        * **中文**

        Parametric Leaky Integrate-and-Fire (PLIF) 神经元模型,提出自 `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_。可以看作是带漏电的积分器。其阈下神经动力学方程为:

        若 ``decay_input == True``:

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        若 ``decay_input == False``:

        .. math::
            H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。

        :param init_tau: 膜电位时间常数的初始值
        :type init_tau: float
        :param decay_input: 输入是否也会参与衰减
        :type decay_input: bool
        :param v_threshold: 神经元的阈值电压
        :type v_threshold: float
        :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
            如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
        :type v_reset: Optional[float]
        :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
        :type surrogate_function: surrogate.SurrogateFunctionBase
        :param detach_reset: 是否将 reset 过程的计算图分离
        :type detach_reset: bool
        :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
        :type step_mode: str
        :param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
            使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 或 ``'triton'`` 后端速度更快。
        :type backend: str
        :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
            的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
            通常设置成 ``False`` ,可以节省内存
        :type store_v_seq: bool

        ----

        .. _ParametricLIFNode.__init__-en:

        * **English**

        The Parametric Leaky Integrate-and-Fire (PLIF) neuron, proposed in `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_,
        can be seen as a leaky integrator. The subthreshold neural dynamics of it is as followed:

        IF ``decay_input == True``:

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        IF ``decay_input == False``:

        .. math::
            H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter.

        :param init_tau: the initial value of membrane time constant; must be a
            real number (``int`` values are converted to ``float``) greater than ``1.0``
        :type init_tau: float
        :param decay_input: whether the input will decay
        :type decay_input: bool
        :param v_threshold: threshold of this neurons layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
            after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
        :type v_reset: Optional[float]
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: surrogate.SurrogateFunctionBase
        :param detach_reset: whether detach the computation graph of reset in backward
        :type detach_reset: bool
        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str
        :param backend: backend for this neurons layer. Different ``step_mode`` may support for different backends. The user can
            print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
            using ``'cupy'`` or ``'triton'`` backend will have the fastest training speed
        :type backend: str
        :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
            whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
            only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
            memory consumption
        :type store_v_seq: bool
        :return: None
        :rtype: None
        :raises AssertionError: if ``init_tau`` is not a real number greater than ``1.0``
        """
        # Validate with explicit raises rather than ``assert`` so the check is not
        # stripped under ``python -O``. AssertionError is kept as the exception
        # type for backward compatibility with the previous assert-based check.
        if isinstance(init_tau, bool) or not isinstance(init_tau, (int, float)):
            raise AssertionError(
                f"init_tau must be a float, got {type(init_tau).__name__}"
            )
        init_tau = float(init_tau)
        # ``not (x > 1.0)`` also rejects NaN, exactly like the original assert.
        if not init_tau > 1.0:
            raise AssertionError(f"init_tau must be greater than 1.0, got {init_tau}")

        super().__init__(
            v_threshold,
            v_reset,
            surrogate_function,
            detach_reset,
            step_mode,
            backend,
            store_v_seq,
        )
        self.decay_input = decay_input
        # Pick w so that sigmoid(w) == 1 / init_tau:
        # sigmoid(-log(tau - 1)) = 1 / (1 + (tau - 1)) = 1 / tau.
        init_w = -math.log(init_tau - 1.0)
        self.w = nn.Parameter(torch.as_tensor(init_w))  # sigmoid(w) is 1 / tau

    @property
    def supported_backends(self):
        """Backends usable with the current ``step_mode``.

        :raises ValueError: if ``step_mode`` is neither ``'s'`` nor ``'m'``
        """
        if self.step_mode == "s":
            return ("torch",)
        elif self.step_mode == "m":
            return ("torch", "cupy", "triton", "inductor")
        else:
            raise ValueError(self.step_mode)

    def extra_repr(self):
        """Append the current membrane time constant ``tau = 1 / sigmoid(w)``."""
        with torch.no_grad():
            tau = 1.0 / self.w.sigmoid()
        return super().extra_repr() + f", tau={tau}"

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate input ``x`` into ``self.v`` for a single time-step.

        ``self.w.sigmoid()`` plays the role of ``1 / tau``. When ``v_reset`` is
        ``None`` (soft reset) or ``0.0``, the leak is taken towards zero and the
        reset term is omitted from the expression.
        """
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.0:
                self.v = self.v + (x - self.v) * self.w.sigmoid()
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
        else:
            if self.v_reset is None or self.v_reset == 0.0:
                self.v = self.v * (1.0 - self.w.sigmoid()) + x
            else:
                self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x

    def _build_inductor_multi_step_graph(self):
        """Build the pure multi-step function passed to ``self._compile_inductor_graph``.

        The neuron configuration is captured as Python constants in the closure,
        so the returned function only takes tensors:
        ``(x_seq, v_init, reciprocal_tau) -> (spike_seq, v)`` or
        ``(spike_seq, v, v_seq)`` when ``store_v_seq`` is set.
        """
        store_v_seq = self.store_v_seq
        soft_reset = self.v_reset is None
        # With soft reset the leak goes towards 0.0, matching neuronal_charge.
        v_reset = 0.0 if soft_reset else self.v_reset
        surrogate_fn = self.surrogate_function
        v_threshold = self.v_threshold
        detach_reset = self.detach_reset
        decay_input = self.decay_input

        def _graph(
            x_seq: torch.Tensor, v_init: torch.Tensor, reciprocal_tau: torch.Tensor
        ):
            v = v_init
            spike_seq = torch.empty_like(x_seq)
            if store_v_seq:
                v_seq = torch.empty_like(x_seq)
            for t in range(x_seq.shape[0]):
                # Charge.
                if decay_input:
                    v = v + (x_seq[t] - (v - v_reset)) * reciprocal_tau
                else:
                    v = v - (v - v_reset) * reciprocal_tau + x_seq[t]
                # Fire.
                spike = surrogate_fn(v - v_threshold)
                spike_d = spike.detach() if detach_reset else spike
                # Reset.
                if soft_reset:
                    v = v - spike_d * v_threshold
                else:
                    v = v_reset * spike_d + (1.0 - spike_d) * v
                spike_seq[t] = spike
                if store_v_seq:
                    v_seq[t] = v
            if store_v_seq:
                return spike_seq, v, v_seq
            return spike_seq, v

        return _graph

    def _inductor_multi_step_forward(self, x_seq: torch.Tensor):
        """Multi-step forward through the compiled graph from ``_build_inductor_multi_step_graph``."""
        self.v_float_to_tensor(x_seq[0])
        x_seq = self._canonicalize_inductor_tensor(x_seq)
        v_init = self._canonicalize_inductor_tensor(self.v)
        reciprocal_tau = self._canonicalize_inductor_tensor(self.w.sigmoid().to(x_seq))
        graph = self._compile_inductor_graph(
            # Every configuration value baked into the graph closure appears in
            # this key (presumably the compiled-graph cache key — see BaseNode).
            (
                "plif",
                self.store_v_seq,
                self.decay_input,
                self.v_threshold,
                self.v_reset,
                self.detach_reset,
                self._surrogate_inductor_cache_key(),
                self._inductor_runtime_cache_key(x_seq, v_init, reciprocal_tau),
            ),
            self._build_inductor_multi_step_graph(),
        )
        out = graph(x_seq, v_init, reciprocal_tau)
        if self.store_v_seq:
            spike_seq, self.v, self.v_seq = out
        else:
            spike_seq, self.v = out
        return spike_seq

    def _cupy_multi_step_forward(self, x_seq: torch.Tensor):
        """Multi-step forward with the auto_cuda (CuPy) kernels.

        Kernels are cached on ``self.forward_kernel`` / ``self.backward_kernel``
        and rebuilt only when their attributes no longer match this neuron.

        :raises ImportError: if the auto_cuda kernel module failed to import
        :raises NotImplementedError: if ``x_seq`` is neither float32 nor float16
        """
        if ac_neuron_kernel is None:
            raise ImportError(
                "backend='cupy' is unavailable: "
                "spikingjelly.activation_based.cuda_kernel.auto_cuda failed to "
                "import (enable INFO logging to see the original error)"
            )
        hard_reset = self.v_reset is not None
        if x_seq.dtype == torch.float:
            dtype = "float"
        elif x_seq.dtype == torch.half:
            dtype = "half2"
        else:
            raise NotImplementedError(x_seq.dtype)

        if self.forward_kernel is None or not self.forward_kernel.check_attributes(
            hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input
        ):
            self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel(
                decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype
            )

        if self.backward_kernel is None or not self.backward_kernel.check_attributes(
            surrogate_function=self.surrogate_function.cuda_codes,
            hard_reset=hard_reset,
            detach_reset=self.detach_reset,
            dtype=dtype,
            decay_input=self.decay_input,
        ):
            self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel(
                decay_input=self.decay_input,
                surrogate_function=self.surrogate_function.cuda_codes,
                hard_reset=hard_reset,
                detach_reset=self.detach_reset,
                dtype=dtype,
            )

        self.v_float_to_tensor(x_seq[0])

        # The kernels operate on flattened [T, N*] inputs; reshape back afterwards.
        spike_seq, v_seq = ac_neuron_kernel.multistep_plif(
            x_seq=x_seq.flatten(1),
            v_init=self.v.flatten(0),
            decay=self.w.sigmoid().to(x_seq),
            decay_input=self.decay_input,
            v_threshold=self.v_threshold,
            v_reset=self.v_reset,
            detach_reset=self.detach_reset,
            surrogate_function=self.surrogate_function,
            forward_kernel=self.forward_kernel,
            backward_kernel=self.backward_kernel,
        )

        spike_seq = spike_seq.reshape(x_seq.shape)
        v_seq = v_seq.reshape(x_seq.shape)

        if self.store_v_seq:
            self.v_seq = v_seq

        self.v = v_seq[-1].clone()
        return spike_seq

    def _triton_multi_step_forward(self, x_seq: torch.Tensor):
        """Multi-step forward with the Triton kernels.

        :raises ImportError: if the triton kernel module failed to import
        """
        if triton_kernel is None:
            raise ImportError(
                "backend='triton' is unavailable: "
                "spikingjelly.activation_based.triton_kernel failed to import "
                "(enable INFO logging to see the original error)"
            )
        self.v_float_to_tensor(x_seq[0])
        spike_seq, v_seq = triton_kernel.multistep_plif(
            x_seq,
            self.v,
            self.w.sigmoid().to(x_seq),
            self.decay_input,
            self.v_threshold,
            self.v_reset,
            self.detach_reset,
            self.surrogate_function,
        )

        if self.store_v_seq:
            self.v_seq = v_seq

        self.v = v_seq[-1].clone()
        return spike_seq

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Run the neuron over a sequence whose first dimension is time.

        Dispatches on ``self.backend``. ``'torch'`` defers to
        ``BaseNode.multi_step_forward``, and ``'inductor'``, ``'cupy'`` and
        ``'triton'`` use the accelerated paths defined in this class.

        :param x_seq: input sequence, ``shape = [T, N, *]``
        :return: spike sequence with the same shape as ``x_seq``
        :raises ValueError: if ``self.backend`` is not recognised
        """
        if self.backend == "inductor":
            return self._inductor_multi_step_forward(x_seq)
        elif self.backend == "torch":
            return super().multi_step_forward(x_seq)
        elif self.backend == "cupy":
            return self._cupy_multi_step_forward(x_seq)
        elif self.backend == "triton":
            return self._triton_multi_step_forward(x_seq)
        else:
            raise ValueError(self.backend)