-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathlesson 29. VAE.py
85 lines (62 loc) · 2.48 KB
/
lesson 29. VAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow import keras
import keras.backend as K
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, Lambda, BatchNormalization, Dropout
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale pixel values from [0, 255] to [0, 1] (min-max scaling, not
# standardization). Casting to float32 first avoids materializing float64
# arrays (twice the memory of a plain `/ 255`); float32 is also Keras'
# default float type, so nothing is lost.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))
# Size of the latent space; 2 so the latent codes can be scatter-plotted directly.
hidden_dim = 2
# The training-set size (60 000) should be divisible by batch_size so that
# every training batch is full-size; parts of this script may assume exactly
# batch_size rows per batch.
batch_size = 60
def dropout_and_batch(x):
    """Pass ``x`` through a new BatchNormalization layer, then a new Dropout(0.3) layer.

    Returns the output tensor of the Dropout layer.
    """
    normalized = BatchNormalization()(x)
    return Dropout(0.3)(normalized)
# Encoder: flatten each 28x28x1 image, pass it through two ReLU dense layers
# (256 then 128 units, each followed by batch norm + dropout), then project
# onto two separate hidden_dim-sized heads: the latent mean and the latent
# log-variance consumed by noiser.
input_img = Input((28, 28, 1))
x = dropout_and_batch(Dense(256, activation='relu')(Flatten()(input_img)))
x = dropout_and_batch(Dense(128, activation='relu')(x))
z_mean, z_log_var = Dense(hidden_dim)(x), Dense(hidden_dim)(x)
def noiser(args):
    """Sample a latent vector with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors from the encoder heads.

    Returns:
        ``exp(z_log_var / 2) * N + z_mean`` where ``N`` is standard-normal
        noise of the same shape as ``z_mean``.
    """
    # Intentional side effect: rebind the module-level z_mean / z_log_var to
    # the tensors this call receives. vae_loss reads those globals, so it
    # presumably sees the tensors of the most recent encoder call (e.g. the
    # one traced during fit). Do not remove without reworking vae_loss.
    global z_mean, z_log_var
    z_mean, z_log_var = args
    # Use the runtime shape of z_mean rather than a fixed
    # (batch_size, hidden_dim), so batches of any size (e.g. a final partial
    # batch in predict()) can be encoded.
    N = K.random_normal(shape=K.shape(z_mean), mean=0., stddev=1.0)
    return K.exp(z_log_var / 2) * N + z_mean
# Sampling layer: noiser turns (z_mean, z_log_var) into a random latent vector h.
# NOTE: noiser declares z_mean / z_log_var global and rebinds them to the
# tensors it is called with, and vae_loss reads those globals, so the order
# of the calls in this section matters.
h = Lambda(noiser, output_shape=(hidden_dim,))([z_mean, z_log_var])
# Decoder: maps a hidden_dim-dimensional latent vector back to a 28x28x1
# image through dense layers of 128 -> 256 -> 784 units (the reverse of the
# encoder's 256 -> 128).
input_dec = Input(shape=(hidden_dim,))
d = Dense(128, activation='relu')(input_dec)
d = dropout_and_batch(d)
d = Dense(256, activation='relu')(d)
d = dropout_and_batch(d)
# Sigmoid keeps every output pixel in (0, 1), the same range as the scaled inputs.
d = Dense(28*28, activation='sigmoid')(d)
decoded = Reshape((28, 28, 1))(d)
# Three models sharing the same layers: encoder (image -> sampled latent),
# decoder (latent -> image), and the end-to-end vae that is trained.
encoder = keras.Model(input_img, h, name='encoder')
decoder = keras.Model(input_dec, decoded, name='decoder')
# Calling encoder on input_img presumably re-runs noiser (and so rebinds the
# globals again) — TODO confirm against the Keras version in use.
vae = keras.Model(input_img, decoder(encoder(input_img)), name="vae")
def vae_loss(x, y):
    """VAE loss: per-sample reconstruction error plus KL divergence.

    Keras invokes a loss as ``loss(y_true, y_pred)``, so ``x`` is the target
    image batch and ``y`` is the reconstruction produced by ``vae``.

    The KL term reads the module-level ``z_mean`` / ``z_log_var``, which
    ``noiser`` rebinds each time it is called.

    Returns:
        A tensor with one loss value per sample.
    """
    # Flatten each image to 784 values. The -1 takes the batch size from the
    # data instead of the fixed batch_size, so batches of other sizes also work.
    x = K.reshape(x, shape=(-1, 28*28))
    y = K.reshape(y, shape=(-1, 28*28))
    # Reconstruction term: sum of squared pixel errors per sample.
    loss = K.sum(K.square(x-y), axis=-1)
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) summed over latent dims,
    # where z_log_var = log(sigma^2) and z_mean = mu.
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return loss + kl_loss
# Train the autoencoder to reproduce its input: the training images are
# both the inputs and the targets.
vae.compile(loss=vae_loss, optimizer='adam')
vae.fit(x_train, x_train, batch_size=batch_size, epochs=5, shuffle=True)
# Encode the first 6000 test images (6000 is a multiple of batch_size) and
# scatter-plot their 2-D latent codes.
h = encoder.predict(x_test[:6000], batch_size=batch_size)
plt.scatter(*h.T)
# Decode a (2n+1) x (2n+1) grid of latent points covering [-3, 3] in each
# dimension and show the generated digits as a tiled figure.
n = 5
total = 2*n+1
# Grid rows follow latent dim 0 and columns follow latent dim 1 (row-major
# order, matching subplot numbering).
steps = np.arange(-n, n+1) * 3 / n
grid = np.array([[a, b] for a in steps for b in steps])
# One batched predict instead of total**2 separate calls. predict() runs the
# decoder in inference mode (Dropout disabled, BatchNormalization using its
# moving statistics), so each image does not depend on the rest of the batch.
imgs = decoder.predict(grid)
plt.figure(figsize=(total, total))
for num, img in enumerate(imgs, start=1):
    ax = plt.subplot(total, total, num)
    plt.imshow(img.squeeze(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
# Without this, nothing is displayed when the file is run as a plain script
# (it shows both the latent scatter plot and this grid).
plt.show()