-
Notifications
You must be signed in to change notification settings - Fork 606
/
Copy pathevaluate.py
67 lines (51 loc) · 2.45 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import time
import importlib
from os.path import join, isdir, isfile, relpath
from glob import glob
from evaluation import BaseConfig, ModelForEvaluation, DEFAULT_CLASS, print_rank_0
from initialize import initialize, initialize_model_and_tokenizer
def add_evaluation_specific_args(parser):
"""Arguments for evaluation"""
group = parser.add_argument_group("evaluation", "Evaluation configurations")
# Task
group.add_argument("--task", nargs="+", default=[], help="All task config to evaluation")
group.add_argument("--data-path", type=str, required=True, help="Data dir path for all tasks")
return parser
def find_all_tasks(all_task_config_path):
tasks = []
for task in all_task_config_path:
if isdir(task):
tasks += [relpath(path, ".") for path in glob(join(task, "**/*.yaml"), recursive=True)]
elif isfile(task):
tasks.append(task)
return tasks
def evaluate_all_tasks(data_path, model, tokenizer, all_task_config_path, task_classes):
for config_path, task_class in zip(all_task_config_path, task_classes):
config = task_class.config_class().from_yaml_file(config_path)
config.path = join(data_path, config.path)
task = task_class(model, tokenizer, config)
task.evaluate()
def main():
args = initialize(extra_args_provider=add_evaluation_specific_args)
args.task = find_all_tasks(args.task)
task_classes = []
print_rank_0("> Loading task configs")
for task_config_path in args.task:
config = BaseConfig.from_yaml_file(task_config_path)
if config.module:
path = ".".join(config.module.split(".")[:-1])
module = importlib.import_module(path)
class_name = config.module.split(".")[-1]
task_class = getattr(module, class_name)
task_classes.append(task_class)
else:
task_classes.append(DEFAULT_CLASS[config.type])
print_rank_0(f" Task {config.name} loaded from config {task_config_path}")
print_rank_0(f"> Successfully load {len(task_classes)} task{'s' if len(task_classes) > 1 else ''}")
model, tokenizer = initialize_model_and_tokenizer(args)
model = ModelForEvaluation(model)
start = time.time()
evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)
print_rank_0(f"Finish {len(task_classes)} task{'s' if len(task_classes) > 1 else ''} in {time.time() - start:.1f}s")
if __name__ == "__main__":
main()