Classifying Names with a Character-level Spiking LSTM

Authors: LiutaoYu, fangwei123456

This tutorial applies a Spiking LSTM to reproduce the PyTorch official tutorial NLP From Scratch: Classifying Names with a Character-Level RNN. Please make sure that you have read the original tutorial and corresponding codes before proceeding. Specifically, we will train a spiking LSTM to classify surnames into different languages according to their spelling, based on a dataset consisting of several thousands of surnames from 18 languages of origin. The integrated script can be found here ( clock_driven/examples/

Preparing the data

First of all, we need to download and preprocess the data as the original tutorial, which produces a dictionary {language: [names ...]} . Then, we split the dataset into a training set and a testing set (the ratio is 4:1), i.e., category_lines_train and category_lines_test . Here, we emphasize several important variables: all_categories is the list of 18 languages, the length of which is n_categories=18; n_letters=58 is the number of all characters composing the surnames.

# split the data into training set and testing set
numExamplesPerCategory = []
category_lines_train = {}
category_lines_test = {}
testNumtot = 0
for c, names in category_lines.items():
    category_lines_train[c] = names[:int(len(names)*0.8)]
    category_lines_test[c] = names[int(len(names)*0.8):]
    numExamplesPerCategory.append([len(category_lines[c]), len(category_lines_train[c]), len(category_lines_test[c])])
    testNumtot += len(category_lines_test[c])

In addition, we rephrase the function randomTrainingExample() to function randomPair(sampleSource) for different conditions. Here we adopt function lineToTensor() and randomChoice() from the original tutorial. lineToTensor() converts a surname into a one-hot tensor, and randomChoice() randomly choose a sample from the dataset.

# Preparing [x, y] pair
def randomPair(sampleSource):
        sampleSource:  'train', 'test', 'all'
        category, line, category_tensor, line_tensor
    category = randomChoice(all_categories)
    if sampleSource == 'train':
        line = randomChoice(category_lines_train[category])
    elif sampleSource == 'test':
        line = randomChoice(category_lines_test[category])
    elif sampleSource == 'all':
        line = randomChoice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.float)
    line_tensor = lineToTensor(line)
    return category, line, category_tensor, line_tensor

Building a spiking LSTM network

We build a spiking LSTM based on the rnn module from spikingjelly . The theory can be found in the paper Long Short-Term Memory Spiking Networks and Their Applications . The amounts of neurons in the input layer, hidden layer and output layer are n_letters, n_hidden and n_categories respectively. We add a fully connected layer to the output layer, and use softmax function to obtain the classification probability.

from spikingjelly.clock_driven import rnn
n_hidden = 256

class Net(nn.Module):
    def __init__(self, n_letters, n_hidden, n_categories):
        self.n_input = n_letters
        self.n_hidden = n_hidden
        self.n_out = n_categories
        self.lstm = rnn.SpikingLSTM(self.n_input, self.n_hidden, 1)
        self.fc = nn.Linear(self.n_hidden, self.n_out)

    def forward(self, x):
        x, _ = self.lstm(x)
        output = self.fc(x[-1])
        output = F.softmax(output, dim=1)
        return output

Training the network

First of all, we initialize the net , and define parameters like TRAIN_EPISODES and learning_rate. Here we adopt mse_loss and Adam optimizer to train the network. The process of one training epoch is as follows: 1) randomly choose a sample from the training set, and convert the input and label into tensors; 2) feed the input to the network, and obtain the classification probability through the forward process; 3) calculate the network loss through mse_loss; 4) back-propagate the gradients, and update the training parameters; 5) judge whether the prediction is correct or not, and count the number of correct predictions to obtain the training accuracy every plot_every epochs; 6) evaluate the network on the testing set every plot_every epochs to obtain the testing accuracy. During training, we record the history of network loss avg_losses , training accuracy accuracy_rec and testing accuracy test_accu_rec , to observe the training process. After training, we will save the final state of the network for testing, and also some variables for later analyses.

# IF_TRAIN = 1
plot_every = 1000
learning_rate = 1e-4

net = Net(n_letters, n_hidden, n_categories)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

current_loss = 0
correct_num = 0
avg_losses = []
accuracy_rec = []
test_accu_rec = []
start = time.time()
for epoch in range(1, TRAIN_EPISODES+1):
    category, line, category_tensor, line_tensor = randomPair('train')
    label_one_hot = F.one_hot(, n_categories).float()

    out_prob_log = net(line_tensor)
    loss = F.mse_loss(out_prob_log, label_one_hot)

    current_loss +=

    guess, _ = categoryFromOutput(
    if guess == category:
        correct_num += 1

    # Add current loss avg to list of losses
    if epoch % plot_every == 0:
        avg_losses.append(current_loss / plot_every)
        accuracy_rec.append(correct_num / plot_every)
        current_loss = 0
        correct_num = 0

    # evaluate the network on the testing set every ``plot_every`` epochs to obtain the testing accuracy
    if epoch % plot_every == 0:  # int(TRAIN_EPISODES/1000)
        with torch.no_grad():
            numCorrect = 0
            for i in range(n_categories):
                category = all_categories[i]
                for tname in category_lines_test[category]:
                    output = net(lineToTensor(tname))
                    guess, _ = categoryFromOutput(
                    if guess == category:
                        numCorrect += 1
            test_accu = numCorrect / testNumtot
            print('Epoch %d %d%% (%s); Avg_loss %.4f; Train accuracy %.4f; Test accuracy %.4f' % (
                epoch, epoch / TRAIN_EPISODES * 100, timeSince(start), avg_losses[-1], accuracy_rec[-1], test_accu)), 'char_rnn_classification.pth')'avg_losses.npy', np.array(avg_losses))'accuracy_rec.npy', np.array(accuracy_rec))'test_accu_rec.npy', np.array(test_accu_rec))'category_lines_train.npy', category_lines_train, allow_pickle=True)'category_lines_test.npy', category_lines_test, allow_pickle=True)
