# batch_exp.py
# This script runs search and training using different hyperparams
import os
import argparse
from search_sr import run_search
from augment_sr import run_train
from search_sr import train_setup as search_setup
from augment_sr import train_setup as augment_setup
from omegaconf import OmegaConf as omg
from validate_sr import get_model, dataset_loop
import genotypes
import utils
import traceback
from pthflops import count_Flops
"""
EXAMPLE: python batch_exp.py -v 0 0.0001 -d v0.0 -g 0 -c quant_config.yaml
"""
# Validation-set config loaded by run_batch() for the final evaluation pass.
VAL_CFG_PATH = "./sr_models/valsets4x.yaml"

# Command-line interface. Arguments are parsed at import time and the
# resulting module-level `args` is read by run_batch().
parser = argparse.ArgumentParser()
# No `type=` is given, so every value reaches run_batch() as a string.
parser.add_argument(
    "-v",
    "--values",
    nargs="+",
    default=[],
    help="argument values ex.: 0.1 0.2 0.3 ",
)
# Sub-directory appended to the config's env.log_dir.
parser.add_argument("-d", "--dir", type=str, default="batch", help="log dir")
# File name only: run_batch() loads it from ./configs/<conf>.
parser.add_argument(
    "-c",
    "--conf",
    type=str,
    default="quant_config.yaml",
    help="config file name inside ./configs/",
)
parser.add_argument(
    "-n", "--name", type=str, default="batch_experiment", help="experiment name"
)
parser.add_argument(
    "-r",
    "--repeat",
    type=int,
    default=1,
    help="repeat experiments",
)
parser.add_argument("-g", "--gpu", type=int, default=0, help="gpu to use")
args = parser.parse_args()
def _run_stage(setup_fn, run_fn, cfg):
    """Set up and execute one pipeline stage, recording any failure.

    Args:
        setup_fn: Callable taking ``cfg`` and returning
            ``(stage_cfg, writer, logger, log_handler)``.
        run_fn: Callable invoked with those four values.
        cfg: Configuration object handed to ``setup_fn``.

    Raises:
        Exception: Whatever ``run_fn`` raised. Before it propagates, the
            traceback is appended to ``ERROR.txt`` under
            ``stage_cfg.env.save_path`` and printed.
    """
    stage_cfg, writer, logger, log_handler = setup_fn(cfg)
    try:
        run_fn(stage_cfg, writer, logger, log_handler)
    except Exception:
        trace = traceback.format_exc()
        error_file = os.path.join(stage_cfg.env.save_path, "ERROR.txt")
        with open(error_file, "a") as f:
            f.write(trace)
        print(trace)
        # Bare `raise` re-raises the active exception unchanged.
        raise


def run_batch():
    """Run search, training and validation once per requested value.

    Reads the module-level ``args``. For every repeat ``r`` in
    ``1..args.repeat`` and every entry of ``args.values``, the ``penalty``
    key is overwritten in ``cfg.search`` and/or ``cfg.train`` (whichever
    section contains it) and then:

    1. ``run_search`` is executed;
    2. ``run_train`` is executed with ``cfg.train.genotype_path`` pointing
       at ``best_arch.gen`` inside the search run directory;
    3. the genotype is read back, a model is built with ``get_model``,
       evaluated with ``dataset_loop`` on the sets listed in
       ``VAL_CFG_PATH``, and ``count_Flops(model)`` is logged.

    Raises:
        AssertionError: If ``penalty`` is in neither ``cfg.train`` nor
            ``cfg.search``.
        Exception: Any error from the search or training stage, re-raised
            after being written to ``ERROR.txt``.
    """
    key = "penalty"
    values = args.values
    base_run_name = args.name
    cfg = omg.load(f"./configs/{args.conf}")
    log_dir = cfg.env.log_dir

    assert (key in cfg.train) or (
        key in cfg.search
    ), f"{key} is not found in config"

    cfg.env.gpu = args.gpu
    for r in range(1, args.repeat + 1):
        # "trail" (sic) is kept so directory and run names stay compatible
        # with logs produced by earlier runs.
        cfg.env.log_dir = os.path.join(log_dir, args.dir, f"trail_{r}")
        os.makedirs(cfg.env.log_dir, exist_ok=True)

        print("TRIAL #", r)
        for val in values:
            cfg.env.run_name = f"{base_run_name}_{key}_{val}_trail_{r}"
            if key in cfg.search:
                cfg.search[key] = val
            if key in cfg.train:
                cfg.train[key] = val

            # get actual run dir with date stamp
            run_path = utils.get_run_path(
                cfg.env.log_dir, "SEARCH_" + cfg.env.run_name
            )
            # Training reads the architecture that the search stage saves.
            cfg.train.genotype_path = os.path.join(run_path, "best_arch.gen")

            print(f"SEARCHING: {str(key).upper()}:{str(val).upper()}")
            _run_stage(search_setup, run_search, cfg)

            print(f"TRAINING: {str(key).upper()}:{str(val).upper()}")
            _run_stage(augment_setup, run_train, cfg)

            with open(cfg.train.genotype_path, "r") as f:
                genotype = genotypes.from_str(f.read())
            # NOTE(review): cfg.env.save_path is never assigned in this
            # function; presumably augment_setup() sets it on cfg - confirm.
            weights_path = os.path.join(cfg.env.save_path, "best.pth.tar")

            # VALIDATE:
            logger = utils.get_logger(os.path.join(run_path, "validation_log.txt"))
            save_dir = os.path.join(run_path, "FINAL_VAL")
            os.makedirs(save_dir, exist_ok=True)
            logger.info(genotype)
            valid_cfg = omg.load(VAL_CFG_PATH)
            model = get_model(
                weights_path,
                cfg.env.gpu,
                genotype,
                cfg.arch.c_fixed,
                cfg.arch.channels,
                cfg.dataset.scale,
                body_cells=cfg.arch.body_cells,
                skip_mode=cfg.arch.skip_mode,
            )
            dataset_loop(valid_cfg, model, logger, save_dir, cfg.env.gpu)
            logger.info(count_Flops(model))
# Script entry point: start the batch only when this file is executed directly.
if __name__ == "__main__":
    run_batch()