# spikingjelly.activation_based.functional.net_config source code

from __future__ import annotations

import logging
from typing import Optional, Union

import torch.nn as nn

from .. import base

__all__ = [
    "collect_reset_modules",
    "detach_net",
    "reset_collected_modules",
    "reset_net",
    "set_backend",
    "set_step_mode",
]


def collect_reset_modules(net: nn.Module) -> tuple[nn.Module, ...]:
    """
    Gather the modules of ``net`` that expose a callable ``reset`` attribute.

    Modules are visited in the order produced by ``net.modules()`` and the
    result keeps that order. A module whose ``reset`` attribute is missing or
    not callable is left out.

    :param net: the network whose modules are inspected
    :type net: torch.nn.Module
    :return: the modules whose ``reset`` attribute is callable
    :rtype: tuple[torch.nn.Module, ...]
    """
    resettable = []
    for module in net.modules():
        reset_attr = getattr(module, "reset", None)
        if callable(reset_attr):
            resettable.append(module)
    return tuple(resettable)


def reset_collected_modules(modules: tuple[nn.Module, ...]) -> None:
    """
    Call ``reset()`` on each module in ``modules``, in order.

    A module that is not an instance of ``base.MemoryModule`` is still reset,
    but a warning is logged first.

    :param modules: the modules to reset; each one must provide ``reset()``
    :type modules: tuple[torch.nn.Module, ...]
    :return: ``None``
    :rtype: None
    """
    for module in modules:
        if isinstance(module, base.MemoryModule):
            module.reset()
            continue
        # Not a MemoryModule: warn, then reset anyway.
        logging.warning(
            f"Trying to call `reset()` of {module}, which is not spikingjelly.activation_based.base"
            f".MemoryModule"
        )
        module.reset()


def reset_net(net: nn.Module):
    r"""
    Reset the states of all resettable modules in ``net``.

    The modules to reset are selected by ``collect_reset_modules(net)`` and
    then handed to ``reset_collected_modules``, which calls ``reset()`` on
    each of them. A selected module that is not an instance of
    :class:`~spikingjelly.activation_based.base.MemoryModule` is still reset,
    and a warning is logged for it.

    :param net: any network that inherits from ``nn.Module``
    :type net: torch.nn.Module
    :return: ``None``
    :rtype: None
    :raises Exception: any exception raised by a submodule's ``reset()`` call
        is propagated unchanged
    """
    resettable_modules = collect_reset_modules(net)
    reset_collected_modules(resettable_modules)


def set_step_mode(net: nn.Module, step_mode: str):
    r"""
    Set ``step_mode`` on every module in ``net`` that exposes a ``step_mode``
    attribute.

    .. admonition:: Note
        :class: note

        The submodules (not including the container itself) of
        :class:`StepModeContainer <spikingjelly.activation_based.layer.container.StepModeContainer>`,
        :class:`ElementWiseRecurrentContainer <spikingjelly.activation_based.layer.container.ElementWiseRecurrentContainer>` and
        :class:`LinearRecurrentContainer <spikingjelly.activation_based.layer.container.LinearRecurrentContainer>`
        will not be changed.

        If a module has a ``step_mode`` attribute but is not an instance of
        :class:`~spikingjelly.activation_based.base.StepModule`, the function
        still attempts to assign the new value and emits a warning.

    :param net: a network
    :type net: torch.nn.Module
    :param step_mode: ``'s'`` (single-step) or ``'m'`` (multi-step)
    :type step_mode: str
    :return: ``None``
    :rtype: None
    :raises ValueError: propagated if a module rejects the provided
        ``step_mode`` in its setter
    """
    # Imported here rather than at module level, as in the original code.
    from ..layer import (
        ElementWiseRecurrentContainer,
        LinearRecurrentContainer,
        StepModeContainer,
    )

    keep_step_mode_instance = (
        StepModeContainer,
        ElementWiseRecurrentContainer,
        LinearRecurrentContainer,
    )

    # Build, once, the identities of all non-container modules nested inside a
    # protected container. This replaces re-walking ``container.modules()``
    # for every module of ``net``. Identity is used (not hashing the module)
    # so modules with a custom ``__eq__`` and no ``__hash__`` are still fine.
    protected_ids = set()
    for m in net.modules():
        if isinstance(m, keep_step_mode_instance):
            for sub in m.modules():
                if not isinstance(sub, keep_step_mode_instance):
                    protected_ids.add(id(sub))

    for m in net.modules():
        if not hasattr(m, "step_mode"):
            continue
        if id(m) in protected_ids:
            # Submodules of protected containers keep their own step mode.
            continue
        if not isinstance(m, base.StepModule):
            logging.warning(
                f"Trying to set the step mode for {m}, which is not spikingjelly.activation_based"
                f".base.StepModule"
            )
        m.step_mode = step_mode


