-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsweep_test.py
75 lines (59 loc) · 3.57 KB
/
sweep_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import argparse
import os
import wandb
from algorithms.algorithms_utils import AlgorithmsEnum
from algorithms.naive_algs import PopularItems
from algorithms.sgd_alg import ECF, DeepMatrixFactorization
from data.data_utils import get_dataloader, DatasetsEnum, build_user_and_item_pop_matrix, build_user_and_item_tag_matrix
from data.dataset import TrainRecDataset, ECFTrainRecDataset
from eval.eval import evaluate_recommender_algorithm, FullEvaluator, FullEvaluatorCalibrationDecorator
from utilities.wandb_utils import fetch_best_in_sweep
# --- CLI: identify the sweep and W&B location whose best run we want to test ---
arg_parser = argparse.ArgumentParser(description='Start a test experiment')
arg_parser.add_argument('--sweep_id', '-s', type=str, help='ID of the sweep')
arg_parser.add_argument('--wandb_entity_name', '-e', help='Name of the Entity on W&B', type=str)
arg_parser.add_argument('--wandb_project_name', '-p', help='Name of the Project on W&B', type=str)
arg_parser.add_argument('--measure_calibration', '-c', help='Whether to compute calibration metrics as well',
                        action='store_true', default=False)
cli_args = arg_parser.parse_args()

# Expose the parsed options under the module-level names the rest of the script uses.
sweep_id = cli_args.sweep_id
measure_calibration = cli_args.measure_calibration
wandb_entity_name = cli_args.wandb_entity_name
wandb_project_name = cli_args.wandb_project_name

# Pull the configuration of the sweep's best run; this also makes the trained
# model available locally (see the "Model is now local" note below).
# NOTE(review): 'wandb_entitiy_name' is misspelled but must match the keyword
# declared by fetch_best_in_sweep — confirm against utilities/wandb_utils before renaming.
best_run_config = fetch_best_in_sweep(
    sweep_id,
    good_faith=False,
    project_base_directory='.',
    preamble_path='~/PycharmProjects',
    wandb_entitiy_name=wandb_entity_name,
    wandb_project_name=wandb_project_name,
)
# Model is now local
# Carry out Test
# Resolve the algorithm and dataset enum members recorded in the best run's config.
alg = AlgorithmsEnum[best_run_config['alg']]
dataset = DatasetsEnum[best_run_config['dataset']]
conf = best_run_config
print('Starting Test')
print(f'Algorithm is {alg.name} - Dataset is {dataset.name}')
# Start a fresh W&B run (job_type='test') grouped by algorithm/dataset so the
# test metrics land next to the sweep's training runs.
wandb.init(project=wandb_project_name, entity=wandb_entity_name, config=conf, tags=[alg.name, dataset.name],
           group=f'{alg.name} - {dataset.name} - test', name=conf['time_run'], job_type='test', reinit=True)
# Force single-process evaluation (no dataloader worker subprocesses).
conf['running_settings']['eval_n_workers'] = 0
test_loader = get_dataloader(conf, 'test')
# Rebuild the algorithm object; some algorithms need the *training* data,
# not just the test loader's dataset.
if alg.value == PopularItems or alg.value == DeepMatrixFactorization:
    # Popular Items requires the popularity distribution over the items learned over the training data
    # DeepMatrixFactorization also requires access to the training data
    alg = alg.value.build_from_conf(conf, TrainRecDataset(conf['dataset_path']))
elif alg.value == ECF:
    # ECF uses its own training-dataset variant.
    alg = alg.value.build_from_conf(conf, ECFTrainRecDataset(conf['dataset_path']))
else:
    alg = alg.value.build_from_conf(conf, test_loader.dataset)
# Restore the trained weights saved by the best sweep run.
alg.load_model_from_path(conf['model_path'])
# Evaluator aggregates metrics per user group as defined by the test dataset.
evaluator = FullEvaluator(aggr_by_group=True, n_groups=test_loader.dataset.n_user_groups,
                          user_to_user_group=test_loader.dataset.user_to_user_group)
if measure_calibration:
    # Optionally wrap the evaluator with calibration decorators for tag and
    # popularity distributions (matrices built from the dataset on disk).
    user_tag_mtx, item_tag_mtx = build_user_and_item_tag_matrix(os.path.join(conf['data_path'], conf['dataset']))
    user_pop_mtx, item_pop_mtx = build_user_and_item_pop_matrix(os.path.join(conf['data_path'], conf['dataset']))
    evaluator = FullEvaluatorCalibrationDecorator(evaluator, item_tag_mtx, user_tag_mtx, metric_name_prefix='tag')
    evaluator = FullEvaluatorCalibrationDecorator(evaluator, item_pop_mtx, user_pop_mtx, metric_name_prefix='pop')
# Run the evaluation loop and log all metrics to W&B in a single step.
metrics_values = evaluate_recommender_algorithm(alg, test_loader, evaluator,
                                                verbose=conf['running_settings']['batch_verbose'])
wandb.log(metrics_values, step=0)
wandb.finish()