利用Spiking LSTM实现基于文本的姓氏分类任务

本教程作者:LiutaoYufangwei123456

本节教程使用Spiking LSTM重新实现PyTorch的官方教程 NLP From Scratch: Classifying Names with a Character-Level RNN。 对应的中文版教程可参见 使用字符级别特征的RNN网络进行名字分类。 请确保你已经阅读了原版教程和代码,因为本教程是对原教程的扩展。本教程将构建和训练字符级的Spiking LSTM来对姓氏进行分类。 具体而言,本教程将在18种语言构成的几千个姓氏的数据集上训练Spiking LSTM模型,网络可根据一个姓氏的拼写预测其属于哪种语言。 完整代码可见于 clock_driven/examples/spiking_lstm_text.py

准备数据

首先,我们参照原教程下载数据,并进行预处理。预处理后,我们可以得到一个语言对应姓氏列表的字典,即 {language: [names ...]} 。 进一步地,我们将数据集按照4:1的比例划分为训练集和测试集,即 category_lines_traincategory_lines_test 。 这里还需要留意几个后续会经常使用的变量: all_categories 是全部语言种类的列表, n_categories=18 则是语言种类的数量, n_letters=58 是组成 names 的所有字母和符号的集合的元素数量。

# split the data into training set and testing set
numExamplesPerCategory = []
category_lines_train = {}
category_lines_test = {}
testNumtot = 0
for c, names in category_lines.items():
    category_lines_train[c] = names[:int(len(names)*0.8)]
    category_lines_test[c] = names[int(len(names)*0.8):]
    numExamplesPerCategory.append([len(category_lines[c]), len(category_lines_train[c]), len(category_lines_test[c])])
    testNumtot += len(category_lines_test[c])

此外,我们改写了原教程中的 randomTrainingExample() 函数,以便在不同条件下进行使用。 注意此处利用了原教程中定义的 lineToTensor()randomChoice() 两个函数。 前者用于将单词转化为one-hot张量,后者用于从数据集中随机抽取一个样本。

# Preparing [x, y] pair
def randomPair(sampleSource):
    """
    Args:
        sampleSource:  'train', 'test', 'all'
    Returns:
        category, line, category_tensor, line_tensor
    """
    category = randomChoice(all_categories)
    if sampleSource == 'train':
        line = randomChoice(category_lines_train[category])
    elif sampleSource == 'test':
        line = randomChoice(category_lines_test[category])
    elif sampleSource == 'all':
        line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.float)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor

构造Spiking LSTM神经网络

我们利用 spikingjelly 中的rnn模块( rnn.SpikingLSTM() )来搭建Spiking LSTM神经网络。 其工作原理可参见论文 Long Short-Term Memory Spiking Networks and Their Applications 。 输入层神经元个数等于 n_letters ,隐藏层神经元个数 n_hidden 可自行定义,输出层神经元个数等于 n_categories 。 我们在LSTM的输出层之后接一个全连接层,并利用softmax函数对全连接层的数据进行处理以获取类别概率。

from spikingjelly.clock_driven import rnn
n_hidden = 256

class Net(nn.Module):
    def __init__(self, n_letters, n_hidden, n_categories):
        super().__init__()
        self.n_input = n_letters
        self.n_hidden = n_hidden
        self.n_out = n_categories
        self.lstm = rnn.SpikingLSTM(self.n_input, self.n_hidden, 1)
        self.fc = nn.Linear(self.n_hidden, self.n_out)

    def forward(self, x):
        x, _ = self.lstm(x)
        output = self.fc(x[-1])
        output = F.softmax(output, dim=1)
        return output

网络训练

首先,我们初始化网络 net ,并定义训练时长 TRAIN_EPISODES 、学习率 learning_rate 等。 这里我们采用 mse_loss 损失函数和 Adam 优化器来对训练网络。 单个epoch的训练流程大致如下:1)从训练集中随机抽取一个样本,获取输入和标签,并转化为tensor;2)网络接收输入,进行前向过程,获取各类别的预测概率; 3)利用 mse_loss 函数获取网络预测概率和真实标签one-hot编码之间的差距,即网络损失;4)梯度反传,并更新网络参数; 5)判断此次预测是否正确,并累计预测正确的数量,以获取模型在训练过程中针对训练集的准确率(每隔 plot_every 个epoch计算一次); 6)每隔 plot_every 个epoch在测试集上测试一次,并统计测试准确率。 此外,在训练过程中,我们会记录网络损失 avg_losses 、训练集准确率 accuracy_rec 和测试集准确率 test_accu_rec ,以便于观察训练效果,并在训练完成后绘制图片。 在训练完成之后,我们会保存网络的最终状态以用于测试;同时,也可以保存相关变量,以便后续分析。

