5.2 Training an RNN on Shakespeare's Sonnets

Just as in the previous training run, let's go through the same process again, this time on Shakespeare's sonnets.

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import regularizers
import tensorflow.keras.utils as ku 
import numpy as np 

Preprocessing

tokenizer = Tokenizer()
# !wget --no-check-certificate \
#     https://storage.googleapis.com/laurencemoroney-blog.appspot.com/sonnets.txt \
#     -O sonnets.txt
data = open('../../tensorflow_datasets/sonnets.txt').read()

corpus = data.lower().split("\n")

tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1

# create input sequences using list of tokens
input_sequences = []
for line in corpus:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

# pad sequences 
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))

# create predictors and label
predictors, label = input_sequences[:,:-1],input_sequences[:,-1]

# one-hot encode the labels for categorical_crossentropy
label = ku.to_categorical(label, num_classes=total_words)
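
To make the n-gram construction concrete, here is a quick sanity check (purely illustrative, using only the objects defined above) showing how a single line of the corpus becomes several training examples:

# Take the first non-empty line and list the n-gram prefixes derived from it.
sample = next(line for line in corpus if line)
tokens = tokenizer.texts_to_sequences([sample])[0]
print(sample, tokens)
for i in range(1, len(tokens)):
    print(tokens[:i + 1])   # each prefix is one input sequence
print(predictors.shape, label.shape)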

Model Definition

After the embedding layer we stack two LSTM layers, the first wrapped in Bidirectional. Together with the Dropout layer and the L2 regularization on the dense layer, this deeper model tends to perform better while keeping overfitting in check.

model = Sequential()
model.add(Embedding(total_words, 100, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(150, return_sequences=True)))
model.add(Dropout(0.2))
model.add(LSTM(100))
# use integer division: Dense expects an integer number of units
model.add(Dense(total_words // 2, activation='relu', kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(total_words, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
print(model.summary())

As the summary below shows, this model has far more parameters than anything we have trained so far.

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding (Embedding)        (None, 10, 100)           321100    
_________________________________________________________________
bidirectional (Bidirectional (None, 10, 300)           301200    
_________________________________________________________________
dropout (Dropout)            (None, 10, 300)           0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 100)               160400    
_________________________________________________________________
dense (Dense)                (None, 1605)              162105    
_________________________________________________________________
dense_1 (Dense)              (None, 3211)              5156866   
=================================================================
Total params: 6,101,671
Trainable params: 6,101,671
Non-trainable params: 0
_________________________________________________________________
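
As a sanity check, these per-layer counts can be reproduced by hand: an LSTM layer has 4 gates, each with units × (input_dim + units + 1) weights, and a Dense layer has (input_dim + 1) × units parameters. With total_words = 3211 (so total_words // 2 = 1605):

# Reproduce the parameter counts in the summary above.
emb     = 3211 * 100                      # 321,100
bi_lstm = 2 * 4 * 150 * (100 + 150 + 1)   # 301,200: two directions of LSTM(150)
lstm_1  = 4 * 100 * (300 + 100 + 1)       # 160,400: input dim 300 from the bidirectional layer
dense   = (100 + 1) * 1605                # 162,105
dense_1 = (1605 + 1) * 3211               # 5,156,866
print(emb + bi_lstm + lstm_1 + dense + dense_1)   # 6,101,671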

Training

Training will certainly take a while, so let's start with 50 epochs and see whether the accuracy converges.

history = model.fit(predictors, label, epochs=50, verbose=1)
Epoch 1/50
484/484 [==============================] - 60s 124ms/step - loss: 6.9128 - accuracy: 0.0225
Epoch 2/50
484/484 [==============================] - 58s 120ms/step - loss: 6.4965 - accuracy: 0.0224
Epoch 3/50
484/484 [==============================] - 58s 119ms/step - loss: 6.4036 - accuracy: 0.0248
Epoch 4/50
484/484 [==============================] - 58s 120ms/step - loss: 6.2776 - accuracy: 0.0312
Epoch 5/50
484/484 [==============================] - 57s 119ms/step - loss: 6.1777 - accuracy: 0.0359
Epoch 6/50
484/484 [==============================] - 57s 118ms/step - loss: 6.0908 - accuracy: 0.0384
Epoch 7/50
484/484 [==============================] - 57s 119ms/step - loss: 6.0096 - accuracy: 0.0429
Epoch 8/50
484/484 [==============================] - 57s 117ms/step - loss: 5.9228 - accuracy: 0.0464
Epoch 9/50
484/484 [==============================] - 56s 117ms/step - loss: 5.8225 - accuracy: 0.0523
Epoch 10/50
484/484 [==============================] - 57s 119ms/step - loss: 5.7132 - accuracy: 0.0578
Epoch 11/50
484/484 [==============================] - 57s 117ms/step - loss: 5.6026 - accuracy: 0.0685
Epoch 12/50
484/484 [==============================] - 57s 118ms/step - loss: 5.4933 - accuracy: 0.0739
Epoch 13/50
484/484 [==============================] - 58s 119ms/step - loss: 5.3858 - accuracy: 0.0817
Epoch 14/50
484/484 [==============================] - 56s 116ms/step - loss: 5.2785 - accuracy: 0.0871
Epoch 15/50
484/484 [==============================] - 57s 117ms/step - loss: 5.1774 - accuracy: 0.0965
Epoch 16/50
484/484 [==============================] - 57s 118ms/step - loss: 5.0682 - accuracy: 0.1021
Epoch 17/50
484/484 [==============================] - 57s 118ms/step - loss: 4.9694 - accuracy: 0.1097
Epoch 18/50
484/484 [==============================] - 57s 118ms/step - loss: 4.8685 - accuracy: 0.1172
Epoch 19/50
484/484 [==============================] - 57s 118ms/step - loss: 4.7710 - accuracy: 0.1266
Epoch 20/50
484/484 [==============================] - 57s 118ms/step - loss: 4.6725 - accuracy: 0.1335
Epoch 21/50
484/484 [==============================] - 56s 116ms/step - loss: 4.5686 - accuracy: 0.1453
Epoch 22/50
484/484 [==============================] - 57s 118ms/step - loss: 4.4746 - accuracy: 0.1539
Epoch 23/50
484/484 [==============================] - 58s 119ms/step - loss: 4.3686 - accuracy: 0.1669
Epoch 24/50
484/484 [==============================] - 46s 95ms/step - loss: 4.2699 - accuracy: 0.1787
Epoch 25/50
484/484 [==============================] - 31s 64ms/step - loss: 4.1703 - accuracy: 0.1931
Epoch 26/50
484/484 [==============================] - 30s 62ms/step - loss: 4.0734 - accuracy: 0.2046
Epoch 27/50
484/484 [==============================] - 31s 65ms/step - loss: 3.9779 - accuracy: 0.2160
Epoch 28/50
484/484 [==============================] - 31s 64ms/step - loss: 3.8763 - accuracy: 0.2354
Epoch 29/50
484/484 [==============================] - 31s 65ms/step - loss: 3.7813 - accuracy: 0.2498
Epoch 30/50
484/484 [==============================] - 31s 65ms/step - loss: 3.6894 - accuracy: 0.2655
Epoch 31/50
484/484 [==============================] - 31s 64ms/step - loss: 3.5991 - accuracy: 0.2870
Epoch 32/50
484/484 [==============================] - 30s 62ms/step - loss: 3.5110 - accuracy: 0.3022
Epoch 33/50
484/484 [==============================] - 30s 62ms/step - loss: 3.4262 - accuracy: 0.3203
Epoch 34/50
484/484 [==============================] - 29s 61ms/step - loss: 3.3444 - accuracy: 0.3359
Epoch 35/50
484/484 [==============================] - 31s 63ms/step - loss: 3.2573 - accuracy: 0.3603
Epoch 36/50
484/484 [==============================] - 31s 63ms/step - loss: 3.1833 - accuracy: 0.3740
Epoch 37/50
484/484 [==============================] - 31s 63ms/step - loss: 3.1027 - accuracy: 0.3953
Epoch 38/50
484/484 [==============================] - 30s 61ms/step - loss: 3.0389 - accuracy: 0.4075
Epoch 39/50
484/484 [==============================] - 30s 62ms/step - loss: 2.9521 - accuracy: 0.4290
Epoch 40/50
484/484 [==============================] - 31s 64ms/step - loss: 2.8959 - accuracy: 0.4430
Epoch 41/50
484/484 [==============================] - 30s 62ms/step - loss: 2.8206 - accuracy: 0.4601
Epoch 42/50
484/484 [==============================] - 31s 63ms/step - loss: 2.7585 - accuracy: 0.4712
Epoch 43/50
484/484 [==============================] - 31s 64ms/step - loss: 2.7074 - accuracy: 0.4788
Epoch 44/50
484/484 [==============================] - 30s 63ms/step - loss: 2.6470 - accuracy: 0.4965
Epoch 45/50
484/484 [==============================] - 30s 63ms/step - loss: 2.5871 - accuracy: 0.5069
Epoch 46/50
484/484 [==============================] - 30s 62ms/step - loss: 2.5337 - accuracy: 0.5217
Epoch 47/50
484/484 [==============================] - 30s 62ms/step - loss: 2.4722 - accuracy: 0.5345
Epoch 48/50
484/484 [==============================] - 30s 63ms/step - loss: 2.4228 - accuracy: 0.5470
Epoch 49/50
484/484 [==============================] - 31s 64ms/step - loss: 2.3668 - accuracy: 0.5609
Epoch 50/50
484/484 [==============================] - 31s 64ms/step - loss: 2.3140 - accuracy: 0.5737

Each epoch takes roughly half a minute, and the accuracy is still climbing steadily at the end, so we train for another 50 epochs.
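
The log below presumably comes from simply calling fit again on the same model; note that re-binding history discards the record of the first 50 epochs, which is why only the second half shows up in the plots later:

history = model.fit(predictors, label, epochs=50, verbose=1)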

Epoch 1/50
484/484 [==============================] - 31s 64ms/step - loss: 2.2824 - accuracy: 0.5778
Epoch 2/50
484/484 [==============================] - 28s 58ms/step - loss: 2.2269 - accuracy: 0.5935
Epoch 3/50
484/484 [==============================] - 30s 62ms/step - loss: 2.1931 - accuracy: 0.6006
Epoch 4/50
484/484 [==============================] - 30s 62ms/step - loss: 2.1484 - accuracy: 0.6068
Epoch 5/50
484/484 [==============================] - 31s 64ms/step - loss: 2.1073 - accuracy: 0.6184
Epoch 6/50
484/484 [==============================] - 30s 62ms/step - loss: 2.0684 - accuracy: 0.6259
Epoch 7/50
484/484 [==============================] - 31s 64ms/step - loss: 2.0184 - accuracy: 0.6375
Epoch 8/50
484/484 [==============================] - 31s 64ms/step - loss: 1.9755 - accuracy: 0.6400
Epoch 9/50
484/484 [==============================] - 31s 64ms/step - loss: 1.9520 - accuracy: 0.6526
Epoch 10/50
484/484 [==============================] - 31s 64ms/step - loss: 1.9145 - accuracy: 0.6569
Epoch 11/50
484/484 [==============================] - 30s 62ms/step - loss: 1.8868 - accuracy: 0.6674
Epoch 12/50
484/484 [==============================] - 31s 64ms/step - loss: 1.8534 - accuracy: 0.6689
Epoch 13/50
484/484 [==============================] - 31s 64ms/step - loss: 1.8297 - accuracy: 0.6749
Epoch 14/50
484/484 [==============================] - 31s 64ms/step - loss: 1.7974 - accuracy: 0.6817
Epoch 15/50
484/484 [==============================] - 31s 64ms/step - loss: 1.7615 - accuracy: 0.6886
Epoch 16/50
484/484 [==============================] - 31s 65ms/step - loss: 1.7234 - accuracy: 0.6991
Epoch 17/50
484/484 [==============================] - 31s 64ms/step - loss: 1.7151 - accuracy: 0.6973
Epoch 18/50
484/484 [==============================] - 31s 64ms/step - loss: 1.6886 - accuracy: 0.7032
Epoch 19/50
484/484 [==============================] - 31s 64ms/step - loss: 1.6563 - accuracy: 0.7098
Epoch 20/50
484/484 [==============================] - 32s 65ms/step - loss: 1.6339 - accuracy: 0.7123
Epoch 21/50
484/484 [==============================] - 31s 65ms/step - loss: 1.6079 - accuracy: 0.7217
Epoch 22/50
484/484 [==============================] - 31s 65ms/step - loss: 1.5861 - accuracy: 0.7246
Epoch 23/50
484/484 [==============================] - 31s 65ms/step - loss: 1.5578 - accuracy: 0.7312
Epoch 24/50
484/484 [==============================] - 31s 64ms/step - loss: 1.5515 - accuracy: 0.7273
Epoch 25/50
484/484 [==============================] - 29s 61ms/step - loss: 1.5161 - accuracy: 0.7370
Epoch 26/50
484/484 [==============================] - 31s 64ms/step - loss: 1.5051 - accuracy: 0.7378
Epoch 27/50
484/484 [==============================] - 31s 63ms/step - loss: 1.4856 - accuracy: 0.7436
Epoch 28/50
484/484 [==============================] - 31s 64ms/step - loss: 1.4528 - accuracy: 0.7508
Epoch 29/50
484/484 [==============================] - 30s 63ms/step - loss: 1.4528 - accuracy: 0.7452
Epoch 30/50
484/484 [==============================] - 30s 62ms/step - loss: 1.4263 - accuracy: 0.7517
Epoch 31/50
484/484 [==============================] - 31s 64ms/step - loss: 1.4123 - accuracy: 0.7516
Epoch 32/50
484/484 [==============================] - 30s 62ms/step - loss: 1.3868 - accuracy: 0.7628
Epoch 33/50
484/484 [==============================] - 30s 61ms/step - loss: 1.3901 - accuracy: 0.7581
Epoch 34/50
484/484 [==============================] - 31s 64ms/step - loss: 1.3698 - accuracy: 0.7630
Epoch 35/50
484/484 [==============================] - 31s 64ms/step - loss: 1.3452 - accuracy: 0.7687
Epoch 36/50
484/484 [==============================] - 31s 63ms/step - loss: 1.3217 - accuracy: 0.7753
Epoch 37/50
484/484 [==============================] - 31s 64ms/step - loss: 1.3269 - accuracy: 0.7700
Epoch 38/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2993 - accuracy: 0.7776
Epoch 39/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2739 - accuracy: 0.7835
Epoch 40/50
484/484 [==============================] - 31s 63ms/step - loss: 1.2822 - accuracy: 0.7753
Epoch 41/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2700 - accuracy: 0.7792
Epoch 42/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2443 - accuracy: 0.7868
Epoch 43/50
484/484 [==============================] - 31s 63ms/step - loss: 1.2436 - accuracy: 0.7855
Epoch 44/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2260 - accuracy: 0.7871
Epoch 45/50
484/484 [==============================] - 31s 63ms/step - loss: 1.2118 - accuracy: 0.7916
Epoch 46/50
484/484 [==============================] - 31s 64ms/step - loss: 1.2227 - accuracy: 0.7868
Epoch 47/50
484/484 [==============================] - 31s 63ms/step - loss: 1.1917 - accuracy: 0.7954
Epoch 48/50
484/484 [==============================] - 31s 64ms/step - loss: 1.1881 - accuracy: 0.7944
Epoch 49/50
484/484 [==============================] - 31s 64ms/step - loss: 1.1712 - accuracy: 0.7978
Epoch 50/50
484/484 [==============================] - 29s 61ms/step - loss: 1.1657 - accuracy: 0.7954

By now the training accuracy has more or less plateaued, and further training would likely bring only marginal gains, so we stop here.
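
Rather than eyeballing the curve, an EarlyStopping callback could make this stopping decision automatically. A minimal sketch, monitoring the training loss since no validation split was kept (the min_delta and patience values are illustrative):

from tensorflow.keras.callbacks import EarlyStopping

# Stop once the training loss fails to improve by at least 0.01 for 3 consecutive
# epochs, and roll back to the best weights seen so far.
early_stop = EarlyStopping(monitor='loss', min_delta=0.01, patience=3,
                           restore_best_weights=True)
history = model.fit(predictors, label, epochs=100, verbose=1,
                    callbacks=[early_stop])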

Plotting and Analysis

import matplotlib.pyplot as plt
acc = history.history['accuracy']
loss = history.history['loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.title('Training accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training loss')
plt.title('Training loss')
plt.legend()

plt.show()

Only the last 50 epochs are recorded here, since the second fit call replaced the original history. Accuracy rises steadily and loss falls steadily, which is exactly what we want.
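
If you want the full 100-epoch curve, keep both History objects instead of overwriting the first; for example (history1 and history2 are hypothetical names for the two runs):

# Concatenate the metric lists from the two training runs before plotting.
acc  = history1.history['accuracy'] + history2.history['accuracy']
loss = history1.history['loss'] + history2.history['loss']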

Prediction

As we'll see, the generated text is still not ideal... there is room for improvement!

seed_text = "i have a dream"
next_words = 100

for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
    # predict_classes was removed in TF 2.6; take the argmax of the softmax output instead
    predicted = int(np.argmax(model.predict(token_list, verbose=0), axis=-1)[0])
    # index_word maps a token id straight back to its word
    output_word = tokenizer.index_word.get(predicted, "")
    seed_text += " " + output_word
print(seed_text)
i have a dream and break do they be sit prove forth up words brow sense prove heaven should stand set men days away rare taken thence lust 'tis quite sit stand ere so behind borne behind free borne forth forth forth pry wrong ' afloat behind behind cross cross cross cross give behind cross fill that if you so much leaves near dwell away sorrow grew with quite endured hits hits stand bred ere behind behind behind defence sit sit sit defence age happy kind 'tis bright eyes back be often lived doth cross bright quite staineth hits hits stand bred so behind
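
Part of the problem is greedy argmax decoding, which easily locks into loops like "behind behind behind". Sampling from the softmax output with a temperature usually gives more varied text; a minimal sketch (sample_word and the temperature value are illustrative choices, not part of the original code):

def sample_word(probs, temperature=0.8):
    # Sharpen or flatten the distribution, then sample instead of taking the argmax.
    logits = np.log(probs + 1e-9) / temperature
    probs = np.exp(logits) / np.sum(np.exp(logits))
    return int(np.random.choice(len(probs), p=probs))

token_list = tokenizer.texts_to_sequences([seed_text])[0]
token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
probs = model.predict(token_list, verbose=0)[0]
print(tokenizer.index_word.get(sample_word(probs), ""))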