Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support pytorch2.0 #500

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions .github/workflows/workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.10.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand All @@ -34,23 +34,17 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
torch: [1.10.1, 1.11.0, 1.12.1, 1.13.1]
torch: [2.0.0]
include:
- torch: 1.10.1
torchvision: 0.11.2
- torch: 1.11.0
torchvision: 0.12.0
- torch: 1.12.1
torchvision: 0.13.1
- torch: 1.13.1
torchvision: 0.14.1
- torch: 2.0.0
torchvision: 0.15.0
steps:
- name: Checkout
uses: actions/checkout@v2
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: 3.8
python-version: 3.10.9
- name: Install dependencies
run: |
python -m pip install -U pip
Expand All @@ -67,7 +61,7 @@ jobs:
coverage report -m
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
if: matrix.torch == '1.12.1'
if: matrix.torch == '2.0.0'
with:
file: ./coverage.xml
flags: unittests
Expand Down
58 changes: 0 additions & 58 deletions nanodet/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import re

import torch
from torch._six import string_classes

np_str_obj_array_pattern = re.compile(r"[SaUO]")

default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}"
)


def collate_function(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""

elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif (
elem_type.__module__ == "numpy"
and elem_type.__name__ != "str_"
and elem_type.__name__ != "string_"
):
elem = batch[0]
if elem_type.__name__ == "ndarray":
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))

return batch
elif elem.shape == (): # scalars
return batch
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, collections.abc.Mapping):
return {key: collate_function([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
return elem_type(*(collate_function(samples) for samples in zip(*batch)))
elif isinstance(elem, collections.abc.Sequence):
transposed = zip(*batch)
return [collate_function(samples) for samples in transposed]

raise TypeError(default_collate_err_msg_format.format(elem_type))


def naive_collate(batch):
"""Only collate dict value in to a list. E.g. meta data dict and img_info
Expand Down
9 changes: 6 additions & 3 deletions nanodet/model/arch/one_stage_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ def forward(self, x):

def inference(self, meta):
with torch.no_grad():
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
time1 = time.time()
preds = self(meta["img"])
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
time2 = time.time()
print("forward time: {:.3f}s".format((time2 - time1)), end=" | ")
results = self.head.post_process(preds, meta)
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.synchronize()
print("decode time: {:.3f}s".format((time.time() - time2)), end=" | ")
return results

Expand Down
166 changes: 78 additions & 88 deletions nanodet/trainer/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import json
import os
import warnings
from typing import Any, Dict, List
from typing import Any, Dict

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -52,7 +52,10 @@ def __init__(self, cfg, evaluator=None):
self.weight_averager = build_weight_averager(
cfg.model.weight_averager, device=self.device
)
self.avg_model = copy.deepcopy(self.model)
self.avg_model = copy.deepcopy(self.model).requires_grad_(False)

self.validation_step_outputs = [] # save validation results
self.test_step_outputs = [] # save test results

def _preprocess_batch_input(self, batch):
batch_imgs = batch["img"]
Expand Down Expand Up @@ -107,9 +110,6 @@ def training_step(self, batch, batch_idx):

return loss

def training_epoch_end(self, outputs: List[Any]) -> None:
self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, "model_last.ckpt"))

def validation_step(self, batch, batch_idx):
batch = self._preprocess_batch_input(batch)
if self.weight_averager is not None:
Expand Down Expand Up @@ -138,84 +138,14 @@ def validation_step(self, batch, batch_idx):
self.logger.info(log_msg)

dets = self.model.head.post_process(preds, batch)
self.validation_step_outputs.append(dets)
return dets

def validation_epoch_end(self, validation_step_outputs):
"""
Called at the end of the validation epoch with the
outputs of all validation steps.Evaluating results
and save best model.
Args:
validation_step_outputs: A list of val outputs

"""
results = {}
for res in validation_step_outputs:
results.update(res)
all_results = (
gather_results(results)
if dist.is_available() and dist.is_initialized()
else results
)
if all_results:
eval_results = self.evaluator.evaluate(
all_results, self.cfg.save_dir, rank=self.local_rank
)
metric = eval_results[self.cfg.evaluator.save_key]
# save best model
if metric > self.save_flag:
self.save_flag = metric
best_save_path = os.path.join(self.cfg.save_dir, "model_best")
mkdir(self.local_rank, best_save_path)
self.trainer.save_checkpoint(
os.path.join(best_save_path, "model_best.ckpt")
)
self.save_model_state(
os.path.join(best_save_path, "nanodet_model_best.pth")
)
txt_path = os.path.join(best_save_path, "eval_results.txt")
if self.local_rank < 1:
with open(txt_path, "a") as f:
f.write("Epoch:{}\n".format(self.current_epoch + 1))
for k, v in eval_results.items():
f.write("{}: {}\n".format(k, v))
else:
warnings.warn(
"Warning! Save_key is not in eval results! Only save model last!"
)
self.logger.log_metrics(eval_results, self.current_epoch + 1)
else:
self.logger.info("Skip val on rank {}".format(self.local_rank))

def test_step(self, batch, batch_idx):
dets = self.predict(batch, batch_idx)
self.test_step_outputs.append(dets)
return dets

def test_epoch_end(self, test_step_outputs):
results = {}
for res in test_step_outputs:
results.update(res)
all_results = (
gather_results(results)
if dist.is_available() and dist.is_initialized()
else results
)
if all_results:
res_json = self.evaluator.results2json(all_results)
json_path = os.path.join(self.cfg.save_dir, "results.json")
json.dump(res_json, open(json_path, "w"))

if self.cfg.test_mode == "val":
eval_results = self.evaluator.evaluate(
all_results, self.cfg.save_dir, rank=self.local_rank
)
txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
with open(txt_path, "a") as f:
for k, v in eval_results.items():
f.write("{}: {}\n".format(k, v))
else:
self.logger.info("Skip test on rank {}".format(self.local_rank))

def configure_optimizers(self):
"""
Prepare optimizer and learning-rate scheduler
Expand All @@ -237,16 +167,7 @@ def configure_optimizers(self):
}
return dict(optimizer=optimizer, lr_scheduler=scheduler)

def optimizer_step(
self,
epoch=None,
batch_idx=None,
optimizer=None,
optimizer_idx=None,
optimizer_closure=None,
on_tpu=None,
using_lbfgs=None,
):
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
"""
Performs a single optimization step (parameter update).
Args:
Expand Down Expand Up @@ -277,7 +198,6 @@ def optimizer_step(

# update params
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()

def scalar_summary(self, tag, phase, value, step):
"""
Expand Down Expand Up @@ -320,6 +240,9 @@ def on_fit_start(self) -> None:
def on_train_epoch_start(self):
self.model.set_epoch(self.current_epoch)

def on_train_epoch_end(self) -> None:
self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, "model_last.ckpt"))