def set_backend(
    net: nn.Module,
    backend: str,
    instance: Optional[Union[type[nn.Module], tuple[type[nn.Module], ...]]] = None,
):
    r"""
    Set ``backend`` for all modules in ``net`` whose type matches ``instance``
    and that expose a ``backend`` attribute.

    The backend is updated only when ``backend`` is listed in the module's
    ``supported_backends``. Otherwise a warning is logged and the existing
    backend is kept unchanged. A warning is also logged for any matching
    module that is not an instance of
    :class:`~spikingjelly.activation_based.base.MemoryModule`; the update is
    still attempted for it.

    :param net: a network
    :type net: torch.nn.Module
    :param backend: the backend to be set
    :type backend: str
    :param instance: the type filter passed to ``isinstance``, i.e. a module
        class or a tuple of module classes. Modules that match this filter
        and have a ``backend`` attribute will be checked. If ``None``, all
        modules with a ``backend`` attribute will be checked
    :type instance: Optional[Union[type[nn.Module], tuple[type[nn.Module], ...]]]
    :return: ``None``
    :rtype: None
    :raises Exception: propagated if a target module raises while exposing
        ``supported_backends`` or assigning ``backend``
    """
    # ``None`` means "no filtering": every nn.Module passes the isinstance test.
    type_filter = (nn.Module,) if instance is None else instance

    for m in net.modules():
        if not isinstance(m, type_filter) or not hasattr(m, "backend"):
            continue

        if not isinstance(m, base.MemoryModule):
            logging.warning(
                f"Trying to set the backend for {m}, which is not spikingjelly.activation_based.base.MemoryModule"
            )

        if backend in m.supported_backends:
            m.backend = backend
        else:
            logging.warning(
                f"{m} does not supports for backend={backend}. It will still use backend={m.backend}."
            )


def detach_net(net: nn.Module):
    r"""
    Detach stateful modules in ``net`` from the computation graphs of previous
    time steps.

    This function iterates over ``net.modules()`` and calls ``detach()`` on
    each submodule that provides a callable ``detach`` attribute. If such a
    submodule is not an instance of
    :class:`~spikingjelly.activation_based.base.MemoryModule`, the function
    still calls ``detach()`` and emits a warning. Modules whose ``detach``
    attribute is missing or not callable are skipped, which matches how
    resettable modules are selected by ``callable(getattr(m, "reset", None))``.

    :param net: any network that inherits from ``nn.Module``
    :type net: torch.nn.Module
    :return: ``None``
    :rtype: None
    :raises Exception: any exception raised by a submodule's ``detach()``
        call is propagated unchanged
    """
    for m in net.modules():
        # A bare hasattr() would let a non-callable ``detach`` attribute fail
        # with TypeError at the call below.
        if not callable(getattr(m, "detach", None)):
            continue
        if not isinstance(m, base.MemoryModule):
            logging.warning(
                f"Trying to call `detach()` of {m}, which is not spikingjelly.activation_based.base"
                f".MemoryModule"
            )
        m.detach()