Classifying Names with a Character-level Spiking LSTM
Authors: LiutaoYu, fangwei123456
This tutorial applies a spiking LSTM to reproduce the PyTorch official tutorial NLP From Scratch: Classifying Names with a Character-Level RNN. Please make sure that you have read the original tutorial and its code before proceeding. Specifically, we will train a spiking LSTM to classify surnames into languages according to their spelling, based on a dataset consisting of several thousand surnames from 18 languages of origin. The integrated script can be found here (clock_driven/examples/spiking_lstm_text.py).
Preparing the data
First of all, we need to download and preprocess the data as in the original tutorial, which produces a dictionary {language: [names ...]}. Then, we split the dataset into a training set and a testing set with a ratio of 4:1, i.e., category_lines_train and category_lines_test. Here, we emphasize several important variables: all_categories is the list of 18 languages, whose length is n_categories=18, and n_letters=58 is the number of distinct characters composing the surnames.
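The preprocessing itself is not shown in this tutorial; as a reference, a minimal sketch following the original PyTorch tutorial (assuming the data archive has been extracted to data/names/[Language].txt) is:

```python
# Minimal sketch of the data loading from the original PyTorch tutorial.
# Paths and helper names are assumptions taken from that tutorial, not from
# the integrated spiking_lstm_text.py script.
import glob
import os
import string
import unicodedata

all_letters = string.ascii_letters + " .,;'"

def unicodeToAscii(s):
    # Strip accents, e.g. 'Slusarski' with accents -> 'Slusarski'
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn' and c in all_letters)

category_lines = {}   # {language: [names ...]}
all_categories = []   # list of the 18 languages
for filename in glob.glob('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    with open(filename, encoding='utf-8') as f:
        category_lines[category] = [unicodeToAscii(line.strip()) for line in f]

n_categories = len(all_categories)  # 18 for the full dataset
```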
# split the data into a training set and a testing set
numExamplesPerCategory = []
category_lines_train = {}
category_lines_test = {}
testNumtot = 0
for c, names in category_lines.items():
    category_lines_train[c] = names[:int(len(names)*0.8)]
    category_lines_test[c] = names[int(len(names)*0.8):]
    numExamplesPerCategory.append([len(category_lines[c]), len(category_lines_train[c]), len(category_lines_test[c])])
    testNumtot += len(category_lines_test[c])
In addition, we rephrase the function randomTrainingExample() from the original tutorial as randomPair(sampleSource), so that samples can be drawn from the training set, the testing set, or the whole dataset. Here we adopt the functions lineToTensor() and randomChoice() from the original tutorial: lineToTensor() converts a surname into a one-hot tensor, and randomChoice() randomly chooses a sample from a list.
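Both helpers are short; hedged sketches following the original tutorial are shown below (the exact character set and value of n_letters are assumptions matching the n_letters=58 stated above, not copied from the integrated script):

```python
# Sketches of the two helpers adopted from the original PyTorch tutorial.
import random
import string

import torch

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters) + 1  # 58: every letter plus one extra marker

def randomChoice(l):
    # pick one element of a list uniformly at random
    return l[random.randint(0, len(l) - 1)]

def letterToIndex(letter):
    # index of a character within ``all_letters``
    return all_letters.find(letter)

def lineToTensor(line):
    # one-hot encode a surname; shape [line_length, batch_size=1, n_letters]
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor
```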
# Preparing [x, y] pair
def randomPair(sampleSource):
    """
    Args:
        sampleSource: 'train', 'test', 'all'
    Returns:
        category, line, category_tensor, line_tensor
    """
    category = randomChoice(all_categories)
    if sampleSource == 'train':
        line = randomChoice(category_lines_train[category])
    elif sampleSource == 'test':
        line = randomChoice(category_lines_test[category])
    elif sampleSource == 'all':
        line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.float)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor
Building a spiking LSTM network
We build a spiking LSTM based on the rnn module of spikingjelly. The theory can be found in the paper Long Short-Term Memory Spiking Networks and Their Applications. The numbers of neurons in the input layer, hidden layer and output layer are n_letters, n_hidden and n_categories, respectively. We add a fully connected layer after the LSTM, and apply the softmax function to obtain the classification probability.
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import rnn

n_hidden = 256

class Net(nn.Module):
    def __init__(self, n_letters, n_hidden, n_categories):
        super().__init__()
        self.n_input = n_letters
        self.n_hidden = n_hidden
        self.n_out = n_categories
        self.lstm = rnn.SpikingLSTM(self.n_input, self.n_hidden, 1)
        self.fc = nn.Linear(self.n_hidden, self.n_out)

    def forward(self, x):
        x, _ = self.lstm(x)
        output = self.fc(x[-1])
        output = F.softmax(output, dim=1)
        return output
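The essential change relative to a standard LSTM is that the gate activations are binarized into spikes, with a surrogate gradient used in the backward pass. As a rough illustration only (not SpikingJelly's actual implementation, which offers several surrogate functions), a Heaviside forward with a rectangular surrogate backward can be sketched as:

```python
# Illustrative surrogate-gradient spiking activation: Heaviside step in the
# forward pass, rectangular surrogate in the backward pass. This is a
# simplified stand-in for the surrogate functions used inside SpikingLSTM.
import torch

class SpikeFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x >= 0).float()  # emit a spike when the input crosses 0

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # pass gradients only in a window around the threshold
        return grad_output * (x.abs() < 0.5).float()

spike = SpikeFunction.apply
```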
Training the network
First of all, we initialize the net, and define hyper-parameters such as TRAIN_EPISODES and learning_rate. Here we adopt mse_loss and the Adam optimizer to train the network. One training epoch proceeds as follows:
1) randomly choose a sample from the training set, and convert the input and the label into tensors;
2) feed the input to the network, and obtain the classification probability through the forward pass;
3) calculate the network loss through mse_loss;
4) back-propagate the gradients, and update the trainable parameters;
5) judge whether the prediction is correct, and count the number of correct predictions to obtain the training accuracy every plot_every epochs;
6) evaluate the network on the testing set every plot_every epochs to obtain the testing accuracy.
During training, we record the history of the network loss avg_losses, the training accuracy accuracy_rec and the testing accuracy test_accu_rec, to monitor the training process. After training, we save the final state of the network for testing, along with some variables for later analyses.
# IF_TRAIN = 1
TRAIN_EPISODES = 1000000
plot_every = 1000
learning_rate = 1e-4

net = Net(n_letters, n_hidden, n_categories)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

print('Training...')
current_loss = 0
correct_num = 0
avg_losses = []
accuracy_rec = []
test_accu_rec = []
start = time.time()
for epoch in range(1, TRAIN_EPISODES+1):
    net.train()
    category, line, category_tensor, line_tensor = randomPair('train')
    label_one_hot = F.one_hot(category_tensor.to(int), n_categories).float()

    optimizer.zero_grad()
    out_prob_log = net(line_tensor)
    loss = F.mse_loss(out_prob_log, label_one_hot)
    loss.backward()
    optimizer.step()

    current_loss += loss.data.item()
    guess, _ = categoryFromOutput(out_prob_log.data)
    if guess == category:
        correct_num += 1

    # Add current loss avg to list of losses
    if epoch % plot_every == 0:
        avg_losses.append(current_loss / plot_every)
        accuracy_rec.append(correct_num / plot_every)
        current_loss = 0
        correct_num = 0

    # evaluate the network on the testing set every ``plot_every`` epochs
    if epoch % plot_every == 0:
        net.eval()
        with torch.no_grad():
            numCorrect = 0
            for i in range(n_categories):
                category = all_categories[i]
                for tname in category_lines_test[category]:
                    output = net(lineToTensor(tname))
                    guess, _ = categoryFromOutput(output.data)
                    if guess == category:
                        numCorrect += 1
            test_accu = numCorrect / testNumtot
            test_accu_rec.append(test_accu)
            print('Epoch %d %d%% (%s); Avg_loss %.4f; Train accuracy %.4f; Test accuracy %.4f' % (
                epoch, epoch / TRAIN_EPISODES * 100, timeSince(start), avg_losses[-1], accuracy_rec[-1], test_accu))

torch.save(net, 'char_rnn_classification.pth')
np.save('avg_losses.npy', np.array(avg_losses))
np.save('accuracy_rec.npy', np.array(accuracy_rec))
np.save('test_accu_rec.npy', np.array(test_accu_rec))
np.save('category_lines_train.npy', category_lines_train, allow_pickle=True)
np.save('category_lines_test.npy', category_lines_test, allow_pickle=True)
# x = np.load('category_lines_test.npy', allow_pickle=True)  # way to load the data back
# xdict = x.item()

plt.figure()
plt.subplot(311)
plt.plot(avg_losses)
plt.title('Average loss')
plt.subplot(312)
plt.plot(accuracy_rec)
plt.title('Train accuracy')
plt.subplot(313)
plt.plot(test_accu_rec)
plt.title('Test accuracy')
plt.xlabel('Epoch (*1000)')
plt.subplots_adjust(hspace=0.6)
plt.savefig('TrainingProcess.svg')
plt.close()
We will observe the following results when executing %run ./spiking_lstm_text.py in the Python Console with IF_TRAIN = 1.
Backend Qt5Agg is interactive backend. Turning interactive mode on.
Training...
Epoch 1000 0% (0m 18s); Avg_loss 0.0525; Train accuracy 0.0830; Test accuracy 0.0806
Epoch 2000 0% (0m 37s); Avg_loss 0.0514; Train accuracy 0.1470; Test accuracy 0.1930
Epoch 3000 0% (0m 55s); Avg_loss 0.0503; Train accuracy 0.1650; Test accuracy 0.0537
Epoch 4000 0% (1m 14s); Avg_loss 0.0494; Train accuracy 0.1920; Test accuracy 0.0938
...
...
Epoch 998000 99% (318m 54s); Avg_loss 0.0063; Train accuracy 0.9300; Test accuracy 0.5036
Epoch 999000 99% (319m 14s); Avg_loss 0.0056; Train accuracy 0.9380; Test accuracy 0.5004
Epoch 1000000 100% (319m 33s); Avg_loss 0.0055; Train accuracy 0.9340; Test accuracy 0.5118
The following picture shows how the average loss avg_losses, the training accuracy accuracy_rec and the testing accuracy test_accu_rec evolve with training.
Testing the network
We first load the well-trained network, and then conduct the following tests: 1) calculate the testing accuracy of the final network; 2) predict the language origin of the surnames provided by the user; 3) calculate the confusion matrix, indicating for every actual language (rows) which language the network guesses (columns).
# IF_TRAIN = 0
print('Testing...')

net = torch.load('char_rnn_classification.pth')

# 1. calculate the testing accuracy of the final network
print('Calculating testing accuracy...')
numCorrect = 0
for i in range(n_categories):
    category = all_categories[i]
    for tname in category_lines_test[category]:
        output = net(lineToTensor(tname))
        guess, _ = categoryFromOutput(output.data)
        if guess == category:
            numCorrect += 1
test_accu = numCorrect / testNumtot
print('Test accuracy: {:.3f}, Random guess: {:.3f}'.format(test_accu, 1/n_categories))

# 2. predict the language origin of the surnames provided by the user
n_predictions = 3
for j in range(n_predictions):
    first_name = input('Please input a surname to predict its language origin:')
    print('\n> %s' % first_name)
    output = net(lineToTensor(first_name))

    # Get top N categories
    topv, topi = output.topk(n_predictions, 1, True)
    predictions = []
    for i in range(n_predictions):
        value = topv[0][i].item()
        category_index = topi[0][i].item()
        print('(%.2f) %s' % (value, all_categories[category_index]))
        predictions.append([value, all_categories[category_index]])

# 3. calculate the confusion matrix
print('Calculating confusion matrix...')
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Keep track of correct guesses in a confusion matrix
for i in range(n_confusion):
    category, line, category_tensor, line_tensor = randomPair('all')
    output = net(line_tensor)
    guess, guess_i = categoryFromOutput(output.data)
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

# Normalize every row by its sum; ``keepdim=True`` keeps the division row-wise
confusion = confusion / confusion.sum(1, keepdim=True)
np.save('confusion.npy', confusion)

# Set up plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)

# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)

# Force label at every tick
ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

# sphinx_gallery_thumbnail_number = 2
plt.show()
plt.savefig('ConfusionMatrix.svg')
plt.close()
We will observe the following results when executing %run ./spiking_lstm_text.py in the Python Console with IF_TRAIN = 0.
Testing...
Calculating testing accuracy...
Test accuracy: 0.512, Random guess: 0.056
Please input a surname to predict its language origin:> YU
> YU
(0.18) Scottish
(0.12) English
(0.11) Italian
Please input a surname to predict its language origin:> Yu
> Yu
(0.63) Chinese
(0.23) Korean
(0.07) Vietnamese
Please input a surname to predict its language origin:> Zou
> Zou
(1.00) Chinese
(0.00) Arabic
(0.00) Polish
Calculating confusion matrix...
The following picture exhibits the confusion matrix, in which a brighter diagonal element indicates better prediction and thus less confusion for that language, e.g., Arabic and Greek. However, some pairs of languages are prone to confusion, such as Korean with Chinese, and English with Scottish.