from abc import abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from . import surrogate, base
from .. import configure
import math
import numpy as np
import logging
from .auto_cuda import neuron_kernel as ac_neuron_kernel
try:
    import cupy
    from . import neuron_kernel, cuda_utils
except Exception as e:
    # cupy and the cupy-based kernels are optional. If they cannot be imported, record why and fall back to
    # ``None`` placeholders. ``Exception`` (not ``BaseException``) lets KeyboardInterrupt/SystemExit propagate.
    logging.info(f'spikingjelly.activation_based.neuron: {e}')
    cupy = None
    neuron_kernel = None
    cuda_utils = None
class BaseNode(base.MemoryModule):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s', backend='torch', store_v_seq: bool = False):
        """
        .. _BaseNode.__init__-en:

        The base class of differentiable spiking neurons.

        :param v_threshold: threshold voltage of this neuron layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer. If not ``None``, a neuron's voltage is set to
            ``v_reset`` after it fires a spike (hard reset). If ``None``, ``v_threshold`` is subtracted from the
            voltage after a spike (soft reset)
        :type v_reset: float
        :param surrogate_function: the function used to compute surrogate gradients of the Heaviside step
            function in the backward pass
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph in the backward pass
        :type detach_reset: bool
        :param step_mode: the step mode, which can be ``'s'`` (single-step) or ``'m'`` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer. Different ``step_mode`` values may support different
            backends. Print ``self.supported_backends`` to check which backends the current ``step_mode``
            supports. When supported, the ``'cupy'`` backend gives the fastest training speed
        :type backend: str
        :param store_v_seq: when ``step_mode = 'm'`` and the input has ``shape = [T, N, *]``, whether to store
            the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``, only the
            voltage of the last time-step is kept in ``self.v`` with ``shape = [N, *]``, which reduces memory
            consumption
        :type store_v_seq: bool
        """
        assert isinstance(v_reset, float) or v_reset is None
        assert isinstance(v_threshold, float)
        assert isinstance(detach_reset, bool)
        super().__init__()

        # The initial membrane potential is v_reset for hard reset and 0 for soft reset. It stays a Python
        # float until the first forward call turns it into a tensor (see ``v_float_to_tensor``).
        if v_reset is None:
            self.register_memory('v', 0.)
        else:
            self.register_memory('v', v_reset)

        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.detach_reset = detach_reset
        self.surrogate_function = surrogate_function
        self.step_mode = step_mode
        self.backend = backend
        # Goes through the ``store_v_seq`` property setter, which registers the ``v_seq`` memory on demand.
        self.store_v_seq = store_v_seq

        # used in lava_exchange
        self.lava_s_cale = 1 << 6

        # Kernels for the cupy backend. They start as None and are created lazily by subclasses that
        # support the cupy backend.
        self.forward_kernel = None
        self.backward_kernel = None

    @property
    def store_v_seq(self):
        """Whether multi-step forward keeps the voltage of every time-step in ``self.v_seq``."""
        return self._store_v_seq

    @store_v_seq.setter
    def store_v_seq(self, value: bool):
        self._store_v_seq = value
        # ``v_seq`` is registered as memory only once, the first time storing is enabled.
        if value:
            if not hasattr(self, 'v_seq'):
                self.register_memory('v_seq', None)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        """Hard reset: where ``spike`` is 1 the voltage becomes ``v_reset``; elsewhere it is unchanged."""
        v = (1. - spike) * v + spike * v_reset
        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        """Soft reset: subtract ``v_threshold`` from the voltage wherever ``spike`` is 1."""
        v = v - spike * v_threshold
        return v

    @abstractmethod
    def neuronal_charge(self, x: torch.Tensor):
        """
        .. _BaseNode.neuronal_charge-en:

        Define the charge difference equation, which updates ``self.v`` from the input ``x``.
        Subclasses must implement this function.
        """
        raise NotImplementedError

    def neuronal_fire(self):
        """
        .. _BaseNode.neuronal_fire-en:

        Calculate the output spikes from the current membrane potential and the threshold voltage by applying
        ``self.surrogate_function`` to ``self.v - self.v_threshold``.
        """
        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        """
        .. _BaseNode.neuronal_reset-en:

        Reset the membrane potential according to the neurons' output spikes. The reset is soft if
        ``self.v_reset`` is ``None`` and hard otherwise. When ``self.detach_reset`` is set, the reset uses a
        detached copy of ``spike``, so no gradient flows through it.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # soft reset
            self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
        else:
            # hard reset
            self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)

    def single_step_forward(self, x: torch.Tensor):
        """
        .. _BaseNode.single_step_forward-en:

        Forward one time-step in the order ``neuronal_charge``, ``neuronal_fire``, ``neuronal_reset``.

        :param x: increment of voltage input to the neurons
        :type x: torch.Tensor
        :return: output spikes of the neurons
        :rtype: torch.Tensor
        """
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Forward ``x_seq`` one time-step at a time along dim 0 by calling ``single_step_forward``.

        :param x_seq: input sequence whose first dimension is time, e.g. ``shape = [T, N, *]``
        :type x_seq: torch.Tensor
        :return: the per-step spikes stacked along dim 0
        :rtype: torch.Tensor

        If ``self.store_v_seq`` is set, the post-reset voltages of all steps are stacked into ``self.v_seq``.
        """
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)

        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        return torch.stack(y_seq)

    def v_float_to_tensor(self, x: torch.Tensor):
        """If ``self.v`` is still its initial Python float, replace it with a tensor like ``x`` filled with that value."""
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
class AdaptBaseNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
                 backend='torch', store_v_seq: bool = False):
        """
        Base class of spiking neurons with an adaptation current ``w``.

        :param v_rest: resting potential used in the subthreshold adaptation update
        :type v_rest: float
        :param w_rest: initial value of the adaptation current ``w``
        :type w_rest: float
        :param tau_w: time constant of the adaptation current
        :type tau_w: float
        :param a: subthreshold coupling between the voltage and ``w``
        :type a: float
        :param b: jump amplitude added to ``w`` each time a neuron fires
        :type b: float

        ``v_threshold``, ``v_reset``, ``surrogate_function``, ``detach_reset``, ``step_mode``, ``backend`` and
        ``store_v_seq`` are passed unchanged to :class:`BaseNode`.
        """
        # b: jump amplitudes
        # a: subthreshold coupling
        assert isinstance(w_rest, float)
        assert isinstance(v_rest, float)
        assert isinstance(tau_w, float)
        assert isinstance(a, float)
        assert isinstance(b, float)

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)

        # Like ``v``, ``w`` starts as a float and becomes a tensor on the first forward (``w_float_to_tensor``).
        self.register_memory('w', w_rest)

        self.w_rest = w_rest
        self.v_rest = v_rest
        self.tau_w = tau_w
        self.a = a
        self.b = b

    @staticmethod
    @torch.jit.script
    def jit_neuronal_adaptation(w: torch.Tensor, tau_w: float, a: float, v_rest: float, v: torch.Tensor):
        """Return ``w + (a * (v - v_rest) - w) / tau_w``."""
        return w + 1. / tau_w * (a * (v - v_rest) - w)

    def neuronal_adaptation(self):
        """
        .. _AdaptBaseNode.neuronal_adaptation-en:

        Subthreshold update of the adaptation current:
        ``w <- w + (a * (v - v_rest) - w) / tau_w``.
        The spike-triggered jump ``w <- w + b * spike`` is applied in ``neuronal_reset``.
        """
        self.w = self.jit_neuronal_adaptation(self.w, self.tau_w, self.a, self.v_rest, self.v)

    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_reset: float, b: float,
                       spike: torch.Tensor):
        """
        Hard-reset ``v`` using ``spike_d`` and add the jump ``b * spike`` to ``w``.

        Both terms of the voltage reset use ``spike_d``. When ``detach_reset`` is enabled, no gradient then flows
        through the voltage reset, matching ``BaseNode.jit_hard_reset``.
        """
        v = (1. - spike_d) * v + spike_d * v_reset
        w = w + b * spike
        return v, w

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_threshold: float, b: float,
                       spike: torch.Tensor):
        """Soft-reset ``v`` using ``spike_d`` and add the jump ``b * spike`` to ``w``."""
        v = v - spike_d * v_threshold
        w = w + b * spike
        return v, w

    def neuronal_reset(self, spike):
        """
        .. _AdaptBaseNode.neuronal_reset-en:

        Reset the membrane potential according to the neurons' output spikes, and apply the spike-triggered
        jump to ``w``. When ``self.detach_reset`` is set, the voltage reset uses a detached copy of ``spike``.
        """
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.v_reset is None:
            # soft reset
            self.v, self.w = self.jit_soft_reset(self.v, self.w, spike_d, self.v_threshold, self.b, spike)
        else:
            # hard reset
            self.v, self.w = self.jit_hard_reset(self.v, self.w, spike_d, self.v_reset, self.b, spike)

    def single_step_forward(self, x: torch.Tensor):
        """Forward one time-step in the order charge, adaptation, fire, reset. Returns the output spikes."""
        self.v_float_to_tensor(x)
        self.w_float_to_tensor(x)
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def w_float_to_tensor(self, x: torch.Tensor):
        """If ``self.w`` is still its initial Python float, replace it with a tensor like ``x`` filled with that value."""
        if isinstance(self.w, float):
            w_init = self.w
            self.w = torch.full_like(x.data, fill_value=w_init)
