from abc import abstractmethod
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import logging
from . import surrogate, base
from .auto_cuda import neuron_kernel as ac_neuron_kernel
from .auto_cuda import ss_neuron_kernel as ss_ac_neuron_kernel
try:
import cupy
from . import neuron_kernel, cuda_utils
except BaseException as e:
logging.info(f'spikingjelly.activation_based.neuron: {e}')
cupy = None
neuron_kernel = None
cuda_utils = None
[文档]class SimpleBaseNode(base.MemoryModule):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
step_mode='s'):
"""
A simple version of ``BaseNode``. The user can modify this neuron easily.
"""
super().__init__()
self.v_threshold = v_threshold
self.v_reset = v_reset
self.surrogate_function = surrogate_function
self.detach_reset = detach_reset
self.step_mode = step_mode
self.register_memory(name='v', value=0.)
[文档] def single_step_forward(self, x: torch.Tensor):
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
[文档] def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike):
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike
if self.v_reset is None:
# soft reset
self.v = self.v - self.v_threshold * spike_d
else:
# hard reset
self.v = spike_d * self.v_reset + (1. - spike_d) * self.v
[文档]class SimpleIFNode(SimpleBaseNode):
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
[文档]class SimpleLIFNode(SimpleBaseNode):
def __init__(self, tau:float, decay_input: bool, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
step_mode='s'):
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode)
self.tau = tau
self.decay_input = decay_input
[文档] def neuronal_charge(self, x: torch.Tensor):
if self.decay_input:
self.v = self.v + (self.v_reset - self.v + x) / self.tau
else:
self.v = self.v + (self.v_reset - self.v) / self.tau + x
[文档]class BaseNode(base.MemoryModule):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <BaseNode.__init__-en>`
.. _BaseNode.__init__-cn:
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
可微分SNN神经元的基类神经元。
* :ref:`中文API <BaseNode.__init__-cn>`
.. _BaseNode.__init__-en:
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
This class is the base class of differentiable spiking neurons.
"""
assert isinstance(v_reset, float) or v_reset is None
assert isinstance(v_threshold, float)
assert isinstance(detach_reset, bool)
super().__init__()
if v_reset is None:
self.register_memory('v', 0.)
else:
self.register_memory('v', v_reset)
self.v_threshold = v_threshold
self.v_reset = v_reset
self.detach_reset = detach_reset
self.surrogate_function = surrogate_function
self.step_mode = step_mode
self.backend = backend
self.store_v_seq = store_v_seq
# used in lava_exchange
self.lava_s_cale = 1 << 6
# used for cupy backend
self.forward_kernel = None
self.backward_kernel = None
@property
def store_v_seq(self):
return self._store_v_seq
@store_v_seq.setter
def store_v_seq(self, value: bool):
self._store_v_seq = value
if value:
if not hasattr(self, 'v_seq'):
self.register_memory('v_seq', None)
[文档] @staticmethod
@torch.jit.script
def jit_hard_reset(v: torch.Tensor, spike: torch.Tensor, v_reset: float):
v = (1. - spike) * v + spike * v_reset
return v
[文档] @staticmethod
@torch.jit.script
def jit_soft_reset(v: torch.Tensor, spike: torch.Tensor, v_threshold: float):
v = v - spike * v_threshold
return v
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
"""
* :ref:`API in English <BaseNode.neuronal_charge-en>`
.. _BaseNode.neuronal_charge-cn:
定义神经元的充电差分方程。子类必须实现这个函数。
* :ref:`中文API <BaseNode.neuronal_charge-cn>`
.. _BaseNode.neuronal_charge-en:
Define the charge difference equation. The sub-class must implement this function.
"""
raise NotImplementedError
[文档] def neuronal_fire(self):
"""
* :ref:`API in English <BaseNode.neuronal_fire-en>`
.. _BaseNode.neuronal_fire-cn:
根据当前神经元的电压、阈值,计算输出脉冲。
* :ref:`中文API <BaseNode.neuronal_fire-cn>`
.. _BaseNode.neuronal_fire-en:
Calculate out spikes of neurons by their current membrane potential and threshold voltage.
"""
return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike):
"""
* :ref:`API in English <BaseNode.neuronal_reset-en>`
.. _BaseNode.neuronal_reset-cn:
根据当前神经元释放的脉冲,对膜电位进行重置。
* :ref:`中文API <BaseNode.neuronal_reset-cn>`
.. _BaseNode.neuronal_reset-en:
Reset the membrane potential according to neurons' output spikes.
"""
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike
if self.v_reset is None:
# soft reset
self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
else:
# hard reset
self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
[文档] def single_step_forward(self, x: torch.Tensor):
"""
* :ref:`API in English <BaseNode.single_step_forward-en>`
.. _BaseNode.single_step_forward-cn:
:param x: 输入到神经元的电压增量
:type x: torch.Tensor
:return: 神经元的输出脉冲
:rtype: torch.Tensor
按照充电、放电、重置的顺序进行前向传播。
* :ref:`中文API <BaseNode.single_step_forward-cn>`
.. _BaseNode.single_step_forward-en:
:param x: increment of voltage inputted to neurons
:type x: torch.Tensor
:return: out spikes of neurons
:rtype: torch.Tensor
Forward by the order of `neuronal_charge`, `neuronal_fire`, and `neuronal_reset`.
"""
self.v_float_to_tensor(x)
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
T = x_seq.shape[0]
y_seq = []
if self.store_v_seq:
v_seq = []
for t in range(T):
y = self.single_step_forward(x_seq[t])
y_seq.append(y)
if self.store_v_seq:
v_seq.append(self.v)
if self.store_v_seq:
self.v_seq = torch.stack(v_seq)
return torch.stack(y_seq)
[文档] def v_float_to_tensor(self, x: torch.Tensor):
if isinstance(self.v, float):
v_init = self.v
self.v = torch.full_like(x.data, v_init)
[文档]class AdaptBaseNode(BaseNode):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
v_rest: float = 0., w_rest: float = 0., tau_w: float = 2., a: float = 0., b: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
backend='torch', store_v_seq: bool = False):
# b: jump amplitudes
# a: subthreshold coupling
assert isinstance(w_rest, float)
assert isinstance(v_rest, float)
assert isinstance(tau_w, float)
assert isinstance(a, float)
assert isinstance(b, float)
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.register_memory('w', w_rest)
self.w_rest = w_rest
self.v_rest = v_rest
self.tau_w = tau_w
self.a = a
self.b = b
[文档] @staticmethod
@torch.jit.script
def jit_neuronal_adaptation(w: torch.Tensor, tau_w: float, a: float, v_rest: float, v: torch.Tensor):
return w + 1. / tau_w * (a * (v - v_rest) - w)
[文档] def neuronal_adaptation(self):
"""
* :ref:`API in English <AdaptBaseNode.neuronal_adaptation-en>`
.. _AdaptBaseNode.neuronal_adaptation-cn:
脉冲触发的适应性电流的更新
* :ref:`中文API <AdaptBaseNode.neuronal_adaptation-cn>`
.. _AdaptBaseNode.neuronal_adaptation-en:
Spike-triggered update of adaptation current.
"""
self.w = self.jit_neuronal_adaptation(self.w, self.tau_w, self.a, self.v_rest, self.v)
[文档] @staticmethod
@torch.jit.script
def jit_hard_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_reset: float, b: float,
spike: torch.Tensor):
v = (1. - spike_d) * v + spike * v_reset
w = w + b * spike
return v, w
[文档] @staticmethod
@torch.jit.script
def jit_soft_reset(v: torch.Tensor, w: torch.Tensor, spike_d: torch.Tensor, v_threshold: float, b: float,
spike: torch.Tensor):
v = v - spike_d * v_threshold
w = w + b * spike
return v, w
[文档] def neuronal_reset(self, spike):
"""
* :ref:`API in English <AdaptBaseNode.neuronal_reset-en>`
.. _AdaptBaseNode.neuronal_reset-cn:
根据当前神经元释放的脉冲,对膜电位进行重置。
* :ref:`中文API <AdaptBaseNode.neuronal_reset-cn>`
.. _AdaptBaseNode.neuronal_reset-en:
Reset the membrane potential according to neurons' output spikes.
"""
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike
if self.v_reset is None:
# soft reset
self.v, self.w = self.jit_soft_reset(self.v, self.w, spike_d, self.v_threshold, self.b, spike)
else:
# hard reset
self.v, self.w = self.jit_hard_reset(self.v, self.w, spike_d, self.v_reset, self.b, spike)
[文档] def single_step_forward(self, x: torch.Tensor):
self.v_float_to_tensor(x)
self.w_float_to_tensor(x)
self.neuronal_charge(x)
self.neuronal_adaptation()
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
[文档] def w_float_to_tensor(self, x: torch.Tensor):
if isinstance(self.w, float):
w_init = self.w
self.w = torch.full_like(x.data, fill_value=w_init)
[文档]class IFNode(BaseNode):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <IFNode.__init__-en>`
.. _IFNode.__init__-cn:
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
Integrate-and-Fire 神经元模型,可以看作理想积分器,无输入时电压保持恒定,不会像LIF神经元那样衰减。其阈下神经动力学方程为:
.. math::
H[t] = V[t-1] + X[t]
* :ref:`中文API <IFNode.__init__-cn>`
.. _IFNode.__init__-en:
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
The Integrate-and-Fire neuron, which can be seen as a ideal integrator. The voltage of the IF neuron will not decay
as that of the LIF neuron. The sub-threshold neural dynamics of it is as followed:
.. math::
H[t] = V[t-1] + X[t]
"""
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch', 'cupy')
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
[文档] @staticmethod
@torch.jit.script
def jit_eval_single_step_forward_hard_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float, v_reset: float):
v = v + x
spike = (v >= v_threshold).to(x)
v = v_reset * spike + (1. - spike) * v
return spike, v
[文档] @staticmethod
@torch.jit.script
def jit_eval_single_step_forward_soft_reset(x: torch.Tensor, v: torch.Tensor, v_threshold: float):
v = v + x
spike = (v >= v_threshold).to(x)
v = v - spike * v_threshold
return spike, v
[文档] @staticmethod
@torch.jit.script
def jit_eval_multi_step_forward_hard_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float,
v_reset: float):
spike_seq = torch.zeros_like(x_seq)
for t in range(x_seq.shape[0]):
v = v + x_seq[t]
spike = (v >= v_threshold).to(x_seq)
v = v_reset * spike + (1. - spike) * v
spike_seq[t] = spike
return spike_seq, v
[文档] @staticmethod
@torch.jit.script
def jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float,
v_reset: float):
spike_seq = torch.zeros_like(x_seq)
v_seq = torch.zeros_like(x_seq)
for t in range(x_seq.shape[0]):
v = v + x_seq[t]
spike = (v >= v_threshold).to(x_seq)
v = v_reset * spike + (1. - spike) * v
spike_seq[t] = spike
v_seq[t] = v
return spike_seq, v, v_seq
[文档] @staticmethod
@torch.jit.script
def jit_eval_multi_step_forward_soft_reset(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float):
spike_seq = torch.zeros_like(x_seq)
for t in range(x_seq.shape[0]):
v = v + x_seq[t]
spike = (v >= v_threshold).to(x_seq)
v = v - spike * v_threshold
spike_seq[t] = spike
return spike_seq, v
[文档] @staticmethod
@torch.jit.script
def jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq: torch.Tensor, v: torch.Tensor, v_threshold: float):
spike_seq = torch.zeros_like(x_seq)
v_seq = torch.zeros_like(x_seq)
for t in range(x_seq.shape[0]):
v = v + x_seq[t]
spike = (v >= v_threshold).to(x_seq)
v = v - spike * v_threshold
spike_seq[t] = spike
v_seq[t] = v
return spike_seq, v, v_seq
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.training:
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
hard_reset = self.v_reset is not None
if x_seq.dtype == torch.float:
dtype = 'float'
elif x_seq.dtype == torch.half:
dtype = 'half2'
else:
raise NotImplementedError(x_seq.dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype):
self.forward_kernel = ac_neuron_kernel.IFNodeFPTTKernel(hard_reset=hard_reset, dtype=dtype)
if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype):
self.backward_kernel = ac_neuron_kernel.IFNodeBPTTKernel(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype)
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = ac_neuron_kernel.IFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
self.v_threshold, self.v_reset,
self.forward_kernel,
self.backward_kernel)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x_seq[0])
if self.v_reset is None:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_with_v_seq(x_seq,
self.v,
self.v_threshold)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset(x_seq, self.v, self.v_threshold)
else:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_with_v_seq(x_seq,
self.v,
self.v_threshold,
self.v_reset)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset(x_seq, self.v, self.v_threshold,
self.v_reset)
return spike_seq
[文档] def single_step_forward(self, x: torch.Tensor):
if self.training:
if self.backend == 'torch':
return super().single_step_forward(x)
elif self.backend == 'cupy':
hard_reset = self.v_reset is not None
if x.dtype == torch.float:
dtype = 'float'
elif x.dtype == torch.half:
dtype = 'half2'
else:
raise NotImplementedError(x.dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype):
self.forward_kernel = ss_ac_neuron_kernel.IFNodeFPKernel(hard_reset=hard_reset, dtype=dtype)
if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype):
self.backward_kernel = ss_ac_neuron_kernel.IFNodeBPKernel(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype)
self.v_float_to_tensor(x)
spike, v = ss_ac_neuron_kernel.IFNodeATGF.apply(x.flatten(0), self.v.flatten(0),
self.v_threshold, self.v_reset,
self.forward_kernel,
self.backward_kernel)
spike = spike.reshape(x.shape)
v = v.reshape(x.shape)
self.v = v
return spike
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x)
if self.v_reset is None:
spike, self.v = self.jit_eval_single_step_forward_soft_reset(x, self.v, self.v_threshold)
else:
spike, self.v = self.jit_eval_single_step_forward_hard_reset(x, self.v, self.v_threshold, self.v_reset)
return spike
[文档]class LIFNode(BaseNode):
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <LIFNode.__init__-en>`
.. _LIFNode.__init__-cn:
:param tau: 膜电位时间常数
:type tau: float
:param decay_input: 输入是否也会参与衰减
:type decay_input: bool
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若 ``decay_input == True``:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
若 ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
* :ref:`中文API <LIFNode.__init__-cn>`
.. _LIFNode.__init__-en:
:param tau: membrane time constant
:type tau: float
:param decay_input: whether the input will decay
:type decay_input: bool
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
The Leaky Integrate-and-Fire neuron, which can be seen as a leaky integrator.
The subthreshold neural dynamics of it is as followed:
IF ``decay_input == True``:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
IF ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
"""
assert isinstance(tau, float) and tau > 1.
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.tau = tau
self.decay_input = decay_input
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch', 'cupy')
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def neuronal_charge(self, x: torch.Tensor):
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)
else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)
[文档] def single_step_forward(self, x: torch.Tensor):
if self.training:
if self.backend == 'torch':
return super().single_step_forward(x)
elif self.backend == 'cupy':
hard_reset = self.v_reset is not None
if x.dtype == torch.float:
dtype = 'float'
elif x.dtype == torch.half:
dtype = 'half2'
else:
raise NotImplementedError(x.dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype,
decay_input=self.decay_input):
self.forward_kernel = ss_ac_neuron_kernel.LIFNodeFPKernel(decay_input=self.decay_input,
hard_reset=hard_reset, dtype=dtype)
if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
self.backward_kernel = ss_ac_neuron_kernel.LIFNodeBPKernel(
decay_input=self.decay_input,
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype)
self.v_float_to_tensor(x)
spike, v = ss_ac_neuron_kernel.LIFNodeATGF.apply(x.flatten(0), self.v.flatten(0),
self.v_threshold, self.v_reset, 1. / self.tau,
self.forward_kernel,
self.backward_kernel)
spike = spike.reshape(x.shape)
v = v.reshape(x.shape)
self.v = v
return spike
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x)
if self.v_reset is None:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
self.v_threshold, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
self.v_threshold,
self.tau)
else:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
self.v_threshold,
self.v_reset, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
self.v_threshold,
self.v_reset,
self.tau)
return spike
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.training:
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
hard_reset = self.v_reset is not None
if x_seq.dtype == torch.float:
dtype = 'float'
elif x_seq.dtype == torch.half:
dtype = 'half2'
else:
raise NotImplementedError(x_seq.dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype,
decay_input=self.decay_input):
self.forward_kernel = ac_neuron_kernel.LIFNodeFPTTKernel(decay_input=self.decay_input,
hard_reset=hard_reset, dtype=dtype)
if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
self.backward_kernel = ac_neuron_kernel.LIFNodeBPTTKernel(decay_input=self.decay_input,
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype)
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = ac_neuron_kernel.LIFNodeATGF.apply(x_seq.flatten(1), self.v.flatten(0),
self.v_threshold, self.v_reset, 1. / self.tau,
self.forward_kernel,
self.backward_kernel)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
else:
self.v_float_to_tensor(x_seq[0])
if self.v_reset is None:
if self.decay_input:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_decay_input_with_v_seq(
x_seq, self.v, self.v_threshold, self.tau)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_decay_input(x_seq, self.v,
self.v_threshold,
self.tau)
else:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_soft_reset_no_decay_input_with_v_seq(
x_seq, self.v, self.v_threshold, self.tau)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_soft_reset_no_decay_input(x_seq, self.v,
self.v_threshold,
self.tau)
else:
if self.decay_input:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_decay_input_with_v_seq(
x_seq, self.v, self.v_threshold, self.v_reset, self.tau)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_decay_input(x_seq, self.v,
self.v_threshold,
self.v_reset,
self.tau)
else:
if self.store_v_seq:
spike_seq, self.v, self.v_seq = self.jit_eval_multi_step_forward_hard_reset_no_decay_input_with_v_seq(
x_seq, self.v, self.v_threshold, self.v_reset, self.tau)
else:
spike_seq, self.v = self.jit_eval_multi_step_forward_hard_reset_no_decay_input(x_seq, self.v,
self.v_threshold,
self.v_reset,
self.tau)
return spike_seq
[文档]class ParametricLIFNode(BaseNode):
def __init__(self, init_tau: float = 2.0, decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <ParametricLIFNode.__init__-en>`
.. _ParametricLIFNode.__init__-cn:
:param init_tau: 膜电位时间常数的初始值
:type init_tau: float
:param decay_input: 输入是否也会参与衰减
:type decay_input: bool
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
:param cupy_fp32_inference: 若为 `True`,在 `eval` 模式下,使用float32,却在GPU上运行,并且 `cupy` 已经安装,则会自动使用 `cupy` 进行加速。
这个选项的优先权高于 ``backend``
:type cupy_fp32_inference: bool
`Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_
提出的 Parametric Leaky Integrate-and-Fire (PLIF)神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若 ``decay_input == True``:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
若 ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
其中 :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`,:math:`w` 是可学习的参数。
* :ref:`中文API <ParametricLIFNode.__init__-cn>`
.. _ParametricLIFNode.__init__-en:
:param init_tau: the initial value of membrane time constant
:type init_tau: float
:param decay_input: whether the input will decay
:type decay_input: bool
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
:param cupy_fp32_inference: If `True`, if this module is in `eval` mode, using float32, running on GPU, and `cupy` is installed, then this
module will use `cupy` to accelerate. This option has priority over ``backend``
:type cupy_fp32_inference: bool
The Parametric Leaky Integrate-and-Fire (PLIF) neuron, which is proposed by `Incorporating Learnable Membrane Time Constant to Enhance Learning of Spiking Neural Networks <https://arxiv.org/abs/2007.05785>`_ and can be seen as a leaky integrator.
The subthreshold neural dynamics of it is as followed:
IF ``decay_input == True``:
.. math::
H = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
IF ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
where :math:`\\frac{1}{\\tau} = {\\rm Sigmoid}(w)`, :math:`w` is a learnable parameter.
"""
assert isinstance(init_tau, float) and init_tau > 1.
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.decay_input = decay_input
init_w = - math.log(init_tau - 1.)
self.w = nn.Parameter(torch.as_tensor(init_w))
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch',)
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def neuronal_charge(self, x: torch.Tensor):
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v + (x - self.v) * self.w.sigmoid()
else:
self.v = self.v + (x - (self.v - self.v_reset)) * self.w.sigmoid()
else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.v * (1. - self.w.sigmoid()) + x
else:
self.v = self.v - (self.v - self.v_reset) * self.w.sigmoid() + x
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
hard_reset = self.v_reset is not None
if x_seq.dtype == torch.float:
dtype = 'float'
elif x_seq.dtype == torch.half:
dtype = 'half2'
else:
raise NotImplementedError(x_seq.dtype)
if self.forward_kernel is None or not self.forward_kernel.check_attributes(hard_reset=hard_reset,
dtype=dtype,
decay_input=self.decay_input):
self.forward_kernel = ac_neuron_kernel.ParametricLIFNodeFPTTKernel(decay_input=self.decay_input,
hard_reset=hard_reset, dtype=dtype)
if self.backward_kernel is None or not self.backward_kernel.check_attributes(
surrogate_function=self.surrogate_function.cuda_codes, hard_reset=hard_reset,
detach_reset=self.detach_reset, dtype=dtype, decay_input=self.decay_input):
self.backward_kernel = ac_neuron_kernel.ParametricLIFNodeBPTTKernel(decay_input=self.decay_input,
surrogate_function=self.surrogate_function.cuda_codes,
hard_reset=hard_reset,
detach_reset=self.detach_reset,
dtype=dtype)
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = ac_neuron_kernel.ParametricLIFNodeATGF.apply(
x_seq.flatten(1), self.v.flatten(0), self.v_threshold, self.v_reset, self.w.sigmoid().to(x_seq),
self.forward_kernel, self.backward_kernel)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
[文档]class QIFNode(BaseNode):
def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1., v_rest: float = 0.,
v_reset: float = -0.1,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <QIFNode.__init__-en>`
.. _QIFNode.__init__-cn:
:param tau: 膜电位时间常数
:type tau: float
:param v_c: 关键电压
:type v_c: float
:param a0:
:type a0: float
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_rest: 静息电位
:type v_rest: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
Quadratic Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,也是指数积分发放神经元(Exponential Integrate-and-Fire)的近似版本。其阈下神经动力学方程为:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))
* :ref:`中文API <QIFNode.__init__-cn>`
.. _QIFNode.__init__-en:
:param tau: membrane time constant
:type tau: float
:param v_c: critical voltage
:type v_c: float
:param a0:
:type a0: float
:param v_threshold: threshold voltage of neurons
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
The Quadratic Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an approximation of the Exponential Integrate-and-Fire model.
The subthreshold neural dynamics of it is as followed:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] + a_0 (V[t-1] - V_{rest})(V[t-1] - V_c))
"""
assert isinstance(tau, float) and tau > 1.
if v_reset is not None:
assert v_threshold > v_reset
assert v_rest >= v_reset
assert a0 > 0
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.tau = tau
self.v_c = v_c
self.v_rest = v_rest
self.a0 = a0
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c)) / self.tau
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch',)
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = neuron_kernel.MultiStepQIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
[文档]class EIFNode(BaseNode):
def __init__(self, tau: float = 2., delta_T: float = 1., theta_rh: float = .8, v_threshold: float = 1.,
v_rest: float = 0., v_reset: float = -0.1,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <EIFNode.__init__-en>`
.. _EIFNode.__init__-cn:
:param tau: 膜电位时间常数
:type tau: float
:param delta_T: 陡峭度参数
:type delta_T: float
:param theta_rh: 基强度电压阈值
:type theta_rh: float
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
Exponential Integrate-and-Fire 神经元模型,一种非线性积分发放神经元模型,是由HH神经元模型(Hodgkin-Huxley model)简化后推导出的一维模型。在 :math:`\\Delta_T\\to 0` 时退化为LIF模型。其阈下神经动力学方程为:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right)
* :ref:`中文API <EIFNode.__init__-cn>`
.. _EIFNode.__init__-en:
:param tau: membrane time constant
:type tau: float
:param delta_T: sharpness parameter
:type delta_T: float
:param theta_rh: rheobase threshold
:type theta_rh: float
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
The Exponential Integrate-and-Fire neuron is a kind of nonlinear integrate-and-fire models and also an one-dimensional model derived from the Hodgkin-Huxley model. It degenerates to the LIF model when :math:`\\Delta_T\\to 0`.
The subthreshold neural dynamics of it is as followed:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}\\left(X[t] - (V[t-1] - V_{rest}) + \\Delta_T\\exp\\left(\\frac{V[t-1] - \\theta_{rh}}{\\Delta_T}\\right)\\right)
"""
assert isinstance(tau, float) and tau > 1.
if v_reset is not None:
assert v_threshold > v_reset
assert v_rest >= v_reset
assert delta_T > 0
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.tau = tau
self.delta_T = delta_T
self.v_rest = v_rest
self.theta_rh = theta_rh
[文档] def neuronal_charge(self, x: torch.Tensor):
with torch.no_grad():
if not isinstance(self.v, torch.Tensor):
self.v = torch.as_tensor(self.v, device=x.device)
self.v = self.v + (x + self.v_rest - self.v + self.delta_T * torch.exp(
(self.v - self.theta_rh) / self.delta_T)) / self.tau
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch',)
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
self.v_float_to_tensor(x_seq[0])
spike_seq, v_seq = neuron_kernel.MultiStepEIFNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.tau, self.v_threshold, self.v_reset, self.v_rest,
self.theta_rh, self.delta_T, self.detach_reset, self.surrogate_function.cuda_code)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
[文档]class IzhikevichNode(AdaptBaseNode):
def __init__(self, tau: float = 2., v_c: float = 0.8, a0: float = 1., v_threshold: float = 1.,
v_reset: float = 0., v_rest: float = -0.1, w_rest: float = 0., tau_w: float = 2., a: float = 0.,
b: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, step_mode='s',
backend='torch', store_v_seq: bool = False):
assert isinstance(tau, float) and tau > 1.
assert a0 > 0
super().__init__(v_threshold, v_reset, v_rest, w_rest, tau_w, a, b, surrogate_function, detach_reset, step_mode,
backend, store_v_seq)
self.tau = tau
self.v_c = v_c
self.a0 = a0
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + (x + self.a0 * (self.v - self.v_rest) * (self.v - self.v_c) - self.w) / self.tau
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch',)
elif self.step_mode == 'm':
return ('torch', 'cupy')
else:
raise ValueError(self.step_mode)
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'torch':
return super().multi_step_forward(x_seq)
elif self.backend == 'cupy':
self.v_float_to_tensor(x_seq[0])
self.w_float_to_tensor(x_seq[0])
spike_seq, v_seq, w_seq = neuron_kernel.MultiStepIzhikevichNodePTT.apply(
x_seq.flatten(1), self.v.flatten(0), self.w.flatten(0), self.tau, self.v_threshold, self.v_reset,
self.v_rest, self.a, self.b, self.tau_w,
self.v_c, self.a0, self.detach_reset, self.surrogate_function.cuda_code)
spike_seq = spike_seq.reshape(x_seq.shape)
v_seq = v_seq.reshape(x_seq.shape)
w_seq = w_seq.reshape(x_seq.shape)
if self.store_v_seq:
self.v_seq = v_seq
self.v = v_seq[-1].clone()
self.w = w_seq[-1].clone()
return spike_seq
else:
raise ValueError(self.backend)
[文档]class LIAFNode(LIFNode):
def __init__(self, act: Callable, threshold_related: bool, *args, **kwargs):
"""
* :ref:`API in English <LIAFNode.__init__-en>`
.. _LIAFNode.__init__-cn:
:param act: 激活函数
:type act: Callable
:param threshold_related: 是否使用阈值依赖模式 (TR mode). 若为 ``True`` 则 ``y = act(h - v_th)``,
否则 ``y = act(h)``
:type threshold_related: bool
`LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_ 提出的LIAF神经元。LIAFNode和LIFNode的行为相同,但输出是 ``self.act(...)`` 而非脉冲。
.. Warning::
The outputs of this neurons layer are not binary spikes.
* :ref:`中文API <LIAFNode.__init__-cn>`
.. _LIAFNode.__init__-en:
:param act: the activation function
:type act: Callable
:param threshold_related: whether the neuron uses threshold related (TR mode). If ``True``, ``y = act(h - v_th)``,
otherwise ``y = act(h)``
:type threshold_related: bool
Other parameters in `*args, **kwargs` are same with :class:`LIFNode`.
The LIAF neuron proposed in `LIAF-Net: Leaky Integrate and Analog Fire Network for Lightweight and Efficient Spatiotemporal Information Processing <https://arxiv.org/abs/2011.06176>`_. LIAFNode has the same behavior as LIFNode, but outputs ``self.act(...)``
rather than spikes.
.. admonition:: Warning
:class: warning
The outputs of this neurons layer are not binary spikes.
"""
super().__init__(*args, **kwargs)
self.act = act
self.threshold_related = threshold_related
assert self.backend == 'torch', "LIAFNode only supports for backend='torch'!"
assert self.single_step_cupy_fp32_inference == False, "LIAFNode does not support for single_step_cupy_fp32_inference!"
@property
def supported_backends(self):
return ('torch',)
[文档] def single_step_forward(self, x: torch.Tensor):
self.neuronal_charge(x)
if self.threshold_related:
y = self.act(self.v - self.v_threshold)
else:
y = self.act(self.v)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return y
[文档]class KLIFNode(BaseNode):
def __init__(self, scale_reset: bool = False, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = False, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <KLIFNode.__init__-en>`
.. _KLIFNode.__init__-cn:
:param scale_reset: 是否在 ``neuronal_reset`` 时将 ``v`` 进行缩放
:type scale_reset: bool
:param tau: 膜电位时间常数
:type tau: float
:param decay_input: 输入是否也会参与衰减
:type decay_input: bool
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步) 或 `'m'` (多步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
`KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential <https://arxiv.org/abs/2302.09238>`_ 提出的K-based Leaky Integrate-and-Fire 神经元模型,可以看作是带漏电的积分器。其阈下神经动力学方程为:
若 ``decay_input == True``:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
若 ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
注意,KLIF神经元的放电和重置与普通的神经元不同,为:
.. math::
F[t] &= \\mathrm{ReLU}(kH[t])
S[t] &= \\Theta(F[t] - V_{th})
如果 ``scale_reset == False``,则
.. math::
V[t] = \\begin{cases}
F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\\\
F[t] - S[t]V_{th}, soft~~reset
\\end{cases}
如果 ``scale_reset == True``,则
.. math::
V[t] = \\begin{cases}
\\frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\\\
\\frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset
\\end{cases}
* :ref:`中文API <KLIFNode.__init__-cn>`
.. _KLIFNode.__init__-en:
:param scale_reset: whether scale ``v`` in ``neuronal_reset``
:type scale_reset: bool
:param tau: membrane time constant
:type tau: float
:param decay_input: whether the input will decay
:type decay_input: bool
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
The K-based Leaky Integrate-and-Fire neuron proposed by `KLIF: An optimized spiking neuron unit for tuning surrogate gradient slope and membrane potential <https://arxiv.org/abs/2302.09238>`_, which can be seen as a leaky integrator.
The subthreshold neural dynamics of it is as followed:
IF ``decay_input == True``:
.. math::
H[t] = V[t-1] + \\frac{1}{\\tau}(X[t] - (V[t-1] - V_{reset}))
IF ``decay_input == False``:
.. math::
H[t] = V[t-1] - \\frac{1}{\\tau}(V[t-1] - V_{reset}) + X[t]
Note that the neuronal fire and reset of the KLIF neuron is different from native neurons:
.. math::
F[t] &= \\mathrm{ReLU}(kH[t])
S[t] &= \\Theta(F[t] - V_{th})
If ``scale_reset == False``, then
.. math::
V[t] = \\begin{cases}
F[t](1-S[t]) + V_{reset}S[t], hard~~reset \\\\
F[t] - S[t]V_{th}, soft~~reset
\\end{cases}
Elif ``scale_reset == True``, then
.. math::
V[t] = \\begin{cases}
\\frac{F[t]}{k}(1-S[t]) + V_{reset}S[t], hard~~reset \\\\
\\frac{1}{k}(F[t] - S[t]V_{th}), soft~~reset
\\end{cases}
"""
assert isinstance(tau, float) and tau > 1.
if backend == 'cupy':
raise NotImplementedError("The CuPy backend for the KLIF neuron has not been implemented!")
super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
self.scale_reset = scale_reset
self.tau = tau
self.decay_input = decay_input
self.k = nn.Parameter(torch.as_tensor(1.))
[文档] def neuronal_charge(self, x: torch.Tensor):
if self.v_reset is None:
v_reset = 0.
else:
v_reset = self.v_reset
if self.decay_input:
self.v = self.neuronal_charge_decay_input(x, self.v, v_reset, self.tau, self.k)
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, v_reset, self.tau, self.k)
[文档] def neuronal_reset(self, spike):
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike
if self.scale_reset:
if self.v_reset is None:
# soft reset
self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold) / self.k
else:
# hard reset
self.v = self.jit_hard_reset(self.v / self.k, spike_d, self.v_reset)
else:
if self.v_reset is None:
# soft reset
self.v = self.jit_soft_reset(self.v, spike_d, self.v_threshold)
else:
# hard reset
self.v = self.jit_hard_reset(self.v, spike_d, self.v_reset)
[文档]class PSN(nn.Module, base.MultiStepModule):
def __init__(self, T: int, surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan()):
"""
:param T: the number of time-steps
:type T: int
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
The Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as
.. math::
H &= WX, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\label{eq psn neuronal charge}\\\\
S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S\\in \\{0, 1\\}^{T \\times N}
where :math:`W` is the learnable weight matrix, and :math:`B` is the learnable threshold.
.. admonition:: Note
:class: note
The PSN only supports the multi-step mode.
"""
super().__init__()
self.T = T
self.surrogate_function = surrogate_function
weight = torch.zeros([T, T])
bias = torch.zeros([T, 1])
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
nn.init.constant_(self.bias, -1.)
[文档] def forward(self, x_seq: torch.Tensor):
# x_seq.shape = [T, N, *]
h_seq = torch.addmm(self.bias, self.weight, x_seq.flatten(1))
spike_seq = self.surrogate_function(h_seq)
return spike_seq.view(x_seq.shape)
[文档]class MaskedPSN(base.MemoryModule):
[文档] @staticmethod
@torch.jit.script
def gen_masked_weight(lambda_: torch.Tensor, mask0: torch.Tensor, mask1: torch.Tensor, weight: torch.Tensor):
return (lambda_ * mask0 + (1. - lambda_) * mask1) * weight
[文档] def masked_weight(self):
if self.lambda_ >= 1.:
return self.weight * self.mask0
else:
return self.gen_masked_weight(self.lambda_, self.mask0, self.mask1, self.weight)
def __init__(self, k: int, T: int, lambda_init: float = 0.,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's'):
"""
:param k: the order of the Masked PSN
:type k: int
:param T: the number of time-steps
:type T: int
:param lambda_init: the initial value of :math:`\\lambda` to adjust the progressive masking process
:type lambda_init: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
The Masked Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as
.. math::
H &= (W \\cdot {M}_{k})X, ~~~~~~~~~~~~~~~W \\in \\mathbb{R}^{T \\times T}, {M}_{k} \\in \\mathbb{R}^{T \\times T}, X \\in \\mathbb{R}^{T \\times N} \\\\
S &= \\Theta(H - B), ~~~~~B \\in \\mathbb{R}^{T}, S\\in \\{0, 1\\}^{T \\times N}
where :math:`W` is the learnable weight matrix, :math:`B` is the learnable threshold, and :math:`{M}_{k}` is defined as
.. math::
{M}_{k}[i][j] = \\begin{cases}
1, ~~ j \\leq i \\leq j + k - 1 \\\\
0, \\mathrm{otherwise}
\\end{cases}.
:math:`\\lambda` is used to adjust the progressive masking process, which is
.. math::
M_{k}(\\lambda) = \\lambda \\cdot M_{k} + (1 - \\lambda) \\cdot J,
where :math:`J` is an all-one matrix.
The user can set :math:`\\lambda` during training by calling ``self.lambda_ = ...``.
.. admonition:: Note
:class: note
The masked PSN supports both single-step and multi-step mode. But using the multi-step mode is much faster than the single-step mode.
"""
super().__init__()
self.register_memory('time_step', 0)
self.register_memory('queue', [])
self.step_mode = step_mode
self.k = k
self.T = T
self.surrogate_function = surrogate_function
weight = torch.zeros([T, T])
bias = torch.zeros([T, 1])
self.register_buffer('_lambda_', torch.as_tensor(lambda_init))
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(bias)
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
nn.init.constant_(self.bias, -1.)
mask1 = torch.ones([T, T])
mask0 = torch.tril(mask1) * torch.triu(mask1, -(self.k - 1))
self.register_buffer('mask0', mask0)
self.register_buffer('mask1', mask1)
[文档] def single_step_forward(self, x: torch.Tensor):
if self.lambda_ < 1.:
raise ValueError("The masked PSN can not work in single-step mode when k < 1!")
self.queue.append(x.flatten())
if self.queue.__len__() > self.k:
self.queue.pop(0)
if self.time_step + 1 > self.T:
raise OverflowError(f"The MaskedPSN(T={self.T}) has run {self.time_step + 1} time-steps!")
weight = self.masked_weight()[self.time_step, self.time_step + 1 - self.queue.__len__(): self.time_step + 1]
x_seq = torch.stack(self.queue)
for i in range(x.dim()):
weight = weight.unsqueeze(-1)
h = torch.sum(weight * x_seq, 0)
spike = self.surrogate_function(h + self.bias[self.time_step])
self.time_step += 1
return spike.view(x.shape)
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
# x_seq.shape = [T, N, *]
assert x_seq.shape[0] == self.T
h_seq = torch.addmm(self.bias, self.masked_weight(), x_seq.flatten(1))
spike_seq = self.surrogate_function(h_seq).view(x_seq.shape)
return spike_seq
@property
def lambda_(self):
return self._lambda_
@lambda_.setter
def lambda_(self, value: float):
torch.fill_(self.lambda_, value)
[文档]class SlidingPSN(base.MemoryModule):
@property
def supported_backends(self):
return 'gemm', 'conv'
[文档] def gen_gemm_weight(self, T: int):
weight = torch.zeros([T, T], device=self.weight.device)
for i in range(T):
end = i + 1
start = max(0, i + 1 - self.k)
length = min(end - start, self.k)
weight[i][start: end] = self.weight[self.k - length: self.k]
return weight
def __init__(self, k: int, exp_init: bool = True,
surrogate_function: surrogate.SurrogateFunctionBase = surrogate.ATan(), step_mode: str = 's',
backend: str = 'gemm'):
"""
:param k: the order of the Sliding PSN
:type k: int
:param exp_init: if ``True``, the weight will be initialized as ``(..., 1/4, 1/2, 1)``. If ``False``, the weight will be initialized by the kaiming uniform
:type exp_init: bool
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step)
:type step_mode: str
:param backend: backend fot this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
:type backend: str
The Sliding Parallel Spiking Neuron proposed in `Parallel Spiking Neurons with High Efficiency and Long-term Dependencies Learning Ability <https://arxiv.org/abs/2304.12760>`_. The neuronal dynamics are defined as
.. math::
H[t] &= \\sum_{i=0}^{k-1}W_{i}\\cdot X[t - k + 1 + i], \\\\
S[t] &= \\Theta(H[t] - B),
where :math:`W = [W_{0}, W_{1}, ..., W_{k-1}] \\in \\mathbb{R}^{T}` is the learnable weight, and :math:`B` is the learnable threshold.
.. admonition:: Note
:class: note
The Sliding PSN supports both single-step and multi-step mode. But using the multi-step mode is much faster than the single-step mode.
"""
super().__init__()
self.register_memory('queue', [])
self.step_mode = step_mode
self.k = k
self.surrogate_function = surrogate_function
self.backend = backend
if exp_init:
weight = torch.ones([k])
for i in range(k - 2, -1, -1):
weight[i] = weight[i + 1] / 2.
else:
weight = torch.ones([1, k])
nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
weight = weight[0]
self.weight = nn.Parameter(weight)
self.bias = nn.Parameter(torch.as_tensor(-1.))
[文档] def single_step_forward(self, x: torch.Tensor):
self.queue.append(x.flatten())
if self.queue.__len__() > self.k:
self.queue.pop(0)
weight = self.weight[self.k - self.queue.__len__(): self.k]
x_seq = torch.stack(self.queue)
for i in range(x.dim()):
weight = weight.unsqueeze(-1)
h = torch.sum(weight * x_seq, 0)
spike = self.surrogate_function(h + self.bias)
return spike.view(x.shape)
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
if self.backend == 'gemm':
weight = self.gen_gemm_weight(x_seq.shape[0])
h_seq = torch.addmm(self.bias, weight, x_seq.flatten(1)).view(x_seq.shape)
return self.surrogate_function(h_seq)
elif self.backend == 'conv':
# x_seq.shape = [T, N, *]
x_seq_shape = x_seq.shape
# [T, N, *] -> [T, N] -> [N, T] -> [N, 1, T]
x_seq = x_seq.flatten(1).t().unsqueeze(1)
x_seq = F.pad(x_seq, pad=(self.k - 1, 0))
x_seq = F.conv1d(x_seq, self.weight.view(1, 1, -1), stride=1)
x_seq = x_seq.squeeze(1).t().view(x_seq_shape)
return self.surrogate_function(x_seq + self.bias)
else:
raise NotImplementedError(self.backend)
[文档]class GatedLIFNode(base.MemoryModule):
def __init__(self, T: int, inplane = None,
init_linear_decay = None, init_v_subreset = None, init_tau: float = 0.25, init_v_threshold: float = 0.5, init_conduct: float = 0.5,
surrogate_function: Callable = surrogate.Sigmoid(), step_mode='m', backend='torch'):
"""
* :ref:`中文API <GatedLIFNode.__init__-cn>`
.. _GatedLIFNode.__init__-cn:
:param T: 时间步长
:type T: int
:param inplane: 输入tensor的通道数。不设置inplane,则默认使用layer-wise GLIF
:type inplane: int
:param init_linear_decay: 膜电位线性衰减常数初始值,不设置就默认为init_v_threshold/(T * 2)
:type init_linear_decay: float
:param init_v_subreset: 膜电位复位电压初始值
:type init_v_subreset: float
:param init_tau: 膜电位时间常数的初始值
:type init_tau: float
:param init_v_threshold: 神经元的阈值电压初始值
:type init_v_threshold: float
:param init_conduct: 膜电位电导率初始值
:type init_conduct: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param step_mode: 步进模式,只支持 `'m'` (多步)
:type step_mode: str
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的。gated-LIF只支持torch
:type backend: str
模型出处:`GLIF: A Unified Gated Leaky Integrate-and-Fire Neuron for Spiking Neural Networks <https://openreview.net/forum?id=UmFSx2c4ubT>`
GLIF中所有的膜电位参数都是可学的,包括新引入的门控系数。
* :ref:`API in English <GatedLIFNode.__init__-en>`
.. _GatedLIFNode.__init__-en:
:param T: time-step
:type T: int
:param inplane: input tensor channel number, default: None(layer-wise GLIF). If set, otherwise(channel-wise GLIF)
:type inplane: int
:param init_linear_decay: initial linear-decay constant,default: init_v_threshold/(T * 2)
:type init_linear_decay: float
:param init_v_subreset: initial soft-reset constant
:type init_v_subreset: float
:param init_tau: initial exponential-decay constant
:type init_tau: float
:param init_v_threshold: initial menbrane potential threshold
:type init_v_threshold: float
:param init_conduct: initial conduct
:type init_conduct: float
:param surrogate_function: surrogate gradient
:type surrogate_function: Callable
:param step_mode: step mode, only support `'m'` (multi-step)
:type step_mode: str
:param backend: backend fot this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
:type backend: str
Gated LIF neuron refers to `GLIF: A Unified Gated Leaky Integrate-and-Fire Neuron for Spiking Neural Networks <https://openreview.net/forum?id=UmFSx2c4ubT>`
All membrane-related parameters are learnable, including the gates.
"""
assert isinstance(init_tau, float) and init_tau < 1.
assert isinstance(T, int) and T is not None
assert isinstance(inplane, int) or inplane is None
assert (isinstance(init_linear_decay, float) and init_linear_decay < 1.) or init_linear_decay is None
assert (isinstance(init_v_subreset, float) and init_v_subreset < 1.) or init_v_subreset is None
assert step_mode == 'm'
super().__init__()
self.surrogate_function = surrogate_function
self.backend = backend
self.step_mode = step_mode
self.T = T
self.register_memory('v', 0.)
self.register_memory('u', 0.)
self.channel_wise = inplane is not None
if self.channel_wise: #channel-wise learnable params
self.alpha, self.beta, self.gamma = [nn.Parameter(torch.tensor(0.2 * (np.random.rand(inplane) - 0.5), dtype=torch.float)) for i in range(3)]
self.tau = nn.Parameter(- math.log(1 / init_tau - 1) * torch.ones(inplane, dtype=torch.float))
self.v_threshold = nn.Parameter(- math.log(1 / init_v_threshold - 1) * torch.ones(inplane, dtype=torch.float))
init_linear_decay = init_v_threshold / (T * 2) if init_linear_decay is None else init_linear_decay
self.linear_decay = nn.Parameter(- math.log(1 / init_linear_decay - 1) * torch.ones(inplane, dtype=torch.float))
init_v_subreset = init_v_threshold if init_v_subreset is None else init_v_subreset
self.v_subreset = nn.Parameter(- math.log(1 / init_v_subreset - 1) * torch.ones(inplane, dtype=torch.float))
self.conduct = nn.Parameter(- math.log(1 / init_conduct - 1) * torch.ones((T, inplane), dtype=torch.float))
else: #layer-wise learnable params
self.alpha, self.beta, self.gamma = [nn.Parameter(torch.tensor(0.2 * (np.random.rand() - 0.5), dtype=torch.float)) for i in range(3)]
self.tau = nn.Parameter(torch.tensor(- math.log(1 / init_tau - 1), dtype=torch.float))
self.v_threshold = nn.Parameter(torch.tensor(- math.log(1 / init_v_threshold - 1), dtype=torch.float))
init_linear_decay = init_v_threshold / (T * 2) if init_linear_decay is None else init_linear_decay
self.linear_decay = nn.Parameter(torch.tensor(- math.log(1 / init_linear_decay - 1), dtype=torch.float))
init_v_subreset = init_v_threshold if init_v_subreset is None else init_v_subreset
self.v_subreset = nn.Parameter(torch.tensor(- math.log(1 / init_v_subreset - 1), dtype=torch.float))
self.conduct = nn.Parameter(- math.log(1 / init_conduct - 1) * torch.ones(T, dtype=torch.float))
@property
def supported_backends(self):
return 'torch'
[文档] def neuronal_charge(self, x: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, t):
input = x * (1 - beta * (1 - self.conduct[t].view(1, -1, 1, 1).sigmoid()))
self.u = ((1 - alpha * (1 - self.tau.view(1, -1, 1, 1).sigmoid())) * self.v \
- (1 - alpha) * self.linear_decay.view(1, -1, 1, 1).sigmoid()) \
+ input
[文档] def neuronal_reset(self, spike, alpha: torch.Tensor, gamma: torch.Tensor):
self.u = self.u - (1 - alpha * (1 - self.tau.view(1, -1, 1, 1).sigmoid())) * self.v * gamma * spike \
- (1 - gamma) * self.v_subreset.view(1, -1, 1, 1).sigmoid() * spike
[文档] def neuronal_fire(self):
return self.surrogate_function(self.u - self.v_threshold.view(1, -1, 1, 1).sigmoid())
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
alpha, beta, gamma = self.alpha.view(1, -1, 1, 1).sigmoid(), self.beta.view(1, -1, 1, 1).sigmoid(), self.gamma.view(1, -1, 1, 1).sigmoid()
y_seq = []
spike = torch.zeros(x_seq.shape[1:], device=x_seq.device)
for t in range(self.T):
self.neuronal_charge(x_seq[t], alpha, beta, t)
self.neuronal_reset(spike, alpha, gamma)
spike = self.neuronal_fire()
self.v = self.u
y_seq.append(spike)
return torch.stack(y_seq)
##########################################################################################################
# DSR modules
##########################################################################################################
import torch.distributed as dist
[文档]class DSRIFNode(base.MemoryModule):
def __init__(self, T: int = 20, v_threshold: float = 6., alpha: float = 0.5, v_threshold_training: bool = True,
v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.01, step_mode='m',
backend='torch', **kwargs):
"""
* :ref:`中文API <DSRIFNode.__init__-cn>`
.. _DSRIFNode.__init__-cn:
:param T: 时间步长
:type T: int
:param v_threshold: 神经元的阈值电压初始值
:type v_threshold: float
:param alpha: 放电阈值的缩放因子
:type alpha: float
:param v_threshold_training: 是否将阈值电压设置为可学习参数,默认为`'True'`
:type v_threshold_training: bool
:param v_threshold_grad_scaling: 对放电阈值的梯度进行缩放的缩放因子
:type v_threshold_grad_scaling: float
:param v_threshold_lower_bound: 训练过程中,阈值电压能取到的最小值
:type v_threshold_lower_bound: float
:param step_mode: 步进模式,只支持 `'m'` (多步)
:type step_mode: str
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的。DSR-IF只支持torch
:type backend: str
模型出处:`Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation
<https://arxiv.org/pdf/2205.00459.pdf>`.
* :ref:`API in English <DSRIFNode.__init__-en>`
.. _DSRIFNode.__init__-en:
:param T: time-step
:type T: int
:param v_threshold: initial menbrane potential threshold
:type v_threshold: float
:param alpha: the scaling factor for the menbrane potential threshold
:type alpha: float
:param v_threshold_training: whether the menbrane potential threshold is trained, default: `'True'`
:type v_threshold_training: bool
:param v_threshold_grad_scaling: the scaling factor for the gradient of the menbrane potential threshold
:type v_threshold_grad_scaling: float
:param v_threshold_lower_bound: the minimum of the menbrane potential threshold during training
:type v_threshold_lower_bound: float
:param step_mode: step mode, only support `'m'` (multi-step)
:type step_mode: str
:param backend: backend fot this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
:type backend: str
DSR IF neuron refers to `Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation
<https://arxiv.org/pdf/2205.00459.pdf>`.
"""
assert isinstance(T, int) and T is not None
assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0
assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0
assert step_mode == 'm'
super().__init__()
self.backend = backend
self.step_mode = step_mode
self.T = T
if v_threshold_training:
self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
else:
self.v_threshold = torch.tensor(v_threshold)
self.alpha = alpha
self.v_threshold_lower_bound = v_threshold_lower_bound
self.v_threshold_grad_scaling = v_threshold_grad_scaling
@property
def supported_backends(self):
return 'torch'
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
with torch.no_grad():
self.v_threshold.copy_(
F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound)
iffunc = self.DSRIFFunction.apply
y_seq = iffunc(x_seq, self.T, self.v_threshold, self.alpha, self.v_threshold_grad_scaling)
return y_seq
[文档] class DSRIFFunction(torch.autograd.Function):
[文档] @staticmethod
def forward(ctx, inp, T=10, v_threshold=1.0, alpha=0.5, v_threshold_grad_scaling=1.0):
ctx.save_for_backward(inp)
mem_potential = torch.zeros_like(inp[0]).to(inp.device)
spikes = []
for t in range(inp.size(0)):
mem_potential = mem_potential + inp[t]
spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float()
mem_potential = mem_potential - spike
spikes.append(spike)
output = torch.stack(spikes)
ctx.T = T
ctx.v_threshold = v_threshold
ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
return output
[文档] @staticmethod
def backward(ctx, grad_output):
with torch.no_grad():
inp = ctx.saved_tensors[0]
T = ctx.T
v_threshold = ctx.v_threshold
v_threshold_grad_scaling = ctx.v_threshold_grad_scaling
input_rate_coding = torch.mean(inp, 0)
grad_output_coding = torch.mean(grad_output, 0) * T
input_grad = grad_output_coding.clone()
input_grad[(input_rate_coding < 0) | (input_rate_coding > v_threshold)] = 0
input_grad = torch.stack([input_grad for _ in range(T)]) / T
v_threshold_grad = grad_output_coding.clone()
v_threshold_grad[input_rate_coding <= v_threshold] = 0
v_threshold_grad = torch.sum(v_threshold_grad) * v_threshold_grad_scaling
if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1:
try:
dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
except:
raise RuntimeWarning(
'Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.')
return input_grad, None, v_threshold_grad, None, None
[文档]class DSRLIFNode(base.MemoryModule):
def __init__(self, T: int = 20, v_threshold: float = 1., tau: float = 2.0, delta_t: float = 0.05,
alpha: float = 0.3, v_threshold_training: bool = True,
v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.1, step_mode='m',
backend='torch', **kwargs):
"""
* :ref:`中文API <DSRLIFNode.__init__-cn>`
.. _DSRLIFNode.__init__-cn:
:param T: 时间步长
:type T: int
:param v_threshold: 神经元的阈值电压初始值
:type v_threshold: float
:param tau: 膜电位时间常数
:type tau: float
:param delta_t: 对微分方程形式的LIF模型进行离散化的步长
:type delta_t: float
:param alpha: 放电阈值的缩放因子
:type alpha: float
:param v_threshold_training: 是否将阈值电压设置为可学习参数,默认为`'True'`
:type v_threshold_training: bool
:param v_threshold_grad_scaling: 对放电阈值的梯度进行缩放的缩放因子
:type v_threshold_grad_scaling: float
:param v_threshold_lower_bound: 训练过程中,阈值电压能取到的最小值
:type v_threshold_lower_bound: float
:param step_mode: 步进模式,只支持 `'m'` (多步)
:type step_mode: str
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的。DSR-IF只支持torch
:type backend: str
模型出处:`Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation
<https://arxiv.org/pdf/2205.00459.pdf>`.
* :ref:`API in English <DSRLIFNode.__init__-en>`
.. _DSRLIFNode.__init__-en:
:param T: time-step
:type T: int
:param v_threshold: initial menbrane potential threshold
:type v_threshold: float
:param tau: membrane time constant
:type tau: float
:param delta_t: discretization step for discretizing the ODE version of the LIF model
:type delta_t: float
:param alpha: the scaling factor for the menbrane potential threshold
:type alpha: float
:param v_threshold_training: whether the menbrane potential threshold is trained, default: `'True'`
:type v_threshold_training: bool
:param v_threshold_grad_scaling: the scaling factor for the gradient of the menbrane potential threshold
:type v_threshold_grad_scaling: float
:param v_threshold_lower_bound: the minimum of the menbrane potential threshold during training
:type v_threshold_lower_bound: float
:param step_mode: step mode, only support `'m'` (multi-step)
:type step_mode: str
:param backend: backend fot this neuron layer, which can be "gemm" or "conv". This option only works for the multi-step mode
:type backend: str
DSR LIF neuron refers to `Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation
<https://arxiv.org/pdf/2205.00459.pdf>`.
"""
assert isinstance(T, int) and T is not None
assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0
assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0
assert step_mode == 'm'
super().__init__()
self.backend = backend
self.step_mode = step_mode
self.T = T
if v_threshold_training:
self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
else:
self.v_threshold = torch.tensor(v_threshold)
self.tau = tau
self.delta_t = delta_t
self.alpha = alpha
self.v_threshold_lower_bound = v_threshold_lower_bound
self.v_threshold_grad_scaling = v_threshold_grad_scaling
@property
def supported_backends(self):
return 'torch'
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
with torch.no_grad():
self.v_threshold.copy_(
F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound)
liffunc = self.DSRLIFFunction.apply
y_seq = liffunc(x_seq, self.T, self.v_threshold, self.tau, self.delta_t, self.alpha,
self.v_threshold_grad_scaling)
return y_seq
[文档] @classmethod
def weight_rate_spikes(cls, data, tau, delta_t):
T = data.shape[0]
chw = data.size()[2:]
data_reshape = data.permute(list(range(1, len(chw) + 2)) + [0])
weight = torch.tensor([math.exp(-1 / tau * (delta_t * T - ii * delta_t)) for ii in range(1, T + 1)]).to(
data_reshape.device)
return (weight * data_reshape).sum(dim=len(chw) + 1) / weight.sum()
[文档] class DSRLIFFunction(torch.autograd.Function):
[文档] @staticmethod
def forward(ctx, inp, T, v_threshold, tau, delta_t=0.05, alpha=0.3, v_threshold_grad_scaling=1.0):
ctx.save_for_backward(inp)
mem_potential = torch.zeros_like(inp[0]).to(inp.device)
beta = math.exp(-delta_t / tau)
spikes = []
for t in range(inp.size(0)):
mem_potential = beta * mem_potential + (1 - beta) * inp[t]
spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float()
mem_potential = mem_potential - spike
spikes.append(spike / delta_t)
output = torch.stack(spikes)
ctx.T = T
ctx.v_threshold = v_threshold
ctx.tau = tau
ctx.delta_t = delta_t
ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
return output
[文档] @staticmethod
def backward(ctx, grad_output):
inp = ctx.saved_tensors[0]
T = ctx.T
v_threshold = ctx.v_threshold
delta_t = ctx.delta_t
tau = ctx.tau
v_threshold_grad_scaling = ctx.v_threshold_grad_scaling
input_rate_coding = DSRLIFNode.weight_rate_spikes(inp, tau, delta_t)
grad_output_coding = DSRLIFNode.weight_rate_spikes(grad_output, tau, delta_t) * T
indexes = (input_rate_coding > 0) & (input_rate_coding < v_threshold / delta_t * tau)
input_grad = torch.zeros_like(grad_output_coding)
input_grad[indexes] = grad_output_coding[indexes].clone() / tau
input_grad = torch.stack([input_grad for _ in range(T)]) / T
v_threshold_grad = grad_output_coding.clone()
v_threshold_grad[input_rate_coding <= v_threshold / delta_t * tau] = 0
v_threshold_grad = torch.sum(v_threshold_grad) * delta_t * v_threshold_grad_scaling
if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1:
try:
dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
except:
raise RuntimeWarning('Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.')
return input_grad, None, v_threshold_grad, None, None, None, None
##########################################################################################################
# OTTT modules
##########################################################################################################
[文档]class OTTTLIFNode(LIFNode):
def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
v_reset: float = None, surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <OTTTLIFNode.__init__-en>`
.. _OTTTLIFNode.__init__-cn:
:param tau: 膜电位时间常数
:type tau: float
:param decay_input: 输入是否也会参与衰减
:type decay_input: bool
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离。该参数在本模块中不起作用,仅为保持代码统一而保留
:type detach_reset: bool
:param step_mode: 步进模式,为了保证神经元的显存占用小,仅可以为 `'s'` (单步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
神经元模型出处:`Online Training Through Time for Spiking Neural Networks <https://arxiv.org/pdf/2210.04195.pdf>`
模型正向传播和Leaky Integrate-and-Fire神经元相同;用于随时间在线训练
* :ref:`中文API <OTTTLIFNode.__init__-cn>`
.. _OTTTLIFNode.__init__-en:
:param tau: membrane time constant
:type tau: float
:param decay_input: whether the input will decay
:type decay_input: bool
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward. this parameter does not take any effect in
the module, and is retained solely for code consistency
:type detach_reset: bool
:param step_mode: the step mode, which can solely be `s` (single-step) to guarantee the memory-efficient computation
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
OTTT LIF neuron refers to `Online Training Through Time for Spiking Neural Networks <https://arxiv.org/pdf/2210.04195.pdf>`
The forward propagation is the same as the Leaky Integrate-and-Fire neuron; used for online training through time
"""
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
assert step_mode == 's', "Please use single-step mode to enable memory-efficient training."
"""
膜电位将在前向传播过程中重新登记为缓存,以支持多卡分布式训练的情况下保留信息在各时刻进行多次反向传播
membrane potential will be registered as buffer during forward, to support multiple backpropagation for all time steps with
reserved informtion under distributed training on multiple GPUs
"""
self._memories.pop('v')
[文档] def reset(self):
super().reset()
if hasattr(self, 'v'):
del self.v
if hasattr(self, 'trace'):
del self.trace
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch')
else:
raise ValueError(self.step_mode)
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v.detach()
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)
else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)
[文档] @staticmethod
@torch.jit.script
def track_trace(spike: torch.Tensor, trace: torch.Tensor, tau: float):
with torch.no_grad():
trace = trace * (1. - 1. / tau) + spike
return trace
[文档] def single_step_forward(self, x: torch.Tensor):
"""
训练时,输出脉冲和迹;推理时,输出脉冲
训练时需要将后续参数模块用layer.py中定义的GradwithTrace进行包装,根据迹计算梯度
output spike and trace during training; output spike during inference
during training, successive parametric modules shoule be wrapped by GradwithTrace defined in layer.py, to calculate gradients with traces
"""
if not hasattr(self, 'v'):
if self.v_reset is None:
self.register_buffer('v', torch.zeros_like(x))
else:
self.register_buffer('v', torch.ones_like(x) * self.v_reset)
if self.training:
if not hasattr(self, 'trace'):
self.register_buffer('trace', torch.zeros_like(x))
if self.backend == 'torch':
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
self.trace = self.track_trace(spike, self.trace, self.tau)
return [spike, self.trace]
else:
raise ValueError(self.backend)
else:
if self.v_reset is None:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
self.v_threshold, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
self.v_threshold,
self.tau)
else:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
self.v_threshold,
self.v_reset, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
self.v_threshold,
self.v_reset,
self.tau)
return spike
##########################################################################################################
# SLTT modules
##########################################################################################################
[文档]class SLTTLIFNode(LIFNode):
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False):
"""
* :ref:`API in English <SLTTLIFNode.__init__-en>`
.. _SLTTLIFNode.__init__-cn:
:param tau: 膜电位时间常数
:type tau: float
:param decay_input: 输入是否也会参与衰减
:type decay_input: bool
:param v_threshold: 神经元的阈值电压
:type v_threshold: float
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
:type v_reset: float
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
:type surrogate_function: Callable
:param detach_reset: 是否将reset过程的计算图分离。该参数在本模块中不起作用,仅为保持代码统一而保留
:type detach_reset: bool
:param step_mode: 步进模式,为了保证神经元的显存占用小,仅可以为 `'s'` (单步)
:type step_mode: str
:param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
通常设置成 ``False`` ,可以节省内存
:type store_v_seq: bool
神经元模型出处:`Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks
<https://arxiv.org/pdf/2302.14311.pdf>`.模型正向传播和Leaky Integrate-and-Fire神经元相同.
* :ref:`中文API <SLTTLIFNode.__init__-cn>`
.. _SLTTLIFNode.__init__-en:
:param tau: membrane time constant
:type tau: float
:param decay_input: whether the input will decay
:type decay_input: bool
:param v_threshold: threshold of this neurons layer
:type v_threshold: float
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
:type v_reset: float
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
:type surrogate_function: Callable
:param detach_reset: whether detach the computation graph of reset in backward. this parameter does not take any effect in
the module, and is retained solely for code consistency
:type detach_reset: bool
:param step_mode: the step mode, which can solely be `s` (single-step) to guarantee the memory-efficient computation
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
using ``'cupy'`` backend will have the fastest training speed
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
memory consumption
:type store_v_seq: bool
SLTT LIF neuron refers to `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks
<https://arxiv.org/pdf/2302.14311.pdf>`. The forward propagation is the same as the Leaky Integrate-and-Fire neuron's.
"""
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
assert step_mode == 's', "Please use single-step mode to enable memory-efficient training."
self._memories.pop('v')
[文档] def reset(self):
super().reset()
if hasattr(self, 'v'):
del self.v
@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch')
else:
raise ValueError(self.step_mode)
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v.detach()
if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)
else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)
[文档] def single_step_forward(self, x: torch.Tensor):
if not hasattr(self, 'v'):
if self.v_reset is None:
self.register_buffer('v', torch.zeros_like(x))
else:
self.register_buffer('v', torch.ones_like(x) * self.v_reset)
if self.training:
if self.backend == 'torch':
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
else:
raise ValueError(self.backend)
else:
if self.v_reset is None:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
self.v_threshold, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
self.v_threshold,
self.tau)
else:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
self.v_threshold,
self.v_reset, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
self.v_threshold,
self.v_reset,
self.tau)
return spike
##########################################################################################################
# Current-based LIF (CLIF) modules
##########################################################################################################
[文档]class CLIFNode(BaseNode):
def __init__(self, c_decay: float = 0.5, v_decay: float = 0.75, v_threshold: float = 0.5,
v_reset: float = 0., surrogate_function: Callable = surrogate.Rect()):
super().__init__(v_threshold, v_reset, surrogate_function)
self.register_memory('c', 0.)
self.c_decay = c_decay
self.v_decay = v_decay
[文档] def neuronal_charge(self, x: torch.Tensor):
self.c = self.c * self.c_decay + x
self.v = self.v * self.v_decay + self.c
[文档] def single_step_forward(self, x: torch.Tensor):
self.v_float_to_tensor(x)
self.c_float_to_tensor(x)
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
[文档] def multi_step_forward(self, x_seq: torch.Tensor):
T = x_seq.shape[0]
spike_seq = []
for t in range(T):
spike = self.single_step_forward(x_seq[t])
spike_seq.append(spike)
return torch.stack(spike_seq)
[文档] def c_float_to_tensor(self, c: torch.Tensor):
if isinstance(self.c, float):
c_init = self.c
self.c = torch.full_like(c.data, fill_value=c_init)
##########################################################################################################
# Noisy modules for exploration of RL
##########################################################################################################
"""Generate colored noise."""
from typing import Union, Iterable, Optional
from numpy import sqrt, newaxis, integer
from numpy.fft import irfft, rfftfreq
from numpy.random import default_rng, Generator, RandomState
from numpy import sum as npsum
[文档]def powerlaw_psd_gaussian(
exponent: float,
size: Union[int, Iterable[int]],
fmin: float = 0.0,
random_state: Optional[Union[int, Generator, RandomState]] = None
):
"""Gaussian (1/f)**beta noise.
Based on the algorithm in:
Timmer, J. and Koenig, M.:
On generating power law noise.
Astron. Astrophys. 300, 707-710 (1995)
Normalised to unit variance
Parameters:
-----------
exponent : float
The power-spectrum of the generated noise is proportional to
S(f) = (1 / f)**beta
flicker / pink noise: exponent beta = 1
brown noise: exponent beta = 2
Furthermore, the autocorrelation decays proportional to lag**-gamma
with gamma = 1 - beta for 0 < beta < 1.
There may be finite-size issues for beta close to one.
shape : int or iterable
The output has the given shape, and the desired power spectrum in
the last coordinate. That is, the last dimension is taken as time,
and all other components are independent.
fmin : float, optional
Low-frequency cutoff.
Default: 0 corresponds to original paper.
The power-spectrum below fmin is flat. fmin is defined relative
to a unit sampling rate (see numpy's rfftfreq). For convenience,
the passed value is mapped to max(fmin, 1/samples) internally
since 1/samples is the lowest possible finite frequency in the
sample. The largest possible value is fmin = 0.5, the Nyquist
frequency. The output for this value is white noise.
random_state : int, numpy.integer, numpy.random.Generator, numpy.random.RandomState,
optional
Optionally sets the state of NumPy's underlying random number generator.
Integer-compatible values or None are passed to np.random.default_rng.
np.random.RandomState or np.random.Generator are used directly.
Default: None.
Returns
-------
out : array
The samples.
Examples:
---------
# generate 1/f noise == pink noise == flicker noise
>>> import colorednoise as cn
>>> y = cn.powerlaw_psd_gaussian(1, 5)
"""
# Make sure size is a list so we can iterate it and assign to it.
if isinstance(size, (integer, int)):
size = [size]
elif isinstance(size, Iterable):
size = list(size)
else:
raise ValueError("Size must be of type int or Iterable[int]")
# The number of samples in each time series
samples = size[-1]
# Calculate Frequencies (we asume a sample rate of one)
# Use fft functions for real output (-> hermitian spectrum)
f = rfftfreq(samples) # type: ignore # mypy 1.5.1 has problems here
# Validate / normalise fmin
if 0 <= fmin <= 0.5:
fmin = max(fmin, 1./samples) # Low frequency cutoff
else:
raise ValueError("fmin must be chosen between 0 and 0.5.")
# Build scaling factors for all frequencies
s_scale = f
ix = npsum(s_scale < fmin) # Index of the cutoff
if ix and ix < len(s_scale):
s_scale[:ix] = s_scale[ix]
s_scale = s_scale**(-exponent/2.)
# Calculate theoretical output standard deviation from scaling
w = s_scale[1:].copy()
w[-1] *= (1 + (samples % 2)) / 2. # correct f = +-0.5
sigma = 2 * sqrt(npsum(w**2)) / samples
# Adjust size to generate one Fourier component per frequency
size[-1] = len(f)
# Add empty dimension(s) to broadcast s_scale along last
# dimension of generated random power + phase (below)
dims_to_add = len(size) - 1
s_scale = s_scale[(newaxis,) * dims_to_add + (Ellipsis,)]
# prepare random number generator
normal_dist = _get_normal_distribution(random_state)
# Generate scaled random power + phase
sr = normal_dist(scale=s_scale, size=size)
si = normal_dist(scale=s_scale, size=size)
# If the signal length is even, frequencies +/- 0.5 are equal
# so the coefficient must be real.
if not (samples % 2):
si[..., -1] = 0
sr[..., -1] *= sqrt(2) # Fix magnitude
# Regardless of signal length, the DC component must be real
si[..., 0] = 0
sr[..., 0] *= sqrt(2) # Fix magnitude
# Combine power + corrected phase to Fourier components
s = sr + 1J * si
# Transform to real time series & scale to unit variance
y = irfft(s, n=samples, axis=-1) / sigma
return y
def _get_normal_distribution(random_state: Optional[Union[int, Generator, RandomState]]):
normal_dist = None
if isinstance(random_state, (integer, int)) or random_state is None:
random_state = default_rng(random_state)
normal_dist = random_state.normal
elif isinstance(random_state, (Generator, RandomState)):
normal_dist = random_state.normal
else:
raise ValueError(
"random_state must be one of integer, numpy.random.Generator, "
"numpy.random.Randomstate"
)
return normal_dist
[文档]class NoisyBaseNode(nn.Module, base.MultiStepModule):
def __init__(self, num_node, is_training: bool = True, T: int = 5, sigma_init: float = 0.5,
beta: float = 0.0, v_threshold: float = 0.5, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
assert isinstance(v_reset, float) or v_reset is None
assert isinstance(v_threshold, float)
super().__init__()
self.num_node = num_node
self.is_training = is_training
self.T = T
self.beta = beta
self.sigma_v = sigma_init / math.sqrt(num_node)
self.cn_v = None
self.sigma_s = sigma_init / math.sqrt(num_node)
self.cn_s = None
self.v_threshold = v_threshold
self.v_reset = v_reset
self.surrogate_function = surrogate_function
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike):
if self.v_reset is None:
self.v = self.v - spike * self.v_threshold
else:
self.v = (1. - spike) * self.v + spike * self.v_reset
[文档] def init_tensor(self, data: torch.Tensor):
self.v = torch.full_like(data, fill_value=self.v_reset)
[文档] def forward(self, x_seq: torch.Tensor):
self.init_tensor(x_seq[0].data)
y = []
if self.is_training:
if self.cn_v is None or self.cn_s is None:
self.noise_step += 1
for t in range(self.T):
if self.cn_v is None:
self.neuronal_charge(x_seq[t] + self.sigma_v * self.eps_v_seq[self.noise_step][t].to(x_seq.device))
else:
self.neuronal_charge(x_seq[t] + self.sigma_v * self.cn_v[:, t])
spike = self.neuronal_fire()
self.neuronal_reset(spike)
if self.cn_s is None:
spike = spike + self.sigma_s * self.eps_s_seq[self.noise_step][t].to(x_seq.device)
else:
spike = spike + self.sigma_s * self.cn_s[:, t]
y.append(spike)
else:
for t in range(self.T):
self.neuronal_charge(x_seq[t])
spike = self.neuronal_fire()
self.neuronal_reset(spike)
y.append(spike)
return torch.stack(y)
[文档] def reset_noise(self, num_rl_step):
eps_shape = [self.num_node, num_rl_step * self.T]
per_order = [1, 2, 0]
# (nodes, steps * T) -> (nodes, steps, T) -> (steps, T, nodes)
self.eps_v_seq = torch.FloatTensor(powerlaw_psd_gaussian(self.beta, eps_shape).reshape(self.num_node, num_rl_step, self.T)).permute(per_order)
self.eps_s_seq = torch.FloatTensor(powerlaw_psd_gaussian(self.beta, eps_shape).reshape(self.num_node, num_rl_step, self.T)).permute(per_order)
self.noise_step = -1
[文档] def get_colored_noise(self):
cn = [self.eps_v_seq[self.noise_step], self.eps_s_seq[self.noise_step]]
return torch.cat(cn, dim=1)
[文档] def load_colored_noise(self, cn):
self.cn_v = cn[:, :, :self.num_node]
self.cn_s = cn[:, :, self.num_node:]
[文档] def cancel_load(self):
self.cn_v = None
self.cn_s = None
[文档]class NoisyCLIFNode(NoisyBaseNode):
def __init__(self, num_node, c_decay: float = 0.5, v_decay: float = 0.75, is_training: bool = True,
T: int = 5, sigma_init: float = 0.5, beta: float = 0.0, v_threshold: float = 0.5,
v_reset: float = 0., surrogate_function: Callable = surrogate.Rect()):
super().__init__(num_node, is_training, T, sigma_init, beta, v_threshold,
v_reset, surrogate_function)
self.c_decay = c_decay
self.v_decay = v_decay
[文档] def neuronal_charge(self, x: torch.Tensor):
self.c = self.c * self.c_decay + x
self.v = self.v * self.v_decay + self.c
[文档] def init_tensor(self, data: torch.Tensor):
self.c = torch.full_like(data, fill_value=0.0)
self.v = torch.full_like(data, fill_value=self.v_reset)
##########################################################################################################
# Inter-Layer Connections (ILC) modules for population-coded spiking actor network
##########################################################################################################
[文档]class ILCBaseNode(nn.Module, base.MultiStepModule):
def __init__(self, act_dim, dec_pop_dim, v_threshold: float = 1.0, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
assert isinstance(v_reset, float) or v_reset is None
assert isinstance(v_threshold, float)
super().__init__()
self.act_dim = act_dim
self.out_pop_dim = act_dim * dec_pop_dim
self.dec_pop_dim = dec_pop_dim
self.conn = nn.Conv1d(act_dim, self.out_pop_dim, dec_pop_dim, groups=act_dim)
self.v_threshold = v_threshold
self.v_reset = v_reset
self.surrogate_function = surrogate_function
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike):
if self.v_reset is None:
self.v = self.v - spike * self.v_threshold
else:
self.v = (1. - spike) * self.v + spike * self.v_reset
[文档] def init_tensor(self, data: torch.Tensor):
self.v = torch.full_like(data, fill_value=self.v_reset)
[文档] def forward(self, x_seq: torch.Tensor):
self.init_tensor(x_seq[0].data)
T = x_seq.shape[0]
spike_seq = []
for t in range(T):
self.neuronal_charge(x_seq[t])
spike = self.neuronal_fire()
self.neuronal_reset(spike)
spike_seq.append(spike)
if t < T - 1:
x_seq[t + 1] = x_seq[t + 1] + self.conn(spike.view(-1, self.act_dim, self.dec_pop_dim)).view(-1, self.out_pop_dim)
return torch.stack(spike_seq)
[文档]class ILCCLIFNode(ILCBaseNode):
def __init__(self, act_dim, dec_pop_dim, c_decay: float = 0.5, v_decay: float = 0.75,
v_threshold: float = 0.5, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
super().__init__(act_dim, dec_pop_dim, v_threshold, v_reset, surrogate_function)
self.c_decay = c_decay
self.v_decay = v_decay
[文档] def neuronal_charge(self, x: torch.Tensor):
self.c = self.c * self.c_decay + x
self.v = self.v * self.v_decay + self.c
[文档] def init_tensor(self, data: torch.Tensor):
self.c = torch.full_like(data, fill_value=0.0)
self.v = torch.full_like(data, fill_value=self.v_reset)
[文档]class ILCLIFNode(ILCBaseNode):
def __init__(self, act_dim, dec_pop_dim, v_decay: float = 0.75,
v_threshold: float = 1.0, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
super().__init__(act_dim, dec_pop_dim, v_threshold, v_reset, surrogate_function)
self.v_decay = v_decay
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v * self.v_decay + x
[文档]class ILCIFNode(ILCBaseNode):
def __init__(self, act_dim, dec_pop_dim, v_threshold: float = 1.0, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
super().__init__(act_dim, dec_pop_dim, v_threshold, v_reset, surrogate_function)
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
##########################################################################################################
# Noisy modules with Inter-Layer Connections (ILC)
##########################################################################################################
[文档]class NoisyILCBaseNode(nn.Module, base.MultiStepModule):
def __init__(self, act_dim, dec_pop_dim, is_training: bool = True, T: int = 5,
sigma_init: float = 0.5, beta: float = 0.0, v_threshold: float = 1.0,
v_reset: float = 0., surrogate_function: Callable = surrogate.Rect()):
assert isinstance(v_reset, float) or v_reset is None
assert isinstance(v_threshold, float)
super().__init__()
self.act_dim = act_dim
self.num_node = act_dim * dec_pop_dim
self.dec_pop_dim = dec_pop_dim
self.conn = nn.Conv1d(act_dim, self.num_node, dec_pop_dim, groups=act_dim)
self.is_training = is_training
self.T = T
self.beta = beta
self.sigma_v = sigma_init / math.sqrt(self.num_node)
self.cn_v = None
self.sigma_s = sigma_init / math.sqrt(self.num_node)
self.cn_s = None
self.v_threshold = v_threshold
self.v_reset = v_reset
self.surrogate_function = surrogate_function
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def neuronal_fire(self):
return self.surrogate_function(self.v - self.v_threshold)
[文档] def neuronal_reset(self, spike):
if self.v_reset is None:
self.v = self.v - spike * self.v_threshold
else:
self.v = (1. - spike) * self.v + spike * self.v_reset
[文档] def init_tensor(self, data: torch.Tensor):
self.v = torch.full_like(data, fill_value=self.v_reset)
[文档] def forward(self, x_seq: torch.Tensor):
self.init_tensor(x_seq[0].data)
y = []
if self.is_training:
if self.cn_v is None or self.cn_s is None:
self.noise_step += 1
for t in range(self.T):
if self.cn_v is None:
self.neuronal_charge(x_seq[t] + self.sigma_v * self.eps_v_seq[self.noise_step][t].to(x_seq.device))
else:
self.neuronal_charge(x_seq[t] + self.sigma_v * self.cn_v[:, t])
spike = self.neuronal_fire()
self.neuronal_reset(spike)
if self.cn_s is None:
spike = spike + self.sigma_s * self.eps_s_seq[self.noise_step][t].to(x_seq.device)
else:
spike = spike + self.sigma_s * self.cn_s[:, t]
y.append(spike)
if t < self.T - 1:
x_seq[t + 1] = x_seq[t + 1] + self.conn(spike.view(-1, self.act_dim, self.dec_pop_dim)).view(-1, self.num_node)
else:
for t in range(self.T):
self.neuronal_charge(x_seq[t])
spike = self.neuronal_fire()
self.neuronal_reset(spike)
y.append(spike)
if t < self.T - 1:
x_seq[t + 1] = x_seq[t + 1] + self.conn(spike.view(-1, self.act_dim, self.dec_pop_dim)).view(-1, self.num_node)
return torch.stack(y)
[文档] def reset_noise(self, num_rl_step):
eps_shape = [self.num_node, num_rl_step * self.T]
per_order = [1, 2, 0]
# (nodes, steps * T) -> (nodes, steps, T) -> (steps, T, nodes)
self.eps_v_seq = torch.FloatTensor(powerlaw_psd_gaussian(self.beta, eps_shape).reshape(self.num_node, num_rl_step, self.T)).permute(per_order)
self.eps_s_seq = torch.FloatTensor(powerlaw_psd_gaussian(self.beta, eps_shape).reshape(self.num_node, num_rl_step, self.T)).permute(per_order)
self.noise_step = -1
[文档] def get_colored_noise(self):
cn = [self.eps_v_seq[self.noise_step], self.eps_s_seq[self.noise_step]]
return torch.cat(cn, dim=1)
[文档] def load_colored_noise(self, cn):
self.cn_v = cn[:, :, :self.num_node]
self.cn_s = cn[:, :, self.num_node:]
[文档] def cancel_load(self):
self.cn_v = None
self.cn_s = None
[文档]class NoisyILCCLIFNode(NoisyILCBaseNode):
def __init__(self, act_dim, dec_pop_dim, c_decay: float = 0.5, v_decay: float = 0.75,
is_training: bool = True, T: int = 5, sigma_init: float = 0.5,
beta: float = 0.0, v_threshold: float = 1.0, v_reset: float = 0.,
surrogate_function: Callable = surrogate.Rect()):
super().__init__(act_dim, dec_pop_dim, is_training, T, sigma_init, beta, v_threshold,
v_reset, surrogate_function)
self.c_decay = c_decay
self.v_decay = v_decay
[文档] def neuronal_charge(self, x: torch.Tensor):
self.c = self.c * self.c_decay + x
self.v = self.v * self.v_decay + self.c
[文档] def init_tensor(self, data: torch.Tensor):
self.c = torch.full_like(data, fill_value=0.0)
self.v = torch.full_like(data, fill_value=self.v_reset)
##########################################################################################################
# Non-spiking modules for floating-point output
##########################################################################################################
[文档]class NonSpikingBaseNode(nn.Module, base.MultiStepModule):
def __init__(self, decode='last-mem'):
super().__init__()
self.decode = decode
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def forward(self, x_seq: torch.Tensor):
self.v = torch.full_like(x_seq[0].data, fill_value=0.0)
T = x_seq.shape[0]
v_seq = []
for t in range(T):
self.neuronal_charge(x_seq[t])
v_seq.append(self.v)
if self.decode == 'max-mem':
mem = torch.max(torch.stack(v_seq, 0), 0).values
elif self.decode == 'max-abs-mem':
v_stack = torch.stack(v_seq, 0)
max_mem = torch.max(v_stack, 0).values
min_mem = torch.min(v_stack, 0).values
mem = max_mem * (max_mem.abs() > min_mem.abs()) + min_mem * (max_mem.abs() <= min_mem.abs())
elif self.decode == 'mean-mem':
mem = torch.mean(torch.stack(v_seq, 0), 0)
else: # 'last-mem'
mem = v_seq[-1]
return mem
[文档]class NonSpikingIFNode(NonSpikingBaseNode):
def __init__(self, decode='last-mem'):
super().__init__(decode)
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x
[文档]class NonSpikingLIFNode(NonSpikingBaseNode):
def __init__(self, tau: float = 2., decode='last-mem'):
super().__init__(decode)
self.tau = tau
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + (x - self.v) / self.tau
##########################################################################################################
# Noisy Non-spiking modules
##########################################################################################################
[文档]class NoisyNonSpikingBaseNode(nn.Module, base.MultiStepModule):
def __init__(self, num_node, is_training: bool = True, T: int = 5,
sigma_init: float = 0.5, beta: float = 0.0, decode: str = 'last-mem'):
super().__init__()
self.num_node = num_node
self.is_training = is_training
self.T = T
self.beta = beta
self.decode = decode
self.sigma = nn.Parameter(torch.FloatTensor(num_node))
self.sigma.data.fill_(sigma_init / math.sqrt(num_node))
self.cn = None
[文档] @abstractmethod
def neuronal_charge(self, x: torch.Tensor):
raise NotImplementedError
[文档] def init_tensor(self, data: torch.Tensor):
self.v = torch.full_like(data, fill_value=0.0)
[文档] def forward(self, x_seq: torch.Tensor):
self.init_tensor(x_seq[0].data)
v_seq = []
if self.is_training:
if self.cn is None:
self.noise_step += 1
for t in range(self.T):
if self.cn is None:
self.neuronal_charge(x_seq[t] + self.sigma.mul(self.eps_seq[self.noise_step][t].to(x_seq.device)))
else:
self.neuronal_charge(x_seq[t] + self.sigma.mul(self.cn[:, t].to(x_seq.device)))
v_seq.append(self.v)
else:
for t in range(self.T):
self.neuronal_charge(x_seq[t])
v_seq.append(self.v)
if self.decode == 'max-mem':
mem = torch.max(torch.stack(v_seq, 0), 0).values
elif self.decode == 'max-abs-mem':
v_stack = torch.stack(v_seq, 0)
max_mem = torch.max(v_stack, 0).values
min_mem = torch.min(v_stack, 0).values
mem = max_mem * (max_mem.abs() > min_mem.abs()) + min_mem * (max_mem.abs() <= min_mem.abs())
elif self.decode == 'mean-mem':
mem = torch.mean(torch.stack(v_seq, 0), 0)
else: # 'last-mem'
mem = v_seq[-1]
return mem
[文档] def reset_noise(self, num_rl_step):
eps_shape = [self.num_node, num_rl_step * self.T]
per_order = [1, 2, 0]
self.eps_seq = torch.FloatTensor(powerlaw_psd_gaussian(self.beta, eps_shape).reshape(self.num_node, num_rl_step, self.T)).permute(per_order)
self.noise_step = -1
[文档] def get_colored_noise(self):
return self.eps_seq[self.noise_step]
[文档] def load_colored_noise(self, cn):
self.cn = cn
[文档] def cancel_load(self):
self.cn = None
[文档]class NoisyNonSpikingIFNode(NoisyNonSpikingBaseNode):
[文档] def neuronal_charge(self, x: torch.Tensor):
self.v = self.v + x