From fa98e68033d48334487b5c13e76ddc7275ff48f0 Mon Sep 17 00:00:00 2001 From: F-jie Date: Tue, 22 Feb 2022 23:11:54 +0800 Subject: [PATCH] =?UTF-8?q?FEAT=EF=BC=9Acopy=20all=20DETR=20detection=20co?= =?UTF-8?q?de?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DETR.py | 73 ++++++++++++++++++++++++++++++++++++++++--- TAU/engine.py | 3 +- TAU/utils/TAUUtils.py | 3 ++ 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/DETR.py b/DETR.py index 0ec5965..2d3c7c8 100644 --- a/DETR.py +++ b/DETR.py @@ -1,13 +1,17 @@ import argparse -import imp +from fileinput import filename +from functools import total_ordering +import json from tabnanny import check -import typing +import time import torch +import datetime from pathlib import Path from torch.utils.data import DataLoader +from TAU.engine import train_one_epoch from TAU.model.detr import build from TAU.dataset.coco import build_coco, get_coco_api_from_dataset -from TAU.utils.TAUUtils import collate_fn +from TAU.utils.TAUUtils import collate_fn, evaluate, is_main_process, save_on_master def get_args_parser(): parser = argparse.ArgumentParser("set transformer detector", add_help=False) @@ -70,6 +74,65 @@ def main(args): args.start_epoch = checkpoint["epoch"] + 1 if args.eval: - test_stats, coco_evaluator = evaluate - + test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_train, base_ds, device, args.output_dir) + if args.output_dir: + save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") + return + + print("Start training") + start_time = time.time() + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + sampler_train.set_epoch(epoch) + train_stats = train_one_epoch( + model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm + ) + lr_scheduler.step() + if args.output_dir: + checkpoint_paths = [output_dir / "checkpoint.pth"] + + if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: + checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") + + for checkpoint_path in checkpoint_paths: + save_on_master({ + "model": model_without_ddp.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch, + "args": args + }, checkpoint_path) + + test_stats, coco_evaluator = evaluate( + model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir + ) + log_stats = { + **{f"train_{k}": v for k, v in train_stats.items()}, + **{f"test_{k}": v for k, v in test_stats.items()}, + "epoch": epoch, + "n_parameters": n_parameters + } + if args.output_dir and is_main_process(): + with (output_dir / "log.txt").open("a") as f: + f.write(json.dump(log_stats) + "\n") + + if coco_evaluator is not None: + (output_dir / "eval").mkdir(exist_ok=True) + if "bbox" in coco_evaluator.coco_eval: + filenames = ["latest.pth"] + if epoch % 50 == 0: + filenames.append(f"{epoch:03}.pth") + for name in filenames: + torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print("Training time {}".format(total_time_str)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser("DETR training and evaluation script", parents=[get_args_parser()]) + args = parser.parse_args() + if args.output_dir: + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args) diff --git a/TAU/engine.py b/TAU/engine.py index 3641ae9..48b1dd5 100644 --- a/TAU/engine.py +++ b/TAU/engine.py @@ -1,5 +1,6 @@ import math import sys +from unittest import result from cv2 import reduce from numpy import flatnonzero import torch @@ -82,7 +83,7 @@ def evalute(model, criterion, postprocessors, data_loader, base_ds, device, outp orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) results = postprocessors["bbox"](outputs, orig_target_sizes) - res = {targets["image_id"].item(): output for target, output in zip(targets, outputs)} + res = {targets["image_id"].item(): output for target, output in zip(targets, results)} if coco_evaluator is not None: coco_evaluator.update(res) diff --git a/TAU/utils/TAUUtils.py b/TAU/utils/TAUUtils.py index e8480f9..bcd6e5a 100644 --- a/TAU/utils/TAUUtils.py +++ b/TAU/utils/TAUUtils.py @@ -569,3 +569,6 @@ def convert_to_xywh(boxes): xmin, ymin, xmax, ymax = boxes.unbind(1) return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs)