-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathCNN_V2.py
162 lines (127 loc) · 5.97 KB
/
CNN_V2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# -*- coding: utf-8 -*-
# @author: Awesome_Tang
# @date: 2018-12-15
# @version: python2.7
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
class constant(object):
    """
    Hyper-parameters and training configuration for the CNN model.
    """
    classes = 10                          # number of output classes (MNIST digits)
    num_filters = 32                      # number of convolution filters (NOTE: unused below)
    kernel_size = 3                       # convolution kernel size (NOTE: unused below)
    alpha = 1e-3                          # initial learning rate
    keep_prob = 0.5                       # dropout keep probability during training
    steps = 10000                         # number of training iterations
    batch_size = 128                      # samples per training batch
    tensorboard_dir = 'tensorboard/CNN'   # TensorBoard log output path
    print_per_batch = 100                 # print metrics every N batches
    save_per_batch = 10                   # write a TensorBoard summary every N batches
    decay_rate = 0.5                      # learning-rate decay factor
    decay_steps = 1000                    # apply decay every N global steps
class CNN:
    """Three-conv-layer CNN for MNIST digit classification.

    Constructing an instance loads MNIST, builds the graph, and runs the
    full training loop. Architecture: 28x28 inputs are zero-padded to
    32x32, passed through three conv(5x5) + ReLU + maxpool(2x2) + dropout
    stages (channels 1 -> 32 -> 64 -> 128, spatial 32 -> 16 -> 8 -> 4),
    then two fully connected layers producing 10-class logits.
    """

    def __init__(self):
        # Downloads/loads MNIST (with one-hot labels), then immediately
        # builds and trains the model — construction has heavy side effects.
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.input_x = tf.placeholder(tf.float32, [None, 784], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, constant.classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32)
        self.cnn_model()

    @staticmethod
    def weight_variable(shape):
        """Weight tensor initialized from a truncated normal (stddev 0.1)."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

    @staticmethod
    def bias_variable(shape):
        """Bias tensor initialized to a small positive constant (0.1)."""
        return tf.Variable(tf.constant(0.1, shape=shape))

    @staticmethod
    def conv2d(x, w):
        """2-D convolution, stride 1, SAME padding (spatial size preserved)."""
        return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")

    @staticmethod
    def max_pool_2x2(x, kernel_size):
        """Max pooling with stride 2 and SAME padding (halves spatial dims)."""
        return tf.nn.max_pool(x, ksize=kernel_size, strides=[1, 2, 2, 1], padding="SAME")

    def feed_data(self, x, y, keep_prob=1.):
        """Build a feed_dict; keep_prob defaults to 1.0 (no dropout / eval mode)."""
        return {self.input_x: x,
                self.input_y: y,
                self.keep_prob: keep_prob}

    def cnn_model(self):
        """Assemble the graph, train for constant.steps batches, then report test metrics."""
        # Reshape flat 784-vectors into NHWC images and zero-pad 28x28 -> 32x32
        # so that three rounds of 2x2 pooling give a clean 4x4 feature map.
        x_image = tf.reshape(self.input_x, [-1, 28, 28, 1])
        x_padding = tf.pad(x_image, [[0, 0], [2, 2], [2, 2], [0, 0]], "CONSTANT")

        # Layer 1: conv 5x5, 1 -> 32 channels; pool 32x32 -> 16x16.
        w_cv1 = self.weight_variable([5, 5, 1, 32])
        b_cv1 = self.bias_variable([32])
        h_cv1 = tf.nn.relu(tf.add(self.conv2d(x_padding, w_cv1), b_cv1))
        h_mp1 = self.max_pool_2x2(h_cv1, [1, 2, 2, 1])
        h_mp1 = tf.nn.dropout(h_mp1, self.keep_prob)

        # Layer 2: conv 5x5, 32 -> 64 channels; pool 16x16 -> 8x8.
        w_cv2 = self.weight_variable([5, 5, 32, 64])
        b_cv2 = self.bias_variable([64])
        h_cv2 = tf.nn.relu(tf.add(self.conv2d(h_mp1, w_cv2), b_cv2))
        h_mp2 = self.max_pool_2x2(h_cv2, [1, 2, 2, 1])
        h_mp2 = tf.nn.dropout(h_mp2, self.keep_prob)

        # Layer 3: conv 5x5, 64 -> 128 channels; pool 8x8 -> 4x4.
        # (Original comment mislabeled this as the second layer.)
        w_cv3 = self.weight_variable([5, 5, 64, 128])
        b_cv3 = self.bias_variable([128])
        h_cv3 = tf.nn.relu(tf.add(self.conv2d(h_mp2, w_cv3), b_cv3))
        h_mp3 = self.max_pool_2x2(h_cv3, [1, 2, 2, 1])
        h_mp3 = tf.nn.dropout(h_mp3, self.keep_prob)

        # Fully connected: 4*4*128 -> 1024, with dropout.
        w_fc1 = self.weight_variable([4 * 4 * 128, 1024])
        b_fc1 = self.bias_variable([1024])
        h_mp3_flat = tf.reshape(h_mp3, [-1, 4 * 4 * 128])
        h_fc1 = tf.nn.relu(tf.add(tf.matmul(h_mp3_flat, w_fc1), b_fc1))
        h_fc1 = tf.nn.dropout(h_fc1, self.keep_prob)

        # Output layer: 1024 -> 10 raw logits (no activation; loss applies softmax).
        w_fc2 = self.weight_variable([1024, 10])
        b_fc2 = self.bias_variable([10])
        y_conv = tf.add(tf.matmul(h_fc1, w_fc2), b_fc2)

        # Exponentially decayed learning rate driven by the global step,
        # which minimize() increments automatically each training op run.
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(constant.alpha, global_step,
                                                   constant.decay_steps,
                                                   constant.decay_rate, staircase=True)

        # FIX: the classes are mutually exclusive one-hot labels, so use
        # softmax cross entropy; the original sigmoid variant treats each
        # class as an independent binary decision, which is the wrong loss
        # for single-label multiclass classification.
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=self.input_y, logits=y_conv))
        train_step = tf.train.AdamOptimizer(
            learning_rate).minimize(loss, global_step=global_step)

        correct_prediction = tf.equal(
            tf.argmax(y_conv, 1), tf.argmax(self.input_y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        tf.summary.scalar("loss", loss)
        tf.summary.scalar("accuracy", accuracy)
        merged_summary = tf.summary.merge_all()
        writer = tf.summary.FileWriter(constant.tensorboard_dir)

        start_time = datetime.now()
        # Context manager guarantees the session is released even on error.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(constant.steps):
                data_x, data_y = self.mnist.train.next_batch(constant.batch_size)
                feed_dict = self.feed_data(data_x, data_y, constant.keep_prob)
                sess.run(train_step, feed_dict=feed_dict)
                if i % constant.save_per_batch == 0:
                    # Log metrics without dropout so summaries reflect eval behavior.
                    feed_dict[self.keep_prob] = 1.
                    s = sess.run(merged_summary, feed_dict=feed_dict)
                    writer.add_summary(s, i)
                if i % constant.print_per_batch == 0:
                    train_acc, train_loss = sess.run([accuracy, loss],
                                                     feed_dict=feed_dict)
                    data_x, data_y = self.mnist.validation.images, self.mnist.validation.labels
                    feed_dict = self.feed_data(data_x, data_y)
                    val_acc, val_loss = sess.run([accuracy, loss],
                                                 feed_dict=feed_dict)
                    msg = 'Step {:5}, train_acc:{:8.2%}, train_loss:{:6.2f},' \
                          ' val_acc:{:8.2%}, val_loss:{:6.2f}'
                    print(msg.format(i, train_acc, train_loss, val_acc, val_loss))
            end_time = datetime.now()
            time_diff = (end_time - start_time).seconds
            print('Time Usage : {:.2f} hours'.format(time_diff / 3600.))

            # Final evaluation on the held-out test set (no dropout).
            data_x, data_y = self.mnist.test.images, self.mnist.test.labels
            feed_dict = self.feed_data(data_x, data_y)
            test_acc, test_loss = sess.run([accuracy, loss],
                                           feed_dict=feed_dict)
            print("Test accuracy :{:8.2%}, loss:{:6.2f}".format(test_acc, test_loss))
# Build and train the network when run as a script (training starts in __init__).
if __name__ == "__main__":
    CNN()