-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrainning.py
85 lines (71 loc) · 2.62 KB
/
trainning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- coding: utf-8 -*-
"""
Created on Sun May 19 19:28:18 2019
@author: 雷浩洁
Train the model (训练模型).
"""
import os
import numpy as np
import tensorflow as tf
import model
import get_batch_data as gbd
# ---- Module-level training configuration ----
N_CLASSES = 3            # presumably the number of target classes; not referenced in this file
IMG_H, IMG_W = 19, 128   # presumably the input matrix height/width; run_trainning uses the same values as row/column
CAPACITY = 1000          # presumably a queue capacity; not referenced in this file
MAX_STEP = 4000          # number of training iterations run by run_trainning
learning_rate = 0.01     # passed to model.trainning when building the train op
def run_trainning():
    """Build the TensorFlow training graph and train the network.

    Reads samples from ``Cleaned_5sRNA_test/`` with ``gbd.get_Data`` and turns
    them into batches with ``gbd.get_batch_data``. It then builds the forward
    pass, loss, training op and accuracy op from the ``model`` module. Finally
    it runs up to ``MAX_STEP`` training steps in a TF1 session whose input is
    fed by queue runners.

    Loss and accuracy are printed every 100 steps. A checkpoint is written to
    ``Net_model/model.ckpt-<step>`` every 200 steps and after the final step.

    Training stops early if the input queues are exhausted
    (``tf.errors.OutOfRangeError``) or the coordinator requests a stop. The
    queue-runner threads are always joined and the session always closed.
    """
    data_dir = "Cleaned_5sRNA_test/"
    logs_train_dir = "Net_model/"  # directory that receives the training checkpoints
    row = IMG_H       # input rows, kept in sync with the module constant
    column = IMG_W    # input columns, kept in sync with the module constant
    vec_len = 8
    temp_batch_size = 64

    inputs, labels = gbd.get_Data(PATH=data_dir, row=row, column=column)
    train_X, train_Y, one_hot_train_Y = gbd.get_batch_data(
        inputs, labels, batch_size=temp_batch_size)

    train_logits, train_v_length = model.interface(inputs=train_X,
                                                   Y=one_hot_train_Y,
                                                   batch_size=temp_batch_size,
                                                   vec_len=vec_len,
                                                   temp_batch_size=temp_batch_size)
    train_loss = model.loss(logits=train_logits,
                            v_length=train_v_length,
                            labels=train_Y,
                            Y=one_hot_train_Y,
                            temp_batch_size=temp_batch_size)
    train_op = model.trainning(train_loss, learning_rate)
    train_acc = model.evalution(train_logits, train_Y)

    # tf.train.Saver.save refuses to write into a parent directory that does
    # not exist, so make sure the checkpoint directory is there up front.
    if not os.path.isdir(logs_train_dir):
        os.makedirs(logs_train_dir)
    checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    saver = tf.train.Saver()
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # train_X is a graph tensor that is never reassigned, so its (static)
    # shape only needs to be reported once rather than every 100 steps.
    print(train_X.shape)
    try:
        for step in range(MAX_STEP):
            if coord.should_stop():
                break
            _, tra_loss, tra_acc = sess.run([train_op, train_loss, train_acc])
            if step % 100 == 0:
                print("Step %d,train loss = %.2f,train accuracy = %.2f" % (step, tra_loss, tra_acc))
            # Checkpoint every 200 steps, and also after the last step: with
            # MAX_STEP=4000 the final step (3999) is not a multiple of 200, so
            # without this the last stretch of training would never be saved.
            if step % 200 == 0 or step + 1 == MAX_STEP:
                saver.save(sess,
                           save_path=checkpoint_path,
                           global_step=step)
    except tf.errors.OutOfRangeError:
        # The input queues ran out of data before MAX_STEP was reached.
        print('Done Trainning')
    finally:
        # Stop and join the queue-runner threads before closing the session.
        coord.request_stop()
        coord.join(threads)
        sess.close()
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when imported.
    run_trainning()