Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/ultralytics/yolov5
Browse files Browse the repository at this point in the history
  • Loading branch information
acai66 committed Nov 23, 2021
2 parents f3ce0d7 + 7a39803 commit 9e38ec3
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 40 deletions.
4 changes: 2 additions & 2 deletions detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,11 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)
# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn)
stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
stride, names, pt, jit, onnx, engine = model.stride, model.names, model.pt, model.jit, model.onnx, model.engine
imgsz = check_img_size(imgsz, s=stride) # check image size

# Half
half &= pt and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
half &= (pt or engine) and device.type != 'cpu' # half precision only supported by PyTorch on CUDA
if pt:
model.model.half() if half else model.model.float()

Expand Down
55 changes: 54 additions & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
TensorFlow GraphDef | yolov5s.pb | 'pb'
TensorFlow Lite | yolov5s.tflite | 'tflite'
TensorFlow.js | yolov5s_web_model/ | 'tfjs'
TensorRT | yolov5s.engine | 'engine'
Usage:
$ python path/to/export.py --weights yolov5s.pt --include torchscript onnx coreml saved_model pb tflite tfjs
Expand All @@ -24,6 +25,7 @@
yolov5s_saved_model
yolov5s.pb
yolov5s.tflite
yolov5s.engine
TensorFlow.js:
$ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
Expand Down Expand Up @@ -263,6 +265,51 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
LOGGER.info(f'\n{prefix} export failure: {e}')


def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
try:
check_requirements(('tensorrt',))
import tensorrt as trt

opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
export_onnx(model, im, file, opset, train, False, simplify)
onnx = file.with_suffix('.onnx')
assert onnx.exists(), f'failed to export ONNX file: {onnx}'

LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
f = str(file).replace('.pt', '.engine') # TensorRT engine file
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = workspace * 1 << 30

flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
network = builder.create_network(flag)
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(str(onnx)):
raise RuntimeError(f'failed to load ONNX file: {onnx}')

inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
LOGGER.info(f'{prefix} Network Description:')
for inp in inputs:
LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
for out in outputs:
LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')

half &= builder.platform_has_fast_fp16
LOGGER.info(f'{prefix} building FP{16 if half else 32} engine in {f}')
if half:
config.set_flag(trt.BuilderFlag.FP16)
with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
t.write(engine.serialize())
LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')

except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
Expand All @@ -278,6 +325,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
dynamic=False, # ONNX/TF: dynamic axes
simplify=False, # ONNX: simplify model
opset=12, # ONNX: opset version
verbose=False, # TensorRT: verbose log
workspace=4, # TensorRT: workspace size (GB)
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
Expand Down Expand Up @@ -322,6 +371,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
export_torchscript(model, im, file, optimize)
if 'onnx' in include:
export_onnx(model, im, file, opset, train, dynamic, simplify)
if 'engine' in include:
export_engine(model, im, file, train, half, simplify, workspace, verbose)
if 'coreml' in include:
export_coreml(model, im, file)

Expand Down Expand Up @@ -360,13 +411,15 @@ def parse_opt():
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes')
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
parser.add_argument('--opset', type=int, default=13, help='ONNX: opset version')
parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='available formats are (torchscript, onnx, coreml, saved_model, pb, tflite, tfjs)')
help='available formats are (torchscript, onnx, engine, coreml, saved_model, pb, tflite, tfjs)')
opt = parser.parse_args()
print_args(FILE.stem, opt)
return opt
Expand Down
3 changes: 2 additions & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
"""
from pathlib import Path

from models.common import AutoShape
from models.experimental import attempt_load
from models.yolo import Model
from utils.downloads import attempt_download
Expand Down Expand Up @@ -55,7 +56,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
if len(ckpt['model'].names) == classes:
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
model = AutoShape(model) # for file/URI/PIL/cv2/np inputs and NMS
return model.to(device)

except Exception as e:
Expand Down
40 changes: 31 additions & 9 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math
import platform
import warnings
from collections import namedtuple
from copy import copy
from pathlib import Path

Expand All @@ -23,7 +24,7 @@
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import time_sync
from utils.torch_utils import copy_attr, time_sync


def autopad(k, p=None): # kernel, padding
Expand Down Expand Up @@ -285,11 +286,12 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
# TensorFlow Lite: *.tflite
# ONNX Runtime: *.onnx
# OpenCV DNN: *.onnx with dnn=True
# TensorRT: *.engine
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel']
check_suffix(w, suffixes) # check weights have acceptable suffix
pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
pt, onnx, engine, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes) # backend booleans
jit = pt and 'torchscript' in w.lower()
stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults

Expand Down Expand Up @@ -317,6 +319,23 @@ def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
import onnxruntime
session = onnxruntime.InferenceSession(w, None)
elif engine: # TensorRT
LOGGER.info(f'Loading {w} for TensorRT inference...')
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
logger = trt.Logger(trt.Logger.INFO)
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
model = runtime.deserialize_cuda_engine(f.read())
bindings = dict()
for index in range(model.num_bindings):
name = model.get_binding_name(index)
dtype = trt.nptype(model.get_binding_dtype(index))
shape = tuple(model.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
binding_addrs = {n: d.ptr for n, d in bindings.items()}
context = model.create_execution_context()
batch_size = bindings['images'].shape[0]
else: # TensorFlow model (TFLite, pb, saved_model)
import tensorflow as tf
if pb: # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
Expand All @@ -334,7 +353,7 @@ def wrap_frozen_graph(gd, inputs, outputs):
model = tf.keras.models.load_model(w)
elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
if 'edgetpu' in w.lower():
LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
import tflite_runtime.interpreter as tfli
delegate = {'Linux': 'libedgetpu.so.1', # install https://coral.ai/software/#edgetpu-runtime
'Darwin': 'libedgetpu.1.dylib',
Expand Down Expand Up @@ -369,6 +388,11 @@ def forward(self, im, augment=False, visualize=False, val=False):
y = self.net.forward()
else: # ONNX Runtime
y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
elif self.engine: # TensorRT
assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
self.binding_addrs['images'] = int(im.data_ptr())
self.context.execute_v2(list(self.binding_addrs.values()))
y = self.bindings['output'].data
else: # TensorFlow model (TFLite, pb, saved_model)
im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
if self.pb:
Expand All @@ -391,7 +415,7 @@ def forward(self, im, augment=False, visualize=False, val=False):
y[..., 1] *= h # y
y[..., 2] *= w # w
y[..., 3] *= h # h
y = torch.tensor(y)
y = torch.tensor(y) if isinstance(y, np.ndarray) else y
return (y, []) if val else y


Expand All @@ -405,12 +429,10 @@ class AutoShape(nn.Module):

def __init__(self, model):
super().__init__()
LOGGER.info('Adding AutoShape... ')
copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
self.model = model.eval()

def autoshape(self):
LOGGER.info('AutoShape already enabled, skipping... ') # model already converted to model.autoshape()
return self

def _apply(self, fn):
# Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
self = super()._apply(fn)
Expand Down
1 change: 0 additions & 1 deletion models/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""

import argparse
import logging
import sys
from copy import deepcopy
from pathlib import Path
Expand Down
9 changes: 1 addition & 8 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from utils.autoanchor import check_anchor_order
from utils.general import LOGGER, check_version, check_yaml, make_divisible, print_args
from utils.plots import feature_visualization
from utils.torch_utils import (copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device,
time_sync)
from utils.torch_utils import fuse_conv_and_bn, initialize_weights, model_info, scale_img, select_device, time_sync

try:
import thop # for FLOPs computation
Expand Down Expand Up @@ -226,12 +225,6 @@ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
self.info()
return self

def autoshape(self): # add AutoShape module
LOGGER.info('Adding AutoShape... ')
m = AutoShape(self) # wrap model
copy_attr(m, self, include=('yaml', 'nc', 'hyp', 'names', 'stride'), exclude=()) # copy attributes
return m

def info(self, verbose=False, img_size=640): # print model information
model_info(self, verbose, img_size)

Expand Down
25 changes: 23 additions & 2 deletions tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/update%2Fnotebook/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/ultralytics/yolov5/blob/master/tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
Expand Down Expand Up @@ -412,7 +412,7 @@
"from yolov5 import utils\n",
"display = utils.notebook_init() # checks"
],
"execution_count": 2,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
Expand Down Expand Up @@ -1081,6 +1081,27 @@
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VTRwsvA9u7ln"
},
"source": [
"# TensorRT \n",
"# https://developer.nvidia.com/nvidia-tensorrt-download\n",
"!lsb_release -a # check system\n",
"%ls /usr/local | grep cuda # check CUDA\n",
"!wget https://ultralytics.com/assets/TensorRT-8.2.0.6.Linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz # download\n",
"![ -d /content/TensorRT-8.2.0.6/ ] || tar -C /content/ -zxf ./TensorRT-8.2.0.6.Linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz # unzip\n",
"%pip list | grep tensorrt || pip install /content/TensorRT-8.2.0.6/python/tensorrt-8.2.0.6-cp37-none-linux_x86_64.whl # install\n",
"%env LD_LIBRARY_PATH=/usr/local/cuda-11.1/lib64:/content/cuda-11.1/lib64:/content/TensorRT-8.2.0.6/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 # add to path\n",
"\n",
"!python export.py --weights yolov5s.pt --include engine --imgsz 640 640 --device 0\n",
"!python detect.py --weights yolov5s.engine --imgsz 640 640 --device 0"
],
"execution_count": null,
"outputs": []
}
]
}
2 changes: 1 addition & 1 deletion utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
IMG_FORMATS = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
VID_FORMATS = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) # DPP
NUM_THREADS = min(8, os.cpu_count()) # number of multiprocessing threads
NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of multiprocessing threads

# Get orientation exif tag
for orientation in ExifTags.TAGS.keys():
Expand Down
1 change: 0 additions & 1 deletion utils/loggers/wandb/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pathlib import Path
from typing import Dict

import pkg_resources as pkg
import yaml
from tqdm import tqdm

Expand Down
21 changes: 15 additions & 6 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def fitness(x):
return (x[:, :4] * w).sum(1)


def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=(), eps=1e-16):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
Expand All @@ -38,15 +38,15 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]

# Find unique classes
unique_classes = np.unique(target_cls)
unique_classes, nt = np.unique(target_cls, return_counts=True)
nc = unique_classes.shape[0] # number of classes, number of detections

# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
for ci, c in enumerate(unique_classes):
i = pred_cls == c
n_l = (target_cls == c).sum() # number of labels
n_l = nt[ci] # number of labels
n_p = i.sum() # number of predictions

if n_p == 0 or n_l == 0:
Expand All @@ -57,7 +57,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
tpc = tp[i].cumsum(0)

# Recall
recall = tpc / (n_l + 1e-16) # recall curve
recall = tpc / (n_l + eps) # recall curve
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases

# Precision
Expand All @@ -71,7 +71,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
py.append(np.interp(px, mrec, mpre)) # precision at [email protected]

# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)
f1 = 2 * p * r / (p + r + eps)
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
names = {i: v for i, v in enumerate(names)} # to dict
if plot:
Expand All @@ -81,7 +81,10 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')

i = f1.mean(0).argmax() # max F1 index
return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32')
p, r, f1 = p[:, i], r[:, i], f1[:, i]
tp = (r * nt).round() # true positives
fp = (tp / (p + eps) - tp).round() # false positives
return tp, fp, p, r, f1, ap, unique_classes.astype('int32')


def compute_ap(recall, precision):
Expand Down Expand Up @@ -174,6 +177,12 @@ def process_batch(self, detections, labels):
def matrix(self):
return self.matrix

def tp_fp(self):
tp = self.matrix.diagonal() # true positives
fp = self.matrix.sum(1) - tp # false positives
# fn = self.matrix.sum(0) - tp # false negatives (missed detections)
return tp[:-1], fp[:-1] # remove background class

def plot(self, normalize=True, save_dir='', names=()):
try:
import seaborn as sn
Expand Down
Loading

0 comments on commit 9e38ec3

Please sign in to comment.