import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from typing import Union, Callable
from . import neuron, base, surrogate
_hw_bits = 12
@torch.jit.script
def step_quantize_forward(x: torch.Tensor, step: float):
x = x / step
torch.round_(x)
return x * step
class step_quantize_atgf(torch.autograd.Function):
    @staticmethod
def forward(ctx, x: torch.Tensor, step: float = 1.):
return step_quantize_forward(x, step)
    @staticmethod
def backward(ctx, grad_output):
return grad_output, None
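# The backward pass of ``step_quantize_atgf`` is a straight-through estimator: the
# quantizer is treated as the identity function w.r.t. ``x``, so the incoming gradient
# passes through unchanged, and ``step`` receives no gradient (``None``).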
def step_quantize(x: torch.Tensor, step: float = 1.):
"""
:param x: a float tensor whose range is ``0 <= x <= 1``.
:type x: torch.Tensor
:param step: the quantization step
:type step: float
:return: ``y = round(x / step) * step``
:rtype: torch.Tensor
    The step quantizer defined in `Lava`.
    Denoting ``k`` as an ``int``, each element ``x[i]`` is quantized to the nearest ``k * step``.
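
    A minimal usage sketch:

    .. code-block:: python

        x = torch.tensor([0.30, 0.49, 0.50])
        y = step_quantize(x, step=0.25)
        # y = [0.25, 0.50, 0.50]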
"""
return step_quantize_atgf.apply(x, step)
def quantize_8b(x, scale, descale=False):
"""
    Denote ``k`` as an ``int``. ``x[i]`` will be quantized to the nearest ``2 * k / scale``,
    where ``k = {-128, -127, ..., 126, 127}``. If ``descale == True``, the quantized value
    is multiplied by ``scale`` before being returned.
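
    A worked example with ``scale = 64`` (quantization step ``2 / 64 = 0.03125``):

    .. code-block:: python

        x = torch.tensor([0.05, 100.])
        y = quantize_8b(x, scale=64)
        # y = [0.0625, 3.984375]; 100. is clamped to 255 / 64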
"""
if not descale:
return step_quantize(x, step=2 / scale).clamp(-256 / scale, 255 / scale)
else:
return step_quantize(x, step=2 / scale).clamp(-256 / scale, 255 / scale) * scale
@torch.jit.script
def right_shift_to_zero(x: torch.Tensor, bits: int):
dtype = x.dtype
assert dtype in (torch.int32, torch.int64)
return (torch.sign(x) * (torch.abs(x) >> bits)).to(dtype)
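# ``right_shift_to_zero`` rounds toward zero, while ``>>`` floors. For example:
#   right_shift_to_zero(torch.tensor([-5], dtype=torch.int32), 1)  # -> -2
#   torch.tensor([-5], dtype=torch.int32) >> 1                     # -> -3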
@torch.jit.script
def _listep_forward(x: torch.Tensor, decay: torch.Tensor, state: torch.Tensor, w_scale: int,
dtype: torch.dtype = torch.int32, hw_bits: int = 12):
# y = (state * w_scale * ((1 << hw_bits) - decay) / (1 << hw_bits) + w_scale * x) / w_scale
# y = state * (1 - decay / (1 << hw_bits)) + x
scaled_state = (state * w_scale).to(dtype=dtype)
decay_int = (1 << hw_bits) - decay.to(dtype=dtype)
output = right_shift_to_zero(scaled_state * decay_int, hw_bits) + (w_scale * x).to(dtype=dtype)
return output / w_scale
@torch.jit.script
def _listep_backward(grad_output: torch.Tensor, decay: torch.Tensor, state: torch.Tensor, hw_bits: int = 12):
grad_state = (1 - decay / (1 << hw_bits)) * grad_output
grad_decay = - state / (1 << hw_bits) * grad_output
grad_decay = grad_decay.sum()
    # returns gradients w.r.t. x, decay, state
    return grad_output, grad_decay, grad_state
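# Gradient derivation for ``_listep_backward`` (a sketch that ignores the integer
# rounding performed in the forward pass):
#   y = state * (1 - decay / (1 << hw_bits)) + x
#   dy/dx     = 1                            -> grad_input = grad_output
#   dy/dstate = 1 - decay / (1 << hw_bits)
#   dy/ddecay = -state / (1 << hw_bits)        (summed over all elements, since decay is shared)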
class BatchNorm2d(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-05, momentum: float = 0.1,
track_running_stats: bool = True, weight_exp_bits: int = 3, pre_hook_fx: Callable = lambda x: x):
super().__init__()
# lava.lib.dl.slayer.neuron.norm.WgtScaleBatchNorm
self.num_features = num_features
self.eps = eps
self.momentum = momentum
self.track_running_stats = track_running_stats
self.weight_exp_bits = weight_exp_bits
self.pre_hook_fx = pre_hook_fx
self.register_buffer(
'running_mean',
torch.zeros(num_features)
)
self.register_buffer(
'running_var',
torch.zeros(1)
)
    def to_lava(self):
        bn = slayer.neuron.norm.WgtScaleBatchNorm(
            num_features=self.num_features, momentum=self.momentum,
            weight_exp_bits=self.weight_exp_bits, eps=self.eps,
            pre_hook_fx=self.pre_hook_fx
        )
        bn.load_state_dict(self.state_dict())
        return bn
    def forward(self, x: torch.Tensor):
if self.track_running_stats and self.training:
x_mean = torch.mean(x, dim=(0, 2, 3), keepdim=True)
x_var = torch.var(x, unbiased=False)
numel = x.numel() / self.num_features
with torch.no_grad():
self.running_mean = (1. - self.momentum) * self.running_mean + self.momentum * x_mean.squeeze()
self.running_var = (1. - self.momentum) * self.running_var + self.momentum * x_var * numel / (numel + 1)
else:
x_mean = self.running_mean.view(1, -1, 1, 1)
x_var = self.running_var.view(1, -1, 1, 1)
x_std = torch.sqrt(x_var + self.eps)
x_std = torch.pow(2., torch.ceil(torch.log2(x_std)).clamp(
-self.weight_exp_bits, self.weight_exp_bits
))
return (x - self.pre_hook_fx(x_mean)) / x_std
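# A minimal usage sketch of the hardware-friendly BatchNorm2d above. The standard
# deviation is rounded up to a power of two, so division by it can be realized as a
# bit shift on hardware:
#
#   bn = BatchNorm2d(num_features=4)
#   x = torch.rand([8, 4, 16, 16])
#   y = bn(x)  # (x - mean) / 2 ** ceil(log2(std))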
class LeakyIntegratorStep(torch.autograd.Function):
    @staticmethod
def forward(
ctx, x, decay, state, w_scale):
output = _listep_forward(x, decay, state, w_scale, dtype=torch.int64, hw_bits=_hw_bits)
if x.requires_grad or state.requires_grad:
ctx.save_for_backward(decay, state)
return output
    @staticmethod
def backward(ctx, grad_output):
decay, state = ctx.saved_tensors
grad_input, grad_decay, grad_state = _listep_backward(
grad_output, decay, state, hw_bits=_hw_bits
)
return grad_input, grad_decay, grad_state, None
class CubaLIFNode(neuron.BaseNode):
def __init__(
self, current_decay: Union[float, torch.Tensor], voltage_decay: Union[float, torch.Tensor],
v_threshold: float = 1., v_reset: float = 0.,
scale=1 << 6,
requires_grad=False,
surrogate_function: Callable = surrogate.Sigmoid(),
norm: BatchNorm2d = None,
detach_reset=False,
step_mode="s", backend="torch",
store_v_seq: bool = False, store_i_seq: bool = False,
):
# author: https://github.com/AllenYolk
"""
* :ref:`API in English <CubaLIFNode.__init__-en>`
.. _CubaLIFNode.__init__-cn:
:param current_decay: 电流衰减常数
:type current_decay: float | torch.Tensor
:param voltage_decay: 电压衰减常数
:type voltage_decay: float | torch.Tensor
:param v_threshold: 神经元阈值电压。默认为1。
:type v_threshold: float
:param v_reset: 重置电压,默认为0
:type v_reset: float, None
:param scale: 量化参数,控制神经元的量化精度(参考了lava-dl的cuba.Neuron)。默认为 ``1<<6`` 。
等效于``w_scale=int(scale)``, ``s_scale=int(scale * (1<<6))``, ``p_scale=1<<12``。
:type scale: float
:param requires_grad: 指明 ``current_decay`` 和 ``voltage_decay`` 两个神经元参数是否可学习(是否需要梯度),默认为 ``False`` 。
:type requires_grad: bool
:param detach_reset: 是否将reset的计算图分离,默认为 ``False`` 。
:type detach_reset: bool
:param step_mode: 步进模式,可以为 `'s'` (单步)或 `'m'` (多步),默认为 `'s'` 。
:type step_mode: str
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
使用的步进模式支持的后端。目前只支持torch
:type backend: str
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.voltage_state`` 。
通常设置成 ``False`` ,可以节省内存。
:type store_v_seq: bool
:param store_i_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
的各个时间步的电流值 ``self.i_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电流,即 ``shape = [N, *]`` 的 ``self.current_state`` 。
通常设置成 ``False`` ,可以节省内存。
:type store_i_seq: bool
.. math::
I[t] = (1 - \\alpha_{I})I[t-1] + X[t]
V[t] = (1 - \\alpha_{V})V[t-1] + I[t]
* :ref:`中文API <CubaLIFNode.__init__-cn>`
.. CubaLIFNode.__init__-en:
:param current_decay: current decay constant
:type current_decay: float | torch.Tensor
:param voltage_decay: voltage decay constant
:type voltage_decay: float | torch.Tensor
:param v_threshold: threshold of the the neurons in this layer. Default to 1.
:type v_threshold: float
:param v_reset: reset potential of the neurons in this layer, 0 by default
:type v_reset: float
:param scale: quantization precision (ref: lava-dl cuba.Neuron). Default to ``1<<6`` .
Equivalent to ``w_scale=int(scale)``, ``s_scale=int(scale * (1<<6))``, ``p_scale=1<<12``.
:type scale: float
:param requires_grad: whether ``current_decay`` and ``voltage_decay`` are learnable. Default to ``False`` .
:type requires_grad: bool
:param detach_reset: whether to detach the computational graph of reset in backward pass. Default to ``False`` .
:type detach_reset: bool
:param step_mode: the step mode, which can be `s` (single-step) or `m` (multi-step). Default to `'s'` .
:type step_mode: str
:param backend: backend fot this neurons layer. Different ``step_mode`` may support for different backends. The user can
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. Only `torch` is supported.
:type backend: str
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the voltage at last time-step will be stored to ``self.voltage_state`` with ``shape = [N, *]``, which can reduce the
memory consumption. Default to ``False`` .
:type store_v_seq: bool
:param store_i_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
whether storing the current at each time-step to ``self.i_seq`` with ``shape = [T, N, *]``. If set to ``False``,
only the current at last time-step will be stored to ``self.current_state`` with ``shape = [N, *]``, which can reduce the
memory consumption. Default to ``False`` .
:type store_i_seq: bool
.. math::
I[t] = (1 - \\alpha_{I})I[t-1] + X[t]
V[t] = (1 - \\alpha_{V})V[t-1] + I[t]
"""
self.lava_cuba_neuron_params = {
'threshold': v_threshold,
'current_decay': current_decay,
'voltage_decay': voltage_decay,
'scale': scale
}
super().__init__(v_threshold=v_threshold, v_reset=v_reset, surrogate_function=surrogate_function,
detach_reset=detach_reset, step_mode=step_mode, backend=backend, store_v_seq=store_v_seq)
self.store_i_seq = store_i_seq
        assert v_reset == 0., 'CubaLIFNode only supports hard reset with v_reset = 0.!'
self.requires_grad = requires_grad
# the default quantization parameter setting in lava
self._scale = int(scale)
self._s_scale = int(scale * (1 << 6))
self._p_scale = 1 << _hw_bits
# Which is equivalent to:
# self.p_scale = 1<<12
# self.w_scale = int(scale)
# self.s_scale = int(scale * (1<<6))
        self.v_threshold = int(v_threshold * self.scale) / self.scale
        # ``v_threshold`` is quantized down to the nearest ``k / scale`` (``k`` is an ``int``)
        # that does not exceed the given ``v_threshold``
        self.v_threshold_eps = 0.01 / self.s_scale
        # Loihi fires with s[t] = v[t] > v_th, while SpikingJelly uses s[t] = v[t] >= v_th.
        # Thus, firing with v[t] >= v_th + eps approximates the strict inequality.
current_decay = torch.tensor(self.p_scale * current_decay, dtype=torch.float32)
voltage_decay = torch.tensor(self.p_scale * voltage_decay, dtype=torch.float32)
if requires_grad:
self.current_decay = nn.Parameter(current_decay)
self.voltage_decay = nn.Parameter(voltage_decay)
else:
self.register_buffer('current_decay', current_decay)
self.register_buffer('voltage_decay', voltage_decay)
self.register_memory('current_state', 0.)
self.register_memory('voltage_state', 0.)
self.clamp_decay_parameters()
self.norm = norm
if self.norm is not None:
if isinstance(self.norm, BatchNorm2d):
self.norm.pre_hook_fx = self.quantize_8bit
else:
raise NotImplementedError(self.norm)
    def quantize_8bit(self, x, descale=False):
return quantize_8b(x, scale=self.scale, descale=descale)
    def clamp_decay_parameters(self):
with torch.no_grad():
self.current_decay.data.clamp_(min=0., max=self.p_scale)
self.voltage_decay.data.clamp_(min=0., max=self.p_scale)
@property
def scale(self):
"""Read-only attribute: scale"""
return self._scale
@property
def s_scale(self):
"""Read-only attribute: s_scale"""
return self._s_scale
@property
def p_scale(self):
"""Read-only attribute: s_scale"""
return self._p_scale
@property
def store_i_seq(self):
return self._store_i_seq
@store_i_seq.setter
def store_i_seq(self, value: bool):
self._store_i_seq = value
if value:
if not hasattr(self, "i_seq"):
self.register_memory("i_seq", None)
@property
def supported_backends(self):
if self.step_mode == "m" or self.step_mode == "s":
return ("torch",)
else:
raise ValueError(
f"self.step_mode should be 's' or 'm', "
f"but get {self.step_mode} instead."
)
# computation process
    def state_initialization(self, x: torch.Tensor):
if isinstance(self.current_state, float):
self.current_state = torch.zeros_like(x.data)
if isinstance(self.voltage_state, float):
self.voltage_state = torch.zeros_like(x.data)
    def neuronal_charge(self, x: torch.Tensor):
if self.requires_grad:
self.clamp_decay_parameters()
current = LeakyIntegratorStep.apply(
x,
step_quantize(self.current_decay),
self.current_state.contiguous(),
self.s_scale,
)
if self.norm is not None:
current = self.norm(current)
voltage = LeakyIntegratorStep.apply(
current,
step_quantize(self.voltage_decay),
self.voltage_state.contiguous(),
self.s_scale,
)
self.current_state = current
self.voltage_state = voltage
    def neuronal_fire(self):
return self.surrogate_function(self.voltage_state - (self.v_threshold + self.v_threshold_eps))
    def neuronal_reset(self, spike):
if self.detach_reset:
spike_d = spike.detach()
else:
spike_d = spike
self.voltage_state = self.jit_hard_reset(self.voltage_state, spike_d, self.v_reset)
    def single_step_forward(self, x):
self.state_initialization(x)
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)
return spike
    def multi_step_forward(self, x_seq: torch.Tensor):
T = x_seq.shape[0]
y_seq = []
if self.store_v_seq:
v_seq = []
if self.store_i_seq:
i_seq = []
for t in range(T):
y = self.single_step_forward(x_seq[t])
y_seq.append(y)
if self.store_v_seq:
v_seq.append(self.voltage_state)
if self.store_i_seq:
i_seq.append(self.current_state)
if self.store_v_seq:
self.v_seq = torch.stack(v_seq)
if self.store_i_seq:
self.i_seq = torch.stack(i_seq)
return torch.stack(y_seq)
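# A minimal usage sketch of CubaLIFNode in multi-step mode (the decay values here
# are illustrative):
#
#   node = CubaLIFNode(current_decay=0.5, voltage_decay=0.5, step_mode='m', store_v_seq=True)
#   x_seq = torch.rand([4, 2, 8])   # [T, N, *]
#   spike_seq = node(x_seq)         # [T, N, *]
#   v_seq = node.v_seq              # voltages at every time-step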
try:
import lava.lib.dl.slayer as slayer
# ----------------------------------------
# data reshape function
    def TNX_to_NXT(x_seq: torch.Tensor):
# x_seq.shape = [T, N, *]
permute_args = list(range(1, x_seq.dim()))
permute_args.append(0)
return x_seq.permute(permute_args)
    def NXT_to_TNX(x_seq: torch.Tensor):
# x_seq.shape = [N, *, T]
permute_args = list(range(x_seq.dim() - 1))
permute_args.insert(0, x_seq.dim() - 1)
return x_seq.permute(permute_args)
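    # Example: for x_seq with shape [T, N, C, H, W] = [4, 2, 3, 8, 8],
    #   TNX_to_NXT(x_seq).shape == [2, 3, 8, 8, 4]                # [N, C, H, W, T]
    #   NXT_to_TNX(TNX_to_NXT(x_seq)).shape == [4, 2, 3, 8, 8]    # back to [T, N, C, H, W]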
    def lava_neuron_forward(lava_neuron: nn.Module, x_seq: torch.Tensor, v: Union[torch.Tensor, float]):
        # x_seq.shape = [T, N, *]
        # lava uses shape = [*, T], while SpikingJelly uses shape = [T, *]
unsqueeze_flag = False
if x_seq.dim() == 2:
x_seq = x_seq.unsqueeze(1)
            # lava expects input with shape [N, ..., T]
unsqueeze_flag = True
if isinstance(v, float):
v_init = v
v = torch.zeros_like(x_seq[0])
if v_init != 0.:
torch.fill_(v, v_init)
x_seq_shape = x_seq.shape
x_seq = x_seq.flatten(2).permute(1, 2, 0)
# [T, N, *] -> [N, *, T]
lava_neuron.voltage_state = v
spike = lava_neuron(x_seq).permute(2, 0, 1)
v = lava_neuron.voltage_state.reshape(x_seq_shape[1:])
spike = spike.reshape(x_seq_shape)
if unsqueeze_flag:
v = v.squeeze(1)
spike = spike.squeeze(1)
return spike, v
# ----------------------------------------
# quantize function
class _step_quantize(torch.autograd.Function):
@staticmethod
def forward(ctx, x, step):
return torch.round(x / step) * step
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
    def step_quantize(x: torch.Tensor, step: float = 1.):
"""
:param x: the input tensor
:type x: torch.Tensor
:param step: the quantize step
:type step: float
:return: quantized tensor
:rtype: torch.Tensor
        The step quantize function. Here is an example:

        .. code-block:: python

            import torch
            from matplotlib import pyplot as plt
            from spikingjelly.activation_based import lava_exchange

            # plt.style.use(['science', 'muted', 'grid'])
            fig = plt.figure(dpi=200, figsize=(6, 4))
            x = torch.arange(-4, 4, 0.001)
            plt.plot(x, lava_exchange.step_quantize(x, 2.), label='quantize(x, step=2)')
            plt.plot(x, x, label='y=x', ls='-.')
            plt.legend()
            plt.grid(ls='--')
            plt.title('step quantize')
            plt.xlabel('Input')
            plt.ylabel('Output')
            plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.svg')
            plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.pdf')

        .. image:: ../_static/API/activation_based/lava_exchange/step_quantize.*
            :width: 100%
        """
return _step_quantize.apply(x, step)
    def quantize_8bit(x: torch.Tensor, scale, descale=False):
if descale:
return step_quantize(x, 2. / scale).clamp(-256. / scale, 255. / scale) * scale
else:
return step_quantize(x, 2. / scale).clamp(-256. / scale, 255. / scale)
# ----------------------------------------
# convert function
    def check_instance(m, instance):
        if not isinstance(m, instance):
            raise ValueError(
                f'expected an instance of {instance}, but got {m} with type {type(m)}!')
    def check_no_bias(m):
        if m.bias is not None:
            raise ValueError(f'lava does not support {type(m)} with bias!')
    def to_lava_neuron_param_dict(sj_ms_neuron: nn.Module):
if isinstance(sj_ms_neuron, neuron.IFNode):
if sj_ms_neuron.v_reset != 0.:
                raise ValueError('lava only supports v_reset == 0!')
return {
'threshold': sj_ms_neuron.v_threshold,
'current_decay': 1.,
'voltage_decay': 0.,
'tau_grad': 1, 'scale_grad': 1, 'scale': sj_ms_neuron.lava_s_cale,
'norm': None, 'dropout': None,
'shared_param': True, 'persistent_state': True, 'requires_grad': False,
'graded_spike': False
}
elif isinstance(sj_ms_neuron, neuron.LIFNode):
            if sj_ms_neuron.v_reset != 0.:
                raise ValueError('lava only supports v_reset == 0!')
            if sj_ms_neuron.decay_input:
                raise ValueError('lava only supports decay_input == False!')
return {
'threshold': sj_ms_neuron.v_threshold,
'current_decay': 1.,
'voltage_decay': 1. / sj_ms_neuron.tau,
'tau_grad': 1, 'scale_grad': 1, 'scale': sj_ms_neuron.lava_s_cale,
'norm': None, 'dropout': None,
'shared_param': True, 'persistent_state': True, 'requires_grad': False,
'graded_spike': False
}
else:
raise NotImplementedError(sj_ms_neuron)
    def to_lava_neuron(sj_ms_neuron: nn.Module):
if isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
return slayer.neuron.cuba.Neuron(
**to_lava_neuron_param_dict(sj_ms_neuron)
)
else:
raise NotImplementedError(sj_ms_neuron)
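    # A minimal usage sketch (the LIFNode arguments here are illustrative):
    #
    #   sj_node = neuron.LIFNode(tau=2., decay_input=False)  # v_reset defaults to 0.
    #   lava_node = to_lava_neuron(sj_node)                  # a slayer.neuron.cuba.Neuron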
    def linear_to_lava_synapse_dense(fc: nn.Linear):
"""
:param fc: a pytorch linear layer without bias
:type fc: nn.Linear
:return: a lava slayer dense synapse
:rtype: slayer.synapse.Dense
        Code example:

        .. code-block:: python

            import torch
            import torch.nn as nn
            from spikingjelly.activation_based import functional, lava_exchange

            T = 4
            N = 2
layer_nn = nn.Linear(8, 4, bias=False)
layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn)
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
print('max error:', (y_nn - y_sl).abs().max())
"""
check_instance(fc, nn.Linear)
check_no_bias(fc)
slayer_dense = slayer.synapse.Dense(fc.in_features, fc.out_features)
        # `slayer_dense` is a `torch.nn.Conv3d`. Its weight has shape [out_features, in_features, 1, 1, 1]
slayer_dense.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()
return slayer_dense
    def conv2d_to_lava_synapse_conv(conv2d_nn: nn.Conv2d):
"""
:param conv2d_nn: a pytorch conv2d layer without bias
:type conv2d_nn: nn.Conv2d
:return: a lava slayer conv synapse
:rtype: slayer.synapse.Conv
        Code example:

        .. code-block:: python

            import torch
            import torch.nn as nn
            from spikingjelly.activation_based import functional, lava_exchange

            T = 4
            N = 2
layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
print('max error:', (y_nn - y_sl).abs().max())
"""
check_instance(conv2d_nn, nn.Conv2d)
check_no_bias(conv2d_nn)
slayer_conv = slayer.synapse.Conv(in_features=conv2d_nn.in_channels, out_features=conv2d_nn.out_channels,
kernel_size=conv2d_nn.kernel_size, stride=conv2d_nn.stride,
padding=conv2d_nn.padding, dilation=conv2d_nn.dilation,
groups=conv2d_nn.groups)
        # `slayer_conv` is a `torch.nn.Conv3d`.
slayer_conv.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()
return slayer_conv
    def avgpool2d_to_lava_synapse_pool(pool2d_nn: nn.AvgPool2d):
"""
:param pool2d_nn: a pytorch AvgPool2d layer
:type pool2d_nn: nn.AvgPool2d
:return: a lava slayer pool layer
:rtype: slayer.synapse.Pool
.. admonition:: Warning
:class: warning
The lava slayer pool layer applies sum pooling, rather than average pooling.
        Code example:

        .. code-block:: python

            import torch
            import torch.nn as nn
            from spikingjelly.activation_based import functional, lava_exchange

            T = 4
            N = 2
            layer_nn = nn.AvgPool2d(kernel_size=2, stride=2)
layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) / 4.
print('max error:', (y_nn - y_sl).abs().max())
"""
check_instance(pool2d_nn, nn.AvgPool2d)
logging.warning(
'The lava slayer pool layer applies sum pooling, rather than average pooling. `avgpool2d_to_lava_synapse_pool` will return a sum pooling layer.')
return slayer.synapse.Pool(pool2d_nn.kernel_size, pool2d_nn.stride, pool2d_nn.padding)
    def to_lava_block_dense(fc: nn.Linear, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True):
check_instance(fc, nn.Linear)
check_no_bias(fc)
neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
if isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
block_init = slayer.block.cuba.Dense
else:
raise NotImplementedError(sj_ms_neuron)
if quantize_to_8bit:
# if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
lava_block = block_init(neuron_params, fc.in_features, fc.out_features, delay_shift=False)
else:
lava_block = block_init(neuron_params, fc.in_features, fc.out_features, delay_shift=False, pre_hook_fx=None)
lava_block.synapse.weight.data[:, :, 0, 0, 0] = fc.weight.data.clone()
return lava_block
    def to_lava_block_conv(conv2d_nn: nn.Conv2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True):
check_instance(conv2d_nn, nn.Conv2d)
check_no_bias(conv2d_nn)
neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
if isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
block_init = slayer.block.cuba.Conv
else:
raise NotImplementedError(sj_ms_neuron)
if quantize_to_8bit:
# if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
lava_block = block_init(neuron_params, in_features=conv2d_nn.in_channels,
out_features=conv2d_nn.out_channels, kernel_size=conv2d_nn.kernel_size,
stride=conv2d_nn.stride, padding=conv2d_nn.padding, dilation=conv2d_nn.dilation,
groups=conv2d_nn.groups, delay_shift=False)
else:
lava_block = block_init(neuron_params, in_features=conv2d_nn.in_channels,
out_features=conv2d_nn.out_channels, kernel_size=conv2d_nn.kernel_size,
stride=conv2d_nn.stride, padding=conv2d_nn.padding, dilation=conv2d_nn.dilation,
groups=conv2d_nn.groups, delay_shift=False, pre_hook_fx=None)
lava_block.synapse.weight.data[:, :, :, :, 0] = conv2d_nn.weight.data.clone()
return lava_block
    def to_lava_block_pool(pool2d_nn: nn.AvgPool2d, sj_ms_neuron: nn.Module, quantize_to_8bit: bool = True):
check_instance(pool2d_nn, nn.AvgPool2d)
neuron_params = to_lava_neuron_param_dict(sj_ms_neuron)
if isinstance(sj_ms_neuron, (neuron.IFNode, neuron.LIFNode)):
block_init = slayer.block.cuba.Pool
else:
raise NotImplementedError(sj_ms_neuron)
if quantize_to_8bit:
# if 'pre_hook_fx' not in kwargs.keys(), then `pre_hook_fx` will be set to `quantize_8bit` by default
lava_block = block_init(neuron_params, pool2d_nn.kernel_size, pool2d_nn.stride, pool2d_nn.padding,
delay_shift=False)
else:
lava_block = block_init(neuron_params, pool2d_nn.kernel_size, pool2d_nn.stride, pool2d_nn.padding,
delay_shift=False, pre_hook_fx=None)
        logging.warning(
            'The lava slayer pool layer applies sum pooling, rather than average pooling. `to_lava_block_pool` will return a block with a sum pooling synapse.')
return lava_block
    def to_lava_block_flatten(flatten_nn: nn.Flatten):
check_instance(flatten_nn, nn.Flatten)
if flatten_nn.start_dim != 1:
            raise ValueError('lava only supports flatten_nn.start_dim == 1!')
return slayer.block.cuba.Flatten()
[文档] def to_lava_blocks(net: list or tuple or nn.Sequential):
# https://lava-nc.org/lava-lib-dl/netx/netx.html
'''
Supported layer types
input : {shape, type}
flatten: {shape, type}
average: {shape, type}
concat : {shape, type, layers}
dense : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)}
pool : {shape, type, neuron, kernelSize, stride, padding, dilation, weight}
conv : {shape, type, neuron, inChannels, outChannels, kernelSize, stride,
| padding, dilation, groups, weight, delay(if available)}
|
|-> this is the description of the compartment parameters
|-> {iDecay, vDecay, vThMant, refDelay, ... (other additional params)}
'''
        blocks = []
        length = len(net)
        i = 0
        k = None
        # ``k`` stores the kernel area of the last converted AvgPool2d. Lava's pool block
        # applies sum pooling, so the weights of the next Conv2d/Linear layer are divided
        # by ``k`` to recover the effect of average pooling.
        while True:
            if k is not None and isinstance(net[i], (nn.Conv2d, nn.Linear)):
                net[i].weight.data /= k
                k = None
            if isinstance(net[i], nn.Linear):
                if i + 1 < length and isinstance(net[i + 1], (neuron.IFNode, neuron.LIFNode)):
                    blocks.append(to_lava_block_dense(net[i], net[i + 1]))
                    i += 2
                else:
                    raise ValueError(type(net[i]))
            elif isinstance(net[i], nn.Conv2d):
                if i + 1 < length and isinstance(net[i + 1], (neuron.IFNode, neuron.LIFNode)):
                    blocks.append(to_lava_block_conv(net[i], net[i + 1]))
                    i += 2
                else:
                    raise ValueError(type(net[i]))
            elif isinstance(net[i], nn.AvgPool2d):
                if i + 1 < length and isinstance(net[i + 1], (neuron.IFNode, neuron.LIFNode)):
                    if isinstance(net[i].kernel_size, int):
                        k = float(net[i].kernel_size * net[i].kernel_size)
                    else:
                        k = float(net[i].kernel_size[0] * net[i].kernel_size[1])
                    blocks.append(to_lava_block_pool(net[i], net[i + 1]))
                    i += 2
                else:
                    raise ValueError(type(net[i]))
            elif isinstance(net[i], nn.Flatten):
                blocks.append(to_lava_block_flatten(net[i]))
                i += 1
            else:
                raise ValueError(type(net[i]))
            if i == length:
                break
return blocks
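    # A minimal usage sketch of ``to_lava_blocks`` (the layer sizes are illustrative):
    #
    #   net = nn.Sequential(
    #       nn.Flatten(),
    #       nn.Linear(28 * 28, 10, bias=False),
    #       neuron.IFNode(),
    #   )
    #   blocks = nn.Sequential(*to_lava_blocks(net))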
    class SumPool2d(nn.Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=1):
"""
            Code example:

            .. code-block:: python

                import torch
                from spikingjelly.activation_based import functional, lava_exchange
                import lava.lib.dl.slayer as slayer

                x = torch.rand([4, 2, 4, 16, 16])
with torch.no_grad():
sp_sj = SumPool2d(kernel_size=2, stride=2)
y_sj = functional.seq_to_ann_forward(x, sp_sj)
sp_la = slayer.synapse.Pool(kernel_size=2, stride=2)
y_la = lava_exchange.NXT_to_TNX(sp_la(lava_exchange.TNX_to_NXT(x)))
print((y_sj - y_la).abs().sum())
"""
super().__init__()
temp_conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=kernel_size, stride=stride,
padding=padding,
dilation=dilation, bias=False)
            # register as a buffer so the all-ones kernel follows the module across devices
            self.register_buffer('weight', torch.ones_like(temp_conv.weight.data))
self.kernel_size = temp_conv.kernel_size
self.stride = temp_conv.stride
self.padding = temp_conv.padding
self.dilation = temp_conv.dilation
del temp_conv
        def forward(self, x: torch.Tensor):
# x.shape = [N, C, H, W]
            if self.dilation == (1, 1):
                # sum pooling equals average pooling scaled by the kernel area
                return F.avg_pool2d(x, self.kernel_size, self.stride, self.padding) * self.weight.numel()
else:
N, C, H, W = x.shape
x = x.view(N * C, 1, H, W)
x = F.conv2d(x, weight=self.weight, bias=None, stride=self.stride, padding=self.padding,
dilation=self.dilation)
x = x.view(N, C, x.shape[2], x.shape[3])
return x
    class BlockContainer(nn.Module, base.StepModule):
@property
def step_mode(self):
return self._step_mode
@step_mode.setter
def step_mode(self, value: str):
if value not in self.supported_step_mode():
raise ValueError(f'step_mode can only be {self.supported_step_mode()}, but got "{value}"!')
self._step_mode = value
if isinstance(self.neuron, base.StepModule):
self.neuron.step_mode = value
if isinstance(self.synapse, base.StepModule):
self.synapse.step_mode = value
        def __init__(self, synapse: Union[nn.Conv2d, nn.Linear, SumPool2d, nn.Flatten],
                     neu: Union[neuron.IFNode, neuron.LIFNode, CubaLIFNode, None],
                     step_mode: str = 's'):
super().__init__()
if isinstance(synapse, nn.Flatten):
assert neu is None
self.synapse = synapse
self.neuron = None
if synapse.start_dim != 1:
                    raise ValueError('lava only supports torch.nn.Flatten with start_dim == 1!')
            else:
                if isinstance(neu, neuron.IFNode):
                    if neu.v_reset != 0.:
                        raise ValueError('lava only supports v_reset == 0!')
                    neu = CubaLIFNode(current_decay=1., voltage_decay=0., v_threshold=neu.v_threshold, scale=neu.lava_s_cale)
                elif isinstance(neu, neuron.LIFNode):
                    if neu.v_reset != 0.:
                        raise ValueError('lava only supports v_reset == 0!')
                    if neu.decay_input:
                        raise ValueError('lava only supports decay_input == False!')
neu = CubaLIFNode(current_decay=1., voltage_decay=1. / neu.tau, v_threshold=neu.v_threshold,
scale=neu.lava_s_cale)
else:
assert isinstance(neu, CubaLIFNode)
self.synapse = synapse
self.neuron = neu
if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
assert self.synapse.bias is None
self.step_mode = step_mode
        def forward(self, x: torch.Tensor):
if self.step_mode == 'm':
T = x.shape[0]
N = x.shape[1]
x = x.flatten(0, 1)
if isinstance(self.synapse, (nn.Conv2d, nn.Linear)):
                weight = self.neuron.quantize_8bit(self.synapse.weight)
                # quantized to 2k / self.neuron.scale, k = -128, -127, ..., 127 (256 values in total)
if isinstance(self.synapse, nn.Conv2d):
x = F.conv2d(x, weight=weight, bias=self.synapse.bias, stride=self.synapse.stride,
padding=self.synapse.padding, dilation=self.synapse.dilation,
groups=self.synapse.groups)
elif isinstance(self.synapse, nn.Linear):
x = F.linear(x, weight, self.synapse.bias)
elif isinstance(self.synapse, (SumPool2d, nn.Flatten)):
x = self.synapse(x)
else:
raise NotImplementedError(type(self.synapse))
if self.step_mode == 'm':
x = x.view([T, N, *x.shape[1:]])
if self.neuron is not None:
x = self.neuron(x)
return x
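        # A minimal usage sketch of BlockContainer (the decay values are illustrative):
        #
        #   fc = nn.Linear(8, 4, bias=False)
        #   node = CubaLIFNode(current_decay=0.5, voltage_decay=0.5)
        #   block = BlockContainer(fc, node, step_mode='m')
        #   y_seq = block(torch.rand([4, 2, 8]))  # [T, N, *]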
        def to_lava_block(self):
if isinstance(self.synapse, nn.Linear):
lava_block = slayer.block.cuba.Dense(self.neuron.lava_cuba_neuron_params, self.synapse.in_features,
self.synapse.out_features, delay_shift=False)
lava_block.synapse.weight.data[:, :, 0, 0, 0] = self.synapse.weight.data.clone()
elif isinstance(self.synapse, nn.Conv2d):
lava_block = slayer.block.cuba.Conv(self.neuron.lava_cuba_neuron_params,
in_features=self.synapse.in_channels,
out_features=self.synapse.out_channels,
kernel_size=self.synapse.kernel_size,
stride=self.synapse.stride, padding=self.synapse.padding,
dilation=self.synapse.dilation,
groups=self.synapse.groups, delay_shift=False)
lava_block.synapse.weight.data[:, :, :, :, 0] = self.synapse.weight.data.clone()
elif isinstance(self.synapse, SumPool2d):
lava_block = slayer.block.cuba.Pool(self.neuron.lava_cuba_neuron_params, self.synapse.kernel_size,
self.synapse.stride, self.synapse.padding, self.synapse.dilation,
delay_shift=False)
elif isinstance(self.synapse, nn.Flatten):
return slayer.block.cuba.Flatten()
else:
raise NotImplementedError
            # attach the norm layer as well
if self.neuron.norm is not None:
lava_block.neuron.norm = self.neuron.norm.to_lava()
return lava_block
except BaseException as e:
    # lava-dl is an optional dependency; if the import fails, fall back gracefully
    logging.info(f'spikingjelly.activation_based.lava_exchange: {e}')
    slayer = None