-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
88 lines (70 loc) · 3.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import json
import os
import sys
from datetime import datetime
from typing import Any, Dict

import numpy as np
import pandas as pd
from omegaconf import OmegaConf
from tqdm import tqdm

from utils import score_ensemble, standardize_scores
def _select_caption(captions, total_scores, short_cap_thresh):
    """
    Pick the caption to keep for one file.

    Returns the highest-scoring caption, except that when the top two totals
    differ by less than ``short_cap_thresh`` the one with fewer
    whitespace-separated words is chosen instead (on equal word counts the
    higher-scoring caption wins, since ``min`` keeps the first minimum).
    A single candidate is returned as-is; ``captions`` must be non-empty.
    """
    if len(captions) == 1:
        # With one candidate there is no runner-up to compare against.
        return captions[0]
    # Indices of the best and second-best totals, best first.
    first, second = np.argsort(total_scores)[-2:][::-1]
    top_two = [captions[first], captions[second]]
    if abs(total_scores[first] - total_scores[second]) < short_cap_thresh:
        # Scores are effectively tied: prefer the shorter caption.
        return min(top_two, key=lambda caption: len(caption.split()))
    return top_two[0]


def final_scoring(scores_dict: Dict[str, Dict[str, Dict[str, float]]], cfg: Any) -> None:
    """
    Calculate final scores using consensus and ensemble scores, and update predictions.

    For every file in the consensus scores, each candidate caption gets the total
    ``cfg.WEIGHT.Consensus_Score * consensus + cfg.WEIGHT.CLIP_Score * ensemble``,
    where ``ensemble`` comes from ``score_ensemble`` over all non-consensus models.
    The selected caption (see ``_select_caption``) overwrites the ``caption`` column
    of every row in ``<cfg.DIR.Origin>/pred.csv`` whose ``filename`` matches.
    Files with no candidate captions are recorded in the JSON output but leave
    ``pred.csv`` unchanged.

    Outputs (the directory is created if missing):
    - ``<cfg.DIR.Result>/pred.csv``: the updated predictions.
    - ``<cfg.DIR.Result>/pred.json``: per-file ``{"captions": [...], "scores": [...]}``.

    Parameters:
    - scores_dict (Dict[str, Dict[str, Dict[str, float]]]): model name -> file name
      -> caption -> score. Must contain a "consensus" entry; every other entry is
      passed to ``score_ensemble``.
    - cfg (OmegaConf): The configuration (reads DIR.Origin, DIR.Result,
      WEIGHT.Consensus_Score, WEIGHT.CLIP_Score, THRESH.Short_Cap_Selection).
    """
    cons_score_dict = scores_dict["consensus"]
    clip_score_dict = {k: v for k, v in scores_dict.items() if k != "consensus"}
    ens_score_dict = score_ensemble(*clip_score_dict.values())

    pred = pd.read_csv(os.path.join(cfg.DIR.Origin, "pred.csv"))

    final_scores_dict = {}
    best_captions = {}
    for file, caption_scores in tqdm(cons_score_dict.items()):
        caption_list = list(caption_scores)
        cons_scores = np.array(list(caption_scores.values()))
        ens_scores = np.array([ens_score_dict[file][caption] for caption in caption_list])
        total_scores = cfg.WEIGHT.Consensus_Score * cons_scores + cfg.WEIGHT.CLIP_Score * ens_scores
        final_scores_dict[file] = {"captions": caption_list, "scores": total_scores.tolist()}
        if caption_list:
            best_captions[file] = _select_caption(
                caption_list, total_scores, cfg.THRESH.Short_Cap_Selection
            )

    # One vectorised update instead of scanning the whole column once per file.
    if best_captions:
        mask = pred["filename"].isin(list(best_captions))
        pred.loc[mask, "caption"] = pred.loc[mask, "filename"].map(best_captions)

    os.makedirs(cfg.DIR.Result, exist_ok=True)
    filename = "pred.csv"
    csv_path = os.path.join(cfg.DIR.Result, filename)
    print(f"Saving predictions to {csv_path}")
    pred.to_csv(csv_path, index=False)
    json_path = os.path.join(cfg.DIR.Result, os.path.splitext(filename)[0] + ".json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final_scores_dict, f)
def load_and_standardize_scores(file_name, score_directory):
    """
    Load and standardize scores from a JSON file, aborting the program on failure.

    :param file_name: The name of the JSON file containing the scores.
    :param score_directory: The directory where the score files are located.
    :return: Whatever ``standardize_scores`` returns for the file. This function
        never returns None; any failure terminates the process instead.
    :raises SystemExit: With exit code 1 if loading or standardizing fails for any
        reason; the error is reported on stderr first.
    """
    score_path = os.path.join(score_directory, file_name)
    try:
        return standardize_scores(score_path)
    except Exception as e:  # Top-level script boundary: report and abort.
        print(f"Error loading {score_path}. Please check the file path.", file=sys.stderr)
        print(f"Error details: {e}", file=sys.stderr)
        # sys.exit is always available; the bare exit() builtin is injected by
        # the site module and may be missing (e.g. `python -S`, frozen apps).
        sys.exit(1)
if __name__ == "__main__":
    config = OmegaConf.load("configs.yaml")

    # (key in the scores mapping, score file name, whether the model is enabled).
    # Consensus scores are always required; CLIP-style models are toggled in config.
    model_specs = (
        ("consensus", "itm_filtered_consensus.json", True),
        ("mobileclip", "mobileclip_scores.json", config.MODEL.MobileCLIP),
        ("openclip", "openclip_scores.json", config.MODEL.OpenCLIP),
        ("evaclip", "evaclip_scores.json", config.MODEL.EvaCLIP),
        ("metaclip", "metaclip_scores.json", config.MODEL.MetaCLIP),
        ("blip2itc", "blip2_itc_scores.json", config.MODEL.Blip2ITC),
    )

    loaded_scores = {}
    for key, score_file, enabled in model_specs:
        if not enabled:
            continue
        standardized = load_and_standardize_scores(score_file, config.DIR.Score)
        if standardized is None:
            continue
        loaded_scores[key] = standardized

    final_scoring(loaded_scores, config)