class IFNode(BaseNode):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
                 backend='torch', store_v_seq: bool = False):
        """
        .. _IFNode.__init__-en:

        :param v_threshold: threshold voltage of this neuron layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer. If not ``None``, a neuron's voltage is set to
            ``v_reset`` after it fires a spike. If ``None``, ``v_threshold`` is subtracted from the voltage after
            a spike
        :type v_reset: float
        :param surrogate_function: the function used to compute surrogate gradients of the Heaviside step
            function in the backward pass
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph in the backward pass
        :type detach_reset: bool
        :param step_mode: the step mode, which can be ``'s'`` (single-step) or ``'m'`` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer. Different ``step_mode`` values may support different
            backends. Print ``self.supported_backends`` to check which backends the current ``step_mode``
            supports. When supported, the ``'cupy'`` backend gives the fastest training speed
        :type backend: str
        :param store_v_seq: when ``step_mode = 'm'`` and the input has ``shape = [T, N, *]``, whether to store
            the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``, only the
            voltage of the last time-step is kept in ``self.v`` with ``shape = [N, *]``, which reduces memory
            consumption
        :type store_v_seq: bool

        The Integrate-and-Fire neuron, which can be seen as an ideal integrator. Unlike the LIF neuron, its
        voltage does not decay when there is no input. Its subthreshold neural dynamics are:

        .. math::
            H[t] = V[t-1] + X[t]
        """
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """Ideal integration: ``v <- v + x``."""
        self.v = self.v + x

    # The ``jit_eval_*`` functions below are fused TorchScript paths used in eval mode.
    # Each one performs charge, Heaviside firing and reset.

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
        v = v + x
        spike = (v >= v_threshold).to(x)
        v = v_reset * spike + (1. - spike) * v
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_soft_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float):
        v = v + x
        spike = (v >= v_threshold).to(x)
        v = v - spike * v_threshold
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                               v_reset: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                                          v_reset: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Multi-step forward over the time dimension (dim 0) of ``x_seq``. Returns spikes shaped like ``x_seq``.

        In training mode, ``backend`` selects the implementation:

        - ``'torch'`` loops over single steps.
        - ``'cupy'`` runs the fused kernels from ``ac_neuron_kernel``.

        In eval mode, the TorchScript ``jit_eval_*`` loops are always used, whatever ``backend`` is.
        """
        if self.training:
            if self.backend == 'torch':
                return super().multi_step_forward(x_seq)
            elif self.backend == 'cupy':
                hard_reset = self.v_reset is not None

                if x_seq.dtype == torch.float:
                    dtype = 'float'
                elif x_seq.dtype == torch.half:
                    dtype = 'half2'
                else:
                    raise NotImplementedError(x_seq.dtype)

                # Reuse the cached kernels unless they were built for different settings.
                if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
                                                                                           dtype=dtype):
                    self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype)

                if self.backward_kernel is None or not self.backward_kernel.check_attributes(
                        surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
                        detach_reset=self.detach_reset, dtype=dtype):
                    self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel(
                        surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
                        detach_reset=self.detach_reset, dtype=dtype)

                self.v_float_to_tensor(x_seq[0])

                # The kernels take flattened data: x_seq.flatten(1) is [T, prod(rest)], v.flatten(0) is 1-D.
                spike_seq, v_seq = ac_neuron_kernel.IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
                                                                     self.v_threshold, self.v_reset,
                                                                     self.forward_kernel,
                                                                     self.backward_kernel)

                spike_seq = spike_seq.reshape(x_seq.shape)
                v_seq = v_seq.reshape(x_seq.shape)

                if self.store_v_seq:
                    self.v_seq = v_seq

                self.v = v_seq[-1].clone()

                return spike_seq
            else:
                raise ValueError(self.backend)

        else:
            self.v_float_to_tensor(x_seq[0])
            if self.v_reset is None:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(
                        x_seq, self.v, self.v_threshold)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
            else:
                if self.store_v_seq:
                    spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(
                        x_seq, self.v, self.v_threshold, self.v_reset)
                else:
                    spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold,
                                                                                    self.v_reset)
            return spike_seq

    def single_step_forward(self, x: torch.Tensor):
        """Single-step forward. Training uses the differentiable BaseNode path; eval uses the fused TorchScript path."""
        if self.training:
            return super().single_step_forward(x)
        else:
            self.v_float_to_tensor(x)
            if self.v_reset is None:
                spike, self.v = self.jit_eval_single_step_forward_soft_reset(x, self.v, self.v_threshold)
            else:
                spike, self.v = self.jit_eval_single_step_forward_hard_reset(x, self.v, self.v_threshold, self.v_reset)
            return spike
class LIFNode(BaseNode):
    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
        """
        .. _LIFNode.__init__-en:

        :param tau: membrane time constant
        :type tau: float
        :param decay_input: whether the input also decays
        :type decay_input: bool
        :param v_threshold: threshold voltage of this neuron layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer. If not ``None``, a neuron's voltage is set to
            ``v_reset`` after it fires a spike. If ``None``, ``v_threshold`` is subtracted from the voltage after
            a spike
        :type v_reset: float
        :param surrogate_function: the function used to compute surrogate gradients of the Heaviside step
            function in the backward pass
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph in the backward pass
        :type detach_reset: bool
        :param step_mode: the step mode, which can be ``'s'`` (single-step) or ``'m'`` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer. Different ``step_mode`` values may support different
            backends. Print ``self.supported_backends`` to check which backends the current ``step_mode``
            supports. When supported, the ``'cupy'`` backend gives the fastest training speed
        :type backend: str
        :param store_v_seq: when ``step_mode = 'm'`` and the input has ``shape = [T, N, *]``, whether to store
            the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``, only the
            voltage of the last time-step is kept in ``self.v`` with ``shape = [N, *]``, which reduces memory
            consumption
        :type store_v_seq: bool

        The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator.
        Its subthreshold neural dynamics are as follows.

        If ``decay_input == True``:

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        If ``decay_input == False``:

        .. math::
            H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        With soft reset (``v_reset is None``), :math:`V_{reset}` is taken as 0 in these equations.
        """
        assert isinstance(tau, float) and tau > 1.

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
        self.tau = tau
        self.decay_input = decay_input

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """Apply the LIF charge equation to ``self.v``. Soft reset or ``v_reset == 0`` uses the cheaper reset-0 form."""
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)

        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)

    # ---- charge equations (see the class docstring) ----

    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v + (x - v) / tau
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float):
        v = v + (x - (v - v_reset)) / tau
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_no_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v * (1. - 1. / tau) + x
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float):
        v = v - (v - v_reset) / tau + x
        return v

    # ---- fused eval-mode paths: charge + Heaviside fire + reset ----

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_hard_reset_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                                            v_reset: float, tau: float):
        v = v + (x - (v - v_reset)) / tau
        spike = (v >= v_threshold).to(x)
        v = v_reset * spike + (1. - spike) * v
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_hard_reset_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                                               v_reset: float, tau: float):
        v = v - (v - v_reset) / tau + x
        spike = (v >= v_threshold).to(x)
        v = v_reset * spike + (1. - spike) * v
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_soft_reset_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                                            tau: float):
        v = v + (x - v) / tau
        spike = (v >= v_threshold).to(x)
        v = v - spike * v_threshold
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_single_step_forward_soft_reset_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float,
                                                               tau: float):
        v = v * (1. - 1. / tau) + x
        spike = (v >= v_threshold).to(x)
        v = v - spike * v_threshold
        return spike, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset_decay_input(x_seq: torch.Tensor, v: torch.Tensor,
                                                           v_threshold: float, v_reset: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + (x_seq[t] - (v - v_reset)) / tau
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor,
                                                                      v_threshold: float, v_reset: float,
                                                                      tau: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + (x_seq[t] - (v - v_reset)) / tau
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset_no_decay_input(x_seq: torch.Tensor, v: torch.Tensor,
                                                              v_threshold: float, v_reset: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v - (v - v_reset) / tau + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor,
                                                                         v_threshold: float, v_reset: float,
                                                                         tau: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v - (v - v_reset) / tau + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v_reset * spike + (1. - spike) * v
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset_decay_input(x_seq: torch.Tensor, v: torch.Tensor,
                                                           v_threshold: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + (x_seq[t] - v) / tau
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor,
                                                                      v_threshold: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v + (x_seq[t] - v) / tau
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset_no_decay_input(x_seq: torch.Tensor, v: torch.Tensor,
                                                              v_threshold: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v * (1. - 1. / tau) + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
        return spike_seq, v

    @staticmethod
    @torch.jit.script
    def jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor,
                                                                         v_threshold: float, tau: float):
        spike_seq = torch.zeros_like(x_seq)
        v_seq = torch.zeros_like(x_seq)
        for t in range(x_seq.shape[0]):
            v = v * (1. - 1. / tau) + x_seq[t]
            spike = (v >= v_threshold).to(x_seq)
            v = v - spike * v_threshold
            spike_seq[t] = spike
            v_seq[t] = v
        return spike_seq, v, v_seq

    def single_step_forward(self, x: torch.Tensor):
        """Single-step forward. Training uses the differentiable BaseNode path; eval uses the fused TorchScript path."""
        if self.training:
            return super().single_step_forward(x)
        else:
            self.v_float_to_tensor(x)
            if self.v_reset is None:
                if self.decay_input:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(
                        x, self.v, self.v_threshold, self.tau)
                else:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(
                        x, self.v, self.v_threshold, self.tau)
            else:
                if self.decay_input:
                    spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(
                        x, self.v, self.v_threshold, self.v_reset, self.tau)
                else:
                    spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(
                        x, self.v, self.v_threshold, self.v_reset, self.tau)
            return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Multi-step forward over the time dimension (dim 0) of ``x_seq``. Returns spikes shaped like ``x_seq``.

        In training mode, ``backend`` selects the implementation:

        - ``'torch'`` loops over single steps.
        - ``'cupy'`` runs the fused kernels from ``ac_neuron_kernel``.

        In eval mode, the TorchScript ``jit_eval_*`` loops are always used, whatever ``backend`` is.
        """
        if self.training:
            if self.backend == 'torch':
                return super().multi_step_forward(x_seq)
            elif self.backend == 'cupy':
                hard_reset = self.v_reset is not None

                if x_seq.dtype == torch.float:
                    dtype = 'float'
                elif x_seq.dtype == torch.half:
                    dtype = 'half2'
                else:
                    raise NotImplementedError(x_seq.dtype)

                # Reuse the cached kernels unless they were built for different settings.
                if self.forward_kernel is None or not self.forward_kernel.check_attributes(
                        hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input):
                    self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel(
                        decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype)

                if self.backward_kernel is None or not self.backward_kernel.check_attributes(
                        surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
                        detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
                    self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel(
                        decay_input=self.decay_input, surrogate_function=self.surrogate_function.cuda_codes,
                        hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype)

                self.v_float_to_tensor(x_seq[0])

                # The kernels take flattened data and the decay factor 1 / tau.
                spike_seq, v_seq = ac_neuron_kernel.LIFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
                                                                      self.v_threshold, self.v_reset, 1. / self.tau,
                                                                      self.forward_kernel,
                                                                      self.backward_kernel)

                spike_seq = spike_seq.reshape(x_seq.shape)
                v_seq = v_seq.reshape(x_seq.shape)

                if self.store_v_seq:
                    self.v_seq = v_seq

                self.v = v_seq[-1].clone()

                return spike_seq
            else:
                raise ValueError(self.backend)

        else:
            self.v_float_to_tensor(x_seq[0])
            if self.v_reset is None:
                if self.decay_input:
                    if self.store_v_seq:
                        spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(
                            x_seq, self.v, self.v_threshold, self.tau)
                    else:
                        spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_decay_input(
                            x_seq, self.v, self.v_threshold, self.tau)
                else:
                    if self.store_v_seq:
                        spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(
                            x_seq, self.v, self.v_threshold, self.tau)
                    else:
                        spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_no_decay_input(
                            x_seq, self.v, self.v_threshold, self.tau)
            else:
                if self.decay_input:
                    if self.store_v_seq:
                        spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(
                            x_seq, self.v, self.v_threshold, self.v_reset, self.tau)
                    else:
                        spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_decay_input(
                            x_seq, self.v, self.v_threshold, self.v_reset, self.tau)
                else:
                    if self.store_v_seq:
                        spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(
                            x_seq, self.v, self.v_threshold, self.v_reset, self.tau)
                    else:
                        spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_no_decay_input(
                            x_seq, self.v, self.v_threshold, self.v_reset, self.tau)

            return spike_seq
