Source code for spikingjelly.activation_based.neuron

from abc import abstractmethod
from typing import Callable
import torch
import torch.nn as nn
from . import surrogate, base
from .. import configure
import math
import numpy as np
import logging
from .auto_cuda import neuron_kernel as ac_neuron_kernel

try:
    import cupy
    from . import neuron_kernel, cuda_utils

except BaseException as e:
    logging.info(f'spikingjelly.activation_based.neuron: {e}')
    cupy = None
    neuron_kernel = None
    cuda_utils = None


[文档]class BaseNode(base.MemoryModule): def __init__(self, v_threshold: float = 1., v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <BaseNode.__init__-en>` .. _BaseNode.__init__-cn: :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool 可微分SNN神经元的基类神经元。 * :ref:`中文API <BaseNode.__init__-cn>` .. _BaseNode.__init__-en: :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool This class is the base class of differentiable spiking neurons. """ assert isinstance(v_reset, float) or v_reset is None assert isinstance(v_threshold, float) assert isinstance(detach_reset, bool) super().__init__() if v_reset is None: self.register_memory('v', 0.) else: self.register_memory('v', v_reset) self.v_threshold = v_threshold self.v_reset = v_reset self.detach_reset = detach_reset self.surrogate_function = surrogate_function self.step_mode = step_mode self.backend = backend self.store_v_seq = store_v_seq # used in lava_exchange self.lava_s_cale = 1 << 6 # used for cupy backend self.forward_kernel = None self.backward_kernel = None @property def store_v_seq(self): return self._store_v_seq @store_v_seq.setter def store_v_seq(self, value: bool): self._store_v_seq = value if value: if not hasattr(self, 'v_seq'): self.register_memory('v_seq', None)
    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
        v = (1. - spike) * v + spike * v_reset
        return v

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
        v = v - spike * v_threshold
        return v
[文档] @abstractmethod def neuronal_charge(self, x: torch.Tensor): """ * :ref:`API in English <BaseNode.neuronal_charge-en>` .. _BaseNode.neuronal_charge-cn: 定义神经元的充电差分方程。子类必须实现这个函数。 * :ref:`中文API <BaseNode.neuronal_charge-cn>` .. _BaseNode.neuronal_charge-en: Define the charge difference equation. The sub-class must implement this function. """ raise NotImplementedError
[文档] def neuronal_fire(self): """ * :ref:`API in English <BaseNode.neuronal_fire-en>` .. _BaseNode.neuronal_fire-cn: 根据当前神经元的电压、阈值,计算输出脉冲。 * :ref:`中文API <BaseNode.neuronal_fire-cn>` .. _BaseNode.neuronal_fire-en: Calculate out spikes of neurons by their current membrane potential and threshold voltage. """ return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike): """ * :ref:`API in English <BaseNode.neuronal_reset-en>` .. _BaseNode.neuronal_reset-cn: 根据当前神经元释放的脉冲,对膜电位进行重置。 * :ref:`中文API <BaseNode.neuronal_reset-cn>` .. _BaseNode.neuronal_reset-en: Reset the membrane potential according to neurons' output spikes. """ if self.detach_reset: spike_d = spike.detach() else: spike_d = spike if self.v_reset is None: # soft reset self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold) else: # hard reset self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
    def extra_repr(self):
        return f'v_threshold={self.v_threshold}, v_reset={self.v_reset}, detach_reset={self.detach_reset}, step_mode={self.step_mode}, backend={self.backend}'
[文档] def single_step_forward(self, x: torch.Tensor): """ * :ref:`API in English <BaseNode.single_step_forward-en>` .. _BaseNode.single_step_forward-cn: :param x: 输入到神经元的电压增量 :type x: torch.Tensor :return: 神经元的输出脉冲 :rtype: torch.Tensor 按照充电、放电、重置的顺序进行前向传播。 * :ref:`中文API <BaseNode.single_step_forward-cn>` .. _BaseNode.single_step_forward-en: :param x: increment of voltage inputted to neurons :type x: torch.Tensor :return: out spikes of neurons :rtype: torch.Tensor Forward by the order of `neuronal_charge`, `neuronal_fire`, and `neuronal_reset`. """ self.v_float_to_tensor(x) self.neuronal_charge(x) spike = self.neuronal_fire() self.neuronal_reset(spike) return spike
    def multi_step_forward(self, x_seq: torch.Tensor):
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.v)

        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)

        return torch.stack(y_seq)

    def v_float_to_tensor(self, x: torch.Tensor):
        if isinstance(self.v, float):
            v_init = self.v
            self.v = torch.full_like(x.data, v_init)
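A minimal sketch, illustrative only, of how a neuron can be derived from BaseNode: only ``neuronal_charge`` has to be provided, and calling the module in the default single-step mode runs the charge-fire-reset pipeline defined above. The class name ``SimpleIF`` is hypothetical and assumes the package is importable as ``spikingjelly``::

    import torch
    from spikingjelly.activation_based import neuron

    class SimpleIF(neuron.BaseNode):  # hypothetical example class, not part of the library
        def neuronal_charge(self, x: torch.Tensor):
            # ideal integrator: V[t] = V[t-1] + X[t]
            self.v = self.v + x

    sn = SimpleIF()
    spike = sn(torch.rand(4))   # default step_mode='s' dispatches to single_step_forward
    print(spike, sn.v)          # binary spikes and the membrane potential after reset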
[文档]class AdaptBaseNode(BaseNode): def __init__(self, v_threshold: float = 1., v_reset: float = 0., v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): # b: jump amplitudes # a: subthreshold coupling assert isinstance(w_rest, float) assert isinstance(v_rest, float) assert isinstance(tau_w, float) assert isinstance(a, float) assert isinstance(b, float) super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.register_memory('w', w_rest) self.w_rest = w_rest self.v_rest = v_rest self.tau_w = tau_w self.a = a self.b = b
    @staticmethod
    @torch.jit.script
    def jit_neuronal_adaptation(w: torch.Tensor, tau_w: float, a: float, v_rest: float, v: torch.Tensor):
        return w + 1. / tau_w * (a * (v - v_rest) - w)
[文档] def neuronal_adaptation(self): """ * :ref:`API in English <AdaptBaseNode.neuronal_adaptation-en>` .. _AdaptBaseNode.neuronal_adaptation-cn: 脉冲触发的适应性电流的更新 * :ref:`中文API <AdaptBaseNode.neuronal_adaptation-cn>` .. _AdaptBaseNode.neuronal_adaptation-en: Spike-triggered update of adaptation current. """ self.w = self.jit_neuronal_adaptation(self.w, self.tau_w, self.a, self.v_rest, self.v)
    @staticmethod
    @torch.jit.script
    def jit_hard_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_reset: float, b: float,
                       spike: torch.Tensor):
        v = (1. - spike_d) * v + spike * v_reset
        w = w + b * spike
        return v, w

    @staticmethod
    @torch.jit.script
    def jit_soft_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_threshold: float, b: float,
                       spike: torch.Tensor):
        v = v - spike_d * v_threshold
        w = w + b * spike
        return v, w
[文档] def neuronal_reset(self, spike): """ * :ref:`API in English <AdaptBaseNode.neuronal_reset-en>` .. _AdaptBaseNode.neuronal_reset-cn: 根据当前神经元释放的脉冲,对膜电位进行重置。 * :ref:`中文API <AdaptBaseNode.neuronal_reset-cn>` .. _AdaptBaseNode.neuronal_reset-en: Reset the membrane potential according to neurons' output spikes. """ if self.detach_reset: spike_d = spike.detach() else: spike_d = spike if self.v_reset is None: # soft reset self.v, self.w = self.jit_soft_reset(self.v, self.w, spike_d, self.v_threshold, self.b, spike) else: # hard reset self.v, self.w = self.jit_hard_reset(self.v, self.w, spike_d, self.v_reset, self.b, spike)
    def extra_repr(self):
        return super().extra_repr() + f', v_rest={self.v_rest}, w_rest={self.w_rest}, tau_w={self.tau_w}, a={self.a}, b={self.b}'

    def single_step_forward(self, x: torch.Tensor):
        self.v_float_to_tensor(x)
        self.w_float_to_tensor(x)
        self.neuronal_charge(x)
        self.neuronal_adaptation()
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def w_float_to_tensor(self, x: torch.Tensor):
        if isinstance(self.w, float):
            w_init = self.w
            self.w = torch.full_like(x.data, fill_value=w_init)
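A standalone numeric sketch of the adaptation and reset rules implemented above, using plain tensors and illustrative values (not using the class itself)::

    import torch

    tau_w, a, b, v_rest, v_reset = 2., 1., 0.5, 0., 0.
    v     = torch.tensor([1.2, 0.3])   # membrane potential after charging
    w     = torch.tensor([0.0, 0.0])   # adaptation current
    spike = (v >= 1.).float()          # threshold = 1.

    # subthreshold coupling: w <- w + (a * (v - v_rest) - w) / tau_w
    w = w + (a * (v - v_rest) - w) / tau_w
    # hard reset of v, plus a spike-triggered jump of amplitude b in w
    v = (1. - spike) * v + spike * v_reset
    w = w + b * spike
    print(v, w)   # tensor([0.0000, 0.3000]) tensor([1.1000, 0.1500])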
[文档]class IFNode(BaseNode): def __init__(self, v_threshold: float = 1., v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <IFNode.__init__-en>` .. _IFNode.__init__-cn: :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。其阈下神经动力学方程为: .. math:: H[t] = V[t-1] + X[t] * :ref:`中文API <IFNode.__init__-cn>` .. _IFNode.__init__-en: :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool The Integrate-and-Fire neuron, which can be seen as a ideal integrator. The voltage of the IF neuron will not decay as that of the LIF neuron. The sub-threshold neural dynamics of it is as followed: .. math:: H[t] = V[t-1] + X[t] """ super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) @property def supported_backends(self): if self.step_mode == 's': return ('torch',) elif self.step_mode == 'm': return ('torch', 'cupy') else: raise ValueError(self.step_mode)
    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + x
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float): v = v + x spike = (v >= v_threshold).to(x) v = v_reset * spike + (1. - spike) * v return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_soft_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float): v = v + x spike = (v >= v_threshold).to(x) v = v - spike * v_threshold return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.training: if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': hard_reset = self.v_reset is not None if x_seq.dtype == torch.float: dtype = 'float' elif x_seq.dtype == torch.half: dtype = 'half2' else: raise NotImplementedError(x_seq.dtype) if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset, dtype=dtype): self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype) if self.backward_kernel is None or not self.backward_kernel.check_attributes(surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype): self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel( surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype) self.v_float_to_tensor(x_seq[0]) spike_seq, v_seq = ac_neuron_kernel.IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.forward_kernel, self.backward_kernel) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() return spike_seq else: raise ValueError(self.backend) else: self.v_float_to_tensor(x_seq[0]) if self.v_reset is None: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq, self.v, self.v_threshold) else: spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold) else: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq, self.v, self.v_threshold, self.v_reset) else: spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold, self.v_reset) return spike_seq
[文档] def single_step_forward(self, x: torch.Tensor): if self.training: return super().single_step_forward(x) else: self.v_float_to_tensor(x) if self.v_reset is None: spike, self.v = self.jit_eval_single_step_forward_soft_reset(x, self.v, self.v_threshold) else: spike, self.v = self.jit_eval_single_step_forward_hard_reset(x, self.v, self.v_threshold, self.v_reset) return spike
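A usage sketch for IFNode in multi-step mode, illustrative only. It assumes the usual package layout with ``spikingjelly.activation_based.functional`` providing ``reset_net`` for clearing neuron states between samples::

    import torch
    from spikingjelly.activation_based import neuron, functional

    T, N = 8, 4
    x_seq = torch.rand(T, N)

    sn = neuron.IFNode(step_mode='m', store_v_seq=True)
    spike_seq = sn(x_seq)                                  # shape [T, N]
    print(spike_seq.shape, sn.v_seq.shape, sn.v.shape)     # [T, N], [T, N], [N]

    functional.reset_net(sn)   # clear the stateful v (and v_seq) before the next sample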
[文档]class LIFNode(BaseNode): def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1., v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <LIFNode.__init__-en>` .. _LIFNode.__init__-cn: :param tau: 膜电位时间常数 :type tau: float :param decay_input: 输入是否也会参与衰减 :type decay_input: bool :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: 若 ``decay_input == True``: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) 若 ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] * :ref:`中文API <LIFNode.__init__-cn>` .. _LIFNode.__init__-en: :param tau: membrane time constant :type tau: float :param decay_input: whether the input will decay :type decay_input: bool :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator. The subthreshold neural dynamics of it is as followed: IF ``decay_input == True``: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) IF ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] """ assert isinstance(tau, float) and tau > 1. 
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.tau = tau self.decay_input = decay_input @property def supported_backends(self): if self.step_mode == 's': return ('torch',) elif self.step_mode == 'm': return ('torch', 'cupy') else: raise ValueError(self.step_mode)
    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}'

    def neuronal_charge(self, x: torch.Tensor):
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)
        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)

    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v + (x - v) / tau
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float):
        v = v + (x - (v - v_reset)) / tau
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_no_decay_input_reset0(x: torch.Tensor, v: torch.Tensor, tau: float):
        v = v * (1. - 1. / tau) + x
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float):
        v = v - (v - v_reset) / tau + x
        return v
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_hard_reset_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): v = v + (x - (v - v_reset)) / tau spike = (v >= v_threshold).to(x) v = v_reset * spike + (1. - spike) * v return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_hard_reset_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): v = v - (v - v_reset) / tau + x spike = (v >= v_threshold).to(x) v = v_reset * spike + (1. - spike) * v return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_soft_reset_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): v = v + (x - v) / tau spike = (v >= v_threshold).to(x) v = v - spike * v_threshold return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_single_step_forward_soft_reset_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): v = v * (1. - 1. / tau) + x spike = (v >= v_threshold).to(x) v = v - spike * v_threshold return spike, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset_decay_input(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + (x_seq[t] - (v - v_reset)) / tau spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + (x_seq[t] - (v - v_reset)) / tau spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset_no_decay_input(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v - (v - v_reset) / tau + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float, tau: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v - (v - v_reset) / tau + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v_reset * spike + (1. - spike) * v spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset_decay_input(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + (x_seq[t] - v) / tau spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v + (x_seq[t] - v) / tau spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset_no_decay_input(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): spike_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v * (1. - 1. / tau) + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike return spike_seq, v
[文档] @staticmethod @torch.jit.script def jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float, tau: float): spike_seq = torch.zeros_like(x_seq) v_seq = torch.zeros_like(x_seq) for t in range(x_seq.shape[0]): v = v * (1. - 1. / tau) + x_seq[t] spike = (v >= v_threshold).to(x_seq) v = v - spike * v_threshold spike_seq[t] = spike v_seq[t] = v return spike_seq, v, v_seq
[文档] def single_step_forward(self, x: torch.Tensor): if self.training: return super().single_step_forward(x) else: self.v_float_to_tensor(x) if self.v_reset is None: if self.decay_input: spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v, self.v_threshold, self.tau) else: spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v, self.v_threshold, self.tau) else: if self.decay_input: spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v, self.v_threshold, self.v_reset, self.tau) else: spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v, self.v_threshold, self.v_reset, self.tau) return spike
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.training: if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': hard_reset = self.v_reset is not None if x_seq.dtype == torch.float: dtype = 'float' elif x_seq.dtype == torch.half: dtype = 'half2' else: raise NotImplementedError(x_seq.dtype) if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input): self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel(decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype) if self.backward_kernel is None or not self.backward_kernel.check_attributes( surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input): self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel(decay_input=self.decay_input, surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype) self.v_float_to_tensor(x_seq[0]) spike_seq, v_seq = ac_neuron_kernel.LIFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, 1. / self.tau, self.forward_kernel, self.backward_kernel) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() return spike_seq else: raise ValueError(self.backend) else: self.v_float_to_tensor(x_seq[0]) if self.v_reset is None: if self.decay_input: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq( x_seq, self.v, self.v_threshold, self.tau) else: spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_decay_input(x_seq, self.v, self.v_threshold, self.tau) else: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq( x_seq, self.v, self.v_threshold, self.tau) else: spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_no_decay_input(x_seq, self.v, self.v_threshold, self.tau) else: if self.decay_input: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq( x_seq, self.v, self.v_threshold, self.v_reset, self.tau) else: spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_decay_input(x_seq, self.v, self.v_threshold, self.v_reset, self.tau) else: if self.store_v_seq: spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq( x_seq, self.v, self.v_threshold, self.v_reset, self.tau) else: spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_no_decay_input(x_seq, self.v, self.v_threshold, self.v_reset, self.tau) return spike_seq
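A small numeric sketch contrasting the two LIF charge equations above for a single step with V_reset = 0 (illustrative values only)::

    import torch

    tau, v, x = 2., torch.tensor(0.5), torch.tensor(1.0)

    h_decay    = v + (x - v) / tau    # decay_input=True:  H = V + (X - (V - V_reset)) / tau
    h_no_decay = v - v / tau + x      # decay_input=False: H = V - (V - V_reset) / tau + X
    print(h_decay.item(), h_no_decay.item())   # 0.75 1.25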
[文档]class ParametricLIFNode(BaseNode): def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1., v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <ParametricLIFNode.__init__-en>` .. _ParametricLIFNode.__init__-cn: :param init_tau: 膜电位时间常数的初始值 :type init_tau: float :param decay_input: 输入是否也会参与衰减 :type decay_input: bool :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool :param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速。 这个选项的优先权高于 ``backend`` :type cupy_fp32_inference: bool `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_ 提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: 若 ``decay_input == True``: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) 若 ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] 其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。 * :ref:`中文API <ParametricLIFNode.__init__-cn>` .. _ParametricLIFNode.__init__-en: :param init_tau: the initial value of membrane time constant :type init_tau: float :param decay_input: whether the input will decay :type decay_input: bool :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. 
If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool :param cupy_fp32_inference: If `True`, if this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this module will use `cupy` to accelerate. This option has priority over ``backend`` :type cupy_fp32_inference: bool The Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_ and can be seen as a leaky integrator. The subthreshold neural dynamics of it is as followed: IF ``decay_input == True``: .. math:: H = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) IF ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter. """ assert isinstance(init_tau, float) and init_tau > 1. super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.decay_input = decay_input init_w = - math.log(init_tau - 1.) self.w = nn.Parameter(torch.as_tensor(init_w)) @property def supported_backends(self): if self.step_mode == 's': return ('torch',) elif self.step_mode == 'm': return ('torch', 'cupy') else: raise ValueError(self.step_mode)
    def extra_repr(self):
        with torch.no_grad():
            tau = 1. / self.w.sigmoid()
        return super().extra_repr() + f', tau={tau}'

    def neuronal_charge(self, x: torch.Tensor):
        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v + (x - self.v) * self.w.sigmoid()
            else:
                self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.v * (1. - self.w.sigmoid()) + x
            else:
                self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': hard_reset = self.v_reset is not None if x_seq.dtype == torch.float: dtype = 'float' elif x_seq.dtype == torch.half: dtype = 'half2' else: raise NotImplementedError(x_seq.dtype) if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset, dtype=dtype, decay_input=self.decay_input): self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel(decay_input=self.decay_input, hard_reset=hard_reset, dtype=dtype) if self.backward_kernel is None or not self.backward_kernel.check_attributes( surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input): self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel(decay_input=self.decay_input, surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset, detach_reset=self.detach_reset, dtype=dtype) self.v_float_to_tensor(x_seq[0]) spike_seq, v_seq = ac_neuron_kernel.ParametricLIFNodeATGF.apply( x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.w.sigmoid().to(x_seq), self.forward_kernel, self.backward_kernel) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() return spike_seq else: raise ValueError(self.backend)
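A quick numeric check, illustrative only, that the parameterisation used above gives 1/tau = sigmoid(w) when w is initialised to -log(init_tau - 1)::

    import math
    import torch

    init_tau = 2.0
    w = torch.as_tensor(-math.log(init_tau - 1.))
    print(w.sigmoid().item())        # 0.5, i.e. 1 / init_tau
    print(1. / w.sigmoid().item())   # 2.0, matching init_tau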
[文档]class QIFNode(BaseNode): def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0., v_reset: float = -0.1, surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <QIFNode.__init__-en>` .. _QIFNode.__init__-cn: :param tau: 膜电位时间常数 :type tau: float :param v_c: 关键电压 :type v_c: float :param a0: :type a0: float :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_rest: 静息电位 :type v_rest: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool Quadratic Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,也是指数积分发放神经元(Exponential Integrate-and-Fire)的近似版本。其阈下神经动力学方程为: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c)) * :ref:`中文API <QIFNode.__init__-cn>` .. _QIFNode.__init__-en: :param tau: membrane time constant :type tau: float :param v_c: critical voltage :type v_c: float :param a0: :type a0: float :param v_threshold: threshold voltage of neurons :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool The Quadratic Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an approximation of the Exponential Integrate-and-Fire model. The subthreshold neural dynamics of it is as followed: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c)) """ assert isinstance(tau, float) and tau > 1. 
if v_reset is not None: assert v_threshold > v_reset assert v_rest >= v_reset assert a0 > 0 super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.tau = tau self.v_c = v_c self.v_rest = v_rest self.a0 = a0
    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}, v_c={self.v_c}, a0={self.a0}, v_rest={self.v_rest}'

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)) / self.tau

    @property
    def supported_backends(self):
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': self.v_float_to_tensor(x_seq[0]) spike_seq, v_seq = neuron_kernel.MultiStepQIFNodePTT.apply( x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() return spike_seq else: raise ValueError(self.backend)
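A one-step numeric sketch of the QIF charge equation above (illustrative values only)::

    import torch

    tau, v_c, a0, v_rest = 2., 0.8, 1., 0.
    v, x = torch.tensor(0.5), torch.tensor(0.2)

    # H = V + (X + a0 * (V - V_rest) * (V - V_c)) / tau
    v = v + (x + a0 * (v - v_rest) * (v - v_c)) / tau
    print(v.item())   # 0.5 + (0.2 + 0.5 * (-0.3)) / 2 = 0.525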
[文档]class EIFNode(BaseNode): def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1., v_rest: float = 0., v_reset: float = -0.1, surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <EIFNode.__init__-en>` .. _EIFNode.__init__-cn: :param tau: 膜电位时间常数 :type tau: float :param delta_T: 陡峭度参数 :type delta_T: float :param theta_rh: 基强度电压阈值 :type theta_rh: float :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool Exponential Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,是由HH神经元模型(Hodgkin-Huxley model)简化后推导出的一维模型。在 :math:`\\Delta_T\\to 0` 时退化为LIF模型。其阈下神经动力学方程为: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right) * :ref:`中文API <EIFNode.__init__-cn>` .. _EIFNode.__init__-en: :param tau: membrane time constant :type tau: float :param delta_T: sharpness parameter :type delta_T: float :param theta_rh: rheobase threshold :type theta_rh: float :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool The Exponential Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an one-dimensional model derived from the Hodgkin-Huxley model. It degenerates to the LIF model when :math:`\\Delta_T\\to 0`. The subthreshold neural dynamics of it is as followed: .. 
math:: H[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right) """ assert isinstance(tau, float) and tau > 1. if v_reset is not None: assert v_threshold > v_reset assert v_rest >= v_reset assert delta_T > 0 super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.tau = tau self.delta_T = delta_T self.v_rest = v_rest self.theta_rh = theta_rh
    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}, delta_T={self.delta_T}, theta_rh={self.theta_rh}'

    def neuronal_charge(self, x: torch.Tensor):
        with torch.no_grad():
            if not isinstance(self.v, torch.Tensor):
                self.v = torch.as_tensor(self.v, device=x.device)

        self.v = self.v + (x + self.v_rest - self.v + self.delta_T * torch.exp(
            (self.v - self.theta_rh) / self.delta_T)) / self.tau

    @property
    def supported_backends(self):
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': self.v_float_to_tensor(x_seq[0]) spike_seq, v_seq = neuron_kernel.MultiStepEIFNodePTT.apply( x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() return spike_seq else: raise ValueError(self.backend)
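A small numeric sketch of the EIF charge equation above, showing that for a membrane potential below theta_rh the exponential term vanishes as delta_T shrinks, so the update approaches the LIF one (illustrative values only)::

    import torch

    tau, v_rest, theta_rh = 2., 0., 0.8
    v, x = torch.tensor(0.3), torch.tensor(0.5)

    for delta_T in (1.0, 0.1, 0.01):
        h = v + (x + v_rest - v + delta_T * torch.exp((v - theta_rh) / delta_T)) / tau
        print(delta_T, h.item())
    # the LIF limit would be v + (x - (v - v_rest)) / tau = 0.4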
[文档]class IzhikevichNode(AdaptBaseNode): def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_reset: float = 0., v_rest: float = -0.1, w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): assert isinstance(tau, float) and tau > 1. assert a0 > 0 super().__init__(v_threshold, v_reset, v_rest, w_rest, tau_w, a, b, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.tau = tau self.v_c = v_c self.a0 = a0
    def extra_repr(self):
        return super().extra_repr() + f', tau={self.tau}, v_c={self.v_c}, a0={self.a0}'

    def neuronal_charge(self, x: torch.Tensor):
        self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c) - self.w) / self.tau

    @property
    def supported_backends(self):
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor): if self.backend == 'torch': return super().multi_step_forward(x_seq) elif self.backend == 'cupy': self.v_float_to_tensor(x_seq[0]) self.w_float_to_tensor(x_seq[0]) spike_seq, v_seq, w_seq = neuron_kernel.MultiStepIzhikevichNodePTT.apply( x_seq.flatten(1), self.v.flatten(0), self.w.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest, self.a, self.b, self.tau_w, self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code) spike_seq = spike_seq.reshape(x_seq.shape) v_seq = v_seq.reshape(x_seq.shape) w_seq = w_seq.reshape(x_seq.shape) if self.store_v_seq: self.v_seq = v_seq self.v = v_seq[-1].clone() self.w = w_seq[-1].clone() return spike_seq else: raise ValueError(self.backend)
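A usage sketch for IzhikevichNode, illustrative only: both the membrane potential ``v`` and the adaptation current ``w`` are stateful, and ``functional.reset_net`` (assumed available in spikingjelly.activation_based.functional) clears them between samples::

    import torch
    from spikingjelly.activation_based import neuron, functional

    sn = neuron.IzhikevichNode(a=0.1, b=0.2)
    spike = sn(torch.rand(4))
    print(spike.shape, sn.v.shape, sn.w.shape)   # all torch.Size([4])
    functional.reset_net(sn)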
[文档]class LIAFNode(LIFNode): def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs): """ * :ref:`API in English <LIAFNode.__init__-en>` .. _LIAFNode.__init__-cn: :param act: 激活函数 :type act: Callable :param threshold_related: 是否使用阈值依赖模式 (TR mode). 若为 ``True`` 则 ``y = act(h - v_th)``, 否则 ``y = act(h)`` :type threshold_related: bool `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_ 提出的LIAF神经元。LIAFNode和LIFNode的行为相同,但输出是 ``self.act(...)`` 而非脉冲。 .. Warning:: The outputs of this neurons layer are not binary spikes. * :ref:`中文API <LIAFNode.__init__-cn>` .. _LIAFNode.__init__-en: :param act: the activation function :type act: Callable :param threshold_related: whether the neuron uses threshold related (TR mode). If ``True``, ``y = act(h - v_th)``, otherwise ``y = act(h)`` :type threshold_related: bool Other parameters in `*args, **kwargs` are same with :class:`LIFNode`. The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_. LIAFNode has the same behavior as LIFNode, but outputs ``self.act(...)`` rather than spikes. .. admonition:: Warning :class: warning The outputs of this neurons layer are not binary spikes. """ super().__init__(*args, **kwargs) self.act = act self.threshold_related = threshold_related assert self.backend == 'torch', "LIAFNode only supports for backend='torch'!" assert self.single_step_cupy_fp32_inference == False, "LIAFNode does not support for single_step_cupy_fp32_inference!" @property def supported_backends(self): return ('torch',)
    def single_step_forward(self, x: torch.Tensor):
        self.neuronal_charge(x)
        if self.threshold_related:
            y = self.act(self.v - self.v_threshold)
        else:
            y = self.act(self.v)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return y
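A standalone sketch of the two LIAF output modes described above, given a membrane potential h after charging (illustrative values; the class itself obtains h from the inherited LIF dynamics)::

    import torch
    import torch.nn as nn

    act, v_th = nn.ReLU(), 1.
    h = torch.tensor([0.4, 1.3])

    y_tr  = act(h - v_th)   # threshold_related=True:  y = act(h - v_th)
    y_ntr = act(h)          # threshold_related=False: y = act(h)
    print(y_tr, y_ntr)      # tensor([0.0000, 0.3000]) tensor([0.4000, 1.3000])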
[文档]class KLIFNode(BaseNode): def __init__(self, scale_reset: bool=False, tau: float = 2., decay_input: bool = True, v_threshold: float = 1., v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False): """ * :ref:`API in English <KLIFNode.__init__-en>` .. _KLIFNode.__init__-cn: :param scale_reset: 是否在 ``neuronal_reset`` 时将 ``v`` 进行缩放 :type scale_reset: bool :param tau: 膜电位时间常数 :type tau: float :param decay_input: 输入是否也会参与衰减 :type decay_input: bool :param v_threshold: 神经元的阈值电压 :type v_threshold: float :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; 如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` :type v_reset: float :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 :type surrogate_function: Callable :param detach_reset: 是否将reset过程的计算图分离 :type detach_reset: bool :param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步) :type step_mode: str :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 :type backend: str :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` 的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 通常设置成 ``False`` ,可以节省内存 :type store_v_seq: bool `KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential <https://arxiv.org/abs/2302.09238>`_ 提出的K-based Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为: 若 ``decay_input == True``: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) 若 ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] 注意,KLIF神经元的放电和重置与普通的神经元不同,为: .. math:: F[t] &= \\mathrm{ReLU}(kH[t]) S[t] &= \\Theta(F[t] - V_{th}) 如果 ``scale_reset == False``,则 .. math:: V[t] = \\begin{cases} F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\\\ F[t] - S[t]V_{th}, soft~~reset \\end{cases} 如果 ``scale_reset == True``,则 .. math:: V[t] = \\begin{cases} \\frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\\\ \\frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset \\end{cases} * :ref:`中文API <KLIFNode.__init__-cn>` .. _KLIFNode.__init__-en: :param scale_reset: whether scale ``v`` in ``neuronal_reset`` :type scale_reset: bool :param tau: membrane time constant :type tau: float :param decay_input: whether the input will decay :type decay_input: bool :param v_threshold: threshold of this neurons layer :type v_threshold: float :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike :type v_reset: float :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward :type surrogate_function: Callable :param detach_reset: whether detach the computation graph of reset in backward :type detach_reset: bool :param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step) :type step_mode: str :param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. 
If supported, using ``'cupy'`` backend will have the fastest training speed :type backend: str :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the memory consumption :type store_v_seq: bool The K-based Leaky Integrate-and-Fire neuron proposed by `KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential <https://arxiv.org/abs/2302.09238>`_, which can be seen as a leaky integrator. The subthreshold neural dynamics of it is as followed: IF ``decay_input == True``: .. math:: H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset})) IF ``decay_input == False``: .. math:: H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t] Note that the neuronal fire and reset of the KLIF neuron is different from native neurons: .. math:: F[t] &= \\mathrm{ReLU}(kH[t]) S[t] &= \\Theta(F[t] - V_{th}) If ``scale_reset == False``, then .. math:: V[t] = \\begin{cases} F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\\\ F[t] - S[t]V_{th}, soft~~reset \\end{cases} Elif ``scale_reset == True``, then .. math:: V[t] = \\begin{cases} \\frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\\\ \\frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset \\end{cases} """ assert isinstance(tau, float) and tau > 1. if backend == 'cupy': raise NotImplementedError("The CuPy backend for the KLIF neuron has not been implemented!") super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) self.scale_reset = scale_reset self.tau = tau self.decay_input = decay_input self.k = nn.Parameter(torch.as_tensor(1.))
    @staticmethod
    @torch.jit.script
    def neuronal_charge_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float, k: torch.Tensor):
        v = v + (x - (v - v_reset)) / tau
        v = torch.relu_(k * v)
        return v

    @staticmethod
    @torch.jit.script
    def neuronal_charge_no_decay_input(x: torch.Tensor, v: torch.Tensor, v_reset: float, tau: float, k: torch.Tensor):
        v = v - (v - v_reset) / tau + x
        v = torch.relu_(k * v)
        return v

    def neuronal_charge(self, x: torch.Tensor):
        if self.decay_input:
            self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau, self.k)
        else:
            self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau, self.k)

    def neuronal_reset(self, spike):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        if self.scale_reset:
            if self.v_reset is None:
                # soft reset
                self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold) / self.k
            else:
                # hard reset
                self.v = self.jit_hard_reset(self.v / self.k, spike_d, self.v_reset)
        else:
            if self.v_reset is None:
                # soft reset
                self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
            else:
                # hard reset
                self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
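A standalone numeric sketch of the KLIF fire and reset rules from the docstring above, with scale_reset=True and soft reset (illustrative values only)::

    import torch

    k, v_th = torch.tensor(1.5), 1.
    h = torch.tensor([0.9, -0.2])    # membrane potential after the leaky charge step

    f = torch.relu(k * h)            # F[t] = ReLU(k * H[t])
    s = (f >= v_th).float()          # S[t] = Theta(F[t] - V_th)
    v = (f - s * v_th) / k           # scaled soft reset
    print(f, s, v)                   # tensor([1.3500, 0.0000]) tensor([1., 0.]) tensor([0.2333, 0.0000])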