def on_train_batch_end(self, outputs, batch, batch_idx) -> None:
if self.weight_averager:
self.weight_averager.update(self.model, self.global_step)
Expand All @@ -328,11 +251,78 @@ def on_validation_epoch_start(self):
if self.weight_averager:
self.weight_averager.apply_to(self.avg_model)

def on_validation_epoch_end(self):
results = {}
for res in self.validation_step_outputs:
results.update(res)
all_results = (
gather_results(results)
if dist.is_available() and dist.is_initialized()
else results
)
if all_results:
eval_results = self.evaluator.evaluate(
all_results, self.cfg.save_dir, rank=self.local_rank
)
metric = eval_results[self.cfg.evaluator.save_key]
# save best model
if metric > self.save_flag:
self.save_flag = metric
best_save_path = os.path.join(self.cfg.save_dir, "model_best")
mkdir(self.local_rank, best_save_path)
self.trainer.save_checkpoint(
os.path.join(best_save_path, "model_best.ckpt")
)
self.save_model_state(
os.path.join(best_save_path, "nanodet_model_best.pth")
)
txt_path = os.path.join(best_save_path, "eval_results.txt")
if self.local_rank < 1:
with open(txt_path, "a") as f:
f.write("Epoch:{}\n".format(self.current_epoch + 1))
for k, v in eval_results.items():
f.write("{}: {}\n".format(k, v))
else:
warnings.warn(
"Warning! Save_key is not in eval results! Only save model last!"
)
self.logger.log_metrics(eval_results, self.current_epoch + 1)
else:
self.logger.info("Skip val on rank {}".format(self.local_rank))

self.validation_step_outputs.clear() # free memory

def on_test_epoch_start(self) -> None:
if self.weight_averager:
self.on_load_checkpoint({"state_dict": self.state_dict()})
self.weight_averager.apply_to(self.model)

def on_test_epoch_end(self):
results = {}
for res in self.test_step_outputs:
results.update(res)
all_results = (
gather_results(results)
if dist.is_available() and dist.is_initialized()
else results
)
if all_results:
res_json = self.evaluator.results2json(all_results)
json_path = os.path.join(self.cfg.save_dir, "results.json")
json.dump(res_json, open(json_path, "w"))

if self.cfg.test_mode == "val":
eval_results = self.evaluator.evaluate(
all_results, self.cfg.save_dir, rank=self.local_rank
)
txt_path = os.path.join(self.cfg.save_dir, "eval_results.txt")
with open(txt_path, "a") as f:
for k, v in eval_results.items():
f.write("{}: {}\n".format(k, v))
else:
self.logger.info("Skip test on rank {}".format(self.local_rank))
self.test_step_outputs.clear() # free memory

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
if self.weight_averager:
avg_params = convert_avg_params(checkpointed_state)
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ onnx-simplifier
opencv-python
pyaml
pycocotools
pytorch-lightning>=1.9.0,<2.0.0
pytorch-lightning>=2.0.0
tabulate
tensorboard
termcolor
torch>=1.10,<2.0
torch>=2.0
torchmetrics
torchvision
tqdm
Loading