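"""Training and evaluation entry point for a DETR-style detector built on the TAU package.

Builds the model, criterion, and COCO data loaders, then runs the standard
train / checkpoint / evaluate loop; supports resuming from a checkpoint and
evaluation-only runs.
"""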
import argparse
import json
import time
import torch
import datetime
from pathlib import Path
from torch.utils.data import DataLoader
from TAU.engine import train_one_epoch
from TAU.model.detr import build
from TAU.dataset.coco import build_coco, get_coco_api_from_dataset
from TAU.utils.TAUUtils import collate_fn, evaluate, is_main_process, save_on_master
def get_args_parser():
    parser = argparse.ArgumentParser("set transformer detector", add_help=False)
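    # The original parser registers no arguments. The set below is an assumed minimal
    # sketch covering the attributes that main() reads; defaults follow the reference
    # DETR configuration. build(args) and build_coco(...) may require additional
    # model/dataset options (e.g. --coco_path is assumed here) beyond this list.
    parser.add_argument("--lr", default=1e-4, type=float)
    parser.add_argument("--lr_backbone", default=1e-5, type=float)
    parser.add_argument("--batch_size", default=2, type=int)
    parser.add_argument("--weight_decay", default=1e-4, type=float)
    parser.add_argument("--epochs", default=300, type=int)
    parser.add_argument("--lr_drop", default=200, type=int)
    parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")
    parser.add_argument("--masks", action="store_true", help="train the segmentation head")
    parser.add_argument("--frozen_weights", default=None, type=str, help="path to pretrained detector weights (segmentation fine-tuning)")
    parser.add_argument("--coco_path", type=str, help="assumed: dataset root read by build_coco")
    parser.add_argument("--output_dir", default="", help="where to save checkpoints and logs; empty disables saving")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--resume", default="", help="checkpoint path or URL to resume from")
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--eval", action="store_true", help="run evaluation only")
    parser.add_argument("--num_workers", default=2, type=int)
    parser.add_argument("--distributed", action="store_true", help="assumed flag so the sampler_train.set_epoch branch is guarded; DDP itself is not set up here")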
    return parser
def main(args):
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)
    device = torch.device(args.device)

    # Build the model, the matching-based criterion, and the post-processors.
    model, criterion, postprocessors = build(args)
    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)
    param_dicts = [
        {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
        {
            "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
            "lr": args.lr_backbone
        }
    ]
    optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
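    # COCO-style datasets and loaders: shuffled fixed-size batches for training, sequential order for validation.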
    dataset_train = build_coco(image_set="train", args=args)
    dataset_val = build_coco(image_set="val", args=args)
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True)
    data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, collate_fn=collate_fn, num_workers=args.num_workers)
    data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, drop_last=False, collate_fn=collate_fn, num_workers=args.num_workers)
    base_ds = get_coco_api_from_dataset(dataset_val)
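    # When fine-tuning the segmentation head, load the frozen detector weights into the wrapped model.detr.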
    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])

    output_dir = Path(args.output_dir)
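    # Optionally resume training state from a local checkpoint file or a URL.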
    if args.resume:
        if args.resume.startswith("http"):
            checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location="cpu", check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if not args.eval and "optimizer" in checkpoint and "lr_scheduler" in checkpoint and "epoch" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
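    # Evaluation-only mode: run a single pass over the validation set and exit.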
    if args.eval:
        test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir)
        if args.output_dir:
            save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm
)
lr_scheduler.step()
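        # Save a rolling checkpoint every epoch, with extra snapshots at LR-drop epochs and every 100 epochs.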
        if args.output_dir:
            checkpoint_paths = [output_dir / "checkpoint.pth"]
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")
            for checkpoint_path in checkpoint_paths:
                save_on_master({
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "args": args
                }, checkpoint_path)
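        # Evaluate on the validation set after every epoch and merge train/test stats for logging.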
        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
        )
        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters
        }
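        # On the main process, append per-epoch stats as one JSON object per line and keep COCO bbox eval dumps.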
        if args.output_dir and is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")
            if coco_evaluator is not None:
                (output_dir / "eval").mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))
if __name__ == "__main__":
    parser = argparse.ArgumentParser("DETR training and evaluation script", parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
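# Example invocation (paths are placeholders):
#   python DETR.py --coco_path /data/coco --output_dir ./output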