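5.1 RNN (Text Prediction)

The theory behind RNNs is not covered in detail here; this section focuses on how to train an RNN with TensorFlow.

Problem description

We are given a text file containing many English sentences and train an RNN on it. Afterwards we supply an opening phrase and let the model predict the next several words.

Training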

import tensorflow as tf

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Embedding, LSTM, Dense, Bidirectional
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np 
import time

Split the text into words and encode each word as an integer.

tokenizer = Tokenizer()

data = open('../../tensorflow_datasets/irish-lyrics-eof.txt').read()

corpus = data.lower().split("\n")

tokenizer.fit_on_texts(corpus)
total_words = len(tokenizer.word_index) + 1

print(tokenizer.word_index)
print(total_words)

## {'the': 1, 'and': 2, 'i': 3, 'to': 4, 'a': 5, 'of': 6, 'my': 7, 'in': 8, 'me': 9, 'for': 10, 'you': 11, 'all': 12, 'was': 13, 'she': 14, 'that': 15, 'on': 16, 'with': 17, 'her': 18, 'but': 19, 'as': 20, 'when': 21, 'love': 22, 'is': 23, 'your': 24, 'it': 25, 'will': 26, 'from': 27, 'by': 28, 'they': 29,.................'stationed': 2678, 'cork': 2679, 'roamin': 2680, 'swear': 2681, 'treat': 2682, 'sportin': 2683, 'hurley': 2684, 'bollin': 2685, 'maids': 2686, 'summertime': 2687, 'pluck': 2688, 'yon': 2689}
## 2690
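
To make the mapping concrete, here is a minimal pure-Python sketch of a frequency-ranked word index on a toy corpus. It only approximates `Tokenizer`: the real class also strips punctuation by default, and the three-line corpus and the `encode` helper are made up for illustration. Index 0 is never assigned to a word because it is reserved for padding, which is why the code above adds 1 to `len(tokenizer.word_index)`.

```python
from collections import Counter

# Hypothetical toy corpus; the tutorial uses the lyrics file instead.
corpus = ["the cat sat", "the dog sat", "the end"]

# Count every word, then rank by frequency (ties keep first-seen order).
counts = Counter(word for line in corpus for word in line.lower().split())
word_index = {word: rank for rank, (word, _) in enumerate(counts.most_common(), start=1)}

# +1 because index 0 is reserved for padding and never maps to a word.
total_words = len(word_index) + 1


def encode(line):
    """Map a line to word ids, silently dropping unknown words."""
    return [word_index[w] for w in line.lower().split() if w in word_index]


print(word_index)
print(total_words)
print(encode("the dog sat"))
```

The most frequent word gets id 1, which matches the printed index above, where 'the' maps to 1.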

Turn every line into n-gram prefix sequences of word ids, then pad them to a common length.

input_sequences = []
for line in corpus:
    token_list = tokenizer.texts_to_sequences([line])[0]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

# pad sequences 
max_sequence_len = max([len(x) for x in input_sequences])
input_sequences = np.array(pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre'))

# create predictors and label
xs, labels = input_sequences[:,:-1],input_sequences[:,-1]

ys = tf.keras.utils.to_categorical(labels, num_classes=total_words)
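
The preprocessing above can be traced by hand on a single line. The sketch below uses plain Python lists instead of `pad_sequences` and `to_categorical`; the token ids and the vocabulary size of 70 are made-up values chosen only to keep the example small.

```python
# Hypothetical word ids for one tokenized line of text.
token_list = [4, 2, 66, 8]
num_classes = 70  # assumed toy vocabulary size (must exceed the largest id)

# Every prefix of length >= 2 becomes one training sample.
prefixes = [token_list[:i + 1] for i in range(1, len(token_list))]

# Left-pad with zeros so all samples share the same length ('pre' padding).
max_len = max(len(p) for p in prefixes)
padded = [[0] * (max_len - len(p)) + p for p in prefixes]

# The last id of each row is the label; everything before it is the input.
xs = [row[:-1] for row in padded]
labels = [row[-1] for row in padded]

# One-hot encode the labels, as to_categorical does.
ys = [[1 if j == label else 0 for j in range(num_classes)] for label in labels]

for x, label in zip(xs, labels):
    print(x, "->", label)
```

Padding on the left keeps the most recent words next to the label position, so each sample reads as "given these words, predict the next one".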

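Model definition and training.

Compared with the previous model, the only change is that the layer after the embedding is now an LSTM (wrapped in Bidirectional), a common RNN architecture.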

model = Sequential()
model.add(Embedding(total_words, 100, input_length=max_sequence_len-1))
model.add(Bidirectional(LSTM(150)))
model.add(Dense(total_words, activation='softmax'))
adam = Adam(learning_rate=0.01)  # `lr` is a deprecated alias of `learning_rate`
model.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
#earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=5, verbose=0, mode='auto')
history = model.fit(xs, ys, epochs=100, verbose=1)
# model.summary()  # uncomment to print a layer-by-layer summary
print(model)
Epoch 1/100
377/377 [==============================] - 13s 33ms/step - loss: 6.6735 - accuracy: 0.0696
Epoch 2/100
377/377 [==============================] - 13s 35ms/step - loss: 5.7891 - accuracy: 0.1103
Epoch 3/100
377/377 [==============================] - 13s 34ms/step - loss: 4.9342 - accuracy: 0.1608
Epoch 4/100
377/377 [==============================] - 13s 35ms/step - loss: 4.0518 - accuracy: 0.2279
Epoch 5/100
377/377 [==============================] - 13s 34ms/step - loss: 3.2253 - accuracy: 0.3251
Epoch 6/100
377/377 [==============================] - 13s 34ms/step - loss: 2.5441 - accuracy: 0.4278
Epoch 7/100
377/377 [==============================] - 13s 34ms/step - loss: 2.0503 - accuracy: 0.5236
Epoch 8/100
377/377 [==============================] - 13s 34ms/step - loss: 1.6816 - accuracy: 0.6000
Epoch 9/100
377/377 [==============================] - 13s 35ms/step - loss: 1.4112 - accuracy: 0.6644
Epoch 10/100
377/377 [==============================] - 13s 33ms/step - loss: 1.2757 - accuracy: 0.6892
Epoch 11/100
377/377 [==============================] - 13s 34ms/step - loss: 1.1403 - accuracy: 0.7225
Epoch 12/100
377/377 [==============================] - 13s 35ms/step - loss: 1.0396 - accuracy: 0.7467
Epoch 13/100
377/377 [==============================] - 13s 35ms/step - loss: 0.9845 - accuracy: 0.7567
Epoch 14/100
377/377 [==============================] - 13s 36ms/step - loss: 1.0100 - accuracy: 0.7534
Epoch 15/100
377/377 [==============================] - 13s 35ms/step - loss: 1.0399 - accuracy: 0.7378
Epoch 16/100
377/377 [==============================] - 13s 35ms/step - loss: 1.1572 - accuracy: 0.7078
Epoch 17/100
377/377 [==============================] - 13s 35ms/step - loss: 1.1517 - accuracy: 0.7010
Epoch 18/100
377/377 [==============================] - 13s 35ms/step - loss: 1.0355 - accuracy: 0.7302
Epoch 19/100
377/377 [==============================] - 13s 35ms/step - loss: 0.9822 - accuracy: 0.7451
Epoch 20/100
377/377 [==============================] - 13s 34ms/step - loss: 0.9095 - accuracy: 0.7674
Epoch 21/100
377/377 [==============================] - 13s 34ms/step - loss: 0.8308 - accuracy: 0.7860
Epoch 22/100
377/377 [==============================] - 13s 34ms/step - loss: 0.8204 - accuracy: 0.7908
Epoch 23/100
377/377 [==============================] - 13s 35ms/step - loss: 0.8255 - accuracy: 0.7848
Epoch 24/100
377/377 [==============================] - 13s 34ms/step - loss: 0.8749 - accuracy: 0.7744
Epoch 25/100
377/377 [==============================] - 13s 33ms/step - loss: 0.9931 - accuracy: 0.7392
Epoch 26/100
377/377 [==============================] - 13s 35ms/step - loss: 1.0942 - accuracy: 0.7183
Epoch 27/100
377/377 [==============================] - 13s 35ms/step - loss: 1.1175 - accuracy: 0.7096
Epoch 28/100
377/377 [==============================] - 13s 35ms/step - loss: 1.1793 - accuracy: 0.6963
Epoch 29/100
377/377 [==============================] - 13s 34ms/step - loss: 0.9927 - accuracy: 0.7450
Epoch 30/100
377/377 [==============================] - 13s 35ms/step - loss: 0.8399 - accuracy: 0.7803
Epoch 31/100
377/377 [==============================] - 13s 34ms/step - loss: 0.7961 - accuracy: 0.7927
Epoch 32/100
377/377 [==============================] - 12s 32ms/step - loss: 0.8260 - accuracy: 0.7846
Epoch 33/100
377/377 [==============================] - 13s 34ms/step - loss: 0.8635 - accuracy: 0.7711
Epoch 34/100
377/377 [==============================] - 13s 35ms/step - loss: 0.8859 - accuracy: 0.7654
Epoch 35/100
377/377 [==============================] - 13s 35ms/step - loss: 0.9220 - accuracy: 0.7563
Epoch 36/100
377/377 [==============================] - 13s 34ms/step - loss: 1.0614 - accuracy: 0.7222
Epoch 37/100
377/377 [==============================] - 13s 34ms/step - loss: 1.1474 - accuracy: 0.7070
Epoch 38/100
377/377 [==============================] - 19s 49ms/step - loss: 1.1066 - accuracy: 0.7127
Epoch 39/100
377/377 [==============================] - 21s 57ms/step - loss: 0.9473 - accuracy: 0.7493
Epoch 40/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8770 - accuracy: 0.7701
Epoch 41/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8103 - accuracy: 0.7836
Epoch 42/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7979 - accuracy: 0.7898
Epoch 43/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8467 - accuracy: 0.7785
Epoch 44/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8333 - accuracy: 0.7770
Epoch 45/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8838 - accuracy: 0.7715
Epoch 46/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9660 - accuracy: 0.7448
Epoch 47/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9930 - accuracy: 0.7380
Epoch 48/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9690 - accuracy: 0.7443
Epoch 49/100
377/377 [==============================] - 22s 57ms/step - loss: 0.9910 - accuracy: 0.7409
Epoch 50/100
377/377 [==============================] - 22s 57ms/step - loss: 0.9470 - accuracy: 0.7481
Epoch 51/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8933 - accuracy: 0.7625
Epoch 52/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8498 - accuracy: 0.7748
Epoch 53/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8372 - accuracy: 0.7797
Epoch 54/100
377/377 [==============================] - 22s 59ms/step - loss: 0.7566 - accuracy: 0.8003
Epoch 55/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7498 - accuracy: 0.8018
Epoch 56/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7654 - accuracy: 0.7957
Epoch 57/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8284 - accuracy: 0.7815
Epoch 58/100
377/377 [==============================] - 22s 57ms/step - loss: 0.9207 - accuracy: 0.7554
Epoch 59/100
377/377 [==============================] - 21s 57ms/step - loss: 0.9827 - accuracy: 0.7400
Epoch 60/100
377/377 [==============================] - 21s 56ms/step - loss: 0.9650 - accuracy: 0.7439
Epoch 61/100
377/377 [==============================] - 21s 57ms/step - loss: 0.9213 - accuracy: 0.7602
Epoch 62/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8284 - accuracy: 0.7769
Epoch 63/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8323 - accuracy: 0.7770
Epoch 64/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8226 - accuracy: 0.7775
Epoch 65/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8527 - accuracy: 0.7691
Epoch 66/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8918 - accuracy: 0.7632
Epoch 67/100
377/377 [==============================] - 21s 56ms/step - loss: 0.9432 - accuracy: 0.7514
Epoch 68/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9718 - accuracy: 0.7417
Epoch 69/100
377/377 [==============================] - 21s 54ms/step - loss: 0.9190 - accuracy: 0.7515
Epoch 70/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8419 - accuracy: 0.7740
Epoch 71/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7728 - accuracy: 0.7907
Epoch 72/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7592 - accuracy: 0.7978
Epoch 73/100
377/377 [==============================] - 21s 56ms/step - loss: 0.7652 - accuracy: 0.7954
Epoch 74/100
377/377 [==============================] - 21s 55ms/step - loss: 0.8769 - accuracy: 0.7677
Epoch 75/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9294 - accuracy: 0.7541
Epoch 76/100
377/377 [==============================] - 21s 57ms/step - loss: 0.9279 - accuracy: 0.7511
Epoch 77/100
377/377 [==============================] - 22s 58ms/step - loss: 0.9372 - accuracy: 0.7574
Epoch 78/100
377/377 [==============================] - 22s 57ms/step - loss: 0.9015 - accuracy: 0.7600
Epoch 79/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8527 - accuracy: 0.7720
Epoch 80/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8620 - accuracy: 0.7691
Epoch 81/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8496 - accuracy: 0.7766
Epoch 82/100
377/377 [==============================] - 21s 56ms/step - loss: 0.8471 - accuracy: 0.7738
Epoch 83/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8793 - accuracy: 0.7672
Epoch 84/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8734 - accuracy: 0.7659
Epoch 85/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8874 - accuracy: 0.7721
Epoch 86/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8655 - accuracy: 0.7680
Epoch 87/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8407 - accuracy: 0.7794
Epoch 88/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8817 - accuracy: 0.7696
Epoch 89/100
377/377 [==============================] - 21s 57ms/step - loss: 0.8687 - accuracy: 0.7689
Epoch 90/100
377/377 [==============================] - 21s 56ms/step - loss: 0.8713 - accuracy: 0.7695
Epoch 91/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8528 - accuracy: 0.7701
Epoch 92/100
377/377 [==============================] - 21s 56ms/step - loss: 0.8543 - accuracy: 0.7737
Epoch 93/100
377/377 [==============================] - 21s 55ms/step - loss: 0.8274 - accuracy: 0.7789
Epoch 94/100
377/377 [==============================] - 22s 58ms/step - loss: 0.7844 - accuracy: 0.7910
Epoch 95/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8883 - accuracy: 0.7775
Epoch 96/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8952 - accuracy: 0.7682
Epoch 97/100
377/377 [==============================] - 22s 58ms/step - loss: 0.8995 - accuracy: 0.7647
Epoch 98/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8762 - accuracy: 0.7679
Epoch 99/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8472 - accuracy: 0.7781
Epoch 100/100
377/377 [==============================] - 22s 57ms/step - loss: 0.8293 - accuracy: 0.7816
<tensorflow.python.keras.engine.sequential.Sequential object at 0x7fde79f33990>
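
As a rough sanity check on the model size, the number of trainable parameters can be worked out by hand. The sketch below assumes Keras defaults (biases enabled, four gates per LSTM, and the two directions of `Bidirectional` concatenated) and takes the vocabulary size 2690 from the output printed earlier; the variable names are illustrative only.

```python
total_words = 2690    # vocabulary size printed by the tokenizer step
embedding_dim = 100   # second argument of Embedding
lstm_units = 150      # units per LSTM direction

# Embedding: one vector of length embedding_dim per word id.
embedding_params = total_words * embedding_dim

# LSTM: 4 gates, each with input weights, recurrent weights and a bias.
lstm_one_direction = 4 * (lstm_units * (embedding_dim + lstm_units) + lstm_units)
bidirectional_params = 2 * lstm_one_direction

# Dense: input is the concatenation of both directions (2 * lstm_units).
dense_params = (2 * lstm_units) * total_words + total_words

total_params = embedding_params + bidirectional_params + dense_params
print(embedding_params, bidirectional_params, dense_params, total_params)
```

Under these assumptions the softmax layer holds well over half of the parameters, because its size grows with the vocabulary.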

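Plotting and analysis.

The final accuracy is acceptable: after the first ten epochs it hovers between roughly 0.70 and 0.80 and ends near 0.78.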

import matplotlib.pyplot as plt

def plot_graphs(history, string):
  plt.plot(history.history[string])
  plt.xlabel("Epochs")
  plt.ylabel(string)
  plt.show()
plot_graphs(history, 'accuracy')
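
To put a number on how stable the curve is, one option is to summarize the last ten epochs. The sketch below hard-codes the accuracies of epochs 91 to 100 copied from the training log above; in a live session the same values would come from `history.history['accuracy'][-10:]`.

```python
from statistics import mean

# Accuracy of epochs 91..100, copied from the training log.
last_ten_accuracy = [
    0.7701, 0.7737, 0.7789, 0.7910, 0.7775,
    0.7682, 0.7647, 0.7679, 0.7781, 0.7816,
]

avg = mean(last_ten_accuracy)
lowest = min(last_ten_accuracy)
highest = max(last_ten_accuracy)
spread = highest - lowest

print(f"mean={avg:.4f} min={lowest:.4f} max={highest:.4f} spread={spread:.4f}")
```

Note that this is training accuracy only; the run has no validation split, so it says nothing about how well the model generalizes.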

Prediction

We provide a seed phrase and see which 100 words the model generates after it.

seed_text = "i have a dream"
next_words = 100

for _ in range(next_words):
    token_list = tokenizer.texts_to_sequences([seed_text])[0]
    token_list = pad_sequences([token_list], maxlen=max_sequence_len-1, padding='pre')
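    # predict_classes was removed in TensorFlow 2.6; take the argmax of the softmax output instead
    predicted = np.argmax(model.predict(token_list, verbose=0), axis=-1)[0]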
    output_word = ""
    for word, index in tokenizer.word_index.items():
        if index == predicted:
            output_word = word
            break
    seed_text += " " + output_word
print(seed_text)

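Well, the output is complete nonsense... Then again, the dataset here is tiny and the model was trained for relatively few epochs.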

i have a dream and gathered them all by only all to smoke no pain love seen he said tree by tree to enchanting rostrevor me brain bent that right stand eyes gone i think gone the gone eyes gone red fail want belling road to dublin whack love flashed had gone no more bravely who marches away me on the broad majestic shannon long love gone where who marches i surely always ye strains gone red love meself and gorey right loch music for erin go bragh love so pray wid you right wid you good one was told the well i followed
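
The mechanics of the generation loop can be followed without a trained model. In the sketch below, `toy_predict` is a hypothetical stand-in for the network that simply returns the next id in a cycle, and `pad_pre` mimics the `pad_sequences` behaviour used above (left padding, and keeping only the most recent `maxlen` tokens once the seed grows too long). The four-word vocabulary is made up.

```python
# Toy vocabulary; in the tutorial this comes from tokenizer.word_index.
word_index = {"i": 1, "have": 2, "a": 3, "dream": 4}
# Build the reverse mapping once instead of scanning word_index at every step.
index_word = {index: word for word, index in word_index.items()}


def pad_pre(tokens, maxlen):
    """Keep the last `maxlen` tokens and left-pad the rest with 0."""
    tokens = tokens[-maxlen:]
    return [0] * (maxlen - len(tokens)) + tokens


def toy_predict(padded):
    """Hypothetical stand-in for the trained model: next id in a fixed cycle."""
    return padded[-1] % len(word_index) + 1


def generate(seed_text, next_words, maxlen=3):
    """Greedy generation: append the single most likely word at every step."""
    for _ in range(next_words):
        tokens = [word_index[w] for w in seed_text.split() if w in word_index]
        predicted = toy_predict(pad_pre(tokens, maxlen))
        seed_text += " " + index_word.get(predicted, "")
    return seed_text


print(generate("i have", 3))
```

Because each step always picks the single most likely word and feeds the result back in, an early mistake affects every later prediction, which may be one reason the real output drifts into repeated words.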