# IF_TRAIN = 1
TRAIN_EPISODES = 1000000
plot_every = 1000
learning_rate = 1e-4

net = Net(n_letters, n_hidden, n_categories)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

print('Training...')
current_loss = 0
correct_num = 0
avg_losses = []
accuracy_rec = []
test_accu_rec = []
start = time.time()
for epoch in range(1, TRAIN_EPISODES+1):
    net.train()
    category, line, category_tensor, line_tensor = randomPair('train')
    label_one_hot = F.one_hot(category_tensor.to(int), n_categories).float()

    optimizer.zero_grad()
    out_prob_log = net(line_tensor)
    loss = F.mse_loss(out_prob_log, label_one_hot)
    loss.backward()
    optimizer.step()

    # 优化一次参数后,需要重置网络的状态。是否需要?结果差别不明显!(2020.11.3)
    # functional.reset_net(net)

    current_loss += loss.data.item()

    guess, _ = categoryFromOutput(out_prob_log.data)
    if guess == category:
        correct_num += 1

    # Add current loss avg to list of losses
    if epoch % plot_every == 0:
        avg_losses.append(current_loss / plot_every)
        accuracy_rec.append(correct_num / plot_every)
        current_loss = 0
        correct_num = 0

    # 每训练一定次数即进行一次测试
    if epoch % plot_every == 0:  # int(TRAIN_EPISODES/1000)
        net.eval()
        with torch.no_grad():
            numCorrect = 0
            for i in range(n_categories):
                category = all_categories[i]
                for tname in category_lines_test[category]:
                    output = net(lineToTensor(tname))
                    # 运行一次后,需要重置网络的状态。是否需要?
                    # functional.reset_net(net)
                    guess, _ = categoryFromOutput(output.data)
                    if guess == category:
                        numCorrect += 1
            test_accu = numCorrect / testNumtot
            test_accu_rec.append(test_accu)
            print('Epoch %d %d%% (%s); Avg_loss %.4f; Train accuracy %.4f; Test accuracy %.4f' % (
                epoch, epoch / TRAIN_EPISODES * 100, timeSince(start), avg_losses[-1], accuracy_rec[-1], test_accu))

torch.save(net, 'char_rnn_classification.pth')
np.save('avg_losses.npy', np.array(avg_losses))
np.save('accuracy_rec.npy', np.array(accuracy_rec))
np.save('test_accu_rec.npy', np.array(test_accu_rec))
np.save('category_lines_train.npy', category_lines_train, allow_pickle=True)
np.save('category_lines_test.npy', category_lines_test, allow_pickle=True)
# x = np.load('category_lines_test.npy', allow_pickle=True)  # 读取数据的方法
# xdict = x.item()

plt.figure()
plt.subplot(311)
plt.plot(avg_losses)
plt.title('Average loss')
plt.subplot(312)
plt.plot(accuracy_rec)
plt.title('Train accuracy')
plt.subplot(313)
plt.plot(test_accu_rec)
plt.title('Test accuracy')
plt.xlabel('Epoch (*1000)')
plt.subplots_adjust(hspace=0.6)
plt.savefig('TrainingProcess.svg')
plt.close()

设定 IF_TRAIN = 1 ,在Python Console中运行 %run ./spiking_lstm_text.py ,输出如下:

Backend Qt5Agg is interactive backend. Turning interactive mode on.
Training...
Epoch 1000 0% (0m 18s); Avg_loss 0.0525; Train accuracy 0.0830; Test accuracy 0.0806
Epoch 2000 0% (0m 37s); Avg_loss 0.0514; Train accuracy 0.1470; Test accuracy 0.1930
Epoch 3000 0% (0m 55s); Avg_loss 0.0503; Train accuracy 0.1650; Test accuracy 0.0537
Epoch 4000 0% (1m 14s); Avg_loss 0.0494; Train accuracy 0.1920; Test accuracy 0.0938
...
...
Epoch 998000 99% (318m 54s); Avg_loss 0.0063; Train accuracy 0.9300; Test accuracy 0.5036
Epoch 999000 99% (319m 14s); Avg_loss 0.0056; Train accuracy 0.9380; Test accuracy 0.5004
Epoch 1000000 100% (319m 33s); Avg_loss 0.0055; Train accuracy 0.9340; Test accuracy 0.5118

