spikingjelly.activation_based.lava_exchange package
Module contents
- class spikingjelly.activation_based.lava_exchange.step_quantize_atgf(*args, **kwargs)[source]
Bases: Function
- spikingjelly.activation_based.lava_exchange.quantize_8b(x, scale, descale=False)[source]
Denote k as an int; x[i] will be quantized to the nearest 2 * k / scale, where k = {-128, -127, ..., 126, 127}.
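A minimal usage sketch based on the description above (the input range and scale are illustrative): with scale = 64 the quantized values land on the grid 2 * k / scale, so multiplying by scale / 2 should give (near-)integer values.
import torch
from spikingjelly.activation_based import lava_exchange

# illustrative input kept inside [-1, 1) so that k stays well within {-128, ..., 127}
x = torch.rand([8]) * 2. - 1.
y = lava_exchange.quantize_8b(x, scale=64)
k = y * 64 / 2.
print(k)                            # expected to be (close to) integer values
print((k - k.round()).abs().max())  # expected to be ~0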
- class spikingjelly.activation_based.lava_exchange.BatchNorm2d(num_features: int, eps: float = 1e-05, momentum: float = 0.1, track_running_stats: bool = True, weight_exp_bits: int = 3, pre_hook_fx: ~typing.Callable = <function BatchNorm2d.<lambda>>)[source]
Bases: Module
- class spikingjelly.activation_based.lava_exchange.LeakyIntegratorStep(*args, **kwargs)[source]
Bases: Function
- class spikingjelly.activation_based.lava_exchange.CubaLIFNode(current_decay: Union[float, Tensor], voltage_decay: Union[float, Tensor], v_threshold: float = 1.0, v_reset: float = 0.0, scale=64, requires_grad=False, surrogate_function: Callable = Sigmoid(alpha=4.0, spiking=True), norm: Optional[BatchNorm2d] = None, detach_reset=False, step_mode='s', backend='torch', store_v_seq: bool = False, store_i_seq: bool = False)[source]
Bases: BaseNode
- Parameters:
current_decay (Union[float, torch.Tensor]) – current decay constant
voltage_decay (Union[float, torch.Tensor]) – voltage decay constant
v_threshold (float) – threshold of the neurons in this layer. Default to 1.
v_reset (float) – reset potential of the neurons in this layer. Default to 0.
scale (float) – quantization precision (ref: lava-dl cuba.Neuron). Default to 1 << 6. Equivalent to w_scale = int(scale), s_scale = int(scale * (1 << 6)), p_scale = 1 << 12.
requires_grad (bool) – whether current_decay and voltage_decay are learnable. Default to False.
detach_reset (bool) – whether to detach the computation graph of the reset in the backward pass. Default to False.
step_mode (str) – the step mode, which can be 's' (single-step) or 'm' (multi-step). Default to 's'.
backend (str) – the backend for this neuron layer. Different step_mode may support different backends. The user can print self.supported_backends to check which backends the current step_mode supports. Only torch is supported.
store_v_seq (bool) – when using step_mode = 'm' and given input with shape = [T, N, *], this option controls whether the voltage at each time-step is stored in self.v_seq with shape = [T, N, *]. If set to False, only the voltage at the last time-step is stored in self.voltage_state with shape = [N, *], which reduces memory consumption. Default to False.
store_i_seq (bool) – when using step_mode = 'm' and given input with shape = [T, N, *], this option controls whether the current at each time-step is stored in self.i_seq with shape = [T, N, *]. If set to False, only the current at the last time-step is stored in self.current_state with shape = [N, *], which reduces memory consumption. Default to False.
\[I[t] = (1 - \alpha_{I})I[t-1] + X[t]\]
\[V[t] = (1 - \alpha_{V})V[t-1] + I[t]\]
A usage sketch is given after the property list below.
- property scale
scale
- Type:
Read-only attribute
- property s_scale
s_scale
- Type:
Read-only attribute
- property p_scale
p_scale
- Type:
Read-only attribute
- property store_i_seq
- property supported_backends
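A minimal usage sketch of CubaLIFNode in multi-step mode (the decay constants, time steps and feature size below are illustrative; the default scale is kept):
import torch
from spikingjelly.activation_based import functional, lava_exchange

T, N, C = 8, 2, 16
node = lava_exchange.CubaLIFNode(current_decay=0.5, voltage_decay=0.5,
                                 step_mode='m', store_v_seq=True, store_i_seq=True)
x_seq = torch.rand([T, N, C])
with torch.no_grad():
    s_seq = node(x_seq)            # output spikes, shape = [T, N, C]
print(s_seq.shape, node.v_seq.shape, node.i_seq.shape)
functional.reset_net(node)         # clear the neuron states before the next sample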
- spikingjelly.activation_based.lava_exchange.lava_neuron_forward(lava_neuron: Module, x_seq: Tensor, v: Union[Tensor, float])[source]
- spikingjelly.activation_based.lava_exchange.step_quantize(x: Tensor, step: float = 1.0)[source]
- Parameters:
x (torch.Tensor) – the input tensor
step (float) – the quantize step
- Returns:
the quantized tensor
- Return type:
torch.Tensor
The step quantize function. Here is an example:
import torch
from matplotlib import pyplot as plt
from spikingjelly.activation_based import lava_exchange

# plt.style.use(['science', 'muted', 'grid'])
fig = plt.figure(dpi=200, figsize=(6, 4))
x = torch.arange(-4, 4, 0.001)
plt.plot(x, lava_exchange.step_quantize(x, 2.), label='quantize(x, step=2)')
plt.plot(x, x, label='y=x', ls='-.')
plt.legend()
plt.grid(ls='--')
plt.title('step quantize')
plt.xlabel('Input')
plt.ylabel('Output')
plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.svg')
plt.savefig('./docs/source/_static/API/activation_based/lava_exchange/step_quantize.pdf')
- spikingjelly.activation_based.lava_exchange.linear_to_lava_synapse_dense(fc: Linear)[source]
- Parameters:
fc (nn.Linear) – a pytorch linear layer without bias
- Returns:
a lava slayer dense synapse
- Return type:
slayer.synapse.Dense
Code example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, lava_exchange

T = 4
N = 2
layer_nn = nn.Linear(8, 4, bias=False)
layer_sl = lava_exchange.linear_to_lava_synapse_dense(layer_nn)
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
    y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
    y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
    print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.conv2d_to_lava_synapse_conv(conv2d_nn: Conv2d)[source]
- Parameters:
conv2d_nn (nn.Conv2d) – a pytorch conv2d layer without bias
- Returns:
a lava slayer conv synapse
- Return type:
slayer.synapse.Conv
Code example:
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, lava_exchange

T = 4
N = 2
layer_nn = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False)
layer_sl = lava_exchange.conv2d_to_lava_synapse_conv(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
    y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
    y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq)))
    print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.avgpool2d_to_lava_synapse_pool(pool2d_nn: AvgPool2d)[source]
- Parameters:
pool2d_nn (nn.AvgPool2d) – a pytorch AvgPool2d layer
- Returns:
a lava slayer pool layer
- Return type:
slayer.synapse.Pool
Warning
The lava slayer pool layer applies sum pooling, rather than average pooling.
import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, lava_exchange

T = 4
N = 2
layer_nn = nn.AvgPool2d(kernel_size=2, stride=2)
layer_sl = lava_exchange.avgpool2d_to_lava_synapse_pool(layer_nn)
x_seq = torch.rand([T, N, 3, 28, 28])
with torch.no_grad():
    y_nn = functional.seq_to_ann_forward(x_seq, layer_nn)
    # divide by the kernel area (2 * 2 = 4) because the lava pool synapse sums rather than averages
    y_sl = lava_exchange.NXT_to_TNX(layer_sl(lava_exchange.TNX_to_NXT(x_seq))) / 4.
    print('max error:', (y_nn - y_sl).abs().max())
- spikingjelly.activation_based.lava_exchange.to_lava_block_dense(fc: Linear, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[source]
- spikingjelly.activation_based.lava_exchange.to_lava_block_conv(conv2d_nn: Conv2d, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[source]
- spikingjelly.activation_based.lava_exchange.to_lava_block_pool(pool2d_nn: AvgPool2d, sj_ms_neuron: Module, quantize_to_8bit: bool = True)[source]
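A hedged sketch of to_lava_block_dense (it assumes lava-dl is installed and that a multi-step IFNode is among the neuron types accepted for conversion; the layer sizes, time steps and batch size are illustrative):
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, lava_exchange

fc = nn.Linear(8, 4, bias=False)
sj_ms_neuron = neuron.IFNode(step_mode='m')    # assumption: an accepted multi-step neuron
block = lava_exchange.to_lava_block_dense(fc, sj_ms_neuron)

T, N = 4, 2
x_seq = torch.rand([T, N, 8])
with torch.no_grad():
    # lava-dl blocks use the [N, *, T] layout, so convert from/to SpikingJelly's [T, N, *]
    s_seq = lava_exchange.NXT_to_TNX(block(lava_exchange.TNX_to_NXT(x_seq)))
print(s_seq.shape)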
- spikingjelly.activation_based.lava_exchange.to_lava_blocks(net: Union[list, tuple, Sequential])[source]
Supported layer types:
input  : {shape, type}
flatten: {shape, type}
average: {shape, type}
concat : {shape, type, layers}
dense  : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)}
pool   : {shape, type, neuron, kernelSize, stride, padding, dilation, weight}
conv   : {shape, type, neuron, inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weight, delay(if available)}
- class spikingjelly.activation_based.lava_exchange.SumPool2d(kernel_size, stride=None, padding=0, dilation=1)[source]
Bases: Module
import torch
import lava.lib.dl.slayer as slayer
from spikingjelly.activation_based import functional, lava_exchange
from spikingjelly.activation_based.lava_exchange import SumPool2d

x = torch.rand([4, 2, 4, 16, 16])
with torch.no_grad():
    sp_sj = SumPool2d(kernel_size=2, stride=2)
    y_sj = functional.seq_to_ann_forward(x, sp_sj)
    sp_la = slayer.synapse.Pool(kernel_size=2, stride=2)
    y_la = lava_exchange.NXT_to_TNX(sp_la(lava_exchange.TNX_to_NXT(x)))
    print((y_sj - y_la).abs().sum())