# Source code for spikingjelly.activation_based.nir_exchange.to_nir

from typing import Optional, Union
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch import fx
from torch.fx.passes.shape_prop import ShapeProp
import nir
import nirtorch

from .. import layer, neuron


# Only the exporter is public; ``_ModuleMapper`` and ``_to_numpy`` are
# implementation details of this module.
__all__ = [
    "export_to_nir",
]


def _to_numpy(x: torch.Tensor) -> np.ndarray:
    """Convert ``x`` into a NumPy array.

    The tensor is first detached from autograd and moved to the CPU, so this
    works for parameters that require grad and for tensors on any device.
    For a tensor already on the CPU the returned array shares memory with it.
    """
    host_tensor = x.detach().to("cpu")
    return host_tensor.numpy()


class _ModuleMapper:
    """Translate supported SpikingJelly / PyTorch modules into NIR nodes.

    On construction the network is traced once and ``example_input`` is
    propagated through it, so that the input/output shape of every leaf module
    can be recorded. ``Conv2d``, ``Flatten`` and the neuron mappings need those
    shapes to fill in NIR's shape/type fields.

    :param net: the network to be exported
    :param example_input: example input used to propagate shapes through ``net``
    :param dt: simulation time step in seconds, used to convert the neurons'
        per-step parameters into NIR parameters
    """

    def __init__(
        self,
        net: nn.Module,
        example_input: torch.Tensor,
        dt: float = 1e-4,
    ):
        self.dt = dt
        self.net = net
        # module instance -> {"input_shape": torch.Size, "output_shape": torch.Size}
        self.module_io_shape = {}
        self.set_module_io_shape(example_input)

    @staticmethod
    def _meta_shape(node: fx.Node) -> Optional[torch.Size]:
        """Return the shape recorded by ``ShapeProp`` on ``node``.

        Returns ``None`` if the node has no ``tensor_meta`` or its metadata is
        not a single tensor's (e.g. a tuple of metadata for a tuple output).
        """
        meta = node.meta.get("tensor_meta")
        return getattr(meta, "shape", None)

    def set_module_io_shape(self, example_input: torch.Tensor):
        """Trace ``self.net`` and record the I/O shapes of its mapped modules.

        Module types listed in :attr:`map_dict` are kept as leaves by the
        tracer. ``ShapeProp`` then runs ``example_input`` through the traced
        graph and annotates each node with tensor metadata. Modules without
        usable input/output metadata are skipped rather than aborting the
        export; modules that actually need shapes report it in ``_io_shape``.
        """
        tracer = nirtorch.torch_tracer.NIRTorchTracer(self.map_dict.keys())
        graph = tracer.trace(self.net)
        gm = fx.GraphModule(tracer.root, graph)
        ShapeProp(gm).propagate(example_input)

        for node in gm.graph.nodes:
            if node.op != "call_module":
                continue
            output_shape = self._meta_shape(node)
            if output_shape is None:
                continue

            input_shapes = [
                shape
                for shape in map(self._meta_shape, node.all_input_nodes)
                if shape is not None
            ]
            if not input_shapes:
                continue

            module = gm.get_submodule(node.target)
            self.module_io_shape[module] = {
                # most modules have only one tensor input
                "input_shape": input_shapes[0],
                "output_shape": output_shape,
            }

    def _io_shape(self, module: nn.Module, key: str) -> torch.Size:
        """Return the recorded ``"input_shape"`` or ``"output_shape"`` of ``module``.

        :raises KeyError: if no shape was recorded for ``module``
        """
        try:
            return self.module_io_shape[module][key]
        except KeyError:
            raise KeyError(
                f"no shape information was recorded for {module!r}; make sure it "
                "is reached by `example_input` and receives a single tensor"
            ) from None

    def _neuron_types(self, module: nn.Module):
        """Return the per-sample ``(input_type, output_type)`` of a neuron module.

        The B dim (single-step mode ``"s"``) or the T and B dims (multi-step
        mode) are removed from the recorded shapes.
        """
        type_start = 1 if module.step_mode == "s" else 2
        input_type = self._io_shape(module, "input_shape")[type_start:]
        output_type = self._io_shape(module, "output_shape")[type_start:]
        return input_type, output_type

    @property
    def map_dict(self) -> dict:
        """Mapping from supported module type to the bound method converting it."""
        return {
            nn.Linear: self.map_linear,
            layer.Linear: self.map_linear,
            nn.Conv2d: self.map_conv2d,
            layer.Conv2d: self.map_conv2d,
            nn.AvgPool2d: self.map_avgpool2d,
            layer.AvgPool2d: self.map_avgpool2d,
            nn.Flatten: self.map_flatten,
            layer.Flatten: self.map_flatten,
            neuron.IFNode: self.map_if,
            neuron.LIFNode: self.map_lif,
            neuron.ParametricLIFNode: self.map_plif,
        }

    def map(self, module: nn.Module) -> nir.NIRNode:
        """Convert ``module`` using the mapper registered for its exact type."""
        return self.map_dict[module.__class__](module)

    def map_linear(self, module: nn.Linear) -> nir.NIRNode:
        """Map a linear layer to ``nir.Linear`` (no bias) or ``nir.Affine``."""
        if module.bias is None:
            return nir.Linear(_to_numpy(module.weight))
        else:
            return nir.Affine(_to_numpy(module.weight), _to_numpy(module.bias))

    def map_conv2d(self, module: nn.Conv2d) -> nir.Conv2d:
        """Map a 2D convolution to ``nir.Conv2d``.

        A missing bias is exported as an all-zero array with the weight's
        dtype, so ``nir.Conv2d`` always receives a bias consistent with the
        weights.
        """
        weight = _to_numpy(module.weight)
        if module.bias is None:
            bias = np.zeros(weight.shape[0], dtype=weight.dtype)
        else:
            bias = _to_numpy(module.bias)

        # only the spatial size of the input is needed
        H, W = self._io_shape(module, "input_shape")[-2:]

        return nir.Conv2d(
            input_shape=(H, W),
            weight=weight,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            groups=module.groups,
            bias=bias,
        )

    def map_avgpool2d(self, module: nn.AvgPool2d) -> nir.NIRNode:
        """Map a 2D average pooling layer to ``nir.AvgPool2d``."""
        return nir.AvgPool2d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
        )

    def map_flatten(self, module: nn.Flatten) -> nir.Flatten:
        """Map a flatten layer to ``nir.Flatten``.

        The exported ``input_type`` excludes the batch dim, so positive
        ``start_dim`` / ``end_dim`` are shifted down by one. Non-positive dims
        are kept unchanged.
        """
        start_dim, end_dim = module.start_dim, module.end_dim
        start_dim = start_dim - 1 if start_dim > 0 else start_dim
        end_dim = end_dim - 1 if end_dim > 0 else end_dim

        # remove the B dim (single-step) or the T and B dims (multi-step)
        input_type_start = 2 if getattr(module, "step_mode", None) == "m" else 1
        input_type = self._io_shape(module, "input_shape")[input_type_start:]

        return nir.Flatten(
            input_type=input_type,
            start_dim=start_dim,
            end_dim=end_dim,
        )

    def map_if(self, module: neuron.IFNode) -> nir.IF:
        """Map an integrate-and-fire neuron to ``nir.IF``.

        .. warning::

            `nir.IF` does not distinguish soft reset from hard reset. If
            `module.v_reset=None` (i.e. soft reset), it will be converted to
            `module.v_reset=0` (i.e. hard reset with 0 reset potential).
        """
        v_reset = 0.0 if module.v_reset is None else module.v_reset
        input_type, output_type = self._neuron_types(module)

        return nir.IF(
            # NOTE(review): assumes NIR's IF integrates dv/dt = r * I, so that
            # r = 1/dt adds the input unscaled over one step of length dt
            r=np.full(input_type, 1 / self.dt),
            v_threshold=np.full(input_type, module.v_threshold),
            v_reset=np.full(input_type, v_reset),
            input_type=input_type,
            output_type=output_type,
        )

    def _lif_node(self, module: nn.Module, tau) -> nir.LIF:
        """Build a ``nir.LIF`` for a (parametric) LIF neuron with time constant ``tau``.

        ``tau`` is given in time steps and exported as ``tau * dt``. ``r`` is 1
        when ``module.decay_input`` is set and ``tau`` otherwise, presumably so
        that an undecayed input is not scaled by ``1/tau`` in NIR. ``tau`` may
        be a Python float or a NumPy scalar/array; it must not be a torch
        tensor, because ``np.full`` cannot convert non-CPU tensors.
        """
        # soft reset (v_reset=None) uses 0 as both leak and reset potential
        v_rest = 0.0 if module.v_reset is None else module.v_reset
        r = 1.0 if module.decay_input else tau
        input_type, output_type = self._neuron_types(module)

        return nir.LIF(
            tau=np.full(input_type, tau * self.dt),
            r=np.full(input_type, r),
            v_leak=np.full(input_type, v_rest),
            v_threshold=np.full(input_type, module.v_threshold),
            v_reset=np.full(input_type, v_rest),
            input_type=input_type,
            output_type=output_type,
        )

    def map_lif(self, module: neuron.LIFNode) -> nir.LIF:
        """Map a leaky integrate-and-fire neuron to ``nir.LIF``.

        .. warning::

            `nir.LIF` does not distinguish soft reset from hard reset. If
            `module.v_reset=None` (i.e. soft reset), it will be converted to
            `module.v_reset=0` (i.e. hard reset with 0 reset potential).
        """
        return self._lif_node(module, module.tau)

    def map_plif(self, module: neuron.ParametricLIFNode) -> nir.LIF:
        """Map a parametric LIF neuron to ``nir.LIF`` with ``tau = 1 / sigmoid(w)``.

        .. warning::

            `nir.LIF` does not distinguish soft reset from hard reset. If
            `module.v_reset=None` (i.e. soft reset), it will be converted to
            `module.v_reset=0` (i.e. hard reset with 0 reset potential).
        """
        with torch.no_grad():
            # move to host NumPy: np.full cannot consume CUDA/MPS tensors
            tau = _to_numpy(1.0 / module.w.sigmoid())
        return self._lif_node(module, tau)


