-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
71 lines (56 loc) · 2.09 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import torch
from Code import Knowformer
def save_model(config: dict, model: torch.nn.Module, model_path: str):
    """
    Write a checkpoint holding the model configuration and its weights.

    The checkpoint is a dict with two entries: ``'config'`` (the given
    ``config`` object) and ``'model'`` (the result of ``model.state_dict()``).

    :param config: configuration stored alongside the weights
    :param model: module whose ``state_dict()`` is serialized
    :param model_path: destination handed to ``torch.save``
    :return: None
    """
    checkpoint = {'config': config, 'model': model.state_dict()}
    torch.save(checkpoint, model_path)
def load_model(model_path: str, device: str):
    """
    Rebuild a Knowformer model from a checkpoint file.

    The checkpoint is expected to be a dict with a ``'config'`` entry (passed
    to the ``Knowformer`` constructor) and a ``'model'`` entry (passed to
    ``load_state_dict``).

    :param model_path: checkpoint location handed to ``torch.load``
    :param device: forwarded to ``torch.load`` as ``map_location``; the model
        itself is not explicitly moved to this device here
    :return: tuple ``(model_config, model)``
    """
    print(f'Loading N-Former from {model_path}')
    checkpoint = torch.load(model_path, map_location=device)

    # Build the network from the stored config first, then restore its weights.
    config = checkpoint['config']
    network = Knowformer(config)
    network.load_state_dict(checkpoint['model'])
    return config, network
def swa(output_path, device):
    """
    Average the parameters of all saved epoch checkpoints and save the result.

    Every file in ``output_path`` whose name starts with ``epoch_`` is loaded
    as a checkpoint (a dict with ``'config'`` and ``'model'`` entries). The
    state-dict entries are averaged key by key, loaded into a fresh
    ``Knowformer`` and written to ``<output_path>/avg.bin`` via ``save_model``.

    The set of keys is taken from the first checkpoint; every other checkpoint
    must contain those keys. The config of the last checkpoint (in sorted file
    name order) is used to build the averaged model.

    :param output_path: directory containing the ``epoch_*`` checkpoints; the
        averaged checkpoint is written here as well
    :param device: forwarded to ``torch.load`` as ``map_location``
    :return: None
    :raises FileNotFoundError: if no ``epoch_*`` file exists in ``output_path``
    """
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # summation order (and the config that ends up being used) reproducible.
    files = sorted(
        file_name for file_name in os.listdir(output_path)
        if file_name.startswith('epoch_')
    )
    if not files:
        raise FileNotFoundError(
            f"no checkpoint starting with 'epoch_' found in {output_path}"
        )

    model_config = None
    sum_model_dict = None
    # Keep only a running sum so a single checkpoint is resident at a time.
    for file_name in files:
        state_dict = torch.load(os.path.join(output_path, file_name), map_location=device)
        model_config = state_dict['config']
        model_dict = state_dict['model']
        if sum_model_dict is None:
            sum_model_dict = dict(model_dict)
            continue
        for k in sum_model_dict:
            # Out-of-place add on purpose: entries of a loaded state dict may
            # share storage (e.g. tied weights), and an in-place ``+=`` on one
            # key would silently alter the running sum of the other.
            sum_model_dict[k] = sum_model_dict[k] + model_dict[k]

    num_models = len(files)
    avg_model_dict = {k: total / num_models for k, total in sum_model_dict.items()}

    model = Knowformer(model_config)
    model.load_state_dict(avg_model_dict)
    save_model(model_config, model, os.path.join(output_path, 'avg.bin'))
def score2str(score):
    """
    Format an evaluation score mapping as a single human-readable line.

    :param score: mapping providing the keys ``'loss'``, ``'hits@1'``,
        ``'hits@3'``, ``'hits@10'`` and ``'MRR'``
    :return: string of the form
        ``'loss: ..., hits@1: ..., hits@3: ..., hits@10: ..., MRR: ...'``
    """
    metric_names = ('loss', 'hits@1', 'hits@3', 'hits@10', 'MRR')
    # Look up every metric before formatting so a missing key fails up front.
    values = [score[name] for name in metric_names]
    return ', '.join(f'{name}: {value}' for name, value in zip(metric_names, values))
def save_results(triples, ranks):
    """
    Pair every triple with its rank.

    :param triples: sequence of 3-element items ``(h, r, t)``
    :param ranks: indexable of ranks; ``ranks[i]`` belongs to ``triples[i]``
    :return: list of ``(h, r, t, rank)`` tuples, one per triple
    """
    return [
        (head, relation, tail, ranks[idx])
        for idx, (head, relation, tail) in enumerate(triples)
    ]