FEAT: copy all DETR detection code
F-jie committed Feb 22, 2022
1 parent e31abfc commit fa98e68
Showing 3 changed files with 73 additions and 6 deletions.
73 changes: 68 additions & 5 deletions DETR.py
@@ -1,13 +1,17 @@
import argparse
import json
import time
import torch
import datetime
from pathlib import Path
from torch.utils.data import DataLoader
from TAU.engine import train_one_epoch
from TAU.model.detr import build
from TAU.dataset.coco import build_coco, get_coco_api_from_dataset
-from TAU.utils.TAUUtils import collate_fn
+from TAU.utils.TAUUtils import collate_fn, evaluate, is_main_process, save_on_master

def get_args_parser():
    parser = argparse.ArgumentParser("set transformer detector", add_help=False)
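The body of get_args_parser is truncated in this diff. Judging from the flags main() reads below, it must define at least the following arguments; the defaults shown here follow upstream DETR and are assumptions about this repo:

    # Assumed arguments, inferred from their use in main(); defaults follow upstream DETR.
    parser.add_argument("--epochs", default=300, type=int)
    parser.add_argument("--lr_drop", default=200, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--clip_max_norm", default=0.1, type=float,
                        help="gradient clipping max norm")
    parser.add_argument("--output_dir", default="",
                        help="path where to save checkpoints; empty for no saving")
    parser.add_argument("--eval", action="store_true")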
@@ -70,6 +74,65 @@ def main(args):
            args.start_epoch = checkpoint["epoch"] + 1

    if args.eval:
        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
        )
        if args.output_dir:
            save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
        return

print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
sampler_train.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm
)
lr_scheduler.step()
if args.output_dir:
checkpoint_paths = [output_dir / "checkpoint.pth"]

            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0:
                checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")

            for checkpoint_path in checkpoint_paths:
                save_on_master({
                    "model": model_without_ddp.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                    "epoch": epoch,
                    "args": args,
                }, checkpoint_path)

        test_stats, coco_evaluator = evaluate(
            model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
        )
        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }

        if args.output_dir and is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

            if coco_evaluator is not None:
                (output_dir / "eval").mkdir(exist_ok=True)
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))

if __name__ == "__main__":
parser = argparse.ArgumentParser("DETR training and evaluation script", parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
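The checkpoint dict written by save_on_master above has an obvious counterpart on the loading side. This diff does not show it, but a resume sketch consistent with the keys saved here (and with the args.start_epoch context line visible at the top of the hunk) would look like the following; the --resume argument is an assumption:

# Resume sketch (hypothetical here; mirrors upstream DETR's main.py).
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
if not args.eval and "optimizer" in checkpoint and "lr_scheduler" in checkpoint:
    optimizer.load_state_dict(checkpoint["optimizer"])
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
    args.start_epoch = checkpoint["epoch"] + 1  # picks up where the saved epoch left off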
3 changes: 2 additions & 1 deletion TAU/engine.py
@@ -1,5 +1,6 @@
import math
import sys
import torch
@@ -82,7 +83,7 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors["bbox"](outputs, orig_target_sizes)
-       res = {targets["image_id"].item(): output for target, output in zip(targets, outputs)}
+       res = {target["image_id"].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)
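The one-word fix (outputs to results) matters because postprocessors["bbox"] converts raw model outputs into one plain dict per image, which is what CocoEvaluator.update expects keyed by COCO image id. The field names below are upstream DETR's PostProcess output and are assumed to match this repo:

# Per-image entry produced by postprocessors["bbox"] (upstream DETR convention):
#   {"scores": Tensor[num_queries],      # confidence per predicted box
#    "labels": Tensor[num_queries],      # COCO category ids
#    "boxes":  Tensor[num_queries, 4]}   # absolute xyxy in original image coordinates
# zip(targets, outputs) would instead pair each target with model-level tensors
# such as "pred_logits", which the evaluator cannot consume.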

3 changes: 3 additions & 0 deletions TAU/utils/TAUUtils.py
@@ -569,3 +569,6 @@ def convert_to_xywh(boxes):
    xmin, ymin, xmax, ymax = boxes.unbind(1)
    return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)

def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)
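save_on_master relies on is_main_process, which this hunk does not show. In upstream DETR's util module it is defined roughly as below; TAUUtils.py is assumed to carry the same helpers:

import torch.distributed as dist

def is_dist_avail_and_initialized():
    # Distributed may be unavailable (e.g. a CPU-only build) or not yet initialized.
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True

def get_rank():
    return dist.get_rank() if is_dist_avail_and_initialized() else 0

def is_main_process():
    # Only rank 0 touches the filesystem, so parallel workers do not clobber files.
    return get_rank() == 0

This is why the torch.save calls above only ever run on one process in a distributed run.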