class ParametricLIFNode(BaseNode):
    def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
        """
        .. _ParametricLIFNode.__init__-en:

        :param init_tau: the initial value of the membrane time constant (must be > 1)
        :type init_tau: float
        :param decay_input: whether the input also decays
        :type decay_input: bool
        :param v_threshold: threshold voltage of this neuron layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neuron layer. If not ``None``, a neuron's voltage is set to
            ``v_reset`` after it fires a spike. If ``None``, ``v_threshold`` is subtracted from the voltage after
            a spike
        :type v_reset: float
        :param surrogate_function: the function used to compute surrogate gradients of the Heaviside step
            function in the backward pass
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the reset from the computation graph in the backward pass
        :type detach_reset: bool
        :param step_mode: the step mode, which can be ``'s'`` (single-step) or ``'m'`` (multi-step)
        :type step_mode: str
        :param backend: backend for this neuron layer. Different ``step_mode`` values may support different
            backends. Print ``self.supported_backends`` to check which backends the current ``step_mode``
            supports. When supported, the ``'cupy'`` backend gives the fastest training speed
        :type backend: str
        :param store_v_seq: when ``step_mode = 'm'`` and the input has ``shape = [T, N, *]``, whether to store
            the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``, only the
            voltage of the last time-step is kept in ``self.v`` with ``shape = [N, *]``, which reduces memory
            consumption
        :type store_v_seq: bool

        The Parametric Leaky Integrate-and-Fire (PLIF) neuron, proposed by `Incorporating Learnable Membrane
        Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_. It
        can be seen as a leaky integrator. Its subthreshold neural dynamics are as follows.

        If ``decay_input == True``:

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

        If ``decay_input == False``:

        .. math::
            H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

        Here :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, where :math:`w` is a learnable parameter.
        """
        assert isinstance(init_tau, float) and init_tau > 1.
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
        self.decay_input = decay_input
        # Choose w so that sigmoid(w) = 1 / (1 + (init_tau - 1)) = 1 / init_tau. init_tau > 1 keeps the log defined.
        init_w = - math.log(init_tau - 1.)
        self.w = nn.Parameter(torch.as_tensor(init_w))

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """Apply the PLIF charge equation to ``self.v`` with the learnable decay ``1 / tau = sigmoid(w)``."""
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v + (x - self.v) * self.w.sigmoid()
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v * (1. - self.w.sigmoid()) + x
            else:
                self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Multi-step forward over the time dimension (dim 0) of ``x_seq``. Returns spikes shaped like ``x_seq``.

        ``'torch'`` loops over single steps and ``'cupy'`` runs the fused kernels. Unlike ``IFNode`` and
        ``LIFNode``, there is no separate eval-mode path here.
        """
        if self.backend == 'torch':
            return super().multi_step_forward(x_seq)
        elif self.backend == 'cupy':
            hard_reset = self.v_reset is not None

            if x_seq.dtype == torch.float:
                dtype = 'float'
            elif x_seq.dtype == torch.half:
                dtype = 'half2'
            else:
                raise NotImplementedError(x_seq.dtype)

            # Reuse the cached kernels unless they were built for different settings.
            if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
                                                                                       dtype=dtype,
                                                                                       decay_input=self.decay_input):
                self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel(decay_input=self.decay_input,
                                                                                   hard_reset=hard_reset, dtype=dtype)

            if self.backward_kernel is None or not self.backward_kernel.check_attributes(
                    surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
                    detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
                self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel(
                    decay_input=self.decay_input, surrogate_function=self.surrogate_function.cuda_codes,
                    hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype)

            self.v_float_to_tensor(x_seq[0])

            # The decay factor sigmoid(w) is cast to x_seq's dtype/device before being handed to the kernels.
            spike_seq, v_seq = ac_neuron_kernel.ParametricLIFNodeATGF.apply(
                x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.w.sigmoid().to(x_seq),
                self.forward_kernel, self.backward_kernel)

            spike_seq = spike_seq.reshape(x_seq.shape)
            v_seq = v_seq.reshape(x_seq.shape)

            if self.store_v_seq:
                self.v_seq = v_seq

            self.v = v_seq[-1].clone()

            return spike_seq
        else:
            raise ValueError(self.backend)
class QIFNode(BaseNode):
    def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0.,
                 v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
                 backend='torch', store_v_seq: bool = False):
        """
        .. _QIFNode.__init__-cn:

        .. _QIFNode.__init__-en:

        The Quadratic Integrate-and-Fire (QIF) neuron: a nonlinear integrate-and-fire model that is also an
        approximation of the Exponential Integrate-and-Fire model. Its subthreshold dynamics are

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))

        :param tau: membrane time constant; must be a ``float`` greater than 1
        :type tau: float
        :param v_c: critical voltage :math:`V_c`
        :type v_c: float
        :param a0: coefficient :math:`a_0` of the quadratic term; must be positive
        :type a0: float
        :param v_threshold: threshold voltage of this neurons layer
        :type v_threshold: float
        :param v_rest: resting potential :math:`V_{rest}`; when ``v_reset`` is given it must not be lower
            than ``v_reset``
        :type v_rest: float
        :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage is set to
            ``v_reset`` after firing a spike, and ``v_reset`` must be lower than ``v_threshold``. If ``None``,
            ``v_threshold`` is subtracted from the voltage after firing a spike
        :type v_reset: float
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step
            function in backward
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the computation graph of reset in backward
        :type detach_reset: bool
        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str
        :param backend: backend for this neurons layer. ``'torch'`` is available in both step modes;
            ``'cupy'`` only in multi-step mode, where it is the fastest. Print ``self.supported_backends`` to
            see the backends of the current ``step_mode``
        :type backend: str
        :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, whether
            to store the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``,
            only the last voltage is kept in ``self.v`` with ``shape = [N, *]``, which saves memory
        :type store_v_seq: bool
        """
        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert a0 > 0
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
        self.tau, self.v_c, self.v_rest, self.a0 = tau, v_c, v_rest, a0

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate ``x`` into ``self.v`` following the QIF subthreshold equation of this class."""
        quadratic = self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)
        self.v = self.v + (x + quadratic) / self.tau

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``; raises ``ValueError`` for an unknown mode."""
        backends_by_mode = {'s': ('torch',), 'm': ('torch', 'cupy')}
        if self.step_mode not in backends_by_mode:
            raise ValueError(self.step_mode)
        return backends_by_mode[self.step_mode]

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Run the neuron over a whole input sequence (first dimension is time).

        ``'torch'`` defers to the parent implementation; ``'cupy'`` runs the fused
        ``MultiStepQIFNodePTT`` kernel. Raises ``ValueError`` for any other backend.
        """
        if self.backend == 'torch':
            return super().multi_step_forward(x_seq)
        if self.backend != 'cupy':
            raise ValueError(self.backend)

        self.v_float_to_tensor(x_seq[0])
        spike_seq, v_seq = neuron_kernel.MultiStepQIFNodePTT.apply(
            x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
            self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code)

        seq_shape = x_seq.shape
        spike_seq, v_seq = spike_seq.reshape(seq_shape), v_seq.reshape(seq_shape)
        if self.store_v_seq:
            self.v_seq = v_seq
        self.v = v_seq[-1].clone()
        return spike_seq
