forked from wrongu/modularity
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_best_ckpt.py
75 lines (63 loc) · 2.76 KB
/
create_best_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python
import torch
import argparse
import warnings
from pathlib import Path
from typing import Union
from eval import evaluate
def bestify(weights_dir: Union[str, Path],
            field: str = "val_loss",
            mode: str = "min",
            overwrite: bool = False,
            data_dir: Union[str, Path] = Path("data"),
            verbose: bool = False):
    """Point a ``best.ckpt`` symlink in ``weights_dir`` at its best-scoring checkpoint.

    Every non-symlink ``*.ckpt`` file in ``weights_dir`` is first passed to
    ``evaluate(ckpt, data_dir, metrics=[field])``, then loaded with ``torch.load`` and
    its ``field`` entry is compared against the running best.

    Args:
        weights_dir: Directory containing the ``*.ckpt`` files to compare.
        field: Entry read from each loaded checkpoint and used as its score.
        mode: ``"min"`` to prefer the smallest score, ``"max"`` to prefer the largest.
        overwrite: If True, replace an existing ``best.ckpt`` symlink that points to a
            different checkpoint. If False, that situation raises ``FileExistsError``.
        data_dir: Passed through unchanged to ``evaluate``.
        verbose: Print each checkpoint's score and the final decision.

    Returns:
        None. Emits a warning and returns early if no checkpoint produced a score.

    Raises:
        ValueError: If ``mode`` is not ``"min"`` or ``"max"``.
        FileExistsError: If ``best.ckpt`` exists and would have to change but
            ``overwrite`` is False, or if ``best.ckpt`` is a regular file rather than a
            symlink (it is never deleted).
    """
    weights_dir = Path(weights_dir)
    data_dir = Path(data_dir)
    best_file = weights_dir / "best.ckpt"

    if mode not in ("max", "min"):
        raise ValueError(f"Argument 'mode' must be 'max' or 'min' but is {mode}")

    best_ckpt = None
    best_val = float("-inf") if mode == "max" else float("+inf")
    # sorted() makes tie-breaking deterministic; raw glob order depends on the filesystem.
    for ckpt in sorted(weights_dir.glob("*.ckpt")):
        # Skip links (including an existing best.ckpt) so no checkpoint is scored twice.
        if ckpt.is_symlink():
            continue
        evaluate(ckpt, data_dir, metrics=[field])
        data = torch.load(ckpt, map_location="cpu")
        val = data[field]
        if verbose:
            print(f"\t{ckpt}: {field}={val}")
        if (mode == "max" and val > best_val) or (mode == "min" and val < best_val):
            best_ckpt, best_val = ckpt.resolve(), val

    if best_ckpt is None:
        warnings.warn(f"No checkpoints to check in {weights_dir}")
        return

    # is_symlink() also catches dangling links, which exists() (it follows links) misses.
    if best_file.is_symlink() or best_file.exists():
        previous_best = best_file.resolve()
        if previous_best == best_ckpt:
            if verbose:
                print(f"Keeping previous best for {weights_dir}: {best_ckpt}")
            return
        if not best_file.is_symlink():
            # A real file named best.ckpt is a checkpoint in its own right; never delete it.
            raise FileExistsError(
                f"{best_file} is a regular file, not a symlink; refusing to replace it")
        if not overwrite:
            raise FileExistsError(
                f"{best_file} points to {previous_best}, but the best checkpoint is "
                f"{best_ckpt}; pass overwrite=True to replace the link")
        # symlink_to() refuses to create a link over an existing path, so drop the stale one.
        best_file.unlink()

    if verbose:
        print(f"Best for {weights_dir} is {best_ckpt}; creating link in {best_file}")
    best_file.symlink_to(target=best_ckpt, target_is_directory=False)
def main(args):
    """Run ``bestify`` on ``args.directory`` and, with ``--recurse``, on every subdirectory.

    Args:
        args: Parsed command-line namespace with ``directory``, ``field``, ``mode``,
            ``overwrite``, ``verbose`` and ``recurse`` attributes. It is not modified.

    Raises:
        NotADirectoryError: If ``args.directory`` is not an existing directory.
    """
    directory = Path(args.directory)
    # Explicit raise instead of assert: asserts are stripped when Python runs with -O.
    if not directory.is_dir():
        raise NotADirectoryError(f"{directory} is not a directory")
    bestify(directory, args.field, args.mode, args.overwrite, verbose=args.verbose)
    if args.recurse:
        # NOTE(review): is_dir() follows symlinks, so a symlink cycle would recurse forever.
        for sub in sorted(directory.iterdir()):
            if sub.is_dir():
                # Recurse on a copy so the caller's namespace keeps its original directory.
                main(argparse.Namespace(**{**vars(args), "directory": sub}))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Symlink best.ckpt to the best-scoring *.ckpt file in a directory.")
    # nargs="?" is what lets a positional fall back to its default; without it argparse
    # makes the positional mandatory and the default is never used.
    parser.add_argument("directory", nargs="?", default=Path("."), type=Path,
                        help="directory containing *.ckpt files (default: current directory)")
    parser.add_argument("--field", default="val_loss", type=str,
                        help="checkpoint entry to compare (default: val_loss)")
    parser.add_argument("--mode", default="min", type=str, choices=["min", "max"],
                        help="whether smaller or larger values of --field are better")
    parser.add_argument("--recurse", action="store_true", default=False,
                        help="also process every subdirectory, recursively")
    parser.add_argument("--overwrite", action="store_true", default=False,
                        help="permit replacing an existing best.ckpt that points elsewhere")
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="print per-checkpoint scores")
    args = parser.parse_args()
    main(args)