# x = np.load('category_lines_test.npy', allow_pickle=True)  # way to loading the data
# xdict = x.item()

plt.title('Average loss')
plt.title('Train accuracy')
plt.title('Test accuracy')
plt.xlabel('Epoch (*1000)')

We will observe the following results when executing %run ./ in Python Console with IF_TRAIN = 1 .

Backend Qt5Agg is interactive backend. Turning interactive mode on.
Epoch 1000 0% (0m 18s); Avg_loss 0.0525; Train accuracy 0.0830; Test accuracy 0.0806
Epoch 2000 0% (0m 37s); Avg_loss 0.0514; Train accuracy 0.1470; Test accuracy 0.1930
Epoch 3000 0% (0m 55s); Avg_loss 0.0503; Train accuracy 0.1650; Test accuracy 0.0537
Epoch 4000 0% (1m 14s); Avg_loss 0.0494; Train accuracy 0.1920; Test accuracy 0.0938
Epoch 998000 99% (318m 54s); Avg_loss 0.0063; Train accuracy 0.9300; Test accuracy 0.5036
Epoch 999000 99% (319m 14s); Avg_loss 0.0056; Train accuracy 0.9380; Test accuracy 0.5004
Epoch 1000000 100% (319m 33s); Avg_loss 0.0055; Train accuracy 0.9340; Test accuracy 0.5118

The following picture shows how average loss avg_losses , training accuracy accuracy_rec and testing accuracy test_accu_rec improve with training.


Testing the network

We first load the well-trained network, and then conduct the following tests: 1) calculate the testing accuracy of the final network; 2) predict the language origin of the surnames provided by the user; 3) calculate the confusion matrix, indicating for every actual language (rows) which language the network guesses (columns).

# IF_TRAIN = 0

net = torch.load('char_rnn_classification.pth')

# calculate the testing accuracy of the final network
print('Calculating testing accuracy...')
numCorrect = 0
for i in range(n_categories):
    category = all_categories[i]
    for tname in category_lines_test[category]:
        output = net(lineToTensor(tname))
        guess, _ = categoryFromOutput(
        if guess == category:
            numCorrect += 1
test_accu = numCorrect / testNumtot
print('Test accuracy: {:.3f}, Random guess: {:.3f}'.format(test_accu, 1/n_categories))

# predict the language origin of the surnames provided by the user
n_predictions = 3
for j in range(3):
    first_name = input('Please input a surname to predict its language origin:')
    print('\n> %s' % first_name)
    output = net(lineToTensor(first_name))

    # Get top N categories
    topv, topi = output.topk(n_predictions, 1, True)
    predictions = []

    for i in range(n_predictions):
        value = topv[0][i].item()
        category_index = topi[0][i].item()
        print('(%.2f) %s' % (value, all_categories[category_index]))
        predictions.append([value, all_categories[category_index]])

# calculate the confusion matrix
print('Calculating confusion matrix...')
confusion = torch.zeros(n_categories, n_categories)
n_confusion = 10000

# Keep track of correct guesses in a confusion matrix
for i in range(n_confusion):
    category, line, category_tensor, line_tensor = randomPair('all')
    output = net(line_tensor)
    guess, guess_i = categoryFromOutput(
    category_i = all_categories.index(category)
    confusion[category_i][guess_i] += 1

confusion = confusion / confusion.sum(1)'confusion.npy', confusion)

# Set up plot
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
# Set up axes
ax.set_xticklabels([''] + all_categories, rotation=90)
ax.set_yticklabels([''] + all_categories)
# Force label at every tick
# sphinx_gallery_thumbnail_number = 2

We will observe the following results when executing %run ./ in Python Console with IF_TRAIN = 0 .

Calculating testing accuracy...
Test accuracy: 0.512, Random guess: 0.056
Please input a surname to predict its language origin:> YU
> YU
(0.18) Scottish
(0.12) English
(0.11) Italian
Please input a surname to predict its language origin:> Yu
> Yu
(0.63) Chinese
(0.23) Korean
(0.07) Vietnamese
Please input a surname to predict its language origin:> Zou
> Zou
(1.00) Chinese
(0.00) Arabic
(0.00) Polish
Calculating confusion matrix...

The following picture exhibits the confusion matrix, of which a brighter diagonal element indicates better prediction, and thus less confusion, such as Arabic and Greek. However, some languages are prone to confusion, such as Korean and Chinese, English and Scottish.