import logging
from typing import Optional
import torch
from .. import surrogate
from .base_node import BaseNode, NonSpikingBaseNode, SimpleBaseNode
try:
from ..cuda_kernel.auto_cuda import neuron_kernel as ac_neuron_kernel
from ..cuda_kernel.auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel
except BaseException as e:
logging.info(f"spikingjelly.activation_based.neuron: {e}")
ac_neuron_kernel = None
ss_ac_neuron_kernel = None
try:
from .. import triton_kernel
except BaseException as e:
logging.info(f"spikingjelly.activation_based.neuron: {e}")
triton_kernel = None
__all__ = ["SimpleIFNode", "IFNode", "NonSpikingIFNode"]
def _is_expected_triton_fallback_error(exc: RuntimeError) -> bool:
message = str(exc).lower()
expected_markers = (
"unsupported",
"not supported",
"no triton",
"triton is not installed",
"failed to import triton",
"dtype",
"invalid argument",
)
return any(marker in message for marker in expected_markers)
[文档]
class SimpleIFNode(SimpleBaseNode):
def __init__(
self,
v_threshold: float = 1.0,
v_reset: Optional[float] = 0.0,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.Sigmoid(),
detach_reset: bool = False,
step_mode="s",
):
"""
**API Language:**
:ref:`中文 <SimpleIFNode.__init__-cn>` | :ref:`English <SimpleIFNode.__init__-en>`
----
.. _SimpleIFNode.__init__-cn:
* **中文**
:class:`IFNode` 的简化版实现。
:param v_threshold: 神经元阈值电压
:type v_threshold: float
:param v_reset: 神经元重置电压
:type v_reset: Optional[float]
:param surrogate_function: 替代梯度函数
:type surrogate_function: surrogate.SurrogateFunctionBase
:param detach_reset: 是否在反向传播时分离 reset 计算图
:type detach_reset: bool
:param step_mode: 步进模式,可为 ``"s"`` 或 ``"m"``
:type step_mode: str
----
.. _SimpleIFNode.__init__-en:
* **English**
A simple version of :class:`IFNode`.
:param v_threshold: Threshold voltage of the neuron
:type v_threshold: float
:param v_reset: Reset voltage of the neuron
:type v_reset: Optional[float]
:param surrogate_function: Surrogate gradient function
:type surrogate_function: surrogate.SurrogateFunctionBase
:param detach_reset: Whether to detach reset graph in backward
:type detach_reset: bool
:param step_mode: Step mode, either ``"s"`` or ``"m"``
:type step_mode: str
:return: None
:rtype: None
"""
super().__init__(
v_threshold, v_reset, surrogate_function, detach_reset, step_mode
)
[文档]
def neuronal_charge(self, x: torch.Tensor):
r"""
**API Language:**
:ref:`中文 <SimpleIFNode.neuronal_charge-cn>` | :ref:`English <SimpleIFNode.neuronal_charge-en>`
----
.. _SimpleIFNode.neuronal_charge-cn:
* **中文**
* **中文**
神经元充电的微分方程:
.. math::
H[t] = V[t-1] + X[t]
:param x: 输入电压
:type x: torch.Tensor
:return: None(膜电位更新存储在 ``self.v`` 中)
:rtype: None
----
.. _SimpleIFNode.neuronal_charge-en:
* **English**
* **English**
The differential equation for neuronal charge:
.. math::
H[t] = V[t-1] + X[t]
:param x: Input voltage
:type x: torch.Tensor
:return: None (membrane potential is stored in ``self.v``)
:rtype: None
"""
self.v = self.v + x
[文档]
class IFNode(BaseNode):
def __init__(
self,
v_threshold: float = 1.0,
v_reset: Optional[float] = 0.0,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.Sigmoid(),
detach_reset: bool = False,
step_mode="s",
backend="torch",
store_v_seq: bool = False,
):
"""
**API Language:**
:ref:`中文 <IFNode.__init__-cn>` | :ref:`English <IFNode.__init__-en>`
----
.. _IFNode.__init__-cn:
* **中文**
Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像 LIF 神经元那样衰减。其阈下神经动力学方程为:
.. math::
H[t] = V[t-1] + X[t]
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: Optional[float]
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: surrogate.SurrogateFunctionBase
:param detach_reset: 是否将 reset 过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 或 ``'triton'`` 后端速度更快。
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
----
.. _IFNode.__init__-en:
* **English**
The Integrate-and-Fire neuron, which can be seen as an ideal integrator. The voltage of the IF neuron will not decay
as that of the LIF neuron. The sub-threshold neural dynamics of it is as followed:
.. math::
H[t] = V[t-1] + X[t]
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: Optional[float]
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: surrogate.SurrogateFunctionBase
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` or ``'triton'`` backend will be faster
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
:return: None
:rtype: None
"""
super().__init__(
v_threshold,
v_reset,
surrogate_function,
detach_reset,
step_mode,
backend,
store_v_seq,
)
@property
def supported_backends(self):
if self.step_mode == "s":
return ("torch", "cupy")
elif self.step_mode == "m":
return ("torch", "cupy", "triton", "inductor")
else:
raise ValueError(self.step_mode)
[文档]
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
@staticmethod
def _eval_single_step_forward(
x: torch.Tensor,
v: torch.Tensor,
v_threshold: float,
v_reset,
tau: Optional[float] = None,
decay_input: Optional[bool] = None,
):
"""Unified single-step eval (replaces jit_eval_single_step_forward_*)."""
v = v + x
spike = (v >= v_threshold).to(x)
v = (
(v - spike * v_threshold)
if v_reset is None
else (v_reset * spike + (1.0 - spike) * v)
)
return spike, v
@staticmethod
def _eval_multi_step_forward(
x_seq: torch.Tensor,
v: torch.Tensor,
v_threshold: float,
v_reset,
tau: Optional[float] = None,
decay_input: Optional[bool] = None,
store_v_seq: bool = False,
):
"""Unified multi-step eval (replaces jit_eval_multi_step_forward_*)."""
T = x_seq.shape[0]
spike_seq = torch.zeros_like(x_seq)
v_seq = torch.zeros_like(x_seq) if store_v_seq else None
soft_reset = v_reset is None
_vr = 0.0 if soft_reset else v_reset
for t in range(T):
v = v + x_seq[t]
spike = (v >= v_threshold).to(x_seq)
v = (
(v - spike * v_threshold)
if soft_reset
else (_vr * spike + (1.0 - spike) * v)
)
spike_seq[t] = spike
if store_v_seq:
v_seq[t] = v
if store_v_seq:
return spike_seq, v, v_seq
return spike_seq, v
# kept for subclass backward-compatibility
@staticmethod
def jit_eval_single_step_forward_hard_reset(
x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float
):
v = v + x
spike = (v >= v_threshold).to(x)
v = v_reset * spike + (1.0 - spike) * v
return spike, v
@staticmethod
def jit_eval_single_step_forward_soft_reset(
x: torch.Tensor, v: torch.Tensor, v_threshold: float
):
v = v + x
spike = (v >= v_threshold).to(x)
v = v - spike * v_threshold
return spike, v
def _build_inductor_multi_step_graph(self):
store_v_seq = self.store_v_seq
soft_reset = self.v_reset is None
v_reset = 0.0 if soft_reset else self.v_reset
surrogate_fn = self.surrogate_function
v_threshold = self.v_threshold
detach_reset = self.detach_reset
def _graph(x_seq: torch.Tensor, v_init: torch.Tensor):
v = v_init
spike_seq = torch.empty_like(x_seq)
if store_v_seq:
v_seq = torch.empty_like(x_seq)
for t in range(x_seq.shape[0]):
v = v + x_seq[t]
spike = surrogate_fn(v - v_threshold)
spike_d = spike.detach() if detach_reset else spike
if soft_reset:
v = v - spike_d * v_threshold
else:
v = spike_d * v_reset + (1.0 - spike_d) * v
spike_seq[t] = spike
if store_v_seq:
v_seq[t] = v
if store_v_seq:
return spike_seq, v, v_seq
return spike_seq, v
return _graph
def _inductor_multi_step_forward(self, x_seq: torch.Tensor):
self.v_float_to_tensor(x_seq[0])
x_seq = self._canonicalize_inductor_tensor(x_seq)
v_init = self._canonicalize_inductor_tensor(self.v)
graph = self._compile_inductor_graph(
(
"if",
self.store_v_seq,
self.v_threshold,
self.v_reset,
self.detach_reset,
self._surrogate_inductor_cache_key(),
self._inductor_runtime_cache_key(x_seq, v_init),
),
self._build_inductor_multi_step_graph(),
)
out = graph(x_seq, v_init)
if self.store_v_seq:
spike_seq, self.v, self.v_seq = out
else:
spike_seq, self.v = out
return spike_seq
[文档]
def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == "inductor":
return self._inductor_multi_step_forward(x_seq)
if self.training:
if self.backend == "torch":
return super().multi_step_forward(x_seq)
elif self.backend == "cupy":
hard_reset = self.v_reset is not None
if x_seq.dtype == torch.float:
dtype = "float"
elif x_seq.dtype == torch.half:
dtype = "half2"
else:
raise NotImplementedError(x_seq.dtype)
if (
self.forward_kernel is None
or not self.forward_kernel.check_attributes(
hard_reset=hard_reset, dtype=dtype
)
):
self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel(
hard_reset=hard_reset, dtype=dtype
)
if (
self.backward_kernel is None
or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype,
)
):
self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel(
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype,
)
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = ac_neuron_kernel.multistep_if(
x_seq=x_seq.flatten(1),
v_init=self.v.flatten(0),
v_threshold=self.v_threshold,
v_reset=self.v_reset,
detach_reset=self.detach_reset,
surrogate_function=self.surrogate_function,
forward_kernel=self.forward_kernel,
backward_kernel=self.backward_kernel,
)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
elif self.backend == "triton":
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = triton_kernel.multistep_if(
x_seq,
self.v,
self.v_threshold,
self.v_reset,
self.detach_reset,
self.surrogate_function,
)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x_seq[0])
if x_seq.is_cuda and getattr(self.surrogate_function, "spiking", True):
try:
spike_seq, v_seq = triton_kernel.multistep_if(
x_seq,
self.v,
self.v_threshold,
self.v_reset,
self.detach_reset,
self.surrogate_function,
)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1]
else:
self.v = v_seq[-1].clone()
return spike_seq
except (NotImplementedError, AttributeError, TypeError, KeyError) as e:
logging.debug("Falling back from Triton IF kernel in eval: %s", e)
except RuntimeError as e:
if _is_expected_triton_fallback_error(e):
logging.debug(
"Falling back from Triton IF kernel in eval: %s", e
)
else:
logging.exception(
"Unexpected Triton IF kernel failure in eval "
"(dtype=%s, surrogate=%s)",
x_seq.dtype,
type(self.surrogate_function).__name__,
)
raise
# torch & cupy backend:
out = self._eval_multi_step_forward(
x_seq,
self.v,
self.v_threshold,
self.v_reset,
store_v_seq=self.store_v_seq,
)
if self.store_v_seq:
spike_seq, self.v, self.v_seq = out
else:
spike_seq, self.v = out
return spike_seq
[文档]
def single_step_forward(self, x: torch.Tensor):
if self.training:
if self.backend == "torch":
return super().single_step_forward(x)
elif self.backend == "cupy":
hard_reset = self.v_reset is not None
if x.dtype == torch.float:
dtype = "float"
elif x.dtype == torch.half:
dtype = "half2"
else:
raise NotImplementedError(x.dtype)
if (
self.forward_kernel is None
or not self.forward_kernel.check_attributes(
hard_reset=hard_reset, dtype=dtype
)
):
self.forward_kernel = ss_ac_neuron_kernel.IFNodeFPKernel(
hard_reset=hard_reset, dtype=dtype
)
if (
self.backward_kernel is None
or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype,
)
):
self.backward_kernel = ss_ac_neuron_kernel.IFNodeBPKernel(
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype,
)
self.v_float_to_tensor(x)
spike, v = ss_ac_neuron_kernel.ss_if_step(
x.flatten(0),
self.v.flatten(0),
self.v_threshold,
self.v_reset,
self.forward_kernel,
self.backward_kernel,
)
spike = spike.reshape(x.shape)
v = v.reshape(x.shape)
self.v = v
return spike
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x)
spike, self.v = self._eval_single_step_forward(
x,
self.v,
self.v_threshold,
self.v_reset,
)
return spike
[文档]
class NonSpikingIFNode(NonSpikingBaseNode):
def __init__(self, decode: Optional[str] = None):
"""
**API Language:**
:ref:`中文 <NonSpikingIFNode.__init__-cn>` | :ref:`English <NonSpikingIFNode.__init__-en>`
----
.. _NonSpikingIFNode.__init__-cn:
* **中文**
* **中文**
不发放脉冲的 IF 节点,输出膜电位(或根据 ``decode`` 进行解码)。
:param decode: 非脉冲输出解码方式,见 :class:`NonSpikingBaseNode`
:type decode: Optional[str]
----
.. _NonSpikingIFNode.__init__-en:
* **English**
* **English**
Non-spiking IF node that outputs membrane potential (or decoded outputs specified by ``decode``).
:param decode: Decoding mode for non-spiking outputs, see :class:`NonSpikingBaseNode`
:type decode: Optional[str]
:return: None
:rtype: None
"""
super().__init__(decode)
[文档]
def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x