import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import spikingjelly.event_driven.encoding as encoding
import spikingjelly.event_driven.neuron as neuron
import sys
if sys.platform != 'win32':
import readline
import math
[文档]class Net(nn.Module):
def __init__(self, m, T):
# m是高斯调谐曲线编码器编码一个像素点所使用的神经元数量
super().__init__()
self.tempotron = neuron.Tempotron(784*m, 10, T)
[文档] def forward(self, x: torch.Tensor):
# 返回的是输出层10个Tempotron在仿真时长内的电压峰值
return self.tempotron(x, 'v_max')
[文档]def main():
'''
:return: None
使用高斯调谐曲线编码器编码图像为脉冲,单层Tempotron进行MNIST识别。运行示例:
.. code-block:: python
>>> import spikingjelly.event_driven.examples.tempotron_mnist as tempotron_mnist
>>> tempotron_mnist.main()
输入运行的设备,例如“cpu”或“cuda:0”
input device, e.g., "cpu" or "cuda:0": cuda:15
输入保存MNIST数据集的位置,例如“./”
input root directory for saving MNIST dataset, e.g., "./": ./mnist
输入batch_size,例如“64”
input batch_size, e.g., "64": 64
输入学习率,例如“1e-3”
input learning rate, e.g., "1e-3": 1e-3
输入仿真时长,例如“100”
input simulating steps, e.g., "100": 100
输入训练轮数,即遍历训练集的次数,例如“100”
input training epochs, e.g., "100": 10
输入使用高斯调谐曲线编码每个像素点使用的神经元数量,例如“16”
input neuron number for encoding a piexl in GaussianTuning encoder, e.g., "16": 16
输入保存tensorboard日志文件的位置,例如“./”
input root directory for saving tensorboard logs, e.g., "./": ./logs_tempotron_mnist
cuda:15 ./mnist 64 0.001 100 100 16 ./logs_tempotron_mnist
train_acc 0.09375 0
cuda:15 ./mnist 64 0.001 100 100 16 ./logs_tempotron_mnist
train_acc 0.78125 512
...
'''
device = input('输入运行的设备,例如“cpu”或“cuda:0”\n input device, e.g., "cpu" or "cuda:0": ')
dataset_dir = input('输入保存MNIST数据集的位置,例如“./”\n input root directory for saving MNIST dataset, e.g., "./": ')
batch_size = int(input('输入batch_size,例如“64”\n input batch_size, e.g., "64": '))
learning_rate = float(input('输入学习率,例如“1e-3”\n input learning rate, e.g., "1e-3": '))
T = int(input('输入仿真时长,例如“100”\n input simulating steps, e.g., "100": '))
train_epoch = int(input('输入训练轮数,即遍历训练集的次数,例如“100”\n input training epochs, e.g., "100": '))
m = int(input('输入使用高斯调谐曲线编码每个像素点使用的神经元数量,例如“16”\n input neuron number for encoding a piexl in GaussianTuning encoder, e.g., "16": '))
log_dir = input('输入保存tensorboard日志文件的位置,例如“./”\n input root directory for saving tensorboard logs, e.g., "./": ')
# 每个像素点用m个神经元来编码
encoder = encoding.GaussianTuning(n=1, m=m, x_min=torch.zeros(size=[1]).to(device), x_max=torch.ones(size=[1]).to(device))
writer = SummaryWriter(log_dir)
# 初始化数据加载器
train_data_loader = torch.utils.data.DataLoader(
dataset=torchvision.datasets.MNIST(
root=dataset_dir,
train=True,
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=batch_size,
shuffle=True,
drop_last=True)
test_data_loader = torch.utils.data.DataLoader(
dataset=torchvision.datasets.MNIST(
root=dataset_dir,
train=False,
transform=torchvision.transforms.ToTensor(),
download=True),
batch_size=batch_size,
shuffle=True,
drop_last=False)
# 初始化网络
net = Net(m, T).to(device)
# 使用Adam优化器
optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)
train_times = 0
for epoch in range(train_epoch):
net.train()
for img, label in train_data_loader:
img = img.view(img.shape[0], -1).unsqueeze(1) # [batch_size, 1, 784]
in_spikes = encoder.encode(img.to(device), T) # [batch_size, 1, 784, m]
in_spikes = in_spikes.view(in_spikes.shape[0], -1) # [batch_size, 784*m]
v_max = net(in_spikes)
train_acc = (v_max.argmax(dim=1) == label.to(device)).float().mean().item()
if train_times % 256 == 0:
writer.add_scalar('train_acc', train_acc, train_times)
if train_times % 512 == 0:
print(device, dataset_dir, batch_size, learning_rate, T, train_epoch, m, log_dir)
print('train_acc', train_acc, train_times)
loss = neuron.Tempotron.mse_loss(v_max, net.tempotron.v_threshold, label.to(device), 10)
loss.backward()
optimizer.step()
train_times += 1
net.eval()
with torch.no_grad():
correct_num = 0
img_num = 0
for img, label in test_data_loader:
img = img.view(img.shape[0], -1).unsqueeze(1) # [batch_size, 1, 784]
in_spikes = encoder.encode(img.to(device), T) # [batch_size, 1, 784, m]
in_spikes = in_spikes.view(in_spikes.shape[0], -1) # [batch_size, 784*m]
v_max = net(in_spikes)
correct_num += (v_max.argmax(dim=1) == label.to(device)).float().sum().item()
img_num += img.shape[0]
test_acc = correct_num / img_num
writer.add_scalar('test_acc', test_acc, epoch)
print('test_acc', test_acc, train_times, log_dir)
if __name__ == '__main__':
main()