# spikingjelly.activation_based.neuron.base_node source code

from abc import abstractmethod
from typing import Optional

import torch
import torch.nn as nn

from .. import base, surrogate

# Public names exported by ``from <this module> import *``: the three neuron
# base classes defined below.
__all__ = ["BaseNode", "NonSpikingBaseNode", "SimpleBaseNode"]


class SimpleBaseNode(base.MemoryModule):
    """A simple version of :class:`BaseNode` that is easy to modify or extend.

    A forward step runs charge -> fire -> reset. Subclasses must implement
    :meth:`neuronal_charge`, which updates ``self.v`` from the input.
    """

    def __init__(
        self,
        v_threshold: float = 1.0,
        v_reset: Optional[float] = 0.0,
        surrogate_function: Optional[surrogate.SurrogateFunctionBase] = None,
        detach_reset: bool = False,
        step_mode="s",
    ):
        """
        :param v_threshold: threshold voltage of this neuron layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer. If not ``None``, a
            spike sets the voltage to ``v_reset`` (hard reset); if ``None``, a
            spike subtracts ``v_threshold`` from the voltage (soft reset)
        :type v_reset: Optional[float]
        :param surrogate_function: function used to compute surrogate gradients
            of the Heaviside step function in backward. ``None`` (the default)
            creates a new :class:`surrogate.Sigmoid` for this layer
        :type surrogate_function: Optional[surrogate.SurrogateFunctionBase]
        :param detach_reset: whether to detach the reset computation from the
            autograd graph
        :type detach_reset: bool
        :param step_mode: the step mode, ``'s'`` (single-step) or ``'m'``
            (multi-step)
        :type step_mode: str
        :return: None
        :rtype: None
        """
        super().__init__()
        if surrogate_function is None:
            # Build a fresh surrogate per layer. A default argument of
            # ``surrogate.Sigmoid()`` would be evaluated once at definition time
            # and the same module instance would be shared (and mutated) by
            # every layer constructed with the default.
            surrogate_function = surrogate.Sigmoid()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.surrogate_function = surrogate_function
        self.detach_reset = detach_reset
        self.step_mode = step_mode
        # The membrane potential starts at 0.0 regardless of ``v_reset``.
        self.register_memory(name="v", value=0.0)

    def single_step_forward(self, x: torch.Tensor):
        """Run one time-step: charge with ``x``, fire, then reset.

        :param x: input to the neurons for this time-step
        :type x: torch.Tensor
        :return: output spikes
        """
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def neuronal_charge(self, x: torch.Tensor):
        """Update ``self.v`` from input ``x``. Must be implemented by subclasses."""
        raise NotImplementedError

    def neuronal_fire(self):
        """Return spikes computed as ``surrogate_function(v - v_threshold)``."""
        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """Reset the membrane potential according to ``spike``.

        Soft reset (``v_reset is None``): ``v <- v - v_threshold * spike``.
        Hard reset: ``v <- spike * v_reset + (1 - spike) * v``.
        """
        # Optionally cut the reset path out of the autograd graph.
        spike_d = spike.detach() if self.detach_reset else spike

        if self.v_reset is None:
            # soft reset
            self.v = self.v - self.v_threshold * spike_d
        else:
            # hard reset
            self.v = spike_d * self.v_reset + (1.0 - spike_d) * self.v