下图展示了训练过程中损失函数、测试集准确率、测试集准确率随时间的变化。 值得注意的一点是,测试表明,在当前Spiking LSTM网络中是否在一次运行完成后重置网络 functional.reset_net(net) 对于结果没有显著的影响。 我们猜测是因为当前网络输入是随时间变化的,而且网络自身需要运行一段时间后才会输出分类结果,因此网络初始状态影响不显著。

../_images/TrainingProcess.svg

网络测试

在测试过程中,我们首先需要导入训练完成后存储的网络,随后进行三方面的测试:(1)计算最终的测试集准确率;(2)让用户输入姓氏拼写以预测其属于哪种语言; (3)计算Confusion matrix,每一行表示当样本源于某一个类别时,网络预测其属于各类别的概率,即对角线表示预测正确的概率。

# IF_TRAIN = 0
print('Testing...')

net = torch.load('char_rnn_classification.pth')

# 遍历测试集计算准确率
print('Calculating testing accuracy...')
numCorrect = 0
for i in range(n_categories):
    category = all_categories[i]
    for tname in category_lines_test[category]:
        output = net(lineToTensor(tname))
        # 运行一次后,需要重置网络的状态。是否需要?
        # functional.reset_net(net)
        guess, _ = categoryFromOutput(output.data)
        if guess == category:
            numCorrect += 1
test_accu = numCorrect / testNumtot
print('Test accuracy: {:.3f}, Random guess: {:.3f}'.format(test_accu, 1/n_categories))

# 让用户输入姓氏以判断其属于哪种语系
n_predictions = 3
for j in range(3):
    first_name = input('请输入一个姓氏以判断其属于哪种语系:')
    print('\n> %s' % first_name)
    output = net(lineToTensor(first_name))
    # 运行一次后,需要重置网络的状态。是否需要?
    # functional.reset_net(net)
    # Get top N categories
    topv, topi = output.topk(n_predictions, 1, True)
    predictions = []

    for i in range(n_predictions):
        value = topv[0][i].item()
        category_index = topi[0][i].item()
        print('(%.2f) %s' % (value, all_categories[category_index]))
        predictions.append([value, all_categories[category_index]])

# 计算confusion矩阵
print('Calculating confusion matrix...')
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Keep track of correct guesses in a confusion matrix
for i in range(n_confusion):
    category, line, category_tensor, line_tensor = randomPair('all')
    output = net(line_tensor)
    # 运行一次后,需要重置网络的状态。是否需要?
    # functional.reset_net(net)
    guess, guess_i = categoryFromOutput(output.data)
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

confusion = confusion / confusion.sum(1)
np.save('confusion.npy', confusion)

# Set up plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)
# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)
# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
# sphinx_gallery_thumbnail_number = 2
plt.show()
plt.savefig('ConfusionMatrix.svg')
plt.close()

设定 IF_TRAIN = 0,在Python Console中运行 %run ./spiking_lstm_text.py,输出如下:

Testing...
Calculating testing accuracy...
Test accuracy: 0.512, Random guess: 0.056
请输入一个姓氏以判断其属于哪种语系:> YU
> YU
(0.18) Scottish
(0.12) English
(0.11) Italian
请输入一个姓氏以判断其属于哪种语系:> Yu
> Yu
(0.63) Chinese
(0.23) Korean
(0.07) Vietnamese
请输入一个姓氏以判断其属于哪种语系:> Zou
> Zou
(1.00) Chinese
(0.00) Arabic
(0.00) Polish
Calculating confusion matrix...

下图展示了Confusion matrix。对角线越亮,表示模型对某一类别预测最好,很少产生混淆,如Arabic和Greek。 而有的语言则较容易产生混淆,如Korean和Chinese,Spanish和Portuguese,English和Scottish。

../_images/ConfusionMatrix.svg