import logging
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import base, neuron, surrogate
# Number of fractional bits of the fixed-point decay arithmetic: it is the
# ``hw_bits`` passed to ``_listep_forward`` / ``_listep_backward`` by
# ``LeakyIntegratorStep`` and defines ``CubaLIFNode._p_scale = 1 << _hw_bits``.
_hw_bits = 12
[文档]
def step_quantize_forward(x: torch.Tensor, step: float):
    r"""
    Forward quantization used by ``step_quantize``.

    ``x`` is divided by ``step``, rounded with ``torch.round`` (ties go to the
    even integer) and multiplied back by ``step``, so every element is snapped
    to the nearest multiple of ``step``. The input tensor is not modified.

    :param x: input tensor
    :type x: torch.Tensor
    :param step: quantization step
    :type step: float
    :return: ``round(x / step) * step``
    :rtype: torch.Tensor
    """
    n_steps = torch.round(x / step)
    return n_steps * step
[文档]
class step_quantize_atgf(torch.autograd.Function):
    r"""
    Autograd function behind ``step_quantize``.

    The forward pass snaps ``x`` to the nearest multiple of ``step``
    (``round(x / step) * step``, the same computation as
    ``step_quantize_forward``). The backward pass is a straight-through
    estimator: the incoming gradient is passed to ``x`` unchanged and ``step``
    receives no gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, step: float = 1.0):
        n_steps = torch.round(x / step)
        return n_steps * step

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat d(output)/d(x) as the identity.
        grad_x = grad_output
        grad_step = None
        return grad_x, grad_step
def step_quantize(x: torch.Tensor, step: float = 1.0):
    r"""
    Step quantizer as defined in Lava.

    Every element ``x[i]`` is mapped to the nearest ``k * step`` with ``k`` an
    integer. Gradients flow straight through (see ``step_quantize_atgf``).

    :param x: float tensor (documented input range ``0 <= x <= 1``; the range is not enforced)
    :type x: torch.Tensor
    :param step: quantization step
    :type step: float
    :return: ``y = round(x / step) * step``
    :rtype: torch.Tensor
    """
    quantized = step_quantize_atgf.apply(x, step)
    return quantized
[文档]
def quantize_8b(x, scale, descale=False):
    r"""
    8-bit quantization following Lava's convention.

    ``x`` is snapped to the nearest multiple of ``2 / scale`` (via
    ``step_quantize``, so gradients pass straight through) and clamped to
    ``[-256 / scale, 255 / scale]``.

    :param x: input tensor
    :type x: torch.Tensor
    :param scale: scale factor
    :type scale: float
    :param descale: if ``True``, the clamped result is multiplied by ``scale``
    :type descale: bool
    :return: quantized tensor
    :rtype: torch.Tensor
    """
    lower, upper = -256 / scale, 255 / scale
    result = step_quantize(x, step=2 / scale).clamp(lower, upper)
    if descale:
        result = result * scale
    return result
[文档]
def right_shift_to_zero(x: torch.Tensor, bits: int):
    r"""
    Signed right shift that rounds toward zero.

    Computes ``sign(x) * (|x| >> bits)``; shifting the magnitude instead of the
    signed value makes negative numbers move toward zero (a plain ``>>`` on a
    negative integer rounds toward minus infinity).

    :param x: integer tensor; must be ``torch.int32`` or ``torch.int64``
        (checked with ``assert``)
    :type x: torch.Tensor
    :param bits: number of bits to shift
    :type bits: int
    :return: shifted tensor with the same dtype as ``x``
    :rtype: torch.Tensor
    """
    orig_dtype = x.dtype
    assert orig_dtype in (torch.int32, torch.int64)
    magnitude = torch.abs(x) >> bits
    return (torch.sign(x) * magnitude).to(orig_dtype)
def _listep_forward(
    x: torch.Tensor,
    decay: torch.Tensor,
    state: torch.Tensor,
    w_scale: int,
    dtype: torch.dtype = torch.int32,
    hw_bits: int = 12,
):
    """One fixed-point leaky-integrator step.

    Float form of the update::

        y = state * (1 - decay / (1 << hw_bits)) + x

    ``state`` and ``x`` are multiplied by ``w_scale`` and cast to ``dtype``.
    The state is multiplied by the integer ``(1 << hw_bits) - decay`` and
    shifted right by ``hw_bits`` with ``right_shift_to_zero``. The integer sum
    is finally divided by ``w_scale``.

    :param x: input added after the decay
    :param decay: decay in units of ``1 / (1 << hw_bits)``
    :param state: previous state
    :param w_scale: fixed-point scale of ``x`` and ``state``
    :param dtype: integer dtype of the intermediates
    :param hw_bits: number of fractional bits of ``decay``
    :return: new state (``integer / w_scale``)
    """
    full_scale = 1 << hw_bits
    state_int = (state * w_scale).to(dtype=dtype)
    keep_int = full_scale - decay.to(dtype=dtype)
    decayed_int = right_shift_to_zero(state_int * keep_int, hw_bits)
    input_int = (w_scale * x).to(dtype=dtype)
    return (decayed_int + input_int) / w_scale
def _listep_backward(
grad_output: torch.Tensor,
decay: torch.Tensor,
state: torch.Tensor,
hw_bits: int = 12,
):
grad_state = (1 - decay / (1 << hw_bits)) * grad_output
grad_decay = -state / (1 << hw_bits) * grad_output
grad_decay = grad_decay.sum()
return grad_output, grad_decay, grad_state
# x, decay, state
[文档]
class BatchNorm2d(nn.Module):
    r"""
    Weight-scale batch normalization for Lava exchange, adapted from
    ``lava.lib.dl.slayer.neuron.norm.WgtScaleBatchNorm``.

    Unlike ``torch.nn.BatchNorm2d``, this layer has no learnable affine
    parameters and divides by a power-of-two quantized standard deviation.
    The mean is tracked per channel; a single variance is tracked for the
    whole input.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-05,
        momentum: float = 0.1,
        track_running_stats: bool = True,
        weight_exp_bits: int = 3,
        pre_hook_fx: Callable = lambda x: x,
    ):
        r"""
        :param num_features: number of feature channels
        :type num_features: int
        :param eps: small constant added to the variance before the square root
        :type eps: float
        :param momentum: momentum of the running statistics
        :type momentum: float
        :param track_running_stats: if ``True``, batch statistics are used and
            the running statistics updated in training mode; otherwise the
            running statistics are always used
        :type track_running_stats: bool
        :param weight_exp_bits: the exponent of the power-of-two standard
            deviation is clamped to ``[-weight_exp_bits, weight_exp_bits]``
        :type weight_exp_bits: int
        :param pre_hook_fx: function applied to the mean before it is subtracted
        :type pre_hook_fx: Callable
        """
        super().__init__()
        # lava.lib.dl.slayer.neuron.norm.WgtScaleBatchNorm
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats
        self.weight_exp_bits = weight_exp_bits
        self.pre_hook_fx = pre_hook_fx
        self.register_buffer("running_mean", torch.zeros(num_features))
        # A single variance shared by all channels.
        self.register_buffer("running_var", torch.zeros(1))

    def to_lava(self):
        """Build the lava-dl ``WgtScaleBatchNorm`` equivalent and copy this
        module's buffers into it.

        Uses the module-level ``slayer`` name, which is only bound when
        ``lava.lib.dl.slayer`` could be imported.

        :return: the lava-dl batch-norm module
        """
        bn = slayer.neuron.norm.WgtScaleBatchNorm(
            num_features=self.num_features,
            momentum=self.momentum,
            weight_exp_bits=self.weight_exp_bits,
            eps=self.eps,
            pre_hook_fx=self.pre_hook_fx,
        )
        bn.load_state_dict(self.state_dict())
        return bn

    def forward(self, x: torch.Tensor):
        """Normalize ``x``.

        :param x: input tensor; the mean is reduced over dims ``(0, 2, 3)``,
            so an ``[N, C, H, W]`` layout is expected
        :return: ``(x - pre_hook_fx(mean)) / std`` with ``std`` rounded up to a
            power of two
        """
        if self.track_running_stats and self.training:
            x_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
            # Biased variance over every element: one value for all channels.
            x_var = torch.var(x, unbiased=False)
            numel = x.numel() / self.num_features
            # Bessel's correction n / (n - 1): the running variance stores an
            # unbiased estimate (as torch.nn.BatchNorm does). Guard n <= 1.
            correction = numel / (numel - 1) if numel > 1 else 1.0
            with torch.no_grad():
                self.running_mean = (
                    1.0 - self.momentum
                ) * self.running_mean + self.momentum * x_mean.reshape(-1)
                self.running_var = (
                    1.0 - self.momentum
                ) * self.running_var + self.momentum * x_var * correction
        else:
            x_mean = self.running_mean.view(1, -1, 1, 1)
            x_var = self.running_var.view(1, -1, 1, 1)
        x_std = torch.sqrt(x_var + self.eps)
        # Round the std up to a power of two with a clamped exponent.
        x_std = torch.pow(
            2.0,
            torch.ceil(torch.log2(x_std)).clamp(
                -self.weight_exp_bits, self.weight_exp_bits
            ),
        )
        return (x - self.pre_hook_fx(x_mean)) / x_std
