spikingjelly.activation_based.examples package

Spiking FCNet for MNIST

class spikingjelly.activation_based.examples.lif_fc_mnist.SNN(tau)

Bases: Module

forward(x: Tensor)
spikingjelly.activation_based.examples.lif_fc_mnist.main()
Returns:

None

A network with an FC-LIF structure (a fully connected layer followed by LIF neurons) for classifying MNIST.

This function initializes the network, trains it, and shows the accuracy on the test dataset during training.
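As a rough illustration of the LIF half of the FC-LIF structure, the sketch below simulates a single leaky integrate-and-fire neuron in plain Python. It is not the SNN class documented above. The update rule, the threshold of 1.0, the reset value of 0.0, and the name lif_simulate are all assumptions made for illustration; only tau corresponds to the constructor argument of SNN(tau).

```python
def lif_simulate(inputs, tau=2.0, v_threshold=1.0, v_reset=0.0):
    """Simulate one LIF neuron over a sequence of input currents.

    Assumed update rule (a common discrete-time LIF form, not taken from
    the documented module): v <- v + (x - (v - v_reset)) / tau, followed
    by a spike and a hard reset when v reaches v_threshold.
    """
    v = v_reset
    spikes = []
    for x in inputs:
        # Leaky integration: the membrane potential decays toward v_reset
        # while being charged by the input x.
        v = v + (x - (v - v_reset)) / tau
        if v >= v_threshold:
            spikes.append(1)
            v = v_reset  # hard reset after firing
        else:
            spikes.append(0)
    return spikes


if __name__ == "__main__":
    # A constant input of 1.5 with tau=2.0 makes the neuron fire every second step.
    print(lif_simulate([1.5] * 6, tau=2.0))
```

A larger tau slows the charging, so the same input produces fewer spikes. In a network like this one, that is the property the tau argument controls.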

Spiking CNN for Fashion MNIST

class spikingjelly.activation_based.examples.conv_fashion_mnist.CSNN(T: int, channels: int, use_cupy=False)

Bases: Module

forward(x: Tensor)
spiking_encoder()
spikingjelly.activation_based.examples.conv_fashion_mnist.main()

(sj-dev) wfang@Precision-5820-Tower-X-Series:~/spikingjelly_dev$ python -m spikingjelly.activation_based.examples.conv_fashion_mnist -h

usage: conv_fashion_mnist.py [-h] [-T T] [-device DEVICE] [-b B] [-epochs N] [-j N] [-data-dir DATA_DIR] [-out-dir OUT_DIR]
                             [-resume RESUME] [-amp] [-cupy] [-opt OPT] [-momentum MOMENTUM] [-lr LR]

Classify Fashion-MNIST

optional arguments:
  -h, --help          show this help message and exit
  -T T                simulating time-steps
  -device DEVICE      device
  -b B                batch size
  -epochs N           number of total epochs to run
  -j N                number of data loading workers (default: 4)
  -data-dir DATA_DIR  root dir of Fashion-MNIST dataset
  -out-dir OUT_DIR    root dir for saving logs and checkpoint
  -resume RESUME      resume from the checkpoint path
  -amp                automatic mixed precision training
  -cupy               use cupy neuron and multi-step forward mode
  -opt OPT            use which optimizer. SDG or Adam
  -momentum MOMENTUM  momentum for SGD
  -save-es            dir for saving a batch spikes encoded by the first {Conv2d-BatchNorm2d-IFNode}
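For readers who want to see how a command line of this shape can be declared, here is a minimal argparse sketch covering a few of the flags listed above. It is not the script's actual parser. The function name build_parser and every default value are hypothetical placeholders, except the default of 4 for -j, which the help text states.

```python
import argparse


def build_parser():
    """Build a parser with a few of the flags from the help text above.

    All defaults are hypothetical placeholders, except -j (default 4),
    which is the only default the help text itself mentions.
    """
    parser = argparse.ArgumentParser(description="Classify Fashion-MNIST")
    parser.add_argument("-T", default=4, type=int, help="simulating time-steps")
    parser.add_argument("-device", default="cpu", help="device")
    parser.add_argument("-b", default=128, type=int, help="batch size")
    parser.add_argument("-epochs", default=64, type=int, metavar="N",
                        help="number of total epochs to run")
    parser.add_argument("-j", default=4, type=int, metavar="N",
                        help="number of data loading workers (default: 4)")
    # argparse derives the attribute name data_dir from the flag -data-dir.
    parser.add_argument("-data-dir", type=str,
                        help="root dir of Fashion-MNIST dataset")
    parser.add_argument("-amp", action="store_true",
                        help="automatic mixed precision training")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["-T", "8", "-data-dir", "./data", "-amp"])
    print(args.T, args.data_dir, args.amp)
```

Boolean switches such as -amp and -cupy take no value, which is why they appear without a metavar in the usage line.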

Spike-based BP for CIFAR-10

Code author: Yanqi Chen <chyq@pku.edu.cn>

A reproduction of the paper Enabling Spike-Based Backpropagation for Training Deep Neural Network Architectures.

This code reproduces the gradient-based SNN training method proposed in that paper. The network structure and some implementation details follow the authors' implementation. Because the training method and neuron models differ slightly from those in this framework, they are rewritten here in a compatible style.

This example assumes that at least one NVIDIA GPU is available.

class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.relu(*args, **kwargs)

Bases: Function

static forward(ctx, x)
static backward(ctx, grad_output)
class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.BaseNode(v_threshold=1.0, v_reset=0.0, surrogate_function=relu.apply, monitor=False)

Bases: Module

spiking()
forward(dv: Tensor)
reset()
class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.LIFNode(tau=100.0, v_threshold=1.0, v_reset=0.0, surrogate_function=relu.apply, fire=True)

Bases: BaseNode

forward(dv: Tensor)
class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.IFNode(v_threshold=0.75, v_reset=0.0, surrogate_function=relu.apply)

Bases: BaseNode

forward(dv: Tensor)
class spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.ResNet11

Bases: Module

forward(x)
reset_()
spikingjelly.activation_based.examples.cifar10_r11_enabling_spikebased_backpropagation.main()

DVS Gesture Classification

spikingjelly.activation_based.examples.classify_dvsg.main()

Optimizing Training Memory: Spiking VGG for CIFAR10-DVS

See the tutorial and the GitHub repo for more details.

class spikingjelly.activation_based.examples.memopt.data_module.Cutout(n_holes, length=None, max_length=None)

Bases: object

Randomly mask out one or more patches from an image.

Parameters:

n_holes (int) -- Number of patches to cut out of each image.

length (int) -- The length (in pixels) of each square patch.

max_length -- If not None, randomly sample the length of the square patch. If None, use the argument length instead.
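The masking described above can be sketched in plain Python on a 2-D list of pixels. This is a simplified illustration under stated assumptions, not the documented class. The function name cutout, the rng argument, the fill value 0, and the choice to clip patches at the image border are all assumptions, and the max_length option is not covered.

```python
import random


def cutout(image, n_holes, length, rng=None):
    """Return a copy of `image` with `n_holes` square patches set to 0.

    `image` is a 2-D list of numbers (rows of pixels). Each patch is
    centred on a random pixel, spans length // 2 pixels to each side,
    and is clipped at the image border. Assumed behaviour only.
    """
    rng = rng or random.Random()
    height, width = len(image), len(image[0])
    out = [row[:] for row in image]  # leave the input untouched
    for _ in range(n_holes):
        cy, cx = rng.randrange(height), rng.randrange(width)
        # Clip the square to the image bounds.
        y0, y1 = max(0, cy - length // 2), min(height, cy + length // 2)
        x0, x1 = max(0, cx - length // 2), min(width, cx + length // 2)
        for y in range(y0, y1):
            for x in range(x0, x1):
                out[y][x] = 0
    return out


if __name__ == "__main__":
    img = [[1] * 8 for _ in range(8)]
    for row in cutout(img, n_holes=1, length=4, rng=random.Random(0)):
        print(row)
```

Because patches are clipped rather than shifted inward, a hole centred near an edge masks fewer pixels than one in the middle of the image.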

class spikingjelly.activation_based.examples.memopt.data_module.CIFAR10DVSNDA(M=1, N=2)

Bases: object

class spikingjelly.activation_based.examples.memopt.data_module.CIFAR10DVSDataModule(*args: Any, **kwargs: Any)

Bases: LightningDataModule

prepare_data()
setup(stage: str)
train_dataloader()
val_dataloader()
test_dataloader()
predict_dataloader()
class spikingjelly.activation_based.examples.memopt.models.VGGBlock(in_plane, out_plane, kernel_size, stride, padding, preceding_avg_pool=False, **kwargs)

Bases: Module

forward(x_seq)
class spikingjelly.activation_based.examples.memopt.models.CIFAR10DVSVGG(dropout: float = 0.25, tau: float = 1.333, decay_input: bool = False, detach_reset: bool = True, surrogate_function=ATan(alpha=2.0, spiking=True), backend='triton')

Bases: Module

forward(input)
class spikingjelly.activation_based.examples.memopt.lightning_modules.ClassificationLightningModule(*args: Any, **kwargs: Any)

Bases: LightningModule

forward(x)
training_step(batch, batch_idx)
on_train_epoch_end()
validation_step(batch, batch_idx)
on_validation_epoch_end()

Speech Commands

Code author: Yanqi Chen <chyq@pku.edu.cn>, Ismail Khalfaoui Hassani <ismail.khalfaoui-hassani@univ-tlse3.fr>

A reproduction of the paper Technical report: supervised training of convolutional spiking neural networks with PyTorch.

This code reproduces an audio recognition task using a convolutional SNN. It achieves performance comparable to that of an ANN.

Note

To avoid extra dependencies such as librosa, MelScale is implemented here directly. Two DCT types are provided: Slaney and HTK. The Slaney style is used in the original paper and is applied by default.
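To make the difference between the two styles concrete, the sketch below gives scalar versions of the Hz-to-mel conversion and its inverse. The HTK form is the usual 2595 * log10(1 + f / 700) formula. The Slaney form is linear below 1 kHz and logarithmic above, with the constants commonly used for that scale. These are assumptions about what the two options compute: the documented functions operate on tensors and their exact implementation is not shown here.

```python
import math

# Constants of the Slaney-style scale (assumed, as commonly defined).
F_SP = 200.0 / 3.0                 # Hz per mel in the linear region
MIN_LOG_HZ = 1000.0                # start of the logarithmic region
MIN_LOG_MEL = MIN_LOG_HZ / F_SP    # the same point expressed in mels
LOGSTEP = math.log(6.4) / 27.0     # step size in the logarithmic region


def hz_to_mel(freq, dct_type="slaney"):
    """Convert a frequency in Hz to mels (scalar sketch)."""
    if dct_type == "htk":
        return 2595.0 * math.log10(1.0 + freq / 700.0)
    if freq >= MIN_LOG_HZ:
        return MIN_LOG_MEL + math.log(freq / MIN_LOG_HZ) / LOGSTEP
    return freq / F_SP


def mel_to_hz(mel, dct_type="slaney"):
    """Inverse of hz_to_mel (scalar sketch)."""
    if dct_type == "htk":
        return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)
    if mel >= MIN_LOG_MEL:
        return MIN_LOG_HZ * math.exp(LOGSTEP * (mel - MIN_LOG_MEL))
    return F_SP * mel


if __name__ == "__main__":
    for style in ("slaney", "htk"):
        print(style, hz_to_mel(440.0, style), hz_to_mel(4000.0, style))
```

The two styles assign different mel values to the same frequency, so a filter bank built with one style should not be mixed with features computed using the other.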

Confusion matrix of TEST set after training (50 epochs):

Rows are ground-truth labels, columns are predicted labels, and entries are sample counts.

Ground Truth  "Yes"  "Stop"  "No"  "Right"  "Up"  "Left"  "On"  "Down"  "Off"  "Go"  Other  Silence
"Yes"           234       0     2        0     0       3     0       0      0     1     16        0
"Stop"            0     233     0        1     5       0     0       0      0     1      9        0
"No"              0       1   223        1     0       1     0       5      0     9     12        0
"Right"           0       0     0      234     0       0     0       0      0     0     24        1
"Up"              0       4     0        0   249       0     0       0      8     0     11        0
"Left"            3       1     2        3     1     250     0       0      1     0      6        0
"On"              0       3     0        0     0       0   231       0      2     1      9        0
"Down"            0       0     7        0     0       1     2     230      0     4      8        1
"Off"             0       0     2        1     4       2     6       0    237     1      9        0
"Go"              0       2     5        0     0       2     0       1      5   220     16        0
Other             6      21    12       25    22      19    25      14     11    40   4072        1
Silence           0       0     0        0     0       0     0       0      0     0      0      260
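Summary metrics can be read off the confusion matrix above with a few lines of Python. The sketch below copies the counts (rows are ground truth, columns are predictions) and computes overall accuracy and per-class recall. The helper names overall_accuracy and per_class_recall are illustrative and are not part of the example module.

```python
LABELS = ["Yes", "Stop", "No", "Right", "Up", "Left",
          "On", "Down", "Off", "Go", "Other", "Silence"]

# Rows: ground truth. Columns: prediction. Counts from the test-set matrix.
CONFUSION = [
    [234, 0, 2, 0, 0, 3, 0, 0, 0, 1, 16, 0],
    [0, 233, 0, 1, 5, 0, 0, 0, 0, 1, 9, 0],
    [0, 1, 223, 1, 0, 1, 0, 5, 0, 9, 12, 0],
    [0, 0, 0, 234, 0, 0, 0, 0, 0, 0, 24, 1],
    [0, 4, 0, 0, 249, 0, 0, 0, 8, 0, 11, 0],
    [3, 1, 2, 3, 1, 250, 0, 0, 1, 0, 6, 0],
    [0, 3, 0, 0, 0, 0, 231, 0, 2, 1, 9, 0],
    [0, 0, 7, 0, 0, 1, 2, 230, 0, 4, 8, 1],
    [0, 0, 2, 1, 4, 2, 6, 0, 237, 1, 9, 0],
    [0, 2, 5, 0, 0, 2, 0, 1, 5, 220, 16, 0],
    [6, 21, 12, 25, 22, 19, 25, 14, 11, 40, 4072, 1],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 260],
]


def overall_accuracy(matrix):
    """Fraction of samples on the diagonal (correct predictions)."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


def per_class_recall(matrix):
    """For each ground-truth class, the fraction predicted correctly."""
    return [row[i] / sum(row) for i, row in enumerate(matrix)]


if __name__ == "__main__":
    print(f"overall accuracy: {overall_accuracy(CONFUSION):.4f}")
    for label, recall in zip(LABELS, per_class_recall(CONFUSION)):
        print(f"{label:>8}: {recall:.3f}")
```

On these counts, 6673 of the 7095 test samples lie on the diagonal, which is roughly 94%. Most of the remaining errors involve the Other class, either as the true label or as the prediction.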

spikingjelly.activation_based.examples.speechcommands.mel_to_hz(mels, dct_type)
spikingjelly.activation_based.examples.speechcommands.hz_to_mel(frequencies, dct_type)
spikingjelly.activation_based.examples.speechcommands.create_fb_matrix(n_freqs: int, f_min: float, f_max: float, n_mels: int, sample_rate: int, dct_type: str | None = 'slaney') -> Tensor
class spikingjelly.activation_based.examples.speechcommands.MelScaleDelta(order, n_mels: int = 128, sample_rate: int = 16000, f_min: float = 0.0, f_max: float | None = None, n_stft: int | None = None, dct_type: str | None = 'slaney')

Bases: Module

forward(specgram: Tensor) -> Tensor
class spikingjelly.activation_based.examples.speechcommands.Pad(size)

Bases: object

class spikingjelly.activation_based.examples.speechcommands.Rescale

Bases: object

spikingjelly.activation_based.examples.speechcommands.collate_fn(data)
class spikingjelly.activation_based.examples.speechcommands.LIFWrapper(module, flatten=False)

Bases: Module

forward(x_seq: Tensor) -> Tensor

The input x_seq has shape [batch_size, channel, T, n_mel]. Before it is passed to the wrapped module, the time and batch dimensions are swapped, giving [T, channel, batch_size, n_mel], to match multi-step forward mode. If self.flatten=True, the output is permuted and flattened to [batch_size, T, channel * n_mel]; otherwise the output shape is [batch_size, channel, T, n_mel].

Parameters:

x_seq (torch.Tensor) -- Input sequence, shape=[batch_size, channel, T, n_mel]

Returns:

Output sequence; shape=[batch_size, T, channel * n_mel] when self.flatten=True, otherwise shape=[batch_size, channel, T, n_mel]

Return type:

torch.Tensor
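The shape bookkeeping described for LIFWrapper.forward can be traced with plain tuples. The sketch below only tracks shapes and moves no data. Both helper names are hypothetical, and the wrapped module is assumed to leave the channel and n_mel sizes unchanged, which need not hold for a real convolution.

```python
def swap_time_and_batch(shape):
    """[batch_size, channel, T, n_mel] -> [T, channel, batch_size, n_mel]."""
    batch_size, channel, T, n_mel = shape
    return (T, channel, batch_size, n_mel)


def lif_wrapper_output_shape(input_shape, flatten=False):
    """Output shape for a given input shape, per the description above.

    Assumes the wrapped module preserves channel and n_mel.
    """
    batch_size, channel, T, n_mel = input_shape
    if flatten:
        # Time moves next to batch, and channel and n_mel merge into one axis.
        return (batch_size, T, channel * n_mel)
    return (batch_size, channel, T, n_mel)


if __name__ == "__main__":
    x_shape = (8, 3, 101, 40)  # hypothetical batch of 8, 3 channels, 101 steps
    print(swap_time_and_batch(x_shape))
    print(lif_wrapper_output_shape(x_shape, flatten=True))
```

Swapping the two axes twice returns the original layout, so the non-flattened output keeps the shape of the input.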

class spikingjelly.activation_based.examples.speechcommands.Net

Bases: Module

Initialize the speech command recognition network. The network stacks three convolution-spiking blocks, flattens the final feature dimension to channel * n_mel, and maps features to class logits with a linear layer. Training statistics fields are also initialized.

forward(x)

Run the input features through the convolutional spiking stack to obtain per-time-step class logits, then apply mean pooling over the time dimension to produce the final logits for each sample.

Parameters:

x (torch.Tensor) -- Input features, shape=[batch_size, delta_order + 1, T, n_mel]

Returns:

Classification logits, shape=[batch_size, label_cnt]

Return type:

torch.Tensor
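The final mean pooling over time can be illustrated without tensors. The sketch below averages per-time-step logits stored as nested lists. The layout [batch_size, T, label_cnt] and the name mean_over_time are assumptions for illustration; the documented method works on torch.Tensor objects.

```python
def mean_over_time(logits_seq):
    """Average logits of shape [batch_size, T, label_cnt] over the T axis.

    Returns a nested list of shape [batch_size, label_cnt].
    """
    pooled = []
    for sample in logits_seq:          # sample holds T rows of label_cnt logits
        num_steps = len(sample)
        label_cnt = len(sample[0])
        pooled.append([
            sum(step[k] for step in sample) / num_steps
            for k in range(label_cnt)
        ])
    return pooled


if __name__ == "__main__":
    # One sample, two time steps, two classes.
    print(mean_over_time([[[1.0, 0.0], [3.0, 2.0]]]))
```

Averaging over time makes the prediction depend on the whole sequence and not on any single time step.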

RSNN for Sequential Fashion MNIST

class spikingjelly.activation_based.examples.rsnn_sequential_fmnist.PlainNet

Bases: Module

forward(x: Tensor)
class spikingjelly.activation_based.examples.rsnn_sequential_fmnist.StatefulSynapseNet

Bases: Module

forward(x: Tensor)
class spikingjelly.activation_based.examples.rsnn_sequential_fmnist.FeedBackNet

Bases: Module

forward(x: Tensor)
spikingjelly.activation_based.examples.rsnn_sequential_fmnist.main()

Spiking LSTM for Sequential MNIST

class spikingjelly.activation_based.examples.spiking_lstm_sequential_mnist.Net

Bases: Module

forward(x)
spikingjelly.activation_based.examples.spiking_lstm_sequential_mnist.main()

Spiking LSTM for Text Classification

A2C

class spikingjelly.activation_based.examples.Spiking_A2C.NonSpikingLIFNode(*args, **kwargs)

Bases: LIFNode

single_step_forward(x: Tensor)
class spikingjelly.activation_based.examples.Spiking_A2C.ActorCritic(num_inputs, num_outputs, hidden_size, T=16)

Bases: Module

forward(x)

DQN_state

class spikingjelly.activation_based.examples.DQN_state.ReplayMemory(capacity)

Bases: object

push(*args)

Saves a transition.

sample(batch_size)
class spikingjelly.activation_based.examples.DQN_state.DQN(input_size, hidden_size, output_size)

Bases: Module

forward(x)
class spikingjelly.activation_based.examples.Spiking_DQN_state.Transition(state, action, next_state, reward)

Bases: tuple

Create new instance of Transition(state, action, next_state, reward)

action

Alias for field number 1

next_state

Alias for field number 2

reward

Alias for field number 3

state

Alias for field number 0

class spikingjelly.activation_based.examples.Spiking_DQN_state.ReplayMemory(capacity)

Bases: object

push(*args)
sample(batch_size)
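A replay memory with this interface is commonly built on a bounded deque. The sketch below is one such arrangement and is not necessarily how the documented class is implemented. The field names of Transition follow the signature documented above, while the deque storage, uniform sampling without replacement, and the __len__ method are assumptions.

```python
import random
from collections import deque, namedtuple

# Field order matches the documented Transition(state, action, next_state, reward).
Transition = namedtuple("Transition", ("state", "action", "next_state", "reward"))


class ReplayMemory:
    """Fixed-capacity buffer of transitions (illustrative sketch)."""

    def __init__(self, capacity):
        # A bounded deque drops the oldest transition once capacity is reached.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from (state, action, next_state, reward)."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        """Draw batch_size distinct transitions uniformly at random."""
        return random.sample(list(self.memory), batch_size)

    def __len__(self):
        return len(self.memory)


if __name__ == "__main__":
    memory = ReplayMemory(capacity=3)
    for step in range(5):
        memory.push(step, 0, step + 1, 1.0)
    print(len(memory), memory.sample(2))
```

Sampling at random from the buffer, instead of replaying transitions in order, breaks the correlation between consecutive steps of an episode.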
class spikingjelly.activation_based.examples.Spiking_DQN_state.NonSpikingLIFNode(*args, **kwargs)

Bases: LIFNode

single_step_forward(x: Tensor)
class spikingjelly.activation_based.examples.Spiking_DQN_state.DQSN(input_size, hidden_size, output_size, T=16)

Bases: Module

forward(x)
spikingjelly.activation_based.examples.Spiking_DQN_state.train(use_cuda, model_dir, log_dir, env_name, hidden_size, num_episodes, seed)
spikingjelly.activation_based.examples.Spiking_DQN_state.play(use_cuda, pt_path, env_name, hidden_size, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None)

PPO

spikingjelly.activation_based.examples.PPO.make_env()
class spikingjelly.activation_based.examples.PPO.ActorCritic(num_inputs, num_outputs, hidden_size, std=0.0)

Bases: Module

forward(x)

Common Utilities