class EIFNode(BaseNode):
    def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
                 v_rest: float = 0., v_reset: float = -0.1,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
                 backend='torch', store_v_seq: bool = False):
        """
        .. _EIFNode.__init__-cn:

        .. _EIFNode.__init__-en:

        The Exponential Integrate-and-Fire (EIF) neuron: a nonlinear integrate-and-fire model and a
        one-dimensional model derived from the Hodgkin-Huxley model. It degenerates to the LIF model when
        :math:`\\Delta_T\\to 0`. Its subthreshold dynamics are

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right)

        :param tau: membrane time constant; must be a ``float`` greater than 1
        :type tau: float
        :param delta_T: sharpness parameter :math:`\\Delta_T`; must be positive
        :type delta_T: float
        :param theta_rh: rheobase threshold :math:`\\theta_{rh}`
        :type theta_rh: float
        :param v_threshold: threshold voltage of this neurons layer
        :type v_threshold: float
        :param v_rest: resting potential :math:`V_{rest}`; when ``v_reset`` is given it must not be lower
            than ``v_reset``
        :type v_rest: float
        :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage is set to
            ``v_reset`` after firing a spike, and ``v_reset`` must be lower than ``v_threshold``. If ``None``,
            ``v_threshold`` is subtracted from the voltage after firing a spike
        :type v_reset: float
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step
            function in backward
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the computation graph of reset in backward
        :type detach_reset: bool
        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str
        :param backend: backend for this neurons layer. ``'torch'`` is available in both step modes;
            ``'cupy'`` only in multi-step mode, where it is the fastest. Print ``self.supported_backends`` to
            see the backends of the current ``step_mode``
        :type backend: str
        :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, whether
            to store the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``,
            only the last voltage is kept in ``self.v`` with ``shape = [N, *]``, which saves memory
        :type store_v_seq: bool
        """
        assert isinstance(tau, float) and tau > 1.
        if v_reset is not None:
            assert v_threshold > v_reset
            assert v_rest >= v_reset
        assert delta_T > 0
        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
        self.tau = tau
        self.delta_T = delta_T
        self.v_rest = v_rest
        self.theta_rh = theta_rh

    def neuronal_charge(self, x: torch.Tensor):
        """
        Integrate ``x`` into ``self.v`` following the EIF subthreshold equation of this class.

        ``torch.exp`` requires a tensor argument, so ``self.v`` is first converted to a tensor on ``x``'s
        device when it is not one yet (e.g. still a Python number). The update itself must run with autograd
        enabled: executing it under ``torch.no_grad()`` would silently stop gradients from reaching ``x`` and
        the previous ``self.v``.

        :param x: input of the current time-step
        :type x: torch.Tensor
        """
        if not isinstance(self.v, torch.Tensor):
            # creating a tensor from a Python number builds no autograd graph, so no no_grad() is needed
            self.v = torch.as_tensor(self.v, device=x.device)
        self.v = self.v + (x + self.v_rest - self.v + self.delta_T * torch.exp(
            (self.v - self.theta_rh) / self.delta_T)) / self.tau

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``; raises ``ValueError`` for an unknown mode."""
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Run the neuron over a whole input sequence (first dimension is time).

        ``'torch'`` defers to the parent implementation; ``'cupy'`` runs the fused
        ``MultiStepEIFNodePTT`` kernel. Raises ``ValueError`` for any other backend.
        """
        if self.backend == 'torch':
            return super().multi_step_forward(x_seq)
        elif self.backend == 'cupy':
            self.v_float_to_tensor(x_seq[0])
            spike_seq, v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
                x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
                self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code)
            spike_seq = spike_seq.reshape(x_seq.shape)
            v_seq = v_seq.reshape(x_seq.shape)
            if self.store_v_seq:
                self.v_seq = v_seq
            self.v = v_seq[-1].clone()
            return spike_seq
        else:
            raise ValueError(self.backend)
