spikingjelly.clock_driven.ann2snn.examples package
Submodules
spikingjelly.clock_driven.ann2snn.examples.if_cnn_mnist module
- class spikingjelly.clock_driven.ann2snn.examples.cnn_mnist.ANN[源代码]
基类:
torch.nn.modules.module.Module
- spikingjelly.clock_driven.ann2snn.examples.cnn_mnist.main(log_dir=None)[源代码]
- 返回
None
使用Conv-ReLU-[Conv-ReLU]-全连接-ReLU的网络结构训练并转换为SNN,进行MNIST识别。运行示例:
>>> import spikingjelly.clock_driven.ann2snn.examples.cnn_mnist as cnn_mnist >>> cnn_mnist.main() 输入运行的设备,例如“cpu”或“cuda:0” input device, e.g., "cpu" or "cuda:0": cuda:15 输入保存MNIST数据集的位置,例如“./” input root directory for saving MNIST dataset, e.g., "./": ./mnist 输入batch_size,例如“64” input batch_size, e.g., "64": 128 输入学习率,例如“1e-3” input learning rate, e.g., "1e-3": 1e-3 输入仿真时长,例如“100” input simulating steps, e.g., "100": 100 输入训练轮数,即遍历训练集的次数,例如“10” input training epochs, e.g., "10": 10 输入模型名字,用于自动生成日志文档,例如“cnn_mnist” input model name, for log_dir generating , e.g., "cnn_mnist" Epoch 0 [1/937] ANN Training Loss:2.252 Accuracy:0.078 Epoch 0 [101/937] ANN Training Loss:1.423 Accuracy:0.669 Epoch 0 [201/937] ANN Training Loss:1.117 Accuracy:0.773 Epoch 0 [301/937] ANN Training Loss:0.953 Accuracy:0.795 Epoch 0 [401/937] ANN Training Loss:0.865 Accuracy:0.788 Epoch 0 [501/937] ANN Training Loss:0.807 Accuracy:0.792 Epoch 0 [601/937] ANN Training Loss:0.764 Accuracy:0.795 Epoch 0 [701/937] ANN Training Loss:0.726 Accuracy:0.835 Epoch 0 [801/937] ANN Training Loss:0.681 Accuracy:0.880 Epoch 0 [901/937] ANN Training Loss:0.641 Accuracy:0.889 100%|██████████| 100/100 [00:00<00:00, 116.12it/s] Epoch 0 [100/100] ANN Validating Loss:0.327 Accuracy:0.881 Save model to: cnn_mnist-XXXXX\cnn_mnist.pkl ...... --------------------simulator summary-------------------- time elapsed: 46.55072790000008 (sec) ---------------------------------------------------------