class BaseNode(base.MemoryModule):
    """Base class of differentiable spiking neurons.

    Subclasses implement :meth:`neuronal_charge`. This class provides the fire
    and reset stages, single-/multi-step forward passes, and helpers used when
    compiling graphs with ``torch.compile`` (inductor backend).
    """

    def __init__(
        self,
        v_threshold: float = 1.0,
        v_reset: Optional[float] = 0.0,
        surrogate_function: Optional[surrogate.SurrogateFunctionBase] = None,
        detach_reset: bool = False,
        step_mode="s",
        backend="torch",
        store_v_seq: bool = False,
    ):
        """
        :param v_threshold: threshold voltage of this neuron layer; must be a
            ``float``
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer; must be a ``float``
            or ``None``. If not ``None``, the voltage is set to ``v_reset``
            after the neuron fires a spike. If ``None``, ``v_threshold`` is
            subtracted from the voltage after the neuron fires a spike
        :type v_reset: Optional[float]
        :param surrogate_function: function used to compute surrogate gradients
            of the Heaviside step function in backward. ``None`` (the default)
            creates a new :class:`surrogate.Sigmoid` for this layer
        :type surrogate_function: Optional[surrogate.SurrogateFunctionBase]
        :param detach_reset: whether to detach the computation graph of reset
            in backward; must be a ``bool``
        :type detach_reset: bool
        :param step_mode: the step mode, which can be ``'s'`` (single-step) or
            ``'m'`` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer. Different ``step_mode``
            values may support different backends. Print
            ``self.supported_backends`` to see which backends the current
            ``step_mode`` supports. When supported, the ``'cupy'`` or
            ``'triton'`` backend is faster
        :type backend: str
        :param store_v_seq: when using ``step_mode = 'm'`` with input of
            ``shape = [T, N, *]``, whether to store the voltage of every
            time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If
            ``False``, only the voltage of the last time-step is kept in
            ``self.v`` with ``shape = [N, *]``, which reduces memory usage
        :type store_v_seq: bool
        :return: None
        :rtype: None
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        if surrogate_function is None:
            # Build a fresh surrogate per layer. A default argument of
            # ``surrogate.Sigmoid()`` would be evaluated once at definition time
            # and the same module instance would be shared (and mutated) by
            # every layer constructed with the default.
            surrogate_function = surrogate.Sigmoid()

        # The initial / reset-time membrane potential: v_reset for hard reset,
        # 0.0 for soft reset.
        if v_reset is None:
            self.register_memory("v", 0.0)
        else:
            self.register_memory("v", v_reset)

        self.v_threshold = v_threshold
        self.v_reset = v_reset

        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function

        self.step_mode = step_mode
        self.backend = backend

        self.store_v_seq = store_v_seq

        # used in lava_exchange
        self.lava_s_cale = 1 << 6

        # used for cupy backend
        self.forward_kernel = None
        self.backward_kernel = None

        # cache_key -> torch.compile'd callable, filled by _compile_inductor_graph
        self._inductor_compiled_graphs = {}

    @property
    def store_v_seq(self):
        """Whether multi-step forward stores per-step voltages in ``self.v_seq``."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        # Register ``v_seq`` as memory the first time storing is enabled.
        if value:
            if not hasattr(self, "v_seq"):
                self.register_memory("v_seq", None)

    @staticmethod
    def apply_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        """Return ``(1 - spike) * v + spike * v_reset``."""
        v = (1.0 - spike) * v + spike * v_reset
        return v

    @staticmethod
    def apply_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        """Return ``v - spike * v_threshold``."""
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Define the charge difference equation. Subclasses must implement this."""

    def neuronal_fire(self):
        """Compute output spikes from the current membrane potential and threshold.

        :return: ``surrogate_function(v - v_threshold)``
        """
        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """Reset the membrane potential according to the neurons' output spikes.

        Uses :meth:`apply_soft_reset` when ``v_reset is None``, otherwise
        :meth:`apply_hard_reset`. If ``detach_reset`` is set, the spikes are
        detached before being used in the reset.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # soft reset
            self.v = self.apply_soft_reset(self.v, spike_d, self.v_threshold)
        else:
            # hard reset
            self.v = self.apply_hard_reset(self.v, spike_d, self.v_reset)

    def extra_repr(self):
        return f"v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}"

    def single_step_forward(self, x: torch.Tensor):
        """Forward one time-step in the order charge, fire, reset.

        :param x: increment of voltage input to the neurons
        :type x: torch.Tensor
        :return: output spikes of the neurons
        :rtype: torch.Tensor
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Apply :meth:`single_step_forward` to each ``x_seq[t]`` along dim 0.

        :param x_seq: input sequence; time is dimension 0
        :type x_seq: torch.Tensor
        :return: output spikes stacked along dimension 0
        :rtype: torch.Tensor

        When ``store_v_seq`` is true, the voltage after each step is also
        stacked into ``self.v_seq``.
        """
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)

        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        return torch.stack(y_seq)

    def v_float_to_tensor(self, x: torch.Tensor):
        """Make ``self.v`` a tensor compatible with ``x``.

        - float ``v``: expanded to ``torch.full_like(x, v)``;
        - tensor ``v`` with a different shape: re-initialised to ``v_reset``
          (or 0.0 when ``v_reset`` is ``None``), discarding the old state;
        - same shape but different dtype/device: cast to ``x``'s dtype/device.
        """
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x, v_init, requires_grad=False)
        elif isinstance(self.v, torch.Tensor):
            if self.v.shape != x.shape:
                self.v = torch.full_like(
                    x,
                    self.v_reset if self.v_reset is not None else 0.0,
                    requires_grad=False,
                )
            elif self.v.dtype != x.dtype or self.v.device != x.device:
                self.v = self.v.to(dtype=x.dtype, device=x.device)

    def _compile_inductor_graph(self, cache_key, fn):
        """Return ``fn`` compiled with ``torch.compile(backend="inductor")``.

        Results are cached per ``cache_key`` in ``self._inductor_compiled_graphs``.
        Compilation is first attempted with cudagraph options disabled. If that
        call raises ``TypeError``, it is retried without ``options``.

        :raises RuntimeError: if ``torch.compile`` is unavailable
        """
        compiled = self._inductor_compiled_graphs.get(cache_key)
        if compiled is not None:
            return compiled

        if not hasattr(torch, "compile"):
            raise RuntimeError(
                f"{self._get_name()} backend='inductor' requires torch.compile."
            )

        compile_kwargs = {"backend": "inductor"}
        try:
            compiled = torch.compile(
                fn,
                **compile_kwargs,
                options={
                    "triton.cudagraphs": False,
                    "triton.cudagraph_trees": False,
                },
            )
        except TypeError:
            # NOTE(review): presumably for torch versions whose torch.compile
            # does not accept ``options`` — confirm.
            compiled = torch.compile(fn, **compile_kwargs)

        self._inductor_compiled_graphs[cache_key] = compiled
        return compiled

    @staticmethod
    def _canonicalize_inductor_tensor(tensor: torch.Tensor) -> torch.Tensor:
        """Return a contiguous version of ``tensor``."""
        return tensor.contiguous()

    @staticmethod
    def _inductor_tensor_signature(tensor: torch.Tensor):
        """Hashable description of ``tensor``'s layout, used in cache keys."""
        return (
            tuple(tensor.shape),
            tensor.ndim,
            str(tensor.dtype),
            tensor.device.type,
            tensor.device.index,
            tensor.is_contiguous(),
            bool(tensor.requires_grad),
        )

    def _inductor_runtime_cache_key(self, *tensors: torch.Tensor):
        """Cache key made of the signature of each tensor in ``tensors``."""
        return tuple(self._inductor_tensor_signature(t) for t in tensors)

    def _surrogate_inductor_cache_key(self):
        """Cache key describing the surrogate function's type, mode and params."""
        sg = self.surrogate_function
        params = tuple(sorted(getattr(sg, "_sg_params", {}).items()))
        return (
            type(sg).__module__,
            type(sg).__qualname__,
            getattr(sg, "spiking", None),
            params,
        )

    def __getstate__(self):
        """Pickle state with the compiled-graph cache emptied."""
        state = super().__getstate__()
        if isinstance(state, dict) and "_inductor_compiled_graphs" in state:
            # Copy before editing: depending on which base class supplies
            # __getstate__ (e.g. object.__getstate__ on Python >= 3.11), the
            # returned dict may be the live instance __dict__, and clearing it
            # in place would wipe this module's cache as a side effect.
            state = dict(state)
            state["_inductor_compiled_graphs"] = {}
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        # States pickled without the cache attribute still get an empty cache.
        if not hasattr(self, "_inductor_compiled_graphs"):
            self._inductor_compiled_graphs = {}