class IzhikevichNode(AdaptBaseNode):
    def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1.,
                 v_reset: float = 0., v_rest: float = -0.1, w_rest: float = 0., tau_w: float = 2., a: float = 0.,
                 b: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
                 backend='torch', store_v_seq: bool = False):
        """
        Izhikevich-style neuron: the QIF membrane equation extended with an adaptation variable ``w`` whose
        bookkeeping is done by ``AdaptBaseNode``. The membrane charge is

        .. math::
            H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c) - W)

        where :math:`W` is the value of ``self.w`` when the charge is computed.

        :param tau: membrane time constant; must be a ``float`` greater than 1
        :type tau: float
        :param v_c: critical voltage :math:`V_c`
        :type v_c: float
        :param a0: coefficient :math:`a_0` of the quadratic term; must be positive
        :type a0: float
        :param v_threshold: threshold voltage of this neurons layer
        :type v_threshold: float
        :param v_reset: reset voltage of this neurons layer, forwarded to ``AdaptBaseNode``
        :type v_reset: float
        :param v_rest: resting potential :math:`V_{rest}`, forwarded to ``AdaptBaseNode``
        :type v_rest: float
        :param w_rest: adaptation parameter forwarded to ``AdaptBaseNode`` (presumably the resting value of
            ``w`` -- confirm there)
        :type w_rest: float
        :param tau_w: adaptation parameter forwarded to ``AdaptBaseNode`` and to the cupy kernel (presumably
            the time constant of ``w`` -- confirm there)
        :type tau_w: float
        :param a: adaptation coefficient forwarded to ``AdaptBaseNode`` and to the cupy kernel
        :type a: float
        :param b: adaptation coefficient forwarded to ``AdaptBaseNode`` and to the cupy kernel
        :type b: float
        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step
            function in backward
        :type surrogate_function: Callable
        :param detach_reset: whether to detach the computation graph of reset in backward
        :type detach_reset: bool
        :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
        :type step_mode: str
        :param backend: backend for this neurons layer. ``'torch'`` is available in both step modes;
            ``'cupy'`` only in multi-step mode
        :type backend: str
        :param store_v_seq: in multi-step mode, whether to keep the voltage of every time-step in
            ``self.v_seq`` instead of only the last one in ``self.v``
        :type store_v_seq: bool
        """
        assert isinstance(tau, float) and tau > 1.
        assert a0 > 0
        super().__init__(v_threshold, v_reset, v_rest, w_rest, tau_w, a, b, surrogate_function, detach_reset,
                         step_mode, backend, store_v_seq)
        self.tau, self.v_c, self.a0 = tau, v_c, a0

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate ``x`` into ``self.v``, subtracting the adaptation term ``self.w``."""
        quadratic = self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)
        self.v = self.v + (x + quadratic - self.w) / self.tau

    @property
    def supported_backends(self):
        """Backends available for the current ``step_mode``; raises ``ValueError`` for an unknown mode."""
        backends_by_mode = {'s': ('torch',), 'm': ('torch', 'cupy')}
        if self.step_mode not in backends_by_mode:
            raise ValueError(self.step_mode)
        return backends_by_mode[self.step_mode]

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Run the neuron over a whole input sequence (first dimension is time).

        ``'torch'`` defers to the parent implementation; ``'cupy'`` runs the fused
        ``MultiStepIzhikevichNodePTT`` kernel and updates both ``self.v`` and ``self.w`` from its last step.
        Raises ``ValueError`` for any other backend.
        """
        if self.backend == 'torch':
            return super().multi_step_forward(x_seq)
        if self.backend != 'cupy':
            raise ValueError(self.backend)

        first_step = x_seq[0]
        self.v_float_to_tensor(first_step)
        self.w_float_to_tensor(first_step)
        spike_seq, v_seq, w_seq = neuron_kernel.MultiStepIzhikevichNodePTT.apply(
            x_seq.flatten(1), self.v.flatten(0), self.w.flatten(0), self.tau, self.v_threshold, self.v_reset,
            self.v_rest, self.a, self.b, self.tau_w,
            self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code)

        seq_shape = x_seq.shape
        spike_seq = spike_seq.reshape(seq_shape)
        v_seq = v_seq.reshape(seq_shape)
        w_seq = w_seq.reshape(seq_shape)
        if self.store_v_seq:
            self.v_seq = v_seq
        self.v = v_seq[-1].clone()
        self.w = w_seq[-1].clone()
        return spike_seq
