-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathlesson 21. RNN words predict.py
61 lines (43 loc) · 1.94 KB
/
lesson 21. RNN words predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
from tensorflow.keras.layers import Dense, SimpleRNN, Input
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.text import Tokenizer, text_to_word_sequence
from tensorflow.keras.utils import to_categorical
# Load the training corpus from the file 'text'.
with open('text', 'r', encoding='utf-8') as f:
    texts = f.read()
texts = texts.replace('\ufeff', '')  # remove the invisible BOM character(s)

# Per the Keras Tokenizer docs, only the (num_words - 1) most frequent words
# receive indices (1..num_words-1); index 0 is reserved and never used for a word.
maxWordsCount = 1000
tokenizer = Tokenizer(num_words=maxWordsCount,
                      filters='!–"—#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n\r«»',
                      lower=True, split=' ', char_level=False)
tokenizer.fit_on_texts([texts])

# (word, count) pairs in the tokenizer's insertion order (not sorted by frequency).
dist = list(tokenizer.word_counts.items())
print(dist[:10])

# Encode the whole corpus as one token sequence, then one-hot encode it:
# res has shape (num_tokens, maxWordsCount).
data = tokenizer.texts_to_sequences([texts])
res = to_categorical(data[0], num_classes=maxWordsCount)
print(res.shape)

inp_words = 3  # context length: the model predicts word t from words t-3..t-1
n = res.shape[0] - inp_words  # number of (context, next word) training samples
if n <= 0:
    # Without this check X would be empty and model.fit would fail obscurely.
    raise ValueError(
        f"corpus has {res.shape[0]} in-vocabulary tokens; "
        f"need more than {inp_words} to build training samples"
    )

# Sliding windows: sample i is tokens i..i+inp_words-1, its target is token i+inp_words.
X = np.array([res[i:i + inp_words, :] for i in range(n)])  # (n, inp_words, maxWordsCount)
Y = res[inp_words:]  # (n, maxWordsCount)

model = Sequential()
model.add(Input((inp_words, maxWordsCount)))
model.add(SimpleRNN(128, activation='tanh'))
model.add(Dense(maxWordsCount, activation='softmax'))  # scores over next-word indices
model.summary()

model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')
history = model.fit(X, Y, batch_size=32, epochs=50)
def buildPhrase(texts, str_len=20):
    """Extend a seed phrase word by word using the trained RNN.

    The seed is tokenized with the module-level ``tokenizer``. At each step
    the most recent ``inp_words`` token indices are one-hot encoded, fed to
    ``model``, and the most probable next word is appended.

    Args:
        texts: seed phrase. After tokenization it must contain at least
            ``inp_words`` in-vocabulary words.
        str_len: number of words to generate.

    Returns:
        The seed phrase followed by the generated words, space-separated.

    Raises:
        ValueError: if the seed yields fewer than ``inp_words`` known tokens.
    """
    res = texts
    data = tokenizer.texts_to_sequences([texts])[0]
    if len(data) < inp_words:
        # Out-of-vocabulary words are dropped by texts_to_sequences, so the
        # window could otherwise be too short to reshape.
        raise ValueError(
            f"seed phrase must contain at least {inp_words} known words, "
            f"got {len(data)}: {texts!r}"
        )
    for _ in range(str_len):
        # Always condition on the LAST inp_words tokens. Slicing by the loop
        # counter only works when the seed is exactly inp_words long.
        window = data[-inp_words:]
        x = to_categorical(window, num_classes=maxWordsCount)  # one-hot encoding
        inp = x.reshape(1, inp_words, maxWordsCount)
        pred = model.predict(inp, verbose=0)
        # Index 0 is never assigned to a word, so skip it to avoid a KeyError
        # in index_word; +1 restores the true index after the slice.
        indx = int(pred[0, 1:].argmax()) + 1
        data.append(indx)
        res += " " + tokenizer.index_word[indx]  # append the predicted word
    return res
# Demo: continue a three-word seed phrase with generated words and print it.
res = buildPhrase("позитив добавляет годы", str_len=20)
print(res)