spikingjelly.activation_based.examples package
Spiking FCNet for MNIST
Spiking CNN for Fashion MNIST
- class spikingjelly.activation_based.examples.conv_fashion_mnist.CSNN(T: int, channels: int, use_cupy=False)[source]
Bases: Module
- spikingjelly.activation_based.examples.conv_fashion_mnist.main()[source]
$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h
usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
                             [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]
Classify Fashion-MNIST
optional arguments:
  -h, --help          show this help message and exit
  -T T                simulating time-steps
  -device DEVICE      device
  -b B                batch size
  -epochs N           number of total epochs to run
  -j N                number of data loading workers (default: 4)
  -data-dir DATA_DIR  root dir of Fashion-MNIST dataset
  -out-dir OUT_DIR    root dir for saving logs and checkpoint
  -resume RESUME      resume from the checkpoint path
  -amp                automatic mixed precision training
  -cupy               use cupy neuron and multi-step forward mode
  -opt OPT            which optimizer to use: SGD or Adam
  -momentum MOMENTUM  momentum for SGD
  -save-es            dir for saving a batch of spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
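The option list above can be mirrored with a small argparse parser. The following is a minimal sketch, not the module's actual code: `build_parser` is a hypothetical name, the argument types are assumptions, the help string for `-lr` is an assumption (the help text above does not describe it), and only the default of `-j` (4) comes from the help text. `-save-es` is left out because its argument form is not shown.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the CLI described in the help text."""
    parser = argparse.ArgumentParser(
        prog="conv_fashion_mnist.py", description="Classify Fashion-MNIST"
    )
    # Only the default of -j (4) is stated in the help text. All types and
    # the missing defaults are assumptions made for illustration.
    parser.add_argument("-T", type=int, help="simulating time-steps")
    parser.add_argument("-device", help="device")
    parser.add_argument("-b", type=int, help="batch size")
    parser.add_argument("-epochs", type=int, metavar="N",
                        help="number of total epochs to run")
    parser.add_argument("-j", type=int, default=4, metavar="N",
                        help="number of data loading workers (default: 4)")
    parser.add_argument("-data-dir", help="root dir of Fashion-MNIST dataset")
    parser.add_argument("-out-dir", help="root dir for saving logs and checkpoint")
    parser.add_argument("-resume", help="resume from the checkpoint path")
    parser.add_argument("-amp", action="store_true",
                        help="automatic mixed precision training")
    parser.add_argument("-cupy", action="store_true",
                        help="use cupy neuron and multi-step forward mode")
    parser.add_argument("-opt", help="which optimizer to use: SGD or Adam")
    parser.add_argument("-momentum", type=float, help="momentum for SGD")
    parser.add_argument("-lr", type=float, help="learning rate (assumed meaning)")
    return parser


# Single-dash long options map to attribute names with '-' replaced by '_'.
args = build_parser().parse_args(
    ["-T", "4", "-device", "cpu", "-b", "128", "-epochs", "64",
     "-opt", "sgd", "-lr", "0.1", "-amp"]
)
print(args.T, args.device, args.b, args.epochs, args.j, args.amp, args.cupy)
```

Because the options use a single leading dash, argparse derives attribute names such as `args.data_dir` from `-data-dir`.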
Spike-based BP for CIFAR-10
Code author: Yanqi Chen <chyq@pku.edu.cn>
A reproduction of the paper Enabling Spike-Based Backpropagation for Training Deep Neural Network Architectures.
This code reproduces a novel gradient-based training method for SNNs. The network structure and some implementation details follow the authors' implementation to some extent. Because the training method and neuron models differ slightly from those in this framework, they are rewritten here in a compatible style.
This example assumes that at least one NVIDIA GPU is available.
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu(*args, **kwargs)[source]
Bases: Function
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.BaseNode(v_threshold=1.0, v_reset=0.0, surrogate_function=<bound method Function.apply of <class 'spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu'>>, monitor=False)[source]
Bases: Module
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.LIFNode(tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=<bound method Function.apply of <class 'spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu'>>, fire=True)[source]
Bases: BaseNode
- class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.IFNode(v_threshold=0.75, v_reset=0.0, surrogate_function=<bound method Function.apply of <class 'spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu'>>)[source]
Bases: BaseNode
DVS Gesture Classification
Optimizing Training Memory: Spiking VGG for CIFAR10-DVS
See the tutorial and the GitHub repo for more details.
- class spikingjelly.activation_based.examples.memopt.data_module.Cutout(n_holes, length=None, max_length=None)[source]
Bases: object
Randomly mask out one or more patches from an image.
- Parameters:
n_holes (int) -- Number of patches to cut out of each image.
length (int) -- The length (in pixels) of each square patch.
max_length -- If not None, randomly sample the length of the square patch. If None, use the argument length instead.
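The masking described by the Cutout parameters can be sketched on a plain 2-D list of pixel values. This is an illustrative sketch, not the class's implementation: the function name `cutout`, the `rng` argument, the uniform sampling of the side length from 1 to `max_length`, and the clipping of patches at the image border are all assumptions.

```python
import random


def cutout(image, n_holes, length=None, max_length=None, rng=None):
    """Return a copy of a 2-D image (list of rows) with square patches zeroed."""
    rng = rng or random.Random()
    height, width = len(image), len(image[0])
    out = [list(row) for row in image]  # leave the input untouched
    for _ in range(n_holes):
        # Assumption: with max_length set, the side is drawn uniformly from
        # 1..max_length; otherwise the fixed `length` is used.
        side = rng.randint(1, max_length) if max_length is not None else length
        # Pick the patch centre anywhere in the image and clip at the borders,
        # so a patch near an edge is partially cut off.
        cy, cx = rng.randrange(height), rng.randrange(width)
        top, left = cy - side // 2, cx - side // 2
        y1, y2 = max(0, top), min(height, top + side)
        x1, x2 = max(0, left), min(width, left + side)
        for y in range(y1, y2):
            for x in range(x1, x2):
                out[y][x] = 0
    return out


image = [[1] * 8 for _ in range(8)]
masked = cutout(image, n_holes=1, length=4, rng=random.Random(0))
zeros = sum(row.count(0) for row in masked)
print(f"masked pixels: {zeros}")
```

Passing a seeded `random.Random` keeps the sketch reproducible; an actual transform would typically operate on tensors instead of nested lists.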
- class spikingjelly.activation_based.examples.memopt.data_module.CIFAR10DVSNDA(M=1, N=2)[source]
Bases: object
- class spikingjelly.activation_based.examples.memopt.data_module.CIFAR10DVSDataModule(*args: Any, **kwargs: Any)[source]
Bases: LightningDataModule
- class spikingjelly.activation_based.examples.memopt.models.VGGBlock(in_plane, out_plane, kernel_size, stride, padding, preceding_avg_pool=False, **kwargs)[source]
Bases: Module
- class spikingjelly.activation_based.examples.memopt.models.CIFAR10DVSVGG(dropout: float = 0.25, tau: float = 1.333, decay_input: bool = False, detach_reset: bool = True, surrogate_function=ATan(alpha=2.0, spiking=True), backend='triton')[source]
Bases: Module
Speech Commands
Code authors: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>
A reproduction of the paper Technical report: supervised training of convolutional spiking neural networks with PyTorch.
This code reproduces an audio recognition task using a convolutional SNN. It achieves performance comparable to that of an ANN.
Note
To avoid extra dependencies such as librosa, we implement MelScale ourselves. Two DCT types are provided: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.
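For orientation, the two styles named in the note correspond to two common Hz-to-mel conversions. The sketch below uses the formulas found in librosa and torchaudio; whether this module uses exactly these constants is an assumption, and the function names `hz_to_mel` and `mel_to_hz` are hypothetical, not part of the module's API.

```python
import math

F_SP = 200.0 / 3.0                   # Hz per mel in the linear Slaney region
MIN_LOG_HZ = 1000.0                  # start of the logarithmic Slaney region
MIN_LOG_MEL = MIN_LOG_HZ / F_SP      # mel value at 1 kHz (about 15)
LOGSTEP = math.log(6.4) / 27.0       # log-region step size


def hz_to_mel(freq_hz: float, style: str = "slaney") -> float:
    """Convert a frequency in Hz to mels using the HTK or Slaney formula."""
    if style == "htk":
        return 2595.0 * math.log10(1.0 + freq_hz / 700.0)
    if freq_hz < MIN_LOG_HZ:
        return freq_hz / F_SP        # linear below 1 kHz
    return MIN_LOG_MEL + math.log(freq_hz / MIN_LOG_HZ) / LOGSTEP


def mel_to_hz(mel: float, style: str = "slaney") -> float:
    """Inverse of hz_to_mel for the same style."""
    if style == "htk":
        return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
    if mel < MIN_LOG_MEL:
        return mel * F_SP
    return MIN_LOG_HZ * math.exp(LOGSTEP * (mel - MIN_LOG_MEL))


for style in ("slaney", "htk"):
    print(style, hz_to_mel(440.0, style), hz_to_mel(4000.0, style))
```

A filterbank such as the one built by `create_fb_matrix` would typically place triangular filters at points equally spaced on the chosen mel axis between `f_min` and `f_max`.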
Confusion matrix of TEST set after training (50 epochs):
| Count (Ground Truth \ Prediction) | "Yes" | "Stop" | "No" | "Right" | "Up" | "Left" | "On" | "Down" | "Off" | "Go" | Other | Silence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| "Yes" | 234 | 0 | 2 | 0 | 0 | 3 | 0 | 0 | 0 | 1 | 16 | 0 |
| "Stop" | 0 | 233 | 0 | 1 | 5 | 0 | 0 | 0 | 0 | 1 | 9 | 0 |
| "No" | 0 | 1 | 223 | 1 | 0 | 1 | 0 | 5 | 0 | 9 | 12 | 0 |
| "Right" | 0 | 0 | 0 | 234 | 0 | 0 | 0 | 0 | 0 | 0 | 24 | 1 |
| "Up" | 0 | 4 | 0 | 0 | 249 | 0 | 0 | 0 | 8 | 0 | 11 | 0 |
| "Left" | 3 | 1 | 2 | 3 | 1 | 250 | 0 | 0 | 1 | 0 | 6 | 0 |
| "On" | 0 | 3 | 0 | 0 | 0 | 0 | 231 | 0 | 2 | 1 | 9 | 0 |
| "Down" | 0 | 0 | 7 | 0 | 0 | 1 | 2 | 230 | 0 | 4 | 8 | 1 |
| "Off" | 0 | 0 | 2 | 1 | 4 | 2 | 6 | 0 | 237 | 1 | 9 | 0 |
| "Go" | 0 | 2 | 5 | 0 | 0 | 2 | 0 | 1 | 5 | 220 | 16 | 0 |
| Other | 6 | 21 | 12 | 25 | 22 | 19 | 25 | 14 | 11 | 40 | 4072 | 1 |
| Silence | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 260 |
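Summary metrics can be derived from the confusion-matrix counts above. The sketch below copies the counts (rows are ground truth, columns are predictions) and computes overall accuracy and per-class recall; the variable names are illustrative and not part of the module.

```python
LABELS = ["Yes", "Stop", "No", "Right", "Up", "Left",
          "On", "Down", "Off", "Go", "Other", "Silence"]

# Rows: ground truth. Columns: prediction, in the same order as LABELS.
CONFUSION = [
    [234, 0, 2, 0, 0, 3, 0, 0, 0, 1, 16, 0],
    [0, 233, 0, 1, 5, 0, 0, 0, 0, 1, 9, 0],
    [0, 1, 223, 1, 0, 1, 0, 5, 0, 9, 12, 0],
    [0, 0, 0, 234, 0, 0, 0, 0, 0, 0, 24, 1],
    [0, 4, 0, 0, 249, 0, 0, 0, 8, 0, 11, 0],
    [3, 1, 2, 3, 1, 250, 0, 0, 1, 0, 6, 0],
    [0, 3, 0, 0, 0, 0, 231, 0, 2, 1, 9, 0],
    [0, 0, 7, 0, 0, 1, 2, 230, 0, 4, 8, 1],
    [0, 0, 2, 1, 4, 2, 6, 0, 237, 1, 9, 0],
    [0, 2, 5, 0, 0, 2, 0, 1, 5, 220, 16, 0],
    [6, 21, 12, 25, 22, 19, 25, 14, 11, 40, 4072, 1],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 260],
]

total = sum(sum(row) for row in CONFUSION)
correct = sum(CONFUSION[i][i] for i in range(len(LABELS)))
accuracy = correct / total

# Recall per class: diagonal entry divided by the row sum.
recall = {label: CONFUSION[i][i] / sum(CONFUSION[i])
          for i, label in enumerate(LABELS)}

print(f"overall accuracy: {accuracy:.4f}")
for label, value in recall.items():
    print(f"{label:>8}: recall {value:.3f}")
```

Since the "Other" class dominates the test set, per-class recall is a more informative view than overall accuracy alone.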
- spikingjelly.activation_based.examples.speechcommands.create_fb_matrix(n_freqs: int, f_min: float, f_max: float, n_mels: int, sample_rate: int, dct_type: str | None = 'slaney') -> Tensor[source]
- class spikingjelly.activation_based.examples.speechcommands.MelScaleDelta(order, n_mels: int = 128, sample_rate: int = 16000, f_min: float = 0.0, f_max: float | None = None, n_stft: int | None = None, dct_type: str | None = 'slaney')[source]
Bases: Module
- class spikingjelly.activation_based.examples.speechcommands.LIFWrapper(module, flatten=False)[source]
Bases: Module
- forward(x_seq: Tensor) -> Tensor[source]
The input x_seq has shape [batch_size, channel, T, n_mel]. Before passing it to the wrapped module, the time and batch dimensions are swapped to [T, channel, batch_size, n_mel] to match multi-step forward mode. If self.flatten=True, the output is permuted and flattened to [batch_size, T, channel * n_mel]; otherwise the output shape is [batch_size, channel, T, n_mel].
- Parameters:
x_seq (torch.Tensor) -- Input sequence, shape=[batch_size, channel, T, n_mel]
- Returns:
Output sequence; shape=[batch_size, T, channel * n_mel] when self.flatten=True, otherwise shape=[batch_size, channel, T, n_mel]
- Return type:
torch.Tensor
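The shape handling described for LIFWrapper.forward can be traced with plain tuples. This is a shape-bookkeeping sketch only: the helper names are hypothetical, the wrapped module is assumed to preserve the shape of its input, and the specific permutation order used for the flattened output is an assumption that merely reproduces the documented result.

```python
def permute_shape(shape, dims):
    """Return the shape obtained by reordering axes, like a tensor permute."""
    return tuple(shape[d] for d in dims)


def lif_wrapper_shapes(in_shape, flatten=False):
    """Return (shape seen by the wrapped module, output shape)."""
    # Swap batch and time: [batch_size, channel, T, n_mel]
    #                   -> [T, channel, batch_size, n_mel]
    multi_step = permute_shape(in_shape, (2, 1, 0, 3))
    # Assumption: the wrapped module returns a tensor of the same shape.
    if flatten:
        # [T, channel, batch_size, n_mel] -> [batch_size, T, channel, n_mel]
        b, t, c, m = permute_shape(multi_step, (2, 0, 1, 3))
        # Merge the last two axes: [batch_size, T, channel * n_mel]
        return multi_step, (b, t, c * m)
    # Swap back to [batch_size, channel, T, n_mel]
    return multi_step, permute_shape(multi_step, (2, 1, 0, 3))


in_shape = (8, 2, 101, 40)  # illustrative batch_size, channel, T, n_mel
print(lif_wrapper_shapes(in_shape))
print(lif_wrapper_shapes(in_shape, flatten=True))
```

With real tensors, the same bookkeeping would be done by transposing, permuting, and flattening the tensor; the tuple version only makes the axis order explicit.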
- class spikingjelly.activation_based.examples.speechcommands.Net[source]
Bases: Module
Initialize the speech command recognition network. The network stacks three convolution-spiking blocks, flattens the final feature dimension to channel * n_mel, and maps features to class logits with a linear layer. Training statistics fields are also initialized.
- forward(x)[source]
Run the input features through the convolutional spiking stack to obtain per-time-step class logits, then apply mean pooling over the time dimension to produce final logits for each sample.
- Parameters:
x (torch.Tensor) -- Input features, shape=[batch_size, delta_order + 1, T, n_mel]
- Returns:
Classification logits, shape=[batch_size, label_cnt]
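The final step of Net.forward, mean pooling of per-time-step logits over the time dimension, can be illustrated with nested lists. This is a minimal sketch: the function name `mean_over_time` is hypothetical, and the layout [batch_size, T, label_cnt] for the per-time-step logits is an assumption.

```python
def mean_over_time(logits_seq):
    """Average logits over time.

    logits_seq: nested list with layout [batch_size][T][label_cnt]
    returns:    nested list with layout [batch_size][label_cnt]
    """
    pooled = []
    for sample in logits_seq:
        steps = len(sample)
        label_cnt = len(sample[0])
        pooled.append(
            [sum(step[k] for step in sample) / steps for k in range(label_cnt)]
        )
    return pooled


# Two samples, two time steps, two classes (illustrative values).
logits_seq = [
    [[1.0, 0.0], [3.0, 2.0]],
    [[0.0, 4.0], [0.0, 0.0]],
]
pooled = mean_over_time(logits_seq)
print(pooled)
```

Averaging over time turns a sequence of per-step outputs into one score vector per sample, which can then be fed to a standard classification loss.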
RSNN for Sequential Fashion MNIST
- class spikingjelly.activation_based.examples.rsnn_sequential_fmnist.StatefulSynapseNet[source]
Bases: Module
Spiking LSTM for Sequential MNIST
Spiking LSTM for Text Classification
A2C
DQN_state
- class spikingjelly.activation_based.examples.DQN_state.DQN(input_size, hidden_size, output_size)[source]
Bases: Module
- class spikingjelly.activation_based.examples.Spiking_DQN_state.Transition(state, action, next_state, reward)
Bases: tuple
Create new instance of Transition(state, action, next_state, reward)
- action
Alias for field number 1
- next_state
Alias for field number 2
- reward
Alias for field number 3
- state
Alias for field number 0
- class spikingjelly.activation_based.examples.Spiking_DQN_state.ReplayMemory(capacity)[source]
Bases: object
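The Transition fields and the capacity argument listed above suggest a fixed-size experience buffer. The following is a minimal sketch under that assumption: the method names `push` and `sample` and the use of a `deque` are illustrative choices, not confirmed details of the module.

```python
import random
from collections import deque, namedtuple

# Field order matches the aliases documented above:
# state=0, action=1, next_state=2, reward=3.
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory:
    """Fixed-capacity buffer that discards the oldest transitions first."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition given as (state, action, next_state, reward)."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Draw a random batch of stored transitions without replacement."""
        return random.sample(list(self.memory), batch_size)

    def __len__(self):
        return len(self.memory)


memory = ReplayMemory(capacity=3)
for step in range(5):
    # Illustrative integer states; a real agent would store observations.
    memory.push(step, 0, step + 1, 1.0)

batch = memory.sample(2)
print(len(memory), batch)
```

Sampling uniformly from past transitions is the usual way to decorrelate consecutive experiences before a DQN update.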
- class spikingjelly.activation_based.examples.Spiking_DQN_state.NonSpikingLIFNode(*args, **kwargs)[source]
Bases: LIFNode
- class spikingjelly.activation_based.examples.Spiking_DQN_state.DQSN(input_size, hidden_size, output_size, T=16)[source]
Bases: Module