def export_to_nir(
    net: nn.Module,
    example_input: torch.Tensor,
    save_path: Optional[Union[str, Path]] = None,
    dt: float = 1e-4,
) -> "nir.NIRGraph":
    """
    .. _export_to_nir-en:

    Convert a SpikingJelly model to a
    `NIR (Neuromorphic Intermediate Representation) <https://neuroir.org/docs/index.html>`_
    graph for conversion to other frameworks or deployment on neuromorphic
    hardware.

    This function infers the input and output shapes of each submodule by
    propagating ``example_input`` through ``net``. It then converts the
    supported SpikingJelly or PyTorch modules into the corresponding NIR nodes.

    .. note::

        Shape inference runs a forward pass of ``net``. Stateful neurons may
        therefore keep state from that pass; reset the network afterwards if
        this matters.

    :param net: the SpikingJelly / PyTorch model to convert
    :type net: torch.nn.Module
    :param example_input: an example input tensor used to infer the input and
        output shapes of each submodule in ``net``
    :type example_input: torch.Tensor
    :param save_path: the path to save the converted NIR graph. If not
        ``None``, the NIR graph will be written to the specified HDF5 file.
        Defaults to ``None``, which means the NIR graph will not be saved
    :type save_path: Optional[Union[str, Path]]
    :param dt: simulation time step in seconds, used to compute time constants
        and other hyperparameters for NIR neuron nodes. The default value is
        ``1e-4``, consistent with most frameworks that support NIR
    :type dt: float
    :return: the converted NIRGraph object
    :rtype: nir.NIRGraph
    """
    mapper = _ModuleMapper(net, example_input, dt=dt)
    graph = nirtorch.torch_to_nir(net, mapper.map_dict, type_check=True)
    if save_path is not None:
        nir.write(save_path, graph)
    return graph