class LIAFNode(LIFNode):
    def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs):
        """
        .. _LIAFNode.__init__-cn:

        .. _LIAFNode.__init__-en:

        The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and
        Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_. It uses the
        charge, fire and reset steps inherited from :class:`LIFNode`, but its output is ``self.act(...)``
        rather than spikes.

        .. admonition:: Warning
            :class: warning

            The outputs of this neurons layer are not binary spikes.

        :param act: the activation function
        :type act: Callable
        :param threshold_related: whether the neuron uses the threshold-related (TR) mode. If ``True``,
            ``y = act(h - v_th)``, otherwise ``y = act(h)``
        :type threshold_related: bool

        All other parameters in ``*args, **kwargs`` are the same as in :class:`LIFNode`. Only
        ``backend='torch'`` is supported, and ``single_step_cupy_fp32_inference`` must be disabled.
        """
        super().__init__(*args, **kwargs)
        self.act = act
        self.threshold_related = threshold_related
        assert self.backend == 'torch', "LIAFNode only supports for backend='torch'!"
        assert not self.single_step_cupy_fp32_inference, "LIAFNode does not support for single_step_cupy_fp32_inference!"

    @property
    def supported_backends(self):
        """Only the ``'torch'`` backend is supported, in every step mode."""
        return ('torch',)

    def single_step_forward(self, x: torch.Tensor):
        """
        Run one time-step and return the analog output.

        The activation is applied to the charged potential *before* firing and reset; the spike is still
        computed so that the inherited reset behaves as in :class:`LIFNode`.

        NOTE(review): only ``single_step_forward`` is overridden here. If ``LIFNode.multi_step_forward`` has a
        path that does not go through ``single_step_forward`` (e.g. a fused evaluation routine), multi-step
        outputs would be spikes instead of ``act(...)`` -- confirm against ``LIFNode``.

        :param x: input of the current time-step
        :type x: torch.Tensor
        :return: ``act(v - v_threshold)`` in TR mode, otherwise ``act(v)``
        """
        self.neuronal_charge(x)
        if self.threshold_related:
            y = self.act(self.v - self.v_threshold)
        else:
            y = self.act(self.v)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return y