[文档]
class LeakyIntegratorStep(torch.autograd.Function):
    r"""
    Custom autograd function for one leaky-integrator step, used by
    ``CubaLIFNode`` for both the current and the voltage update.

    Forward runs the fixed-point update of ``_listep_forward`` with ``int64``
    intermediates and ``hw_bits=_hw_bits``. Backward returns the gradients of
    ``y = state * (1 - decay / (1 << _hw_bits)) + x`` computed by
    ``_listep_backward``. ``w_scale`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, decay, state, w_scale):
        output = _listep_forward(
            x, decay, state, w_scale, dtype=torch.int64, hw_bits=_hw_bits
        )
        # Saved unconditionally: backward needs ``decay`` and ``state`` whenever
        # any input requires grad, including the case where only ``decay`` does
        # (learnable decays). If no input requires grad, backward never runs.
        ctx.save_for_backward(decay, state)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        decay, state = ctx.saved_tensors
        grad_input, grad_decay, grad_state = _listep_backward(
            grad_output, decay, state, hw_bits=_hw_bits
        )
        # No gradient for the integer ``w_scale``.
        return grad_input, grad_decay, grad_state, None
[文档]
class CubaLIFNode(neuron.BaseNode):
    def __init__(
        self,
        current_decay: Union[float, torch.Tensor],
        voltage_decay: Union[float, torch.Tensor],
        v_threshold: float = 1.0,
        v_reset: float = 0.0,
        scale=1 << 6,
        requires_grad=False,
        surrogate_function: Callable = surrogate.Sigmoid(),
        norm: BatchNorm2d = None,
        detach_reset=False,
        step_mode="s",
        backend="torch",
        store_v_seq: bool = False,
        store_i_seq: bool = False,
    ):
        # author: https://github.com/AllenYolk
        r"""
        Current-based leaky integrate-and-fire (CUBA-LIF) neuron with the
        fixed-point quantization of lava-dl's ``cuba.Neuron``.

        The neuron keeps a current state ``I`` and a voltage state ``V``:

        .. math::
            I[t] = (1 - \alpha_{I})I[t-1] + X[t]

            V[t] = (1 - \alpha_{V})V[t-1] + I[t]

        After firing, ``V`` is hard-reset to ``v_reset`` (only ``0`` is
        supported).

        :param current_decay: current decay constant :math:`\alpha_{I}`. It is
            stored multiplied by ``p_scale`` and clamped to ``[0, p_scale]``.
        :type current_decay: Union[float, torch.Tensor]
        :param voltage_decay: voltage decay constant :math:`\alpha_{V}`, stored
            like ``current_decay``
        :type voltage_decay: Union[float, torch.Tensor]
        :param v_threshold: threshold of the neurons in this layer. Default to 1.
        :type v_threshold: float
        :param v_reset: reset potential; only ``0.`` is accepted. Default to 0.
        :type v_reset: float
        :param scale: quantization precision (ref: lava-dl ``cuba.Neuron``).
            Default to ``1 << 6``. Equivalent to ``w_scale=int(scale)``,
            ``s_scale=int(scale * (1 << 6))``, ``p_scale=1 << 12``.
        :type scale: float
        :param requires_grad: whether ``current_decay`` and ``voltage_decay``
            are learnable (``nn.Parameter`` instead of buffers). Default to ``False``.
        :type requires_grad: bool
        :param surrogate_function: surrogate spiking function. Default to ``surrogate.Sigmoid()``.
        :type surrogate_function: Callable
        :param norm: optional normalization applied to the current before it is
            integrated into the voltage. Only ``BatchNorm2d`` of this module is
            supported; its ``pre_hook_fx`` is replaced by ``self.quantize_8bit``.
        :type norm: BatchNorm2d, optional
        :param detach_reset: whether to detach the computational graph of reset
            in backward pass. Default to ``False``.
        :type detach_reset: bool
        :param step_mode: the step mode, ``'s'`` (single-step) or ``'m'`` (multi-step). Default to ``'s'``.
        :type step_mode: str
        :param backend: backend of this layer; print ``self.supported_backends``
            for the backends of the current ``step_mode``. Only ``'torch'`` is supported.
        :type backend: str
        :param store_v_seq: when ``step_mode = 'm'`` and the input has
            ``shape = [T, N, *]``, whether to store the voltage of every time-step
            in ``self.v_seq`` (``shape = [T, N, *]``). Otherwise only the last
            voltage is kept in ``self.voltage_state`` (``shape = [N, *]``), which
            saves memory. Default to ``False``.
        :type store_v_seq: bool
        :param store_i_seq: as ``store_v_seq``, for the current: ``self.i_seq``
            versus ``self.current_state``. Default to ``False``.
        :type store_i_seq: bool
        :raises NotImplementedError: if ``norm`` is neither ``None`` nor a ``BatchNorm2d``
        """
        # Raw constructor arguments in lava-dl ``cuba.Neuron`` naming.
        self.lava_cuba_neuron_params = {
            "threshold": v_threshold,
            "current_decay": current_decay,
            "voltage_decay": voltage_decay,
            "scale": scale,
        }
        super().__init__(
            v_threshold=v_threshold,
            v_reset=v_reset,
            surrogate_function=surrogate_function,
            detach_reset=detach_reset,
            step_mode=step_mode,
            backend=backend,
            store_v_seq=store_v_seq,
        )
        self.store_i_seq = store_i_seq
        assert v_reset == 0.0, (
            "CubaLIFNode only supports for hard reset with v_reset = 0. !"
        )
        self.requires_grad = requires_grad
        # The default quantization setting of lava, equivalent to
        # w_scale = int(scale), s_scale = int(scale * (1 << 6)), p_scale = 1 << 12.
        self._scale = int(scale)
        self._s_scale = int(scale * (1 << 6))
        self._p_scale = 1 << _hw_bits
        # Threshold truncated to a multiple of 1 / scale (for a positive
        # threshold: the largest k / scale, k an int, not exceeding it).
        # NOTE(review): not read anywhere in this class; ``neuronal_fire`` uses
        # ``self.v_threshold`` — confirm whether that is intended.
        self._v_threshold = int(v_threshold * self.scale) / self.scale
        # Loihi fires on v[t] > v_th while SpikingJelly fires on v[t] >= v_th;
        # comparing against v_th + eps approximates the strict inequality.
        self.v_threshold_eps = 0.01 / self.s_scale
        current_decay = torch.tensor(self.p_scale * current_decay, dtype=torch.float32)
        voltage_decay = torch.tensor(self.p_scale * voltage_decay, dtype=torch.float32)
        if requires_grad:
            self.current_decay = nn.Parameter(current_decay)
            self.voltage_decay = nn.Parameter(voltage_decay)
        else:
            self.register_buffer("current_decay", current_decay)
            self.register_buffer("voltage_decay", voltage_decay)
        self.register_memory("current_state", 0.0)
        self.register_memory("voltage_state", 0.0)
        self.clamp_decay_parameters()
        self.norm = norm
        if self.norm is not None:
            if isinstance(self.norm, BatchNorm2d):
                self.norm.pre_hook_fx = self.quantize_8bit
            else:
                raise NotImplementedError(self.norm)

    def quantize_8bit(self, x, descale=False):
        """Quantize ``x`` with ``quantize_8b`` using this neuron's ``scale``.

        Installed as ``norm.pre_hook_fx`` when a ``BatchNorm2d`` is given.
        """
        return quantize_8b(x, scale=self.scale, descale=descale)

    def clamp_decay_parameters(self):
        """Clamp ``current_decay`` and ``voltage_decay`` in place to ``[0, p_scale]``."""
        with torch.no_grad():
            self.current_decay.data.clamp_(min=0.0, max=self.p_scale)
            self.voltage_decay.data.clamp_(min=0.0, max=self.p_scale)

    @property
    def scale(self):
        """Synaptic weight scale ``w_scale = int(scale)``.

        :rtype: int
        """
        return self._scale

    @property
    def s_scale(self):
        """State scale ``int(scale * (1 << 6))`` used as the fixed-point scale
        of the current and voltage updates.

        :rtype: int
        """
        return self._s_scale

    @property
    def p_scale(self):
        """Decay scale ``1 << _hw_bits``; decays are stored in these units.

        :rtype: int
        """
        return self._p_scale

    @property
    def store_i_seq(self):
        """Whether ``multi_step_forward`` stores every current in ``self.i_seq``."""
        return self._store_i_seq

    @store_i_seq.setter
    def store_i_seq(self, value: bool):
        self._store_i_seq = value
        if value:
            # Register the ``i_seq`` memory lazily, the first time it is enabled.
            if not hasattr(self, "i_seq"):
                self.register_memory("i_seq", None)

    @property
    def supported_backends(self):
        if self.step_mode == "m" or self.step_mode == "s":
            return ("torch",)
        else:
            raise ValueError(
                f"self.step_mode should be 's' or 'm', "
                f"but get {self.step_mode} instead."
            )

    # computation process

    def state_initialization(self, x: torch.Tensor):
        """Replace the float placeholder states with zero tensors shaped like ``x``."""
        if isinstance(self.current_state, float):
            self.current_state = torch.zeros_like(x.data)
        if isinstance(self.voltage_state, float):
            self.voltage_state = torch.zeros_like(x.data)

    def neuronal_charge(self, x: torch.Tensor):
        """Update the current with input ``x``, then the voltage with the
        (optionally normalized) current, both via ``LeakyIntegratorStep``."""
        if self.requires_grad:
            self.clamp_decay_parameters()
        current = LeakyIntegratorStep.apply(
            x,
            step_quantize(self.current_decay),
            self.current_state.contiguous(),
            self.s_scale,
        )
        if self.norm is not None:
            current = self.norm(current)
        voltage = LeakyIntegratorStep.apply(
            current,
            step_quantize(self.voltage_decay),
            self.voltage_state.contiguous(),
            self.s_scale,
        )
        self.current_state = current
        self.voltage_state = voltage

    def neuronal_fire(self):
        """Apply the surrogate function to ``V - (v_threshold + v_threshold_eps)``."""
        return self.surrogate_function(
            self.voltage_state - (self.v_threshold + self.v_threshold_eps)
        )

    def neuronal_reset(self, spike):
        """Hard-reset the voltage of spiking neurons to ``v_reset``."""
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike
        self.voltage_state = self.apply_hard_reset(
            self.voltage_state, spike_d, self.v_reset
        )

    def single_step_forward(self, x):
        """One time-step: initialize states, charge, fire, reset; return the spikes."""
        self.state_initialization(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def multi_step_forward(self, x_seq: torch.Tensor):
        """Run ``single_step_forward`` over the leading time dimension of ``x_seq``.

        :param x_seq: input with ``shape = [T, N, *]``
        :return: spikes with ``shape = [T, N, *]``
        """
        T = x_seq.shape[0]
        y_seq = []
        if self.store_v_seq:
            v_seq = []
        if self.store_i_seq:
            i_seq = []
        for t in range(T):
            y = self.single_step_forward(x_seq[t])
            y_seq.append(y)
            if self.store_v_seq:
                v_seq.append(self.voltage_state)
            if self.store_i_seq:
                i_seq.append(self.current_state)
        if self.store_v_seq:
            self.v_seq = torch.stack(v_seq)
        if self.store_i_seq:
            self.i_seq = torch.stack(i_seq)
        return torch.stack(y_seq)
try:
import lava.lib.dl.slayer as slayer
# ----------------------------------------
# data reshape function
[文档]
def TNX_to_NXT(x_seq: torch.Tensor):
    """Move the leading (time) dimension to the end.

    ``[T, N, *]`` (SpikingJelly layout) -> ``[N, *, T]`` (lava-dl layout).

    :param x_seq: tensor with ``shape = [T, N, *]``
    :return: a permuted view with ``shape = [N, *, T]``
    """
    order = tuple(range(1, x_seq.dim())) + (0,)
    return x_seq.permute(order)
[文档]
def NXT_to_TNX(x_seq: torch.Tensor):
    """Move the trailing (time) dimension to the front.

    ``[N, *, T]`` (lava-dl layout) -> ``[T, N, *]`` (SpikingJelly layout).

    :param x_seq: tensor with ``shape = [N, *, T]``
    :return: a permuted view with ``shape = [T, N, *]``
    """
    last = x_seq.dim() - 1
    order = (last,) + tuple(range(last))
    return x_seq.permute(order)
[文档]
def lava_neuron_forward(
    lava_neuron: nn.Module, x_seq: torch.Tensor, v: Union[torch.Tensor, float]
):
    """Run a lava-dl neuron on a SpikingJelly-layout sequence.

    :param lava_neuron: lava-dl neuron module. Its ``voltage_state`` attribute
        is overwritten with the initial voltage before the call and read back
        afterwards.
    :param x_seq: input with ``shape = [T, N, *]`` or ``[T, N]``
    :param v: initial voltage. A Python number (int or float) is broadcast to
        the shape of ``x_seq[0]``; a tensor is passed through unchanged.
    :return: ``(spike, v)``, where ``spike`` has the shape of ``x_seq`` and
        ``v`` is the final voltage reshaped to the shape of ``x_seq[0]``
    """
    # lava uses shape = [N, *, T], while SJ uses shape = [T, N, *]
    squeeze_back = x_seq.dim() == 2
    if squeeze_back:
        # [T, N] -> [T, 1, N]: lava needs input with shape [N, ..., T]
        x_seq = x_seq.unsqueeze(1)

    if isinstance(v, (int, float)):
        v_init = v
        v = torch.zeros_like(x_seq[0])
        if v_init != 0.0:
            v.fill_(v_init)

    x_seq_shape = x_seq.shape
    # [T, N, *] -> [N, prod(*), T]
    x_seq = x_seq.flatten(2).permute(1, 2, 0)
    lava_neuron.voltage_state = v
    spike = lava_neuron(x_seq).permute(2, 0, 1)
    v = lava_neuron.voltage_state.reshape(x_seq_shape[1:])
    spike = spike.reshape(x_seq_shape)
    if squeeze_back:
        # The axis inserted at dim 1 of x_seq is dim 0 of the per-step voltage.
        v = v.squeeze(0)
        spike = spike.squeeze(1)
    return spike, v
# ----------------------------------------
# quantize function
class _step_quantize(torch.autograd.Function):
@staticmethod
def forward(ctx, x, step):
return torch.round(x / step) * step
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
[文档]
def step_quantize(x: torch.Tensor, step: float = 1.0):
    """
    Quantize ``x`` to the nearest multiple of ``step``. The forward pass is
    ``round(x / step) * step``; the backward pass is the identity
    (straight-through estimator).

    :param x: the input tensor
    :type x: torch.Tensor
    :param step: the quantize step
    :type step: float
    :return: quantized tensor
    :rtype: torch.Tensor

    The figure below was produced with:

    .. code-block:: python

        # plt.style.use(['science', 'muted', 'grid'])
        fig = plt.figure(dpi=200, figsize=(6, 4))
        x = torch.arange(-4, 4, 0.001)
        plt.plot(
            x, lava_exchange.step_quantize(x, 2.0), label="quantize(x, step=2)"
        )
        plt.plot(x, x, label="y=x", ls="-.")
        plt.legend()
        plt.grid(ls="--")
        plt.title("step quantize")
        plt.xlabel("Input")
        plt.ylabel("Output")
        plt.savefig(
            "./docs/source/_static/API/activation_based/lava_exchange/step_quantize.svg"
        )
        plt.savefig(
            "./docs/source/_static/API/activation_based/lava_exchange/step_quantize.pdf"
        )

    .. image:: ../_static/API/activation_based/lava_exchange/step_quantize.*
        :width: 100%
    """
    quantized = _step_quantize.apply(x, step)
    return quantized
[文档]
def quantize_8bit(x: torch.Tensor, scale, descale=False):
    """Lava-style 8-bit quantization.

    Snaps ``x`` to multiples of ``2 / scale`` (straight-through gradient) and
    clamps it to ``[-256 / scale, 255 / scale]``.

    :param x: input tensor
    :param scale: scale factor
    :param descale: if ``True``, multiply the clamped result by ``scale``
    :return: quantized tensor
    """
    clamped = step_quantize(x, 2.0 / scale).clamp(-256.0 / scale, 255.0 / scale)
    return clamped * scale if descale else clamped
# ----------------------------------------
# convert function
[文档]
def check_instance(m, instance):
    """Raise ``ValueError`` unless ``m`` is an instance of ``instance``.

    :param m: object to check
    :param instance: a type or tuple of types accepted by ``isinstance``
    :raises ValueError: if the check fails
    """
    if isinstance(m, instance):
        return
    raise ValueError(
        f"expected {m} with type {instance}, but got {m} with type {type(m)}!"
    )
[文档]
def check_no_bias(m):
    """Raise ``ValueError`` if module ``m`` has a bias (lava synapses have none).

    :param m: module with a ``bias`` attribute
    :raises ValueError: if ``m.bias`` is not ``None``
    """
    has_bias = m.bias is not None
    if has_bias:
        raise ValueError(f"lava does not support for {type(m)} with bias!")
[文档]
def to_lava_neuron_param_dict(sj_ms_neuron: nn.Module):
    """Build the keyword arguments of lava-dl ``slayer.neuron.cuba.Neuron``
    that reproduce a SpikingJelly IF or LIF neuron.

    IF maps to ``voltage_decay = 0``; LIF maps to ``voltage_decay = 1 / tau``.
    Both use ``current_decay = 1`` and the neuron's ``lava_s_cale`` as scale.

    :param sj_ms_neuron: ``neuron.IFNode`` or ``neuron.LIFNode``
    :return: parameter dict for ``cuba.Neuron``
    :raises ValueError: if ``v_reset != 0`` or (LIF) ``decay_input`` is ``True``
    :raises NotImplementedError: for any other neuron type
    """
    if isinstance(sj_ms_neuron, neuron.IFNode):
        if sj_ms_neuron.v_reset != 0.0:
            raise ValueError("lava only supports for v_reset == 0!")
        voltage_decay = 0.0
    elif isinstance(sj_ms_neuron, neuron.LIFNode):
        if sj_ms_neuron.v_reset != 0.0:
            raise ValueError("lava only supports for v_reset == 0!")
        if sj_ms_neuron.decay_input:
            raise ValueError("lava only supports for decay_input == False!")
        voltage_decay = 1.0 / sj_ms_neuron.tau
    else:
        raise NotImplementedError(sj_ms_neuron)

    return {
        "threshold": sj_ms_neuron.v_threshold,
        "current_decay": 1.0,
        "voltage_decay": voltage_decay,
        "tau_grad": 1,
        "scale_grad": 1,
        "scale": sj_ms_neuron.lava_s_cale,
        "norm": None,
        "dropout": None,
        "shared_param": True,
        "persistent_state": True,
        "requires_grad": False,
        "graded_spike": False,
    }
[文档]
def to_lava_neuron(sj_ms_neuron: nn.Module):
    """Convert a SpikingJelly IF/LIF neuron into lava-dl ``slayer.neuron.cuba.Neuron``.

    :raises NotImplementedError: for any other neuron type
    """
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    params = to_lava_neuron_param_dict(sj_ms_neuron)
    return slayer.neuron.cuba.Neuron(**params)
[文档]
def linear_to_lava_synapse_dense(fc: nn.Linear):
    """
    Convert a bias-free ``nn.Linear`` into a lava-dl ``slayer.synapse.Dense``
    with the same weights.

    :param fc: a pytorch linear layer without bias
    :type fc: nn.Linear
    :return: a lava slayer dense synapse
    :rtype: slayer.synapse.Dense
    :raises ValueError: if ``fc`` is not ``nn.Linear`` or has a bias

    Codes example:

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.Linear(8, 4, bias=False)
        layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn)
        x_seq = torch.rand([T, N, 8])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = lava_exchange.NXT_to_TNX(
                layer_sl(lava_exchange.TNX_to_NXT(x_seq))
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(fc, nn.Linear)
    check_no_bias(fc)
    dense = slayer.synapse.Dense(fc.in_features, fc.out_features)
    # The Dense synapse is a Conv3d; its weight has shape
    # [out_features, in_features, 1, 1, 1].
    dense.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()
    return dense
[文档]
def conv2d_to_lava_synapse_conv(conv2d_nn: nn.Conv2d):
    """
    Convert a bias-free ``nn.Conv2d`` into a lava-dl ``slayer.synapse.Conv``
    with the same geometry and weights.

    :param conv2d_nn: a pytorch conv2d layer without bias
    :type conv2d_nn: nn.Conv2d
    :return: a lava slayer conv synapse
    :rtype: slayer.synapse.Conv
    :raises ValueError: if ``conv2d_nn`` is not ``nn.Conv2d`` or has a bias

    Codes example:

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
        layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn)
        x_seq = torch.rand([T, N, 3, 28, 28])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = lava_exchange.NXT_to_TNX(
                layer_sl(lava_exchange.TNX_to_NXT(x_seq))
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(conv2d_nn, nn.Conv2d)
    check_no_bias(conv2d_nn)
    geometry = dict(
        in_features=conv2d_nn.in_channels,
        out_features=conv2d_nn.out_channels,
        kernel_size=conv2d_nn.kernel_size,
        stride=conv2d_nn.stride,
        padding=conv2d_nn.padding,
        dilation=conv2d_nn.dilation,
        groups=conv2d_nn.groups,
    )
    lava_conv = slayer.synapse.Conv(**geometry)
    # The Conv synapse is a Conv3d; the 2-D kernel goes into the first slot of
    # its last (time) dimension.
    lava_conv.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()
    return lava_conv
[文档]
def avgpool2d_to_lava_synapse_pool(pool2d_nn: nn.AvgPool2d):
    """
    Convert an ``nn.AvgPool2d`` into a lava-dl ``slayer.synapse.Pool`` with
    the same kernel size, stride and padding.

    :param pool2d_nn: a pytorch AvgPool2d layer
    :type pool2d_nn: nn.AvgPool2d
    :return: a lava slayer pool layer
    :rtype: slayer.synapse.Pool
    :raises ValueError: if ``pool2d_nn`` is not ``nn.AvgPool2d``

    .. admonition:: Warning
        :class: warning

        The lava slayer pool layer applies sum pooling, rather than average
        pooling (a warning is also logged).

    .. code-block:: python

        T = 4
        N = 2
        layer_nn = nn.AvgPool2d(kernel_size=2, stride=2)
        layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn)
        x_seq = torch.rand([T, N, 3, 28, 28])
        with torch.no_grad():
            y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
            y_sl = (
                lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
                / 4.0
            )
            print("max error:", (y_nn - y_sl).abs().max())
    """
    check_instance(pool2d_nn, nn.AvgPool2d)
    logging.warning(
        "The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer."
    )
    kernel, stride, padding = pool2d_nn.kernel_size, pool2d_nn.stride, pool2d_nn.padding
    return slayer.synapse.Pool(kernel, stride, padding)
[文档]
def to_lava_block_dense(
    fc: nn.Linear, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """Convert ``nn.Linear`` (no bias) + IF/LIF neuron into a lava-dl
    ``slayer.block.cuba.Dense`` block.

    :param fc: linear layer without bias
    :param sj_ms_neuron: ``neuron.IFNode`` or ``neuron.LIFNode``
    :param quantize_to_8bit: if ``False``, ``pre_hook_fx=None`` is passed to
        the block; otherwise the block's default hook is kept
    :return: the lava-dl dense block with ``fc``'s weights
    :raises ValueError: for a wrong layer type, a bias, or unsupported neuron settings
    :raises NotImplementedError: for unsupported neuron types
    """
    check_instance(fc, nn.Linear)
    check_no_bias(fc)
    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
    hook_kwargs = {} if quantize_to_8bit else {"pre_hook_fx": None}
    lava_block = slayer.block.cuba.Dense(
        neuron_params,
        fc.in_features,
        fc.out_features,
        delay_shift=False,
        **hook_kwargs,
    )
    lava_block.synapse.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()
    return lava_block
[文档]
def to_lava_block_conv(
    conv2d_nn: nn.Conv2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """Convert ``nn.Conv2d`` (no bias) + IF/LIF neuron into a lava-dl
    ``slayer.block.cuba.Conv`` block.

    :param conv2d_nn: 2-D convolution without bias
    :param sj_ms_neuron: ``neuron.IFNode`` or ``neuron.LIFNode``
    :param quantize_to_8bit: if ``False``, ``pre_hook_fx=None`` is passed to
        the block; otherwise the block's default hook is kept
    :return: the lava-dl conv block with ``conv2d_nn``'s weights
    :raises ValueError: for a wrong layer type, a bias, or unsupported neuron settings
    :raises NotImplementedError: for unsupported neuron types
    """
    check_instance(conv2d_nn, nn.Conv2d)
    check_no_bias(conv2d_nn)
    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    block_kwargs = dict(
        in_features=conv2d_nn.in_channels,
        out_features=conv2d_nn.out_channels,
        kernel_size=conv2d_nn.kernel_size,
        stride=conv2d_nn.stride,
        padding=conv2d_nn.padding,
        dilation=conv2d_nn.dilation,
        groups=conv2d_nn.groups,
        delay_shift=False,
    )
    # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
    if not quantize_to_8bit:
        block_kwargs["pre_hook_fx"] = None
    lava_block = slayer.block.cuba.Conv(neuron_params, **block_kwargs)
    lava_block.synapse.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()
    return lava_block
[文档]
def to_lava_block_pool(
    pool2d_nn: nn.AvgPool2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True
):
    """Convert ``nn.AvgPool2d`` + IF/LIF neuron into a lava-dl
    ``slayer.block.cuba.Pool`` block.

    .. admonition:: Warning
        :class: warning

        The lava slayer pool layer applies sum pooling, rather than average
        pooling (a warning is also logged).

    :param pool2d_nn: average pooling layer
    :param sj_ms_neuron: ``neuron.IFNode`` or ``neuron.LIFNode``
    :param quantize_to_8bit: if ``False``, ``pre_hook_fx=None`` is passed to
        the block; otherwise the block's default hook is kept
    :return: the lava-dl pool block
    :raises ValueError: for a wrong layer type or unsupported neuron settings
    :raises NotImplementedError: for unsupported neuron types
    """
    check_instance(pool2d_nn, nn.AvgPool2d)
    neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
    if not isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
        raise NotImplementedError(sj_ms_neuron)
    # if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
    hook_kwargs = {} if quantize_to_8bit else {"pre_hook_fx": None}
    lava_block = slayer.block.cuba.Pool(
        neuron_params,
        pool2d_nn.kernel_size,
        pool2d_nn.stride,
        pool2d_nn.padding,
        delay_shift=False,
        **hook_kwargs,
    )
    # The message names this function (it previously named
    # `avgpool2d_to_lava_synapse_pool`).
    logging.warning(
        "The lava slayer pool layer applies sum pooling, rather than average pooling. `to_lava_block_pool` will return a sum pooling layer."
    )
    return lava_block
[文档]
def to_lava_block_flatten(flatten_nn: nn.Flatten):
    """Convert ``nn.Flatten(start_dim=1)`` into a lava-dl ``slayer.block.cuba.Flatten``.

    :raises ValueError: if ``flatten_nn`` is not ``nn.Flatten`` or ``start_dim != 1``
    """
    check_instance(flatten_nn, nn.Flatten)
    if flatten_nn.start_dim == 1:
        return slayer.block.cuba.Flatten()
    raise ValueError("lava only supports for flatten_nn.start_dim == 1!")
[文档]
def to_lava_blocks(net: Union[list, tuple, nn.Sequential]):
    # https://lava-nc.org/lava-lib-dl/netx/netx.html
    """
    Convert a sequence of SpikingJelly layers into a list of lava-dl blocks.

    Accepted patterns, in any order:
    ``nn.Linear -> IFNode/LIFNode``, ``nn.Conv2d -> IFNode/LIFNode``,
    ``nn.AvgPool2d -> IFNode/LIFNode`` and a bare ``nn.Flatten``.

    Lava pools by summation, so after an ``nn.AvgPool2d`` the weights of the
    next ``nn.Linear``/``nn.Conv2d`` block are divided by the pooling window
    area. The division is applied to the converted lava block; ``net`` is not
    modified.

    Supported netx layer types::

        input  : {shape, type}
        flatten: {shape, type}
        average: {shape, type}
        concat : {shape, type, layers}
        dense  : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)}
        pool   : {shape, type, neuron, kernelSize, stride, padding, dilation, weight}
        conv   : {shape, type, neuron, inChannels, outChannels, kernelSize, stride,
                 |  padding, dilation, groups, weight, delay(if available)}
                 |
                 |-> this is the description of the compartment parameters
                 |-> {iDecay, vDecay, vThMant, refDelay, ... (other additional params)}

    :param net: layers to convert; an empty sequence gives an empty list
    :return: list of lava-dl blocks
    :raises ValueError: if a layer is unsupported or a synapse is not followed
        by an IF/LIF neuron
    """
    blocks = []
    length = len(net)
    i = 0
    # Window area of the last AvgPool2d, waiting for the next weighted block.
    pending_scale = None
    while i < length:
        layer = net[i]
        if isinstance(layer, (nn.Linear, nn.Conv2d)):
            if i + 1 >= length or not isinstance(
                net[i + 1], (neuron.IFNode, neuron.LIFNode)
            ):
                raise ValueError(type(layer))
            if isinstance(layer, nn.Linear):
                blocks.append(to_lava_block_dense(layer, net[i + 1]))
            else:
                blocks.append(to_lava_block_conv(layer, net[i + 1]))
            if pending_scale is not None:
                # Compensate the preceding sum pooling on the lava copy only.
                blocks[-1].synapse.weight.data /= pending_scale
                pending_scale = None
            i += 2
        elif isinstance(layer, nn.AvgPool2d):
            if i + 1 >= length or not isinstance(
                net[i + 1], (neuron.IFNode, neuron.LIFNode)
            ):
                raise ValueError(type(layer))
            blocks.append(to_lava_block_pool(layer, net[i + 1]))
            # Read the kernel size from the pooling layer itself, before
            # advancing past it.
            kernel_size = layer.kernel_size
            if isinstance(kernel_size, int):
                pending_scale = float(kernel_size * kernel_size)
            else:
                pending_scale = float(kernel_size[0] * kernel_size[1])
            i += 2
        elif isinstance(layer, nn.Flatten):
            blocks.append(to_lava_block_flatten(layer))
            i += 1
        else:
            raise ValueError(type(layer))
    return blocks
[文档]
class SumPool2d(nn.Module):
    def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
        """
        Sum pooling over ``[N, C, H, W]`` inputs. It matches lava-dl's
        ``slayer.synapse.Pool``, which sums rather than averages.

        :param kernel_size: pooling window size
        :param stride: window stride; ``None`` (default) means ``kernel_size``,
            as in ``nn.AvgPool2d``
        :param padding: implicit zero padding
        :param dilation: window dilation

        .. code-block:: python

            x = torch.rand([4, 2, 4, 16, 16])
            with torch.no_grad():
                sp_sj = SumPool2d(kernel_size=2, stride=2)
                y_sj = functional.seq_to_ann_forward(x, sp_sj)
                sp_la = slayer.synapse.Pool(kernel_size=2, stride=2)
                y_la = lava_exchange.NXT_to_TNX(sp_la(lava_exchange.TNX_to_NXT(x)))
                print((y_sj - y_la).abs().sum())
        """
        super().__init__()
        if stride is None:
            # Pooling convention: non-overlapping windows by default. Passing
            # None to nn.Conv2d would yield stride (None, None) and crash later.
            stride = kernel_size
        # A throwaway conv normalizes the int/tuple arguments to 2-tuples.
        temp_conv = nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        # All-ones kernel. Registered as a non-persistent buffer so it follows
        # .to()/.cuda() without adding a state_dict entry.
        self.register_buffer(
            "weight", torch.ones_like(temp_conv.weight.data), persistent=False
        )
        self.kernel_size = temp_conv.kernel_size
        self.stride = temp_conv.stride
        self.padding = temp_conv.padding
        self.dilation = temp_conv.dilation
        del temp_conv

    def forward(self, x: torch.Tensor):
        """
        :param x: input with ``shape = [N, C, H, W]``
        :return: window sums with ``shape = [N, C, H_out, W_out]``
        """
        if self.dilation == (1, 1):
            # avg_pool2d counts padded zeros by default, so scaling by the
            # window area gives the window sum.
            return (
                F.avg_pool2d(x, self.kernel_size, self.stride, self.padding)
                * self.weight.numel()
            )
        # Dilated windows: convolve every channel separately with the ones kernel.
        N, C, H, W = x.shape
        x = x.reshape(N * C, 1, H, W)
        x = F.conv2d(
            x,
            weight=self.weight,
            bias=None,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
        return x.view(N, C, x.shape[2], x.shape[3])
[文档]
class BlockContainer(nn.Module, base.StepModule):
    """
    Container pairing one synapse layer with a Lava-compatible CUBA LIF neuron.

    Supported synapses are ``nn.Conv2d`` and ``nn.Linear`` (both without
    bias), :class:`SumPool2d`, and ``nn.Flatten`` (``start_dim == 1``, with no
    neuron). ``neuron.IFNode`` / ``neuron.LIFNode`` inputs are converted into
    a :class:`CubaLIFNode`. The block can be simulated in SpikingJelly with
    :meth:`forward` and exported to lava-dl with :meth:`to_lava_block`.
    """

    @property
    def step_mode(self):
        return self._step_mode

    @step_mode.setter
    def step_mode(self, value: str):
        """
        Set the step mode and propagate it to the neuron and the synapse when
        they are ``base.StepModule`` instances.

        :raises ValueError: if ``value`` is not a supported step mode
        """
        if value not in self.supported_step_mode():
            raise ValueError(
                f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!'
            )
        self._step_mode = value
        for sub_module in (self.neuron, self.synapse):
            if isinstance(sub_module, base.StepModule):
                sub_module.step_mode = value

    def __init__(
        self,
        synapse: Union[nn.Conv2d, nn.Linear, SumPool2d, nn.Flatten],
        neu: Optional[Union[CubaLIFNode, neuron.IFNode, neuron.LIFNode]],
        step_mode: str = "s",
    ):
        """
        :param synapse: the synapse layer. ``nn.Conv2d`` / ``nn.Linear`` must
            have no bias; ``nn.Flatten`` must use ``start_dim == 1``
        :param neu: the neuron following the synapse. Must be ``None`` for
            ``nn.Flatten``. ``neuron.IFNode`` is converted to a
            :class:`CubaLIFNode` with ``current_decay=1`` and
            ``voltage_decay=0``; ``neuron.LIFNode`` (``decay_input=False``)
            with ``current_decay=1`` and ``voltage_decay=1/tau``
        :param step_mode: ``"s"`` (single-step) or ``"m"`` (multi-step)
        :raises ValueError: if the Flatten ``start_dim`` is not 1, if an
            IF/LIF neuron has ``v_reset != 0``, or if a LIF neuron uses
            ``decay_input=True``
        """
        super().__init__()
        if isinstance(synapse, nn.Flatten):
            assert neu is None
            self.synapse = synapse
            self.neuron = None
            if synapse.start_dim != 1:
                raise ValueError(
                    "lava only supports for torch.nn.Flatten with start_dim == 1!"
                )
        else:
            if isinstance(neu, neuron.IFNode):
                if neu.v_reset != 0.0:
                    raise ValueError("lava only supports for v_reset == 0!")
                neu = CubaLIFNode(
                    current_decay=1.0,
                    voltage_decay=0.0,
                    v_threshold=neu.v_threshold,
                    scale=neu.lava_s_cale,
                )
            elif isinstance(neu, neuron.LIFNode):
                if neu.v_reset != 0.0:
                    raise ValueError("lava only supports for v_reset == 0!")
                if neu.decay_input:
                    raise ValueError("lava only supports for decay_input == False!")
                neu = CubaLIFNode(
                    current_decay=1.0,
                    voltage_decay=1.0 / neu.tau,
                    v_threshold=neu.v_threshold,
                    scale=neu.lava_s_cale,
                )
            else:
                assert isinstance(neu, CubaLIFNode)
            self.synapse = synapse
            self.neuron = neu
            if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
                assert self.synapse.bias is None
        self.step_mode = step_mode

    def forward(self, x: torch.Tensor):
        """
        Apply the synapse and then the neuron (if any).

        :param x: input of shape ``[T, N, *]`` in multi-step mode (``"m"``),
            or ``[N, *]`` in single-step mode (``"s"``)
        :return: the neuron's output, or the synapse output when the block has
            no neuron (``nn.Flatten``)
        :raises NotImplementedError: if the synapse type is unsupported
        """
        multi_step = self.step_mode == "m"
        if multi_step:
            T, N = x.shape[0], x.shape[1]
            # Merge time and batch so the stateless synapse sees [T * N, *].
            x = x.flatten(0, 1)
        if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
            # Quantized to 2k / self.neuron.scale, k = -128, -127, ..., 127
            # (256 possible values in total).
            weight = self.neuron.quantize_8bit(self.synapse.weight)
            if isinstance(self.synapse, nn.Conv2d):
                x = F.conv2d(
                    x,
                    weight=weight,
                    bias=self.synapse.bias,
                    stride=self.synapse.stride,
                    padding=self.synapse.padding,
                    dilation=self.synapse.dilation,
                    groups=self.synapse.groups,
                )
            else:
                x = F.linear(x, weight, self.synapse.bias)
        elif isinstance(self.synapse, (SumPool2d, nn.Flatten)):
            x = self.synapse(x)
        else:
            raise NotImplementedError(type(self.synapse))
        if multi_step:
            # reshape (not view) also copes with non-contiguous synapse output
            x = x.reshape(T, N, *x.shape[1:])
        if self.neuron is not None:
            x = self.neuron(x)
        return x

    def to_lava_block(self):
        """
        Export this block as a lava-dl ``slayer.block.cuba`` block.

        ``nn.Linear`` / ``nn.Conv2d`` weights are copied (unquantized) into the
        lava synapse, blocks are built with ``delay_shift=False``, and the
        neuron's ``norm`` is attached when present.

        :return: a ``slayer.block.cuba`` Dense / Conv / Pool / Flatten block
        :raises NotImplementedError: if the synapse type is unsupported
        """
        if isinstance(self.synapse, nn.Flatten):
            return slayer.block.cuba.Flatten()

        neuron_params = self.neuron.lava_cuba_neuron_params
        if isinstance(self.synapse, nn.Linear):
            lava_block = slayer.block.cuba.Dense(
                neuron_params,
                self.synapse.in_features,
                self.synapse.out_features,
                delay_shift=False,
            )
            # lava dense weight layout: [out, in, 1, 1, 1]
            lava_block.synapse.weight.data[:, :, 0, 0, 0] = (
                self.synapse.weight.data.clone()
            )
        elif isinstance(self.synapse, nn.Conv2d):
            lava_block = slayer.block.cuba.Conv(
                neuron_params,
                in_features=self.synapse.in_channels,
                out_features=self.synapse.out_channels,
                kernel_size=self.synapse.kernel_size,
                stride=self.synapse.stride,
                padding=self.synapse.padding,
                dilation=self.synapse.dilation,
                groups=self.synapse.groups,
                delay_shift=False,
            )
            # lava conv weight layout: [out, in, kH, kW, 1]
            lava_block.synapse.weight.data[:, :, :, :, 0] = (
                self.synapse.weight.data.clone()
            )
        elif isinstance(self.synapse, SumPool2d):
            lava_block = slayer.block.cuba.Pool(
                neuron_params,
                self.synapse.kernel_size,
                self.synapse.stride,
                self.synapse.padding,
                self.synapse.dilation,
                delay_shift=False,
            )
        else:
            raise NotImplementedError(type(self.synapse))

        # Attach the neuron's norm, if any.
        if self.neuron.norm is not None:
            lava_block.neuron.norm = self.neuron.norm.to_lava()
        return lava_block
except BaseException as e:
logging.info(f"spikingjelly.activation_based.lava_exchange: {e}")
slayer = None