-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathmain.py
94 lines (73 loc) · 2.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import yaml
from chainer.dataset import convert
import chainer
from chainer.training import extensions
from preprocess import DataProcessor
from CNN import CNN
import chainer.links as L
import chainer.optimizers as O
from chainer import training
import sys
def main(options):
    """Train a CNN classifier and evaluate it on the test split every epoch.

    Builds the dataset with ``DataProcessor``, wraps a ``CNN`` predictor in a
    Chainer ``Classifier``, optimizes it with Adam and runs a Chainer
    ``Trainer`` that reports loss and accuracy for both splits.

    Args:
        options: Mapping with the following keys:
            ``gpu``: device id; a negative value keeps everything on the CPU.
            ``path_dataset``: path handed to ``DataProcessor``.
            ``path_vectors``: path handed to ``CNN.load_embeddings``.
            ``epochs``: number of training epochs.
            ``batchsize``: minibatch size for both iterators.
            ``embed_dim``: passed to ``CNN`` as ``embed_dim``.
            ``freeze_embeddings``: passed to ``CNN`` as ``freeze``.
            ``distance_embed_dim``: passed to ``CNN`` as ``position_dims``.
            Any other key (e.g. ``test``) is ignored.

    Raises:
        KeyError: If one of the keys listed above is missing.
    """
    # Unpack the configuration.
    gpu = options['gpu']
    data_path = options['path_dataset']
    embeddings_path = options['path_vectors']
    n_epoch = options['epochs']
    batch_size = options['batchsize']
    embed_dim = options['embed_dim']
    freeze = options['freeze_embeddings']
    distance_embed_dim = options['distance_embed_dim']

    # Load the data.
    data_processor = DataProcessor(data_path)
    data_processor.prepare_dataset()
    train_data = data_processor.train_data
    test_data = data_processor.test_data
    vocab = data_processor.vocab

    # Build the predictor; channel counts and the number of labels are fixed.
    cnn = CNN(n_vocab=len(vocab), input_channel=1,
              output_channel=100,
              n_label=19,
              embed_dim=embed_dim, position_dims=distance_embed_dim,
              freeze=freeze)
    cnn.load_embeddings(embeddings_path, data_processor.vocab)
    model = L.Classifier(cnn)

    # Use the GPU if requested. Pass the id explicitly so the model is placed
    # on the same device the updater and evaluator send their batches to,
    # rather than on whichever device happens to be current.
    if gpu >= 0:
        model.to_gpu(gpu)

    # Set up the optimizer.
    optimizer = O.Adam()
    optimizer.setup(model)

    train_iter = chainer.iterators.SerialIterator(train_data, batch_size)
    # Single ordered pass over the test split for each evaluation.
    test_iter = chainer.iterators.SerialIterator(
        test_data, batch_size, repeat=False, shuffle=False)

    updater = training.StandardUpdater(
        train_iter, optimizer,
        converter=convert.concat_examples, device=gpu)
    trainer = training.Trainer(updater, (n_epoch, 'epoch'))

    # Evaluation runs on a copy of the model whose predictor has its
    # ``train`` flag switched off.
    test_model = model.copy()
    test_model.predictor.train = False
    trainer.extend(extensions.Evaluator(
        test_iter, test_model,
        device=gpu, converter=convert.concat_examples))

    # Reporting.
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy']))
    trainer.extend(extensions.ProgressBar(update_interval=10))

    trainer.run()
if __name__ == '__main__':
    # The only command-line argument is the path to the YAML config file.
    # Without it there is nothing to load, so print the usage line and stop
    # instead of falling through to an undefined ``path_config``.
    if len(sys.argv) < 2:
        print("command format : python3 main.py config.yaml")
        sys.exit(1)
    path_config = sys.argv[1]

    with open(path_config, 'r') as ymlfile:
        # safe_load builds only plain Python values (all this config needs)
        # and, unlike a bare yaml.load(), does not require a Loader argument
        # on current PyYAML releases.
        cfg = yaml.safe_load(ymlfile)

    # Option name expected by main() -> key used in the config file.
    # Some names differ between the two (batchsize / batch_size,
    # freeze_embeddings / freeze_embedding).
    option_to_cfg_key = {
        "path_vectors": "path_vectors",
        "path_dataset": "path_dataset",
        "gpu": "gpu",
        "epochs": "epochs",
        "batchsize": "batch_size",
        "test": "test",
        "embed_dim": "embed_dim",
        "freeze_embeddings": "freeze_embedding",
        "distance_embed_dim": "distance_embed_dim",
    }
    options = {name: cfg[key] for name, key in option_to_cfg_key.items()}

    main(options)