[文档]class KLIFNode(BaseNode):
def __init__(self, scale_reset: bool=False, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
             v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
             detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
    """
    .. _KLIFNode.__init__-cn:

    .. _KLIFNode.__init__-en:

    The K-based Leaky Integrate-and-Fire (KLIF) neuron proposed in `KLIF: An optimized spiking neuron unit for
    tuning surrogate gradient slope and membrane potential <https://arxiv.org/abs/2302.09238>`_. It can be seen
    as a leaky integrator with a learnable scaling factor :math:`k` (``self.k``, initialised to 1).

    Subthreshold dynamics, if ``decay_input == True``:

    .. math::
        H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))

    if ``decay_input == False``:

    .. math::
        H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]

    Unlike ordinary neurons, KLIF fires and resets as follows:

    .. math::
        F[t] &= \\mathrm{ReLU}(kH[t])

        S[t] &= \\Theta(F[t] - V_{th})

    If ``scale_reset == False``:

    .. math::
        V[t] = \\begin{cases}
            F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\\\
            F[t] - S[t]V_{th}, soft~~reset
        \\end{cases}

    If ``scale_reset == True``:

    .. math::
        V[t] = \\begin{cases}
            \\frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\\\
            \\frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset
        \\end{cases}

    :param scale_reset: whether to scale ``v`` by ``1 / k`` in ``neuronal_reset``
    :type scale_reset: bool
    :param tau: membrane time constant; must be a ``float`` greater than 1
    :type tau: float
    :param decay_input: whether the input also decays
    :type decay_input: bool
    :param v_threshold: threshold voltage of this neurons layer
    :type v_threshold: float
    :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage is set to
        ``v_reset`` after firing a spike. If ``None``, ``v_threshold`` is subtracted from the voltage after
        firing a spike
    :type v_reset: float
    :param surrogate_function: the function for calculating surrogate gradients of the heaviside step
        function in backward
    :type surrogate_function: Callable
    :param detach_reset: whether to detach the computation graph of reset in backward
    :type detach_reset: bool
    :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
    :type step_mode: str
    :param backend: backend for this neurons layer. ``'cupy'`` is not implemented for KLIF and raises
        ``NotImplementedError``
    :type backend: str
    :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, whether to
        store the voltage of every time-step in ``self.v_seq`` with ``shape = [T, N, *]``. If ``False``, only
        the last voltage is kept in ``self.v`` with ``shape = [N, *]``, which saves memory
    :type store_v_seq: bool
    """
    assert isinstance(tau, float) and tau > 1.
    if backend == 'cupy':
        raise NotImplementedError("The CuPy backend for the KLIF neuron has not been implemented!")
    super().__init__(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate_function,
                     detach_reset=detach_reset, step_mode=step_mode, backend=backend,
                     store_v_seq=store_v_seq)
    self.scale_reset, self.tau, self.decay_input = scale_reset, tau, decay_input
    # learnable scalar gain k, initialised to 1
    self.k = nn.Parameter(torch.as_tensor(1.))
def neuronal_charge(self, x: torch.Tensor):
    """
    Charge ``self.v`` with input ``x`` through the KLIF charge helpers.

    The helpers (defined elsewhere in the class) receive a numeric reset voltage; per the class equations
    they compute ``v - v_reset``, which cannot work with ``None``. With soft reset (``self.v_reset is None``)
    the leak is therefore taken toward 0, the same way ``ParametricLIFNode`` treats ``None`` and ``0.``
    identically in its charge step.

    :param x: input of the current time-step
    :type x: torch.Tensor
    """
    v_reset = 0. if self.v_reset is None else self.v_reset
    if self.decay_input:
        self.v = self.neuronal_charge_decay_input(x, self.v, v_reset, self.tau, self.k)
    else:
        self.v = self.neuronal_charge_no_decay_input(x, self.v, v_reset, self.tau, self.k)
def neuronal_reset(self, spike):
    """
    Reset the membrane potential of the neurons that fired.

    ``spike`` is detached first when ``self.detach_reset`` is set. A ``None`` reset voltage selects soft
    reset (subtract ``v_threshold``); any other value selects hard reset to ``v_reset``. With
    ``self.scale_reset`` the potential is also divided by the gain ``self.k``: after a soft reset, or before
    a hard reset, matching the ``scale_reset == True`` equations of the class.

    :param spike: output spikes of the current time-step
    :type spike: torch.Tensor
    """
    spike_d = spike.detach() if self.detach_reset else spike
    soft = self.v_reset is None

    if not self.scale_reset:
        if soft:
            self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
        else:
            self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
        return

    if soft:
        # (F - S * V_th) / k
        self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold) / self.k
    else:
        # (F / k) * (1 - S) + V_reset * S
        self.v = self.jit_hard_reset(self.v / self.k, spike_d, self.v_reset)