Source code for spikingjelly.activation_based.ann2snn.factories
from typing import Optional, Type, Union
import torch.nn as nn
from spikingjelly.activation_based import neuron
from spikingjelly.activation_based.ann2snn.modules import VoltageHook
class NeuronFactory:
    def __init__(
        self,
        neuron_type: Type[nn.Module] = neuron.IFNode,
        v_threshold: float = 1.0,
        v_reset: Optional[float] = None,
        **kwargs,
    ):
        """
        **API Language** - :ref:`中文 <NeuronFactory.__init__-cn>` | :ref:`English <NeuronFactory.__init__-en>`

        ----

        .. _NeuronFactory.__init__-cn:

        * **中文**

        用于创建替换激活函数的脉冲神经元模块。默认创建
        :class:`spikingjelly.activation_based.neuron.IFNode`,并使用
        ``v_threshold=1.0`` 与 ``v_reset=None`` 保持原有 ANN2SNN 行为。默认转换会通过
        :class:`VoltageScaler` 处理激活尺度,因此默认工厂不会把 ``scale`` 直接写入
        神经元阈值;自定义工厂可读取 ``scale`` 派生阈值或其他参数。

        :param neuron_type: 神经元类,必须接受 ``v_threshold`` 与 ``v_reset`` 关键字参数。
            默认为 :class:`spikingjelly.activation_based.neuron.IFNode`。
        :type neuron_type: Type[nn.Module]
        :param v_threshold: 神经元发放阈值,传递给神经元构造函数。
        :type v_threshold: float
        :param v_reset: 膜电位复位值。``None`` 表示软复位(减法复位),默认为 ``None``。
        :type v_reset: Optional[float]
        :param kwargs: 透传给神经元构造函数的其他关键字参数。

        ----

        .. _NeuronFactory.__init__-en:

        * **English**

        Factory that creates spiking-neuron modules used to replace ANN activation
        functions. By default it instantiates
        :class:`spikingjelly.activation_based.neuron.IFNode` with
        ``v_threshold=1.0`` and ``v_reset=None`` to preserve the original
        ANN2SNN behaviour. The default conversion handles the activation scale
        with :class:`VoltageScaler`, so the default factory does not copy
        ``scale`` into the neuron threshold. Custom factories may derive
        thresholds or other neuron parameters from ``scale``.

        :param neuron_type: Neuron class to instantiate. Must accept
            ``v_threshold`` and ``v_reset`` keyword arguments. Defaults to
            :class:`spikingjelly.activation_based.neuron.IFNode`.
        :type neuron_type: Type[nn.Module]
        :param v_threshold: Firing threshold passed to the neuron constructor.
        :type v_threshold: float
        :param v_reset: Membrane reset value. ``None`` means soft reset
            (subtractive reset). Defaults to ``None``.
        :type v_reset: Optional[float]
        :param kwargs: Additional keyword arguments forwarded to the neuron
            constructor.
        """
        self.neuron_type = neuron_type
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.neuron_kwargs = kwargs
    def create(self, scale: float) -> nn.Module:
        r"""
        **API Language** - :ref:`中文 <NeuronFactory.create-cn>` | :ref:`English <NeuronFactory.create-en>`

        ----

        .. _NeuronFactory.create-cn:

        * **中文**

        根据工厂配置创建一个脉冲神经元模块实例。``scale`` 为当前层校准得到的激活
        尺度,默认实现不直接使用该值,但子类可据此派生阈值或其他参数。

        :param scale: 当前层的校准尺度。
        :type scale: float
        :return: 配置完成的脉冲神经元模块。
        :rtype: nn.Module

        ----

        .. _NeuronFactory.create-en:

        * **English**

        Instantiate a spiking-neuron module with the configured parameters.
        ``scale`` is the calibrated activation scale of the current layer. The
        default implementation ignores it, but subclasses can derive thresholds
        or other neuron parameters from it.

        :param scale: Calibrated activation scale of the current layer.
        :type scale: float
        :return: A configured spiking-neuron module.
        :rtype: nn.Module
        """
        return self.neuron_type(
            v_threshold=self.v_threshold,
            v_reset=self.v_reset,
            **self.neuron_kwargs,
        )
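The docstrings above say that a custom factory may derive the threshold from ``scale``. The following is a minimal sketch of that idea, not part of the library. ``ToyIFNode`` and ``ScaledThresholdFactory`` are hypothetical names, the toy neuron is a plain-Python stand-in rather than ``neuron.IFNode``, and ``NeuronFactory`` is re-declared here in condensed form only so that the snippet runs without torch. Whether multiplying the threshold by ``scale`` is appropriate depends on how the rest of the conversion handles scaling.

```python
from typing import Any, Optional


class ToyIFNode:
    """Hypothetical stand-in for a spiking neuron; NOT spikingjelly's IFNode."""

    def __init__(self, v_threshold: float = 1.0, v_reset: Optional[float] = None):
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.v = 0.0  # membrane potential

    def step(self, x: float) -> int:
        # Integrate the input, then fire once the threshold is reached.
        self.v += x
        if self.v < self.v_threshold:
            return 0
        if self.v_reset is None:
            self.v -= self.v_threshold  # soft (subtractive) reset
        else:
            self.v = self.v_reset  # hard reset to a fixed value
        return 1


class NeuronFactory:
    """Condensed copy of the factory above, without the torch type hints."""

    def __init__(self, neuron_type=ToyIFNode, v_threshold=1.0, v_reset=None, **kwargs: Any):
        self.neuron_type = neuron_type
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.neuron_kwargs = kwargs

    def create(self, scale: float):
        # The default behaviour ignores ``scale``.
        return self.neuron_type(
            v_threshold=self.v_threshold,
            v_reset=self.v_reset,
            **self.neuron_kwargs,
        )


class ScaledThresholdFactory(NeuronFactory):
    """Hypothetical subclass: multiplies the configured threshold by ``scale``."""

    def create(self, scale: float):
        return self.neuron_type(
            v_threshold=self.v_threshold * scale,
            v_reset=self.v_reset,
            **self.neuron_kwargs,
        )


default_node = NeuronFactory().create(scale=2.5)
scaled_node = ScaledThresholdFactory().create(scale=2.5)
# Drive the scaled neuron with a constant input for four time steps.
spikes = [scaled_node.step(1.5) for _ in range(4)]
```

With ``v_reset=None`` the toy neuron keeps the residual potential after each spike, which is the subtractive-reset behaviour the ``v_reset`` parameter description refers to.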
class HookFactory:
    def __init__(self, mode: Union[str, float] = "Max", momentum: float = 0.1):
        """
        **API Language** - :ref:`中文 <HookFactory.__init__-cn>` | :ref:`English <HookFactory.__init__-en>`

        ----

        .. _HookFactory.__init__-cn:

        * **中文**

        用于创建校准阶段使用的 :class:`VoltageHook` 实例。每个匹配到的激活节点会获得
        独立的 hook 实例。

        :param mode: 校准模式,传递给 :class:`VoltageHook`。``"Max"`` 记录激活最大值;
            ``"99.9%"`` 记录 99.9 分位点;``(0, 1]`` 区间的 float 表示
            ``max * mode``。
        :type mode: Union[str, float]
        :param momentum: :class:`VoltageHook` 的 EMA 动量。
        :type momentum: float

        ----

        .. _HookFactory.__init__-en:

        * **English**

        Factory that creates :class:`VoltageHook` instances used during
        calibration. Each matched activation node receives an independent hook
        instance.

        :param mode: Calibration mode forwarded to :class:`VoltageHook`.
            ``"Max"`` records the maximum activation; ``"99.9%"`` records the
            99.9th percentile; a float in ``(0, 1]`` records ``max * mode``.
        :type mode: Union[str, float]
        :param momentum: EMA momentum for :class:`VoltageHook`.
        :type momentum: float
        """
        self.mode = mode
        self.momentum = momentum
    def create(self) -> VoltageHook:
        r"""
        **API Language** - :ref:`中文 <HookFactory.create-cn>` | :ref:`English <HookFactory.create-en>`

        ----

        .. _HookFactory.create-cn:

        * **中文**

        创建一个新的 :class:`VoltageHook` 实例。

        :return: 配置完成的 :class:`VoltageHook`。
        :rtype: VoltageHook

        ----

        .. _HookFactory.create-en:

        * **English**

        Create a new :class:`VoltageHook` instance.

        :return: A configured :class:`VoltageHook`.
        :rtype: VoltageHook
        """
        return VoltageHook(momentum=self.momentum, mode=self.mode)
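The ``HookFactory`` docstring states that each matched activation node receives its own hook. The sketch below illustrates why a factory, rather than a shared hook object, gives that property. ``ToyHook`` is a hypothetical plain-Python stand-in for ``VoltageHook``. It handles only the ``"Max"`` and float modes, stores ``momentum`` without applying any EMA, and keeps just the value from the latest batch. None of this is claimed to match the real hook's internals. ``HookFactory`` is re-declared in condensed form so the snippet runs without torch, and the layer names are made up.

```python
from typing import List, Optional, Union


class ToyHook:
    """Hypothetical stand-in for VoltageHook; supports only "Max" and float modes."""

    def __init__(self, momentum: float = 0.1, mode: Union[str, float] = "Max"):
        self.momentum = momentum  # stored only; this toy applies no EMA
        self.mode = mode
        self.scale: Optional[float] = None

    def observe(self, activations: List[float]) -> None:
        peak = max(activations)
        if self.mode == "Max":
            value = peak
        elif isinstance(self.mode, float) and 0.0 < self.mode <= 1.0:
            value = peak * self.mode  # the "max * mode" rule from the docstring
        else:
            raise NotImplementedError("percentile modes are omitted from this toy")
        self.scale = value  # toy choice: keep the latest batch's value


class HookFactory:
    """Condensed copy of the factory above, producing ToyHook instead of VoltageHook."""

    def __init__(self, mode: Union[str, float] = "Max", momentum: float = 0.1):
        self.mode = mode
        self.momentum = momentum

    def create(self) -> ToyHook:
        # A new object on every call, so calibration state is never shared.
        return ToyHook(momentum=self.momentum, mode=self.mode)


factory = HookFactory(mode=0.5)
# Hypothetical activation names; one hook per matched node.
hooks = {name: factory.create() for name in ("relu1", "relu2")}
hooks["relu1"].observe([0.5, 4.0, 1.0])
hooks["relu2"].observe([2.0, 1.0])
```

Because ``create`` returns a fresh object each time, the scale recorded for ``relu1`` does not leak into ``relu2``.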