利用Spiking LSTM实现基于文本的姓氏分类任务
本教程作者:LiutaoYu,fangwei123456
本节教程使用Spiking LSTM重新实现PyTorch的官方教程 NLP From Scratch: Classifying Names with a Character-Level RNN。 对应的中文版教程可参见 使用字符级别特征的RNN网络进行名字分类。 请确保你已经阅读了原版教程和代码,因为本教程是对原教程的扩展。本教程将构建和训练字符级的Spiking LSTM来对姓氏进行分类。 具体而言,本教程将在18种语言构成的几千个姓氏的数据集上训练Spiking LSTM模型,网络可根据一个姓氏的拼写预测其属于哪种语言。 完整代码可见于 clock_driven/examples/spiking_lstm_text.py。
准备数据
首先,我们参照原教程下载数据,并进行预处理。预处理后,我们可以得到一个语言对应姓氏列表的字典,即 {language: [names ...]}
。
进一步地,我们将数据集按照4:1的比例划分为训练集和测试集,即 category_lines_train
和 category_lines_test
。
这里还需要留意几个后续会经常使用的变量: all_categories
是全部语言种类的列表, n_categories=18
则是语言种类的数量,
n_letters=58
是组成 names
的所有字母和符号的集合的元素数量。
# split the data into training set and testing set
numExamplesPerCategory = []
category_lines_train = {}
category_lines_test = {}
testNumtot = 0
for c, names in category_lines.items():
category_lines_train[c] = names[:int(len(names)*0.8)]
category_lines_test[c] = names[int(len(names)*0.8):]
numExamplesPerCategory.append([len(category_lines[c]), len(category_lines_train[c]), len(category_lines_test[c])])
testNumtot += len(category_lines_test[c])
此外,我们改写了原教程中的 randomTrainingExample()
函数,以便在不同条件下进行使用。
注意此处利用了原教程中定义的 lineToTensor()
和 randomChoice()
两个函数。
前者用于将单词转化为one-hot张量,后者用于从数据集中随机抽取一个样本。
# Preparing [x, y] pair
def randomPair(sampleSource):
"""
Args:
sampleSource: 'train', 'test', 'all'
Returns:
category, line, category_tensor, line_tensor
"""
category = randomChoice(all_categories)
if sampleSource == 'train':
line = randomChoice(category_lines_train[category])
elif sampleSource == 'test':
line = randomChoice(category_lines_test[category])
elif sampleSource == 'all':
line = randomChoice(category_lines[category])
category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.float)
line_tensor = lineToTensor(line)
return category, line, category_tensor, line_tensor
构造Spiking LSTM神经网络
我们利用 spikingjelly 中的rnn模块( rnn.SpikingLSTM()
)来搭建Spiking LSTM神经网络。
其工作原理可参见论文 Long Short-Term Memory Spiking Networks and Their Applications 。
输入层神经元个数等于 n_letters
,隐藏层神经元个数 n_hidden
可自行定义,输出层神经元个数等于 n_categories
。
我们在LSTM的输出层之后接一个全连接层,并利用softmax函数对全连接层的数据进行处理以获取类别概率。
from spikingjelly.clock_driven import rnn
n_hidden = 256
class Net(nn.Module):
def __init__(self, n_letters, n_hidden, n_categories):
super().__init__()
self.n_input = n_letters
self.n_hidden = n_hidden
self.n_out = n_categories
self.lstm = rnn.SpikingLSTM(self.n_input, self.n_hidden, 1)
self.fc = nn.Linear(self.n_hidden, self.n_out)
def forward(self, x):
x, _ = self.lstm(x)
output = self.fc(x[-1])
output = F.softmax(output, dim=1)
return output
网络训练
首先,我们初始化网络 net
,并定义训练时长 TRAIN_EPISODES
、学习率 learning_rate
等。
这里我们采用 mse_loss
损失函数和 Adam
优化器来对训练网络。
单个epoch的训练流程大致如下:1)从训练集中随机抽取一个样本,获取输入和标签,并转化为tensor;2)网络接收输入,进行前向过程,获取各类别的预测概率;
3)利用 mse_loss
函数获取网络预测概率和真实标签one-hot编码之间的差距,即网络损失;4)梯度反传,并更新网络参数;
5)判断此次预测是否正确,并累计预测正确的数量,以获取模型在训练过程中针对训练集的准确率(每隔 plot_every
个epoch计算一次);
6)每隔 plot_every
个epoch在测试集上测试一次,并统计测试准确率。
此外,在训练过程中,我们会记录网络损失 avg_losses
、训练集准确率 accuracy_rec
和测试集准确率 test_accu_rec
,以便于观察训练效果,并在训练完成后绘制图片。
在训练完成之后,我们会保存网络的最终状态以用于测试;同时,也可以保存相关变量,以便后续分析。
# IF_TRAIN = 1
TRAIN_EPISODES = 1000000
plot_every = 1000
learning_rate = 1e-4
net = Net(n_letters, n_hidden, n_categories)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
print('Training...')
current_loss = 0
correct_num = 0
avg_losses = []
accuracy_rec = []
test_accu_rec = []
start = time.time()
for epoch in range(1, TRAIN_EPISODES+1):
net.train()
category, line, category_tensor, line_tensor = randomPair('train')
label_one_hot = F.one_hot(category_tensor.to(int), n_categories).float()
optimizer.zero_grad()
out_prob_log = net(line_tensor)
loss = F.mse_loss(out_prob_log, label_one_hot)
loss.backward()
optimizer.step()
# 优化一次参数后,需要重置网络的状态。是否需要?结果差别不明显!(2020.11.3)
# functional.reset_net(net)
current_loss += loss.data.item()
guess, _ = categoryFromOutput(out_prob_log.data)
if guess == category:
correct_num += 1
# Add current loss avg to list of losses
if epoch % plot_every == 0:
avg_losses.append(current_loss / plot_every)
accuracy_rec.append(correct_num / plot_every)
current_loss = 0
correct_num = 0
# 每训练一定次数即进行一次测试
if epoch % plot_every == 0: # int(TRAIN_EPISODES/1000)
net.eval()
with torch.no_grad():
numCorrect = 0
for i in range(n_categories):
category = all_categories[i]
for tname in category_lines_test[category]:
output = net(lineToTensor(tname))
# 运行一次后,需要重置网络的状态。是否需要?
# functional.reset_net(net)
guess, _ = categoryFromOutput(output.data)
if guess == category:
numCorrect += 1
test_accu = numCorrect / testNumtot
test_accu_rec.append(test_accu)
print('Epoch %d %d%% (%s); Avg_loss %.4f; Train accuracy %.4f; Test accuracy %.4f' % (
epoch, epoch / TRAIN_EPISODES * 100, timeSince(start), avg_losses[-1], accuracy_rec[-1], test_accu))
torch.save(net, 'char_rnn_classification.pth')
np.save('avg_losses.npy', np.array(avg_losses))
np.save('accuracy_rec.npy', np.array(accuracy_rec))
np.save('test_accu_rec.npy', np.array(test_accu_rec))
np.save('category_lines_train.npy', category_lines_train, allow_pickle=True)
np.save('category_lines_test.npy', category_lines_test, allow_pickle=True)
# x = np.load('category_lines_test.npy', allow_pickle=True) # 读取数据的方法
# xdict = x.item()
plt.figure()
plt.subplot(311)
plt.plot(avg_losses)
plt.title('Average loss')
plt.subplot(312)
plt.plot(accuracy_rec)
plt.title('Train accuracy')
plt.subplot(313)
plt.plot(test_accu_rec)
plt.title('Test accuracy')
plt.xlabel('Epoch (*1000)')
plt.subplots_adjust(hspace=0.6)
plt.savefig('TrainingProcess.svg')
plt.close()
设定 IF_TRAIN = 1
,在Python Console中运行 %run ./spiking_lstm_text.py
,输出如下:
Backend Qt5Agg is interactive backend. Turning interactive mode on.
Training...
Epoch 1000 0% (0m 18s); Avg_loss 0.0525; Train accuracy 0.0830; Test accuracy 0.0806
Epoch 2000 0% (0m 37s); Avg_loss 0.0514; Train accuracy 0.1470; Test accuracy 0.1930
Epoch 3000 0% (0m 55s); Avg_loss 0.0503; Train accuracy 0.1650; Test accuracy 0.0537
Epoch 4000 0% (1m 14s); Avg_loss 0.0494; Train accuracy 0.1920; Test accuracy 0.0938
...
...
Epoch 998000 99% (318m 54s); Avg_loss 0.0063; Train accuracy 0.9300; Test accuracy 0.5036
Epoch 999000 99% (319m 14s); Avg_loss 0.0056; Train accuracy 0.9380; Test accuracy 0.5004
Epoch 1000000 100% (319m 33s); Avg_loss 0.0055; Train accuracy 0.9340; Test accuracy 0.5118
下图展示了训练过程中损失函数、测试集准确率、测试集准确率随时间的变化。
值得注意的一点是,测试表明,在当前Spiking LSTM网络中是否在一次运行完成后重置网络 functional.reset_net(net)
对于结果没有显著的影响。
我们猜测是因为当前网络输入是随时间变化的,而且网络自身需要运行一段时间后才会输出分类结果,因此网络初始状态影响不显著。
网络测试
在测试过程中,我们首先需要导入训练完成后存储的网络,随后进行三方面的测试:(1)计算最终的测试集准确率;(2)让用户输入姓氏拼写以预测其属于哪种语言; (3)计算Confusion matrix,每一行表示当样本源于某一个类别时,网络预测其属于各类别的概率,即对角线表示预测正确的概率。
# IF_TRAIN = 0
print('Testing...')
net = torch.load('char_rnn_classification.pth')
# 遍历测试集计算准确率
print('Calculating testing accuracy...')
numCorrect = 0
for i in range(n_categories):
category = all_categories[i]
for tname in category_lines_test[category]:
output = net(lineToTensor(tname))
# 运行一次后,需要重置网络的状态。是否需要?
# functional.reset_net(net)
guess, _ = categoryFromOutput(output.data)
if guess == category:
numCorrect += 1
test_accu = numCorrect / testNumtot
print('Test accuracy: {:.3f}, Random guess: {:.3f}'.format(test_accu, 1/n_categories))
# 让用户输入姓氏以判断其属于哪种语系
n_predictions = 3
for j in range(3):
first_name = input('请输入一个姓氏以判断其属于哪种语系:')
print('\n> %s' % first_name)
output = net(lineToTensor(first_name))
# 运行一次后,需要重置网络的状态。是否需要?
# functional.reset_net(net)
# Get top N categories
topv, topi = output.topk(n_predictions, 1, True)
predictions = []
for i in range(n_predictions):
value = topv[0][i].item()
category_index = topi[0][i].item()
print('(%.2f) %s' % (value, all_categories[category_index]))
predictions.append([value, all_categories[category_index]])
# 计算confusion矩阵
print('Calculating confusion matrix...')
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000
# Keep track of correct guesses in a confusion matrix
for i in range(n_confusion):
category, line, category_tensor, line_tensor = randomPair('all')
output = net(line_tensor)
# 运行一次后,需要重置网络的状态。是否需要?
# functional.reset_net(net)
guess, guess_i = categoryFromOutput(output.data)
category_i = all_categories.index(category)
confusion[category_i][guess_i] += 1
confusion = confusion / confusion.sum(1)
np.save('confusion.npy', confusion)
# Set up plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)
# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
# sphinx_gallery_thumbnail_number = 2
plt.show()
plt.savefig('ConfusionMatrix.svg')
plt.close()
设定 IF_TRAIN = 0
,在Python Console中运行 %run ./spiking_lstm_text.py
,输出如下:
Testing...
Calculating testing accuracy...
Test accuracy: 0.512, Random guess: 0.056
请输入一个姓氏以判断其属于哪种语系:> YU
> YU
(0.18) Scottish
(0.12) English
(0.11) Italian
请输入一个姓氏以判断其属于哪种语系:> Yu
> Yu
(0.63) Chinese
(0.23) Korean
(0.07) Vietnamese
请输入一个姓氏以判断其属于哪种语系:> Zou
> Zou
(1.00) Chinese
(0.00) Arabic
(0.00) Polish
Calculating confusion matrix...
下图展示了Confusion matrix。对角线越亮,表示模型对某一类别预测最好,很少产生混淆,如Arabic和Greek。 而有的语言则较容易产生混淆,如Korean和Chinese,Spanish和Portuguese,English和Scottish。