class NonSpikingBaseNode(nn.Module, base.MultiStepModule):
    """Base class of non-spiking neurons that only integrate their input.

    Subclasses implement :meth:`neuronal_charge`, which updates ``self.v``
    from one time-step of input. :meth:`forward` runs the charge over a whole
    input sequence and optionally decodes the membrane-potential sequence.
    """

    def __init__(self, decode: Optional[str] = None):
        """
        :param decode: decoding method. If not ``None``, :meth:`forward` decodes
            the membrane-potential sequence with it. Supported values are
            ``'max-mem'``, ``'max-abs-mem'``, ``'mean-mem'`` and ``'last_mem'``
            (``'last-mem'`` is accepted as an alias)
        :type decode: Optional[str]
        :return: None
        :rtype: None
        """
        super().__init__()
        self.decode = decode

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """Update ``self.v`` from one time-step of input. Must be implemented."""
        raise NotImplementedError

    def forward(self, x_seq: torch.Tensor):
        """Charge over ``x_seq`` (time along dim 0) and decode the potentials.

        ``self.v`` is reset to zeros shaped like ``x_seq[0]`` at every call.

        :return: with ``decode=None``, the Python list of per-step potentials;
            ``'max-mem'`` / ``'mean-mem'``: max / mean over time;
            ``'max-abs-mem'``: per element, the temporal max or min, whichever
            has the larger magnitude (the min on ties);
            ``'last_mem'`` / ``'last-mem'``: the potential of the last step.
        :raises ValueError: if ``decode`` is not one of the values above
        """
        self.v = torch.full_like(x_seq[0].data, fill_value=0.0)
        T = x_seq.shape[0]
        v_seq = []
        for t in range(T):
            self.neuronal_charge(x_seq[t])
            v_seq.append(self.v)

        if self.decode is None:
            return v_seq
        # "last_mem" is the historical spelling; "last-mem" matches the
        # hyphenated names of the other decoders.
        if self.decode in ("last_mem", "last-mem"):
            return v_seq[-1]

        v_stack = torch.stack(v_seq, 0)
        if self.decode == "max-mem":
            return torch.max(v_stack, 0).values
        if self.decode == "max-abs-mem":
            max_mem = torch.max(v_stack, 0).values
            min_mem = torch.min(v_stack, 0).values
            # torch.where selects without multiplying, so an inf in the branch
            # that is not chosen cannot turn into nan (inf * 0).
            return torch.where(max_mem.abs() > min_mem.abs(), max_mem, min_mem)
        if self.decode == "mean-mem":
            return torch.mean(v_stack, 0)

        raise ValueError(
            f"unsupported decode={self.decode!r}; expected None, 'max-mem', "
            f"'max-abs-mem', 'mean-mem' or 'last_mem'"
        )