spikingjelly.activation_based.lava_exchange 源代码

import logging
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import base, neuron, surrogate

# Bit width of the fixed-point decay representation: decay values live on a scale of
# ``1 << _hw_bits`` (used as ``CubaLIFNode.p_scale`` and by ``LeakyIntegratorStep``).
_hw_bits = 12


def step_quantize_forward(x: torch.Tensor, step: float):
    r"""
    Forward pass of ``step_quantize``: snap every element of ``x`` to the
    nearest integer multiple of ``step``.

    :param x: Input tensor (left unmodified)
    :type x: torch.Tensor
    :param step: Quantization step
    :type step: float
    :return: ``round(x / step) * step``
    :rtype: torch.Tensor
    """
    scaled = x / step
    rounded = torch.round(scaled)
    return rounded * step
class step_quantize_atgf(torch.autograd.Function):
    r"""
    Custom autograd function behind ``step_quantize``.

    The forward pass quantizes through ``step_quantize_forward``; the backward
    pass is a straight-through estimator that hands the incoming gradient to
    ``x`` unchanged and gives no gradient to ``step``.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, step: float = 1.0):
        quantized = step_quantize_forward(x, step)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: d(out)/d(x) is treated as 1, ``step`` is not differentiable.
        grad_x, grad_step = grad_output, None
        return grad_x, grad_step
def step_quantize(x: torch.Tensor, step: float = 1.0):
    r"""
    The step quantizer defined in `Lava`.

    With ``k`` an integer, every ``x[i]`` is quantized to the nearest
    ``k * step``. Gradients flow through unchanged (straight-through
    estimator, see ``step_quantize_atgf``).

    :param x: a float tensor
    :type x: torch.Tensor
    :param step: the quantization step
    :type step: float
    :return: ``y = round(x / step) * step``
    :rtype: torch.Tensor
    """
    quantized = step_quantize_atgf.apply(x, step)
    return quantized
def quantize_8b(x, scale, descale=False):
    r"""
    Quantize ``x`` onto the grid ``2 * k / scale`` and clamp the result to
    ``[-256 / scale, 255 / scale]``.

    :param x: input tensor
    :type x: torch.Tensor
    :param scale: scale factor
    :type scale: float
    :param descale: if ``True``, multiply the clamped result by ``scale``
        before returning it
    :type descale: bool
    :return: quantized tensor
    :rtype: torch.Tensor
    """
    lower, upper = -256 / scale, 255 / scale
    quantized = step_quantize(x, step=2 / scale).clamp(lower, upper)
    if descale:
        return quantized * scale
    return quantized
def right_shift_to_zero(x: torch.Tensor, bits: int):
    r"""
    Signed right shift that rounds toward zero.

    Computes ``sign(x) * (|x| >> bits)``, so negative values are shifted
    toward zero instead of toward negative infinity.

    :param x: Input integer tensor, must be ``torch.int32`` or ``torch.int64``
    :type x: torch.Tensor
    :param bits: Number of bits to shift
    :type bits: int
    :return: Right-shifted integer tensor with the dtype of ``x``
    :rtype: torch.Tensor
    """
    original_dtype = x.dtype
    assert original_dtype in (torch.int32, torch.int64)
    # Shift the magnitude only, then restore the sign.
    magnitude = torch.abs(x) >> bits
    signed = torch.sign(x) * magnitude
    return signed.to(original_dtype)
def _listep_forward(
    x: torch.Tensor,
    decay: torch.Tensor,
    state: torch.Tensor,
    w_scale: int,
    dtype: torch.dtype = torch.int32,
    hw_bits: int = 12,
):
    """
    One leaky-integrator step evaluated with integer arithmetic.

    In real-valued terms the result is
    ``state * (1 - decay / (1 << hw_bits)) + x``; internally ``state`` and
    ``x`` are scaled by ``w_scale`` and cast to ``dtype`` before the
    multiply/shift, and the result is divided by ``w_scale`` again.
    """
    full_scale = 1 << hw_bits
    state_int = (state * w_scale).to(dtype=dtype)
    keep_factor = full_scale - decay.to(dtype=dtype)
    decayed = right_shift_to_zero(state_int * keep_factor, hw_bits)
    injected = (w_scale * x).to(dtype=dtype)
    return (decayed + injected) / w_scale


def _listep_backward(
    grad_output: torch.Tensor,
    decay: torch.Tensor,
    state: torch.Tensor,
    hw_bits: int = 12,
):
    """
    Gradients of the leaky-integrator step with respect to ``x``, ``decay``
    and ``state`` (returned in that order).

    The gradient for ``decay`` is summed over all elements, i.e. it is a
    single scalar.
    """
    full_scale = 1 << hw_bits
    grad_state = (1 - decay / full_scale) * grad_output
    grad_decay = (-state / full_scale * grad_output).sum()
    return grad_output, grad_decay, grad_state  # x, decay, state
class BatchNorm2d(nn.Module):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-05,
        momentum: float = 0.1,
        track_running_stats: bool = True,
        weight_exp_bits: int = 3,
        pre_hook_fx: Callable = lambda x: x,
    ):
        r"""
        Weight-quantized batch normalization for Lava exchange, adapted from
        ``lava.lib.dl.slayer.neuron.norm.WgtScaleBatchNorm``.

        Unlike the standard ``nn.BatchNorm2d``, the standard deviation used for
        normalization is rounded up to a power of two (with the exponent
        clamped to ``[-weight_exp_bits, weight_exp_bits]``), and there are no
        learnable affine parameters.

        :param num_features: Number of feature channels
        :type num_features: int
        :param eps: Small constant added to the variance for numerical stability
        :type eps: float
        :param momentum: Momentum of the running statistics
        :type momentum: float
        :param track_running_stats: Whether batch statistics are computed and the
            running statistics updated in training mode
        :type track_running_stats: bool
        :param weight_exp_bits: Bound on the power-of-two exponent of the
            quantized standard deviation
        :type weight_exp_bits: int
        :param pre_hook_fx: Function applied to the mean before it is subtracted
        :type pre_hook_fx: Callable
        :return: ``None``
        :rtype: None
        """
        super().__init__()
        # lava.lib.dl.slayer.neuron.norm.WgtScaleBatchNorm
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.weight_exp_bits = weight_exp_bits
        self.pre_hook_fx = pre_hook_fx
        # Per-channel running mean, but a single running variance for the whole tensor.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.zeros(1))

    def to_lava(self):
        """
        Build the equivalent ``slayer.neuron.norm.WgtScaleBatchNorm`` and copy
        this module's state dict into it.

        :return: the lava-dl normalization layer
        """
        # Imported locally so the method does not depend on a module-level
        # ``slayer`` name that is only bound when lava-dl could be imported.
        import lava.lib.dl.slayer as slayer

        bn = slayer.neuron.norm.WgtScaleBatchNorm(
            num_features=self.num_features,
            momentum=self.momentum,
            weight_exp_bits=self.weight_exp_bits,
            eps=self.eps,
            pre_hook_fx=self.pre_hook_fx,
        )
        bn.load_state_dict(self.state_dict())
        return bn

    def forward(self, x: torch.Tensor):
        """
        Normalize ``x`` (indexed as ``[N, C, H, W]``) by the mean and the
        power-of-two-quantized standard deviation.

        :param x: input tensor
        :type x: torch.Tensor
        :return: ``(x - pre_hook_fx(mean)) / quantized_std``
        :rtype: torch.Tensor
        """
        if self.track_running_stats and self.training:
            # Per-channel mean, but the variance is taken over the whole tensor.
            x_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
            x_var = torch.var(x, unbiased=False)
            numel = x.numel() / self.num_features
            with torch.no_grad():
                self.running_mean = (
                    1.0 - self.momentum
                ) * self.running_mean + self.momentum * x_mean.squeeze()
                self.running_var = (
                    1.0 - self.momentum
                ) * self.running_var + self.momentum * x_var * numel / (numel + 1)
        else:
            x_mean = self.running_mean.view(1, -1, 1, 1)
            x_var = self.running_var.view(1, -1, 1, 1)

        x_std = torch.sqrt(x_var + self.eps)
        # Round the standard deviation up to a power of two with a bounded exponent.
        x_std = torch.pow(
            2.0,
            torch.ceil(torch.log2(x_std)).clamp(
                -self.weight_exp_bits, self.weight_exp_bits
            ),
        )
        return (x - self.pre_hook_fx(x_mean)) / x_std
class LeakyIntegratorStep(torch.autograd.Function):
    r"""
    Custom autograd function for the leaky integrator used for the
    current/voltage decay in Lava exchange.

    The forward pass evaluates the integrator with integer arithmetic through
    ``_listep_forward``; the backward pass propagates gradients through
    ``_listep_backward``.
    """

    @staticmethod
    def forward(ctx, x, decay, state, w_scale):
        """
        :param x: input added to the decayed state
        :param decay: decay value on the ``1 << _hw_bits`` scale
        :param state: state of the previous step
        :param w_scale: integer scale used for the fixed-point computation
        :return: the new state
        """
        output = _listep_forward(
            x, decay, state, w_scale, dtype=torch.int64, hw_bits=_hw_bits
        )
        # ``backward`` runs whenever ANY differentiable input requires grad,
        # including the case where only ``decay`` does (learnable decay with a
        # grad-free input and a freshly initialised state). The tensors must be
        # saved in that case too, otherwise ``backward`` cannot unpack them.
        if x.requires_grad or decay.requires_grad or state.requires_grad:
            ctx.save_for_backward(decay, state)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        :return: gradients for ``(x, decay, state, w_scale)``; ``w_scale`` is
            not differentiable and gets ``None``
        """
        decay, state = ctx.saved_tensors
        grad_input, grad_decay, grad_state = _listep_backward(
            grad_output, decay, state, hw_bits=_hw_bits
        )
        return grad_input, grad_decay, grad_state, None
class CubaLIFNode(neuron.BaseNode):
    def __init__(
        self,
        current_decay: Union[float, torch.Tensor],
        voltage_decay: Union[float, torch.Tensor],
        v_threshold: float = 1.0,
        v_reset: float = 0.0,
        scale=1 << 6,
        requires_grad=False,
        surrogate_function: Callable = surrogate.Sigmoid(),
        norm: BatchNorm2d = None,
        detach_reset=False,
        step_mode="s",
        backend="torch",
        store_v_seq: bool = False,
        store_i_seq: bool = False,
    ):
        # author: https://github.com/AllenYolk
        r"""
        Current-based (CUBA) leaky integrate-and-fire neuron with the quantized
        dynamics of lava-dl's ``cuba.Neuron``.

        .. math::
            I[t] = (1 - \alpha_{I})I[t-1] + X[t]

            V[t] = (1 - \alpha_{V})V[t-1] + I[t]

        :param current_decay: current decay constant
        :type current_decay: Union[float, torch.Tensor]
        :param voltage_decay: voltage decay constant
        :type voltage_decay: Union[float, torch.Tensor]
        :param v_threshold: threshold of the neurons in this layer. Default to 1.
        :type v_threshold: float
        :param v_reset: reset potential of the neurons in this layer. Only ``0``
            (hard reset to zero) is supported.
        :type v_reset: float
        :param scale: quantization precision (ref: lava-dl ``cuba.Neuron``).
            Default to ``1<<6``. Equivalent to ``w_scale=int(scale)``,
            ``s_scale=int(scale * (1<<6))``, ``p_scale=1<<12``.
        :type scale: float
        :param requires_grad: whether ``current_decay`` and ``voltage_decay`` are
            learnable. Default to ``False``.
        :type requires_grad: bool
        :param surrogate_function: surrogate gradient function. Default to
            ``surrogate.Sigmoid()``.
        :type surrogate_function: Callable
        :param norm: optional normalization layer applied to the current at each
            step; only ``BatchNorm2d`` of this module is supported
        :type norm: BatchNorm2d, optional
        :param detach_reset: whether to detach the computational graph of reset
            in the backward pass. Default to ``False``.
        :type detach_reset: bool
        :param step_mode: the step mode, ``'s'`` (single-step) or ``'m'``
            (multi-step). Default to ``'s'``.
        :type step_mode: str
        :param backend: backend for this neuron layer. Only ``'torch'`` is
            supported.
        :type backend: str
        :param store_v_seq: with ``step_mode = 'm'`` and input of
            ``shape = [T, N, *]``, whether to store the voltage of every time
            step in ``self.v_seq`` (``shape = [T, N, *]``). If ``False`` only the
            last voltage is kept in ``self.voltage_state``.
        :type store_v_seq: bool
        :param store_i_seq: same as ``store_v_seq`` but for the current, stored
            in ``self.i_seq`` / ``self.current_state``.
        :type store_i_seq: bool
        """
        self.lava_cuba_neuron_params = {
            "threshold": v_threshold,
            "current_decay": current_decay,
            "voltage_decay": voltage_decay,
            "scale": scale,
        }
        super().__init__(
            v_threshold=v_threshold,
            v_reset=v_reset,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode=step_mode,
            backend=backend,
            store_v_seq=store_v_seq,
        )
        self.store_i_seq = store_i_seq

        assert v_reset == 0.0, (
            "CubaLIFNode only supports for hard reset with v_reset = 0. !"
        )
        self.requires_grad = requires_grad

        # The default quantization parameter setting in lava, equivalent to
        # p_scale = 1 << 12, w_scale = int(scale), s_scale = int(scale * (1 << 6)).
        self._scale = int(scale)
        self._s_scale = int(scale * (1 << 6))
        self._p_scale = 1 << _hw_bits

        # ``_v_threshold`` is the largest ``k / scale`` (``k`` an integer) that
        # does not exceed ``v_threshold``.
        self._v_threshold = int(v_threshold * self.scale) / self.scale
        # Loihi fires on v[t] > v_th while the surrogate function fires on
        # v[t] >= v_th; comparing against v_th + eps approximates the strict
        # inequality.
        self.v_threshold_eps = 0.01 / self.s_scale

        current_decay = self._to_decay_tensor(current_decay)
        voltage_decay = self._to_decay_tensor(voltage_decay)

        if requires_grad:
            self.current_decay = nn.Parameter(current_decay)
            self.voltage_decay = nn.Parameter(voltage_decay)
        else:
            self.register_buffer("current_decay", current_decay)
            self.register_buffer("voltage_decay", voltage_decay)

        self.register_memory("current_state", 0.0)
        self.register_memory("voltage_state", 0.0)

        self.clamp_decay_parameters()

        self.norm = norm
        if self.norm is not None:
            if isinstance(self.norm, BatchNorm2d):
                self.norm.pre_hook_fx = self.quantize_8bit
            else:
                raise NotImplementedError(self.norm)

    def _to_decay_tensor(self, decay):
        """Return ``p_scale * decay`` as a detached ``float32`` tensor."""
        scaled = self.p_scale * decay
        if isinstance(scaled, torch.Tensor):
            # ``torch.tensor(tensor)`` warns about copy construction; clone explicitly.
            return scaled.detach().clone().to(dtype=torch.float32)
        return torch.tensor(scaled, dtype=torch.float32)

    def quantize_8bit(self, x, descale=False):
        """Quantize ``x`` with ``quantize_8b`` using this neuron's ``scale``."""
        return quantize_8b(x, scale=self.scale, descale=descale)

    def clamp_decay_parameters(self):
        """Clamp both decay values in place to ``[0, p_scale]``."""
        with torch.no_grad():
            self.current_decay.data.clamp_(min=0.0, max=self.p_scale)
            self.voltage_decay.data.clamp_(min=0.0, max=self.p_scale)

    @property
    def scale(self):
        r"""
        :return: Synaptic weight scaling factor (``int(scale)``)
        :rtype: int
        """
        return self._scale

    @property
    def s_scale(self):
        r"""
        :return: State scaling factor (``int(scale * (1 << 6))``)
        :rtype: int
        """
        return self._s_scale

    @property
    def p_scale(self):
        r"""
        :return: Decay parameter scaling factor (``1 << _hw_bits``)
        :rtype: int
        """
        return self._p_scale

    @property
    def store_i_seq(self):
        return self._store_i_seq

    @store_i_seq.setter
    def store_i_seq(self, value: bool):
        self._store_i_seq = value
        if value:
            # Register the memory lazily, only once it is actually needed.
            if not hasattr(self, "i_seq"):
                self.register_memory("i_seq", None)

    @property
    def supported_backends(self):
        if self.step_mode == "m" or self.step_mode == "s":
            return ("torch",)
        else:
            raise ValueError(
                f"self.step_mode should be 's' or 'm', "
                f"but get {self.step_mode} instead."
            )

    # computation process

    def state_initialization(self, x: torch.Tensor):
        """Replace the float placeholders of both states by zero tensors shaped like ``x``."""
        if isinstance(self.current_state, float):
            self.current_state = torch.zeros_like(x.data)
        if isinstance(self.voltage_state, float):
            self.voltage_state = torch.zeros_like(x.data)

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate ``x`` into the current, then the current into the voltage."""
        if self.requires_grad:
            # Learnable decays may have left the valid range after an optimizer step.
            self.clamp_decay_parameters()

        current = LeakyIntegratorStep.apply(
            x,
            step_quantize(self.current_decay),
            self.current_state.contiguous(),
            self.s_scale,
        )

        if self.norm is not None:
            current = self.norm(current)

        voltage = LeakyIntegratorStep.apply(
            current,
            step_quantize(self.voltage_decay),
            self.voltage_state.contiguous(),
            self.s_scale,
        )

        self.current_state = current
        self.voltage_state = voltage

    def neuronal_fire(self):
        """Fire against the quantized threshold ``_v_threshold`` plus ``v_threshold_eps``."""
        # Use the threshold quantized to the ``k / scale`` grid computed in
        # ``__init__`` rather than the raw ``v_threshold``.
        return self.surrogate_function(
            self.voltage_state - (self._v_threshold + self.v_threshold_eps)
        )

    def neuronal_reset(self, spike):
        """Hard-reset the voltage of the neurons that fired to ``v_reset``."""
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        self.voltage_state = self.apply_hard_reset(
            self.voltage_state, spike_d, self.v_reset
        )

    def single_step_forward(self, x):
        """Run one time step: initialize states, charge, fire, reset."""
        self.state_initialization(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """
        Run ``single_step_forward`` over the leading (time) dimension of
        ``x_seq`` and stack the spikes; optionally record ``v_seq`` / ``i_seq``.
        """
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        if self.store_i_seq:
            i_seq = []

        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.voltage_state)
            if self.store_i_seq:
                i_seq.append(self.current_state)

        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)
        if self.store_i_seq:
            self.i_seq = torch.stack(i_seq)

        return torch.stack(y_seq)
try: import lava.lib.dl.slayer as slayer # ---------------------------------------- # data reshape function
def TNX_to_NXT(x_seq: torch.Tensor):
    """
    Move the leading (time) dimension to the end: ``[T, N, *] -> [N, *, T]``.

    :param x_seq: tensor with time as the first dimension
    :type x_seq: torch.Tensor
    :return: a permuted view with time as the last dimension
    :rtype: torch.Tensor
    """
    trailing_dims = range(1, x_seq.dim())
    return x_seq.permute([*trailing_dims, 0])
def NXT_to_TNX(x_seq: torch.Tensor):
    """
    Move the trailing (time) dimension to the front: ``[N, *, T] -> [T, N, *]``.

    :param x_seq: tensor with time as the last dimension
    :type x_seq: torch.Tensor
    :return: a permuted view with time as the first dimension
    :rtype: torch.Tensor
    """
    last_dim = x_seq.dim() - 1
    return x_seq.permute([last_dim, *range(last_dim)])
def lava_neuron_forward(
    lava_neuron: nn.Module, x_seq: torch.Tensor, v: Union[torch.Tensor, float]
):
    """
    Run a lava neuron on a SpikingJelly-style sequence.

    Lava uses ``shape = [*, T]`` while SpikingJelly uses ``shape = [T, *]``, so
    the input is flattened to ``[N, features, T]`` before the call and the
    spikes are brought back to the input layout afterwards.

    :param lava_neuron: the lava neuron module; its ``voltage_state`` attribute
        is written before and read after the call
    :type lava_neuron: nn.Module
    :param x_seq: input with ``shape = [T, N, *]`` (a 2-D input gets a
        singleton dimension inserted at position 1 for the call)
    :type x_seq: torch.Tensor
    :param v: initial voltage, either a tensor or a float fill value
    :type v: Union[torch.Tensor, float]
    :return: ``(spike, v)``
    :rtype: tuple
    """
    added_dim = x_seq.dim() == 2
    if added_dim:
        # lava needs input with shape [N, ..., T]
        x_seq = x_seq.unsqueeze(1)

    if isinstance(v, float):
        initial_value = v
        v = torch.zeros_like(x_seq[0])
        if initial_value != 0.0:
            v.fill_(initial_value)

    original_shape = x_seq.shape
    lava_input = x_seq.flatten(2).permute(1, 2, 0)  # [T, N, *] -> [N, *, T]

    lava_neuron.voltage_state = v
    spike = lava_neuron(lava_input).permute(2, 0, 1).reshape(original_shape)
    v = lava_neuron.voltage_state.reshape(original_shape[1:])

    if added_dim:
        v = v.squeeze(1)
        spike = spike.squeeze(1)

    return spike, v
# ----------------------------------------
# quantize function


class _step_quantize(torch.autograd.Function):
    """
    Round ``x`` to the nearest multiple of ``step`` in the forward pass and
    pass the gradient through unchanged in the backward pass.
    """

    @staticmethod
    def forward(ctx, x, step):
        steps = torch.round(x / step)
        return steps * step

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through for ``x``; ``step`` receives no gradient.
        grad_x, grad_step = grad_output, None
        return grad_x, grad_step
def step_quantize(x: torch.Tensor, step: float = 1.0):
    """
    The step quantize function: rounds ``x`` to the nearest multiple of
    ``step``; gradients pass through unchanged.

    :param x: the input tensor
    :type x: torch.Tensor
    :param step: the quantize step
    :type step: float
    :return: quantized tensor
    :rtype: torch.Tensor

    Here is an example:

    .. code-block:: python

        # plt.style.use(['science', 'muted', 'grid'])
        fig = plt.figure(dpi=200, figsize=(6, 4))
        x = torch.arange(-4, 4, 0.001)
        plt.plot(
            x, lava_exchange.step_quantize(x, 2.0), label="quantize(x, step=2)"
        )
        plt.plot(x, x, label="y=x", ls="-.")
        plt.legend()
        plt.grid(ls="--")
        plt.title("step quantize")
        plt.xlabel("Input")
        plt.ylabel("Output")
        plt.savefig(
            "./docs/source/_static/API/activation_based/lava_exchange/step_quantize.svg"
        )
        plt.savefig(
            "./docs/source/_static/API/activation_based/lava_exchange/step_quantize.pdf"
        )

    .. image:: ../_static/API/activation_based/lava_exchange/step_quantize.*
        :width: 100%
    """
    quantized = _step_quantize.apply(x, step)
    return quantized
def quantize_8bit(x: torch.Tensor, scale, descale=False):
    """
    Quantize ``x`` with a step of ``2 / scale`` and clamp the result to
    ``[-256 / scale, 255 / scale]``.

    :param x: the input tensor
    :type x: torch.Tensor
    :param scale: the scale factor
    :param descale: if ``True``, multiply the clamped result by ``scale``
    :type descale: bool
    :return: quantized tensor
    :rtype: torch.Tensor
    """
    clamped = step_quantize(x, 2.0 / scale).clamp(-256.0 / scale, 255.0 / scale)
    if not descale:
        return clamped
    return clamped * scale
# ---------------------------------------- # convert function
def check_instance(m, instance):
    """
    Make sure ``m`` is an instance of ``instance``.

    :param m: the object to check
    :param instance: a type or tuple of types accepted by ``isinstance``
    :raises ValueError: if ``m`` is not an instance of ``instance``
    """
    if isinstance(m, instance):
        return
    raise ValueError(
        f"expected {m} with type {instance}, but got {m} with type {type(m)}!"
    )
def check_no_bias(m):
    """
    Make sure the layer ``m`` has no bias.

    :param m: an object with a ``bias`` attribute
    :raises ValueError: if ``m.bias`` is not ``None``
    """
    if m.bias is None:
        return
    raise ValueError(f"lava does not support for {type(m)} with bias!")
def to_lava_neuron_param_dict(sj_ms_neuron: nn.Module):
    """
    Build the parameter dict of a lava-dl CUBA neuron from a SpikingJelly
    ``IFNode`` or ``LIFNode``.

    :param sj_ms_neuron: the SpikingJelly neuron
    :type sj_ms_neuron: nn.Module
    :return: keyword arguments describing the lava neuron
    :rtype: dict
    :raises ValueError: if ``v_reset != 0`` or, for ``LIFNode``, if
        ``decay_input`` is set
    :raises NotImplementedError: for any other neuron type
    """
    is_if_node = isinstance(sj_ms_neuron, neuron.IFNode)
    if not is_if_node and not isinstance(sj_ms_neuron, neuron.LIFNode):
        raise NotImplementedError(sj_ms_neuron)

    if sj_ms_neuron.v_reset != 0.0:
        raise ValueError("lava only supports for v_reset == 0!")

    if is_if_node:
        # An IF neuron does not leak.
        voltage_decay = 0.0
    else:
        if sj_ms_neuron.decay_input:
            raise ValueError("lava only supports for decay_input == False!")
        voltage_decay = 1.0 / sj_ms_neuron.tau

    return {
        "threshold": sj_ms_neuron.v_threshold,
        "current_decay": 1.0,
        "voltage_decay": voltage_decay,
        "tau_grad": 1,
        "scale_grad": 1,
        "scale": sj_ms_neuron.lava_s_cale,
        "norm": None,
        "dropout": None,
        "shared_param": True,
        "persistent_state": True,
        "requires_grad": False,
        "graded_spike": False,
    }
def to_lava_neuron(sj_ms_neuron: nn.Module):
    """
    Convert a SpikingJelly ``IFNode`` / ``LIFNode`` to ``slayer.neuron.cuba.Neuron``.

    :param sj_ms_neuron: the SpikingJelly neuron
    :type sj_ms_neuron: nn.Module
    :return: the lava-dl CUBA neuron
    :raises NotImplementedError: for any other neuron type
    """
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    params = to_lava_neuron_param_dict(sj_ms_neuron)
    return slayer.neuron.cuba.Neuron(**params)
def linear_to_lava_synapse_dense(fc: nn.Linear):
    """
    Convert a bias-free ``nn.Linear`` to a lava slayer dense synapse carrying
    a copy of its weights.

    :param fc: a pytorch linear layer without bias
    :type fc: nn.Linear
    :return: a lava slayer dense synapse
    :rtype: slayer.synapse.Dense

    Codes example:

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.Linear(8, 4, bias=False)
        layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn)
        x_seq = torch.rand([T, N, 8])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = lava_exchange.NXT_to_TNX(
                layer_sl(lava_exchange.TNX_to_NXT(x_seq))
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(fc, nn.Linear)
    check_no_bias(fc)

    dense = slayer.synapse.Dense(fc.in_features, fc.out_features)

    # The synapse weight is indexed as [out_features, in_features, 1, 1, 1].
    copied_weight = fc.weight.data.clone()
    dense.weight.data[:, :, 0, 0, 0] = copied_weight

    return dense
def conv2d_to_lava_synapse_conv(conv2d_nn: nn.Conv2d):
    """
    Convert a bias-free ``nn.Conv2d`` to a lava slayer conv synapse with the
    same geometry and a copy of its weights.

    :param conv2d_nn: a pytorch conv2d layer without bias
    :type conv2d_nn: nn.Conv2d
    :return: a lava slayer conv synapse
    :rtype: slayer.synapse.Conv

    Codes example:

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
        layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn)
        x_seq = torch.rand([T, N, 3, 28, 28])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = lava_exchange.NXT_to_TNX(
                layer_sl(lava_exchange.TNX_to_NXT(x_seq))
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(conv2d_nn, nn.Conv2d)
    check_no_bias(conv2d_nn)

    geometry = {
        "in_features": conv2d_nn.in_channels,
        "out_features": conv2d_nn.out_channels,
        "kernel_size": conv2d_nn.kernel_size,
        "stride": conv2d_nn.stride,
        "padding": conv2d_nn.padding,
        "dilation": conv2d_nn.dilation,
        "groups": conv2d_nn.groups,
    }
    conv = slayer.synapse.Conv(**geometry)

    # The synapse weight has one extra trailing dimension of size 1.
    copied_weight = conv2d_nn.weight.data.clone()
    conv.weight.data[:, :, :, :, 0] = copied_weight

    return conv
def avgpool2d_to_lava_synapse_pool(pool2d_nn: nn.AvgPool2d):
    """
    Convert an ``nn.AvgPool2d`` to a lava slayer pool synapse with the same
    kernel size, stride and padding.

    :param pool2d_nn: a pytorch AvgPool2d layer
    :type pool2d_nn: nn.AvgPool2d
    :return: a lava slayer pool layer
    :rtype: slayer.synapse.Pool

    .. admonition:: Warning
        :class: warning

        The lava slayer pool layer applies sum pooling, rather than average pooling.

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.AvgPool2d(kernel_size=2, stride=2)
        layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn)
        x_seq = torch.rand([T, N, 3, 28, 28])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = (
                lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
                / 4.0
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(pool2d_nn, nn.AvgPool2d)
    logging.warning(
        "The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer."
    )

    kernel_size = pool2d_nn.kernel_size
    stride = pool2d_nn.stride
    padding = pool2d_nn.padding
    return slayer.synapse.Pool(kernel_size, stride, padding)
def to_lava_block_dense(
    fc: nn.Linear, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """
    Convert a bias-free ``nn.Linear`` followed by an ``IFNode`` / ``LIFNode``
    to a ``slayer.block.cuba.Dense`` block carrying a copy of the weights.

    :param fc: a pytorch linear layer without bias
    :type fc: nn.Linear
    :param sj_ms_neuron: the SpikingJelly neuron following the layer
    :type sj_ms_neuron: nn.Module
    :param quantize_to_8bit: if ``False``, the block is built with
        ``pre_hook_fx=None``; otherwise the block's default pre-hook is kept
    :type quantize_to_8bit: bool
    :return: the lava block
    """
    check_instance(fc, nn.Linear)
    check_no_bias(fc)

    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    block_init = slayer.block.cuba.Dense

    # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to
    # `quantize_8bit` by default
    extra_kwargs = {} if quantize_to_8bit else {"pre_hook_fx": None}
    lava_block = block_init(
        neuron_params,
        fc.in_features,
        fc.out_features,
        delay_shift=False,
        **extra_kwargs,
    )

    lava_block.synapse.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()

    return lava_block
def to_lava_block_conv(
    conv2d_nn: nn.Conv2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """
    Convert a bias-free ``nn.Conv2d`` followed by an ``IFNode`` / ``LIFNode``
    to a ``slayer.block.cuba.Conv`` block carrying a copy of the weights.

    :param conv2d_nn: a pytorch conv2d layer without bias
    :type conv2d_nn: nn.Conv2d
    :param sj_ms_neuron: the SpikingJelly neuron following the layer
    :type sj_ms_neuron: nn.Module
    :param quantize_to_8bit: if ``False``, the block is built with
        ``pre_hook_fx=None``; otherwise the block's default pre-hook is kept
    :type quantize_to_8bit: bool
    :return: the lava block
    """
    check_instance(conv2d_nn, nn.Conv2d)
    check_no_bias(conv2d_nn)

    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    block_init = slayer.block.cuba.Conv

    block_kwargs = {
        "in_features": conv2d_nn.in_channels,
        "out_features": conv2d_nn.out_channels,
        "kernel_size": conv2d_nn.kernel_size,
        "stride": conv2d_nn.stride,
        "padding": conv2d_nn.padding,
        "dilation": conv2d_nn.dilation,
        "groups": conv2d_nn.groups,
        "delay_shift": False,
    }
    if not quantize_to_8bit:
        # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set
        # to `quantize_8bit` by default
        block_kwargs["pre_hook_fx"] = None

    lava_block = block_init(neuron_params, **block_kwargs)

    lava_block.synapse.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()

    return lava_block
def to_lava_block_pool(
    pool2d_nn: nn.AvgPool2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """
    Convert an ``nn.AvgPool2d`` followed by an ``IFNode`` / ``LIFNode`` to a
    ``slayer.block.cuba.Pool`` block. A warning is logged because the lava
    pool applies sum pooling rather than average pooling.

    :param pool2d_nn: a pytorch AvgPool2d layer
    :type pool2d_nn: nn.AvgPool2d
    :param sj_ms_neuron: the SpikingJelly neuron following the layer
    :type sj_ms_neuron: nn.Module
    :param quantize_to_8bit: if ``False``, the block is built with
        ``pre_hook_fx=None``; otherwise the block's default pre-hook is kept
    :type quantize_to_8bit: bool
    :return: the lava block
    """
    check_instance(pool2d_nn, nn.AvgPool2d)

    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    block_init = slayer.block.cuba.Pool

    # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to
    # `quantize_8bit` by default
    extra_kwargs = {} if quantize_to_8bit else {"pre_hook_fx": None}
    lava_block = block_init(
        neuron_params,
        pool2d_nn.kernel_size,
        pool2d_nn.stride,
        pool2d_nn.padding,
        delay_shift=False,
        **extra_kwargs,
    )

    logging.warning(
        "The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer."
    )

    return lava_block
def to_lava_block_flatten(flatten_nn: nn.Flatten):
    """
    Convert an ``nn.Flatten`` with ``start_dim == 1`` to
    ``slayer.block.cuba.Flatten``.

    :param flatten_nn: a pytorch flatten layer
    :type flatten_nn: nn.Flatten
    :return: the lava flatten block
    :raises ValueError: if ``flatten_nn.start_dim != 1``
    """
    check_instance(flatten_nn, nn.Flatten)
    if flatten_nn.start_dim == 1:
        return slayer.block.cuba.Flatten()
    raise ValueError("lava only supports for flatten_nn.start_dim == 1!")
def to_lava_blocks(net: Union[list, tuple, nn.Sequential]):
    # https://lava-nc.org/lava-lib-dl/netx/netx.html
    """
    Convert a sequence of layers to a list of lava-dl blocks.

    ``nn.Linear``, ``nn.Conv2d`` and ``nn.AvgPool2d`` must each be directly
    followed by an ``IFNode`` or ``LIFNode``; every such pair becomes one
    block. ``nn.Flatten`` becomes a block on its own.

    The lava pool block applies sum pooling instead of average pooling. The
    kernel area ``k`` of a converted pool layer is therefore remembered, and
    the weights of the next converted ``nn.Conv2d`` / ``nn.Linear`` block are
    divided by ``k``. The division is applied to the weights copied into the
    lava block; the layers of ``net`` are not modified.

    Supported layer types (netx)::

        input  : {shape, type}
        flatten: {shape, type}
        average: {shape, type}
        concat : {shape, type, layers}
        dense  : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)}
        pool   : {shape, type, neuron, kernelSize, stride, padding, dilation, weight}
        conv   : {shape, type, neuron, inChannels, outChannels, kernelSize, stride,
                  padding, dilation, groups, weight, delay(if available)}

        ``neuron`` is the description of the compartment parameters:
        {iDecay, vDecay, vThMant, refDelay, ... (other additional params)}

    :param net: the layers to convert, in order
    :type net: Union[list, tuple, nn.Sequential]
    :return: the lava blocks (an empty list for an empty ``net``)
    :rtype: list
    :raises ValueError: for an unsupported layer type, or when a
        Linear / Conv2d / AvgPool2d layer is not followed by an
        ``IFNode`` / ``LIFNode``
    """
    blocks = []
    length = len(net)
    pending_k = None  # kernel area of the last pool layer not yet compensated
    i = 0

    while i < length:
        layer = net[i]

        if isinstance(layer, nn.Flatten):
            blocks.append(to_lava_block_flatten(layer))
            i += 1
            continue

        if not isinstance(layer, (nn.Linear, nn.Conv2d, nn.AvgPool2d)):
            raise ValueError(type(layer))

        followed_by_neuron = i + 1 < length and isinstance(
            net[i + 1], (neuron.IFNode, neuron.LIFNode)
        )
        if not followed_by_neuron:
            raise ValueError(type(layer))

        if isinstance(layer, nn.AvgPool2d):
            blocks.append(to_lava_block_pool(layer, net[i + 1]))
            # Read the kernel size from the pool layer itself, before advancing.
            kernel_size = layer.kernel_size
            if isinstance(kernel_size, int):
                pending_k = float(kernel_size * kernel_size)
            else:
                pending_k = float(kernel_size[0] * kernel_size[1])
        else:
            if isinstance(layer, nn.Linear):
                block = to_lava_block_dense(layer, net[i + 1])
            else:
                block = to_lava_block_conv(layer, net[i + 1])
            if pending_k is not None:
                # Compensate the preceding sum pooling on the converted copy only.
                block.synapse.weight.data /= pending_k
                pending_k = None
            blocks.append(block)

        i += 2

    return blocks
class SumPool2d(nn.Module):
    """
    2-D sum pooling.

    With ``dilation == 1`` the sum is computed as average pooling times the
    kernel area; otherwise as a per-channel convolution with an all-ones
    kernel.

    .. code-block:: python

        x = torch.rand([4, 2, 4, 16, 16])

        with torch.no_grad():
            sp_sj = SumPool2d(kernel_size=2, stride=2)
            y_sj = functional.seq_to_ann_forward(x, sp_sj)

            sp_la = slayer.synapse.Pool(kernel_size=2, stride=2)
            y_la = lava_exchange.NXT_to_TNX(sp_la(lava_exchange.TNX_to_NXT(x)))
            print((y_sj - y_la).abs().sum())
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
        """
        :param kernel_size: size of the pooling window (int or pair)
        :param stride: stride of the window; ``None`` means ``kernel_size``,
            following the pooling-layer convention
        :param padding: zero padding added on both sides (int or pair)
        :param dilation: spacing between window elements (int or pair)
        """
        super().__init__()

        if stride is None:
            # ``nn.Conv2d`` would turn ``None`` into ``(None, None)`` and the
            # forward pass would fail; pooling layers default to the kernel size.
            stride = kernel_size

        # A throw-away convolution normalises all arguments to pairs and
        # provides the shape of the all-ones kernel.
        temp_conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )

        # Registered as a non-persistent buffer: it follows ``.to()`` / ``.cuda()``
        # but stays out of the state dict, so existing checkpoints still load.
        self.register_buffer(
            "weight", torch.ones_like(temp_conv.weight.data), persistent=False
        )
        self.kernel_size = temp_conv.kernel_size
        self.stride = temp_conv.stride
        self.padding = temp_conv.padding
        self.dilation = temp_conv.dilation
        del temp_conv

    def forward(self, x: torch.Tensor):
        """
        :param x: input with ``shape = [N, C, H, W]``
        :type x: torch.Tensor
        :return: the sum-pooled tensor
        :rtype: torch.Tensor
        """
        if self.dilation == (1, 1):
            return (
                F.avg_pool2d(x, self.kernel_size, self.stride, self.padding)
                * self.weight.numel()
            )

        N, C, H, W = x.shape
        # ``reshape`` also accepts non-contiguous input, unlike ``view``.
        x = x.reshape(N * C, 1, H, W)
        x = F.conv2d(
            x,
            weight=self.weight.to(device=x.device, dtype=x.dtype),
            bias=None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        return x.view(N, C, x.shape[2], x.shape[3])
class BlockContainer(nn.Module, base.StepModule):
    """
    A synapse layer bundled with the ``CubaLIFNode`` that consumes its output.

    ``forward`` runs the synapse (with weights passed through the neuron's
    ``quantize_8bit`` for ``nn.Conv2d`` / ``nn.Linear``) and then the neuron,
    and ``to_lava_block`` builds the matching ``slayer.block.cuba`` block.

    Note that ``forward`` and ``to_lava_block`` dispatch on ``nn.Conv2d``,
    ``nn.Linear``, ``SumPool2d`` and ``nn.Flatten``; any other synapse type
    raises ``NotImplementedError`` there.
    """

    def __init__(
        self,
        synapse: Union[nn.Conv2d, nn.Linear, nn.AvgPool2d, nn.Flatten],
        neu: Optional[CubaLIFNode],
        step_mode: str = "s",
    ):
        """
        :param synapse: the synapse layer; ``nn.Conv2d`` / ``nn.Linear`` must
            have no bias, ``nn.Flatten`` must have ``start_dim == 1``
        :param neu: ``None`` for ``nn.Flatten``; otherwise an ``IFNode`` or
            ``LIFNode`` (converted to ``CubaLIFNode``) or a ``CubaLIFNode``
        :param step_mode: the step mode, forwarded to the ``step_mode`` setter
        :type step_mode: str
        :raises ValueError: for an unsupported ``start_dim``, ``v_reset`` or
            ``decay_input``, or an unsupported ``step_mode``
        """
        super().__init__()
        if isinstance(synapse, nn.Flatten):
            # A flatten block carries no neuron at all.
            assert neu is None
            self.synapse = synapse
            self.neuron = None
            if synapse.start_dim != 1:
                raise ValueError(
                    "lava only supports for torch.nn.Flatten with start_dim == 1!"
                )
        else:
            cuba_neuron = self._to_cuba_neuron(neu)
            self.synapse = synapse
            self.neuron = cuba_neuron

        if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
            assert self.synapse.bias is None

        # Assigned last: the setter reads self.neuron and self.synapse.
        self.step_mode = step_mode

    @staticmethod
    def _to_cuba_neuron(neu):
        # Turn an IFNode / LIFNode into the equivalent CubaLIFNode; a
        # CubaLIFNode is handed back untouched.
        from_if_node = isinstance(neu, neuron.IFNode)
        if not from_if_node and not isinstance(neu, neuron.LIFNode):
            assert isinstance(neu, CubaLIFNode)
            return neu

        if neu.v_reset != 0.0:
            raise ValueError("lava only supports for v_reset == 0!")

        if from_if_node:
            voltage_decay = 0.0
        else:
            if neu.decay_input:
                raise ValueError("lava only supports for decay_input == False!")
            voltage_decay = 1.0 / neu.tau

        return CubaLIFNode(
            current_decay=1.0,
            voltage_decay=voltage_decay,
            v_threshold=neu.v_threshold,
            scale=neu.lava_s_cale,
        )

    @property
    def step_mode(self):
        """The current step mode of the container."""
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value: str):
        # Validate, store, then push the mode down to the wrapped modules.
        allowed = self.supported_step_mode()
        if value not in allowed:
            raise ValueError(
                f'step_mode can only be {allowed}, but got "{value}"!'
            )
        self._step_mode = value
        for child in (self.neuron, self.synapse):
            if isinstance(child, base.StepModule):
                child.step_mode = value

    def forward(self, x: torch.Tensor):
        """
        Run the synapse and then, if present, the neuron.

        In step mode ``"m"`` the first two dimensions of ``x`` are merged
        before the synapse and restored afterwards.

        :param x: the input tensor
        :type x: torch.Tensor
        :return: the output of the neuron, or of the synapse if there is no neuron
        :rtype: torch.Tensor
        :raises NotImplementedError: if the synapse type is not handled
        """
        multi_step = self.step_mode == "m"
        if multi_step:
            T, N = x.shape[0], x.shape[1]
            x = x.flatten(0, 1)

        syn = self.synapse
        if isinstance(syn, nn.Conv2d):
            # Quantized to 2k / self.neuron.scale with k = -128, -127, ..., 127,
            # i.e. 256 possible values.
            x = F.conv2d(
                x,
                weight=self.neuron.quantize_8bit(syn.weight),
                bias=syn.bias,
                stride=syn.stride,
                padding=syn.padding,
                dilation=syn.dilation,
                groups=syn.groups,
            )
        elif isinstance(syn, nn.Linear):
            # Same weight quantization as for the convolution above.
            x = F.linear(x, self.neuron.quantize_8bit(syn.weight), syn.bias)
        elif isinstance(syn, (SumPool2d, nn.Flatten)):
            x = syn(x)
        else:
            raise NotImplementedError(type(syn))

        if multi_step:
            x = x.view([T, N, *x.shape[1:]])

        if self.neuron is None:
            return x
        return self.neuron(x)

    def to_lava_block(self):
        """
        Build the ``slayer.block.cuba`` block that corresponds to this container.

        For ``nn.Linear`` and ``nn.Conv2d`` the synapse weights are copied
        into the new block.

        :return: the Lava block
        :raises NotImplementedError: if the synapse type is not handled
        """
        syn = self.synapse
        if isinstance(syn, nn.Flatten):
            return slayer.block.cuba.Flatten()

        if isinstance(syn, nn.Linear):
            lava_block = slayer.block.cuba.Dense(
                self.neuron.lava_cuba_neuron_params,
                syn.in_features,
                syn.out_features,
                delay_shift=False,
            )
            copied = syn.weight.data.clone()
            lava_block.synapse.weight.data[:, :, 0, 0, 0] = copied
        elif isinstance(syn, nn.Conv2d):
            lava_block = slayer.block.cuba.Conv(
                self.neuron.lava_cuba_neuron_params,
                in_features=syn.in_channels,
                out_features=syn.out_channels,
                kernel_size=syn.kernel_size,
                stride=syn.stride,
                padding=syn.padding,
                dilation=syn.dilation,
                groups=syn.groups,
                delay_shift=False,
            )
            copied = syn.weight.data.clone()
            lava_block.synapse.weight.data[:, :, :, :, 0] = copied
        elif isinstance(syn, SumPool2d):
            lava_block = slayer.block.cuba.Pool(
                self.neuron.lava_cuba_neuron_params,
                syn.kernel_size,
                syn.stride,
                syn.padding,
                syn.dilation,
                delay_shift=False,
            )
        else:
            raise NotImplementedError

        # Carry the neuron's norm over to the Lava block as well.
        norm = self.neuron.norm
        if norm is not None:
            lava_block.neuron.norm = norm.to_lava()
        return lava_block
except BaseException as e: logging.info(f"spikingjelly.activation_based.lava_exchange: {e}") slayer = None