diff --git a/Dockerfile b/Dockerfile index 6727548fa5..d0a797fdf3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,12 +17,6 @@ WORKDIR /usr/src/app # Copy contents COPY . /usr/src/app -# Copy weights -#RUN python3 -c "from models import *; \ -#attempt_download('weights/yolov3.pt'); \ -#attempt_download('weights/yolov3-spp.pt'); \ -#attempt_download('weights/yolov3-tiny.pt')" - # --------------------------------------------------- Extras Below --------------------------------------------------- @@ -31,7 +25,7 @@ COPY . /usr/src/app # for v in {300..303}; do t=ultralytics/coco:v$v && sudo docker build -t $t . && sudo docker push $t; done # Pull and Run -# t=ultralytics/yolov3:latest && sudo docker pull $t && sudo docker run -it --ipc=host $t +# t=ultralytics/yolov3:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t # Pull and Run with local directory access # t=ultralytics/yolov3:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/coco:/usr/src/coco $t diff --git a/detect.py b/detect.py index 4e4de61f34..2ccc20d31a 100644 --- a/detect.py +++ b/detect.py @@ -9,8 +9,8 @@ from models.experimental import attempt_load from utils.datasets import LoadStreams, LoadImages -from utils.general import check_img_size, non_max_suppression, apply_classifier, scale_coords, xyxy2xywh, \ - strip_optimizer, set_logging, increment_path +from utils.general import check_img_size, check_requirements, non_max_suppression, apply_classifier, scale_coords, \ + xyxy2xywh, strip_optimizer, set_logging, increment_path from utils.plots import plot_one_box from utils.torch_utils import select_device, load_classifier, time_synchronized @@ -81,12 +81,13 @@ def detect(save_img=False): # Process detections for i, det in enumerate(pred): # detections per image if webcam: # batch_size >= 1 - p, s, im0 = Path(path[i]), '%g: ' % i, im0s[i].copy() + p, s, im0, frame = path[i], '%g: ' % i, im0s[i].copy(), dataset.count else: - p, s, im0 = Path(path), '', im0s + p, s, im0, frame = path, '', im0s, getattr(dataset, 'frame', 0) - save_path = str(save_dir / p.name) - txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '') + p = Path(p) # to Path + save_path = str(save_dir / p.name) # img.jpg + txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}') # img.txt s += '%gx%g ' % img.shape[2:] # print string gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh if len(det): @@ -96,7 +97,7 @@ def detect(save_img=False): # Print results for c in det[:, -1].unique(): n = (det[:, -1] == c).sum() # detections per class - s += '%g %ss, ' % (n, names[int(c)]) # add to string + s += f'{n} {names[int(c)]}s, ' # add to string # Write results for *xyxy, conf, cls in reversed(det): @@ -107,23 +108,21 @@ def detect(save_img=False): f.write(('%g ' * len(line)).rstrip() % line + '\n') if save_img or view_img: # Add bbox to image - label = '%s %.2f' % (names[int(cls)], conf) + label = f'{names[int(cls)]} {conf:.2f}' plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) # Print time (inference + NMS) - print('%sDone. (%.3fs)' % (s, t2 - t1)) + print(f'{s}Done. 
({t2 - t1:.3f}s)') # Stream results if view_img: cv2.imshow(str(p), im0) - if cv2.waitKey(1) == ord('q'): # q to quit - raise StopIteration # Save results (image with detections) if save_img: - if dataset.mode == 'images': + if dataset.mode == 'image': cv2.imwrite(save_path, im0) - else: + else: # 'video' if vid_path != save_path: # new video vid_path = save_path if isinstance(vid_writer, cv2.VideoWriter): @@ -140,7 +139,7 @@ def detect(save_img=False): s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' print(f"Results saved to {save_dir}{s}") - print('Done. (%.3fs)' % (time.time() - t0)) + print(f'Done. ({time.time() - t0:.3f}s)') if __name__ == '__main__': @@ -163,6 +162,7 @@ def detect(save_img=False): parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') opt = parser.parse_args() print(opt) + check_requirements() with torch.no_grad(): if opt.update: # update all models (to fix SourceChangeWarning) diff --git a/hubconf.py b/hubconf.py index df7796d4ba..7b0f20115e 100644 --- a/hubconf.py +++ b/hubconf.py @@ -17,7 +17,7 @@ set_logging() -def create(name, pretrained, channels, classes): +def create(name, pretrained, channels, classes, autoshape): """Creates a specified YOLOv3 model Arguments: @@ -41,7 +41,8 @@ def create(name, pretrained, channels, classes): model.load_state_dict(state_dict, strict=False) # load if len(ckpt['model'].names) == classes: model.names = ckpt['model'].names # set class names attribute - # model = model.autoshape() # for PIL/cv2/np inputs and NMS + if autoshape: + model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS return model except Exception as e: @@ -50,7 +51,7 @@ def create(name, pretrained, channels, classes): raise Exception(s) from e -def yolov3(pretrained=False, channels=3, classes=80): +def yolov3(pretrained=False, channels=3, classes=80, autoshape=True): """YOLOv3 model from https://github.com/ultralytics/yolov3 Arguments: @@ -61,10 +62,10 @@ def yolov3(pretrained=False, channels=3, classes=80): Returns: pytorch model """ - return create('yolov3', pretrained, channels, classes) + return create('yolov3', pretrained, channels, classes, autoshape) -def yolov3_spp(pretrained=False, channels=3, classes=80): +def yolov3_spp(pretrained=False, channels=3, classes=80, autoshape=True): """YOLOv3-SPP model from https://github.com/ultralytics/yolov3 Arguments: @@ -75,10 +76,10 @@ def yolov3_spp(pretrained=False, channels=3, classes=80): Returns: pytorch model """ - return create('yolov3-spp', pretrained, channels, classes) + return create('yolov3-spp', pretrained, channels, classes, autoshape) -def yolov3_tiny(pretrained=False, channels=3, classes=80): +def yolov3_tiny(pretrained=False, channels=3, classes=80, autoshape=True): """YOLOv3-tiny model from https://github.com/ultralytics/yolov3 Arguments: @@ -89,16 +90,17 @@ def yolov3_tiny(pretrained=False, channels=3, classes=80): Returns: pytorch model """ - return create('yolov3-tiny', pretrained, channels, classes) + return create('yolov3-tiny', pretrained, channels, classes, autoshape) -def custom(path_or_model='path/to/model.pt'): +def custom(path_or_model='path/to/model.pt', autoshape=True): """YOLOv3-custom model from https://github.com/ultralytics/yolov3 - + Arguments (3 options): path_or_model (str): 'path/to/model.pt' path_or_model (dict): torch.load('path/to/model.pt') path_or_model (nn.Module): torch.load('path/to/model.pt')['model'] + Returns: pytorch model """ @@ -109,13 +111,12 
@@ def custom(path_or_model='path/to/model.pt'): hub_model = Model(model.yaml).to(next(model.parameters()).device) # create hub_model.load_state_dict(model.float().state_dict()) # load state_dict hub_model.names = model.names # class names - return hub_model + return hub_model.autoshape() if autoshape else hub_model if __name__ == '__main__': - model = create(name='yolov3', pretrained=True, channels=3, classes=80) # pretrained example + model = create(name='yolov3', pretrained=True, channels=3, classes=80, autoshape=True) # pretrained example # model = custom(path_or_model='path/to/model.pt') # custom example - model = model.autoshape() # for PIL/cv2/np inputs and NMS # Verify inference from PIL import Image diff --git a/models/common.py b/models/common.py index f26ffdd0ff..fd9d9fcdd7 100644 --- a/models/common.py +++ b/models/common.py @@ -1,7 +1,9 @@ # This file contains modules common to various models import math + import numpy as np +import requests import torch import torch.nn as nn from PIL import Image, ImageDraw @@ -29,7 +31,7 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, k super(Conv, self).__init__() self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) self.bn = nn.BatchNorm2d(c2) - self.act = nn.LeakyReLU(0.1) if act else nn.Identity() + self.act = nn.LeakyReLU(0.1) if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) def forward(self, x): return self.act(self.bn(self.conv(x))) @@ -70,6 +72,21 @@ def forward(self, x): return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super(C3, self).__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)]) + # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + class SPP(nn.Module): # Spatial pyramid pooling layer used in YOLOv3-SPP def __init__(self, c1, c2, k=(5, 9, 13)): @@ -89,9 +106,39 @@ class Focus(nn.Module): def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups super(Focus, self).__init__() self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + # self.contract = Contract(gain=2) def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + # return self.conv(self.contract(x)) + + +class Contract(nn.Module): + # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + N, C, H, W = x.size() # assert (H / s == 0) and (W / s == 0), 'Indivisible gain' + s = self.gain + x = x.view(N, C, H // s, s, W // s, s) # x(1,64,40,2,40,2) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40) + return x.view(N, C * s * s, H // s, W // s) # x(1,256,40,40) + + +class Expand(nn.Module): + # Expand channels into width-height, i.e. 
x(1,64,80,80) to x(1,16,160,160) + def __init__(self, gain=2): + super().__init__() + self.gain = gain + + def forward(self, x): + N, C, H, W = x.size() # assert C / s ** 2 == 0, 'Indivisible gain' + s = self.gain + x = x.view(N, s, s, C // s ** 2, H, W) # x(1,2,2,16,80,80) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2) + return x.view(N, C // s ** 2, H * s, W * s) # x(1,16,160,160) class Concat(nn.Module): @@ -128,35 +175,42 @@ def __init__(self, model): super(autoShape, self).__init__() self.model = model.eval() + def autoshape(self): + print('autoShape already enabled, skipping... ') # model already converted to model.autoshape() + return self + def forward(self, imgs, size=640, augment=False, profile=False): - # supports inference from various sources. For height=720, width=1280, RGB images example inputs are: - # opencv: imgs = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) - # PIL: imgs = Image.open('image.jpg') # HWC x(720,1280,3) - # numpy: imgs = np.zeros((720,1280,3)) # HWC - # torch: imgs = torch.zeros(16,3,720,1280) # BCHW - # multiple: imgs = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + # Inference from various sources. For height=720, width=1280, RGB images example inputs are: + # filename: imgs = 'data/samples/zidane.jpg' + # URI: = 'https://github.com/ultralytics/yolov5/releases/download/v1.0/zidane.jpg' + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images p = next(self.model.parameters()) # for device and type if isinstance(imgs, torch.Tensor): # torch return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference # Pre-process - if not isinstance(imgs, list): - imgs = [imgs] + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images shape0, shape1 = [], [] # image and inference shapes - batch = range(len(imgs)) # batch size - for i in batch: - imgs[i] = np.array(imgs[i]) # to numpy - if imgs[i].shape[0] < 5: # image in CHW - imgs[i] = imgs[i].transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) - imgs[i] = imgs[i][:, :, :3] if imgs[i].ndim == 3 else np.tile(imgs[i][:, :, None], 3) # enforce 3ch input - s = imgs[i].shape[:2] # HWC + for i, im in enumerate(imgs): + if isinstance(im, str): # filename or uri + im = Image.open(requests.get(im, stream=True).raw if im.startswith('http') else im) # open + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC shape0.append(s) # image shape g = (size / max(s)) # gain shape1.append([y * g for y in s]) + imgs[i] = im # update shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape - x = [letterbox(imgs[i], new_shape=shape1, auto=False)[0] for i in batch] # pad - x = np.stack(x, 0) if batch[-1] else x[0][None] # stack + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW x = torch.from_numpy(x).to(p.device).type_as(p) / 255. 
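# ------------------------------------------------------------------------------------------
# Editor's sketch (not part of this diff): a shape round-trip for the new Contract/Expand
# modules above. Shapes follow their inline comments; assumes the diff is applied.
import torch
from models.common import Contract, Expand

x = torch.randn(1, 64, 80, 80)
c = Contract(gain=2)(x)  # (1, 256, 40, 40): each 2x2 spatial block folded into channels
y = Expand(gain=2)(c)    # (1, 64, 80, 80): Expand is the exact inverse for the same gain
assert torch.equal(x, y)
# ------------------------------------------------------------------------------------------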
# uint8 to fp16/32 @@ -166,7 +220,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS # Post-process - for i in batch: + for i in range(n): scale_coords(shape1, y[i][:, :4], shape0[i]) return Detections(imgs, y, self.names) @@ -187,7 +241,7 @@ def __init__(self, imgs, pred, names=None): self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized self.n = len(self.pred) - def display(self, pprint=False, show=False, save=False): + def display(self, pprint=False, show=False, save=False, render=False): colors = color_list() for i, (img, pred) in enumerate(zip(self.imgs, self.pred)): str = f'Image {i + 1}/{len(self.pred)}: {img.shape[0]}x{img.shape[1]} ' @@ -195,19 +249,21 @@ def display(self, pprint=False, show=False, save=False): for c in pred[:, -1].unique(): n = (pred[:, -1] == c).sum() # detections per class str += f'{n} {self.names[int(c)]}s, ' # add to string - if show or save: + if show or save or render: img = Image.fromarray(img.astype(np.uint8)) if isinstance(img, np.ndarray) else img # from np for *box, conf, cls in pred: # xyxy, confidence, class # str += '%s %.2f, ' % (names[int(cls)], conf) # label ImageDraw.Draw(img).rectangle(box, width=4, outline=colors[int(cls) % 10]) # plot + if pprint: + print(str) + if show: + img.show(f'Image {i}') # show if save: f = f'results{i}.jpg' str += f"saved to '{f}'" img.save(f) # save - if show: - img.show(f'Image {i}') # show - if pprint: - print(str) + if render: + self.imgs[i] = np.asarray(img) def print(self): self.display(pprint=True) # print results @@ -218,6 +274,10 @@ def show(self): def save(self): self.display(save=True) # save results + def render(self): + self.display(render=True) # render results + return self.imgs + def __len__(self): return self.n @@ -230,20 +290,13 @@ def tolist(self): return x -class Flatten(nn.Module): - # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions - @staticmethod - def forward(x): - return x.view(x.size(0), -1) - - class Classify(nn.Module): # Classification head, i.e. 
x(b,c1,20,20) to x(b,c2) def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups super(Classify, self).__init__() self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1) self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1) - self.flat = Flatten() + self.flat = nn.Flatten() def forward(self, x): z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list diff --git a/models/experimental.py b/models/experimental.py index a2908a15cf..2dbbf7fa32 100644 --- a/models/experimental.py +++ b/models/experimental.py @@ -22,25 +22,6 @@ def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) -class C3(nn.Module): - # Cross Convolution CSP - def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion - super(C3, self).__init__() - c_ = int(c2 * e) # hidden channels - self.cv1 = Conv(c1, c_, 1, 1) - self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) - self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) - self.cv4 = Conv(2 * c_, c2, 1, 1) - self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) - self.act = nn.LeakyReLU(0.1, inplace=True) - self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)]) - - def forward(self, x): - y1 = self.cv3(self.m(self.cv1(x))) - y2 = self.cv2(x) - return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) - - class Sum(nn.Module): # Weighted sum of 2 or more layers https://arxiv.org/abs/1911.09070 def __init__(self, n, weight=False): # n: number of inputs @@ -124,8 +105,8 @@ def forward(self, x, augment=False): for module in self: y.append(module(x, augment)[0]) # y = torch.stack(y).max(0)[0] # max ensemble - # y = torch.cat(y, 1) # nms ensemble - y = torch.stack(y).mean(0) # mean ensemble + # y = torch.stack(y).mean(0) # mean ensemble + y = torch.cat(y, 1) # nms ensemble return y, None # inference, train output diff --git a/models/export.py b/models/export.py index 7fbc3d9599..49df43e9fe 100644 --- a/models/export.py +++ b/models/export.py @@ -15,7 +15,7 @@ import models from models.experimental import attempt_load -from utils.activations import Hardswish +from utils.activations import Hardswish, SiLU from utils.general import set_logging, check_img_size if __name__ == '__main__': @@ -43,9 +43,12 @@ # Update model for k, m in model.named_modules(): m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility - if isinstance(m, models.common.Conv) and isinstance(m.act, nn.Hardswish): - m.act = Hardswish() # assign activation - # if isinstance(m, models.yolo.Detect): + if isinstance(m, models.common.Conv): # assign export-friendly activations + if isinstance(m.act, nn.Hardswish): + m.act = Hardswish() + elif isinstance(m.act, nn.SiLU): + m.act = SiLU() + # elif isinstance(m, models.yolo.Detect): # m.forward = m.forward_export # assign forward (optional) model.model[-1].export = True # set Detect() layer export=True y = model(img) # dry run diff --git a/models/yolo.py b/models/yolo.py index 7ef9d501eb..9f47100999 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -4,15 +4,11 @@ from copy import deepcopy from pathlib import Path -import math -import torch -import torch.nn as nn - sys.path.append('./') # to run '$ python *.py' files in subdirectories logger = logging.getLogger(__name__) -from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, NMS, autoShape -from models.experimental import MixConv2d, CrossConv, C3 +from 
models.common import * +from models.experimental import MixConv2d, CrossConv from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ @@ -78,17 +74,18 @@ def __init__(self, cfg='yolov3.yaml', ch=3, nc=None): # model, input channels, self.yaml = yaml.load(f, Loader=yaml.FullLoader) # model dict # Define model + ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels if nc and nc != self.yaml['nc']: logger.info('Overriding model.yaml nc=%g with nc=%g' % (self.yaml['nc'], nc)) self.yaml['nc'] = nc # override yaml value - self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist, ch_out + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist self.names = [str(i) for i in range(self.yaml['nc'])] # default names # print([x.shape for x in self.forward(torch.zeros(1, ch, 64, 64))]) # Build strides, anchors m = self.model[-1] # Detect() if isinstance(m, Detect): - s = 128 # 2x min stride + s = 256 # 2x min stride m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward m.anchors /= m.stride.view(-1, 1, 1) check_anchor_order(m) @@ -108,7 +105,7 @@ def forward(self, x, augment=False, profile=False): f = [None, 3, None] # flips (2-ud, 3-lr) y = [] # outputs for si, fi in zip(s, f): - xi = scale_img(x.flip(fi) if fi else x, si) + xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max())) yi = self.forward_once(xi)[0] # forward # cv2.imwrite('img%g.jpg' % s, 255 * xi[0].numpy().transpose((1, 2, 0))[:, :, ::-1]) # save yi[..., :4] /= si # de-scale @@ -241,13 +238,17 @@ def parse_model(d, ch): # model_dict, input_channels(3) elif m is nn.BatchNorm2d: args = [ch[f]] elif m is Concat: - c2 = sum([ch[-1 if x == -1 else x + 1] for x in f]) + c2 = sum([ch[x if x < 0 else x + 1] for x in f]) elif m is Detect: args.append([ch[x + 1] for x in f]) if isinstance(args[1], int): # number of anchors args[1] = [list(range(args[1] * 2))] * len(f) + elif m is Contract: + c2 = ch[f if f < 0 else f + 1] * args[0] ** 2 + elif m is Expand: + c2 = ch[f if f < 0 else f + 1] // args[0] ** 2 else: - c2 = ch[f] + c2 = ch[f if f < 0 else f + 1] m_ = nn.Sequential(*[m(*args) for _ in range(n)]) if n > 1 else m(*args) # module t = str(m)[8:-2].replace('__main__.', '') # module type diff --git a/requirements.txt b/requirements.txt index 4cb16138a0..3c23f2b750 100755 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,7 @@ tqdm>=4.41.0 # wandb # plotting ------------------------------------ -seaborn +seaborn>=0.11.0 pandas # export -------------------------------------- diff --git a/test.py b/test.py index 5c8a70b93a..c570a7889a 100644 --- a/test.py +++ b/test.py @@ -11,8 +11,8 @@ from models.experimental import attempt_load from utils.datasets import create_dataloader -from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \ - non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path +from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \ + box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr from utils.loss import compute_loss from utils.metrics import ap_per_class, ConfusionMatrix from utils.plots import plot_images, output_to_target, plot_study_txt @@ -86,7 
+86,8 @@ def test(data, img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once path = data['test'] if opt.task == 'test' else data['val'] # path to val/test images - dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0] + dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True, + prefix=colorstr('test: ' if opt.task == 'test' else 'val: '))[0] seen = 0 confusion_matrix = ConfusionMatrix(nc=nc) @@ -226,7 +227,7 @@ def test(data, print(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) # Print results per class - if verbose and nc > 1 and len(stats): + if (verbose or (nc <= 20 and not training)) and nc > 1 and len(stats): for i, c in enumerate(ap_class): print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) @@ -302,6 +303,7 @@ def test(data, opt.save_json |= opt.data.endswith('coco.yaml') opt.data = check_file(opt.data) # check file print(opt) + check_requirements() if opt.task in ['val', 'test']: # run normally test(opt.data, diff --git a/train.py b/train.py index 91c8084fea..9f869cfd92 100644 --- a/train.py +++ b/train.py @@ -1,13 +1,12 @@ import argparse import logging +import math import os import random import time from pathlib import Path from threading import Thread -from warnings import warn -import math import numpy as np import torch.distributed as dist import torch.nn as nn @@ -28,7 +27,7 @@ from utils.datasets import create_dataloader from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \ fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ - print_mutation, set_logging + check_requirements, print_mutation, set_logging, one_cycle, colorstr from utils.google_utils import attempt_download from utils.loss import compute_loss from utils.plots import plot_images, plot_labels, plot_results, plot_evolution @@ -36,15 +35,9 @@ logger = logging.getLogger(__name__) -try: - import wandb -except ImportError: - wandb = None - logger.info("Install Weights & Biases for experiment logging via 'pip install wandb' (recommended)") - def train(hyp, opt, device, tb_writer=None, wandb=None): - logger.info(f'Hyperparameters {hyp}') + logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items())) save_dir, epochs, batch_size, total_batch_size, weights, rank = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank @@ -71,7 +64,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): check_dataset(data_dict) # check train_path = data_dict['train'] test_path = data_dict['val'] - nc, names = (1, ['item']) if opt.single_cls else (int(data_dict['nc']), data_dict['names']) # number classes, names + nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes + names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check # Model @@ -103,6 +97,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): nbs = 64 # nominal batch size accumulate = max(round(nbs / total_batch_size), 1) # accumulate loss before optimizing hyp['weight_decay'] *= total_batch_size * accumulate / nbs # scale weight_decay + logger.info(f"Scaled weight_decay = {hyp['weight_decay']}") pg0, pg1, pg2 = [], [], [] 
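# ------------------------------------------------------------------------------------------
# Editor's sketch (not part of this diff): the weight-decay scaling now logged above keeps
# decay proportional to the effective (accumulated) batch. Example values are assumed:
nbs = 64                                            # nominal batch size
total_batch_size = 16                               # assumed per-run batch size
accumulate = max(round(nbs / total_batch_size), 1)  # -> 4 gradient-accumulation steps
weight_decay = 0.0005 * total_batch_size * accumulate / nbs  # -> 0.0005, i.e. unchanged
                                                             # while batch * accumulate == nbs
# ------------------------------------------------------------------------------------------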
# optimizer parameter groups for k, v in model.named_modules(): @@ -125,12 +120,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Scheduler https://arxiv.org/pdf/1812.01187.pdf # https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR - lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (1 - hyp['lrf']) + hyp['lrf'] # cosine + lf = one_cycle(1, hyp['lrf'], epochs) # cosine 1->hyp['lrf'] scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf) # plot_lr_scheduler(optimizer, scheduler, epochs) # Logging - if wandb and wandb.run is None: + if rank in [-1, 0] and wandb and wandb.run is None: opt.hyp = hyp # add hyperparameters wandb_run = wandb.init(config=opt, resume="allow", project='YOLOv3' if opt.project == 'runs/train' else Path(opt.project).stem, @@ -163,7 +158,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): del ckpt, state_dict # Image sizes - gs = int(max(model.stride)) # grid size (max stride) + gs = int(model.stride.max()) # grid size (max stride) + nl = model.model[-1].nl # number of detection layers (used for scaling hyp['obj']) imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples # DP mode @@ -186,7 +182,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank, world_size=opt.world_size, workers=opt.workers, - image_weights=opt.image_weights) + image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: ')) mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class nb = len(dataloader) # number of batches assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1) @@ -195,8 +191,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if rank in [-1, 0]: ema.updates = start_epoch * nb // accumulate # set EMA updates testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader - hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, - rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5)[0] + hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, + world_size=opt.world_size, workers=opt.workers, + pad=0.5, prefix=colorstr('val: '))[0] if not opt.resume: labels = np.concatenate(dataset.labels, 0) @@ -204,7 +201,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) if plots: - Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start() + plot_labels(labels, save_dir, loggers) if tb_writer: tb_writer.add_histogram('classes', c, 0) @@ -213,11 +210,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # Model parameters - hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset + hyp['box'] *= 3. / nl # scale to layers + hyp['cls'] *= nc / 80. * 3. / nl # scale to classes and layers + hyp['obj'] *= (imgsz / 640) ** 2 * 3. 
/ nl # scale to image size and layers model.nc = nc # attach number of classes to model model.hyp = hyp # attach hyperparameters to model model.gr = 1.0 # iou loss ratio (obj_loss = 1.0 or iou) - model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights + model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights model.names = names # Start training @@ -228,9 +227,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls) scheduler.last_epoch = start_epoch - 1 # do not move scaler = amp.GradScaler(enabled=cuda) - logger.info('Image sizes %g train, %g test\n' - 'Using %g dataloader workers\nLogging results to %s\n' - 'Starting training for %g epochs...' % (imgsz, imgsz_test, dataloader.num_workers, save_dir, epochs)) + logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n' + f'Using {dataloader.num_workers} dataloader workers\n' + f'Logging results to {save_dir}\n' + f'Starting training for {epochs} epochs...') for epoch in range(start_epoch, epochs): # epoch ------------------------------------------------------------------ model.train() @@ -238,7 +238,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if opt.image_weights: # Generate indices if rank in [-1, 0]: - cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 # class weights + cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx # Broadcast if DDP @@ -289,6 +289,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size if rank != -1: loss *= opt.world_size # gradient averaged between devices in DDP mode + if opt.quad: + loss *= 4. 
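# ------------------------------------------------------------------------------------------
# Editor's note on `loss *= 4.` above: collate_fn4 (added in utils/datasets.py below) merges
# four images per batch element, so the quad dataloader yields 1/4 as many samples and the
# loss is scaled back up to keep gradient magnitudes comparable.
# Editor's sketch (not part of this diff) of the new per-layer hyp scaling, assuming a
# standard 3-layer P3-P5 head and example gain values:
nl, nc, imgsz = 3, 80, 640           # assumed: detection layers, classes, train size
box, cls, obj = 0.05, 0.5, 1.0       # assumed hyp gains
box *= 3. / nl                       # 0.05: unchanged at nl=3
cls *= nc / 80. * 3. / nl            # 0.5:  unchanged at nc=80, nl=3
obj *= (imgsz / 640) ** 2 * 3. / nl  # 1.0:  unchanged at 640px; scales with image area
# ------------------------------------------------------------------------------------------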
# Backward scaler.scale(loss).backward() @@ -330,7 +332,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if rank in [-1, 0]: # mAP if ema: - ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride']) + ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not opt.notest or final_epoch: # Calculate mAP results, maps, times = test.test(opt.data, @@ -386,10 +388,12 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if rank in [-1, 0]: # Strip optimizers + final = best if best.exists() else last # final model for f in [last, best]: - if f.exists(): # is *.pt - strip_optimizer(f) # strip optimizer - os.system('gsutil cp %s gs://%s/weights' % (f, opt.bucket)) if opt.bucket else None # upload + if f.exists(): + strip_optimizer(f) # strip optimizers + if opt.bucket: + os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload # Plots if plots: @@ -398,19 +402,24 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png'] wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files if (save_dir / f).exists()]}) - logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) + if opt.log_artifacts: + wandb.log_artifact(artifact_or_path=str(final), type='model', name=save_dir.stem) # Test best.pt + logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) if opt.data.endswith('coco.yaml') and nc == 80: # if COCO - results, _, _ = test.test(opt.data, - batch_size=total_batch_size, - imgsz=imgsz_test, - model=attempt_load(best if best.exists() else last, device).half(), - single_cls=opt.single_cls, - dataloader=testloader, - save_dir=save_dir, - save_json=True, # use pycocotools - plots=False) + for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests + results, _, _ = test.test(opt.data, + batch_size=total_batch_size, + imgsz=imgsz_test, + conf_thres=conf, + iou_thres=iou, + model=attempt_load(final, device).half(), + single_cls=opt.single_cls, + dataloader=testloader, + save_dir=save_dir, + save_json=save_json, + plots=False) else: dist.destroy_process_group() @@ -440,32 +449,35 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): parser.add_argument('--image-weights', action='store_true', help='use weighted image selection for training') parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--multi-scale', action='store_true', help='vary img-size +/- 50%%') - parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') + parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class') parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer') parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode') parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify') parser.add_argument('--log-imgs', type=int, default=16, help='number of images for W&B logging, max 100') + parser.add_argument('--log-artifacts', action='store_true', help='log artifacts, i.e. 
final trained model') parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers') parser.add_argument('--project', default='runs/train', help='save to project/name') parser.add_argument('--name', default='exp', help='save to project/name') parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') + parser.add_argument('--quad', action='store_true', help='quad dataloader') opt = parser.parse_args() # Set DDP variables - opt.total_batch_size = opt.batch_size opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1 set_logging(opt.global_rank) if opt.global_rank in [-1, 0]: check_git_status() + check_requirements() # Resume if opt.resume: # resume an interrupted run ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' + apriori = opt.global_rank, opt.local_rank with open(Path(ckpt).parent.parent / 'opt.yaml') as f: opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace - opt.cfg, opt.weights, opt.resume = '', ckpt, True + opt.cfg, opt.weights, opt.resume, opt.global_rank, opt.local_rank = '', ckpt, True, *apriori # reinstate logger.info('Resuming training from %s' % ckpt) else: # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml') @@ -476,6 +488,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve) # increment run # DDP mode + opt.total_batch_size = opt.batch_size device = select_device(opt.device, batch_size=opt.batch_size) if opt.local_rank != -1: assert torch.cuda.device_count() > opt.local_rank @@ -488,13 +501,15 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Hyperparameters with open(opt.hyp) as f: hyp = yaml.load(f, Loader=yaml.FullLoader) # load hyps - if 'box' not in hyp: - warn('Compatibility: %s missing "box" which was renamed from "giou" in %s' % - (opt.hyp, 'https://github.com/ultralytics/yolov5/pull/1120')) - hyp['box'] = hyp.pop('giou') # Train logger.info(opt) + try: + import wandb + except ImportError: + wandb = None + prefix = colorstr('wandb: ') + logger.info(f"{prefix}Install Weights & Biases for YOLOv3 logging with 'pip install wandb' (recommended)") if not opt.evolve: tb_writer = None # init loggers if opt.global_rank in [-1, 0]: diff --git a/utils/activations.py b/utils/activations.py index 24f5a30f02..aa3ddf071d 100644 --- a/utils/activations.py +++ b/utils/activations.py @@ -5,8 +5,8 @@ import torch.nn.functional as F -# Swish https://arxiv.org/pdf/1905.02244.pdf --------------------------------------------------------------------------- -class Swish(nn.Module): # +# SiLU https://arxiv.org/pdf/1606.08415.pdf ---------------------------------------------------------------------------- +class SiLU(nn.Module): # export-friendly version of nn.SiLU() @staticmethod def forward(x): return x * torch.sigmoid(x) diff --git a/utils/autoanchor.py b/utils/autoanchor.py index 98fea9981f..c6e6b9daf5 100644 --- a/utils/autoanchor.py +++ b/utils/autoanchor.py @@ -6,6 +6,8 @@ from scipy.cluster.vq import kmeans from tqdm import tqdm +from utils.general import colorstr + def check_anchor_order(m): # Check anchor order against stride order for YOLOv3 Detect() module m, and correct if necessary @@ -20,7 +22,8 @@ def 
check_anchor_order(m): def check_anchors(dataset, model, thr=4.0, imgsz=640): # Check anchor fit to data, recompute if necessary - print('\nAnalyzing anchors... ', end='') + prefix = colorstr('autoanchor: ') + print(f'\n{prefix}Analyzing anchors... ', end='') m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale @@ -35,7 +38,7 @@ def metric(k): # compute metric return bpr, aat bpr, aat = metric(m.anchor_grid.clone().cpu().view(-1, 2)) - print('anchors/target = %.2f, Best Possible Recall (BPR) = %.4f' % (aat, bpr), end='') + print(f'anchors/target = {aat:.2f}, Best Possible Recall (BPR) = {bpr:.4f}', end='') if bpr < 0.98: # threshold to recompute print('. Attempting to improve anchors, please wait...') na = m.anchor_grid.numel() // 2 # number of anchors @@ -46,9 +49,9 @@ def metric(k): # compute metric m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss check_anchor_order(m) - print('New anchors saved to model. Update model *.yaml to use these anchors in the future.') + print(f'{prefix}New anchors saved to model. Update model *.yaml to use these anchors in the future.') else: - print('Original anchors better than new anchors. Proceeding with original anchors.') + print(f'{prefix}Original anchors better than new anchors. Proceeding with original anchors.') print('') # newline @@ -70,6 +73,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10 from utils.autoanchor import *; _ = kmean_anchors() """ thr = 1. / thr + prefix = colorstr('autoanchor: ') def metric(k, wh): # compute metrics r = wh[:, None] / k[None] @@ -85,9 +89,9 @@ def print_results(k): k = k[np.argsort(k.prod(1))] # sort small to large x, best = metric(k, wh0) bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr - print('thr=%.2f: %.4f best possible recall, %.2f anchors past thr' % (thr, bpr, aat)) - print('n=%g, img_size=%s, metric_all=%.3f/%.3f-mean/best, past_thr=%.3f-mean: ' % - (n, img_size, x.mean(), best.mean(), x[x > thr].mean()), end='') + print(f'{prefix}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr') + print(f'{prefix}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' + f'past_thr={x[x > thr].mean():.3f}-mean: ', end='') for i, x in enumerate(k): print('%i,%i' % (round(x[0]), round(x[1])), end=', ' if i < len(k) - 1 else '\n') # use in *.cfg return k @@ -107,12 +111,12 @@ def print_results(k): # Filter i = (wh0 < 3.0).any(1).sum() if i: - print('WARNING: Extremely small objects found. ' - '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0))) + print(f'{prefix}WARNING: Extremely small objects found. {i} of {len(wh0)} labels are < 3 pixels in size.') wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels + # wh = wh * (np.random.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 # Kmeans calculation - print('Running kmeans for %g anchors on %g points...' 
% (n, len(wh))) + print(f'{prefix}Running kmeans for {n} anchors on {len(wh)} points...') s = wh.std(0) # sigmas for whitening k, dist = kmeans(wh / s, n, iter=30) # points, mean distance k *= s @@ -135,7 +139,7 @@ def print_results(k): # Evolve npr = np.random f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma - pbar = tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm') # progress bar + pbar = tqdm(range(gen), desc=f'{prefix}Evolving anchors with Genetic Algorithm:') # progress bar for _ in pbar: v = np.ones(sh) while (v == 1).all(): # mutate until a change occurs (prevent duplicates) @@ -144,7 +148,7 @@ def print_results(k): fg = anchor_fitness(kg) if fg > f: f, k = fg, kg.copy() - pbar.desc = 'Evolving anchors with Genetic Algorithm: fitness = %.4f' % f + pbar.desc = f'{prefix}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' if verbose: print_results(k) diff --git a/utils/datasets.py b/utils/datasets.py index 4b87004523..d2002fabfc 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -15,11 +15,12 @@ import cv2 import numpy as np import torch +import torch.nn.functional as F from PIL import Image, ExifTags from torch.utils.data import Dataset from tqdm import tqdm -from utils.general import xyxy2xywh, xywh2xyxy +from utils.general import xyxy2xywh, xywh2xyxy, clean_str from utils.torch_utils import torch_distributed_zero_first # Parameters @@ -55,7 +56,7 @@ def exif_size(img): def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, - rank=-1, world_size=1, workers=8, image_weights=False): + rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''): # Make sure only the first process in DDP process the dataset first, and the following others can use the cache with torch_distributed_zero_first(rank): dataset = LoadImagesAndLabels(path, imgsz, batch_size, @@ -66,8 +67,8 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa single_cls=opt.single_cls, stride=int(stride), pad=pad, - rank=rank, - image_weights=image_weights) + image_weights=image_weights, + prefix=prefix) batch_size = min(batch_size, len(dataset)) nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers @@ -79,7 +80,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa num_workers=nw, sampler=sampler, pin_memory=True, - collate_fn=LoadImagesAndLabels.collate_fn) + collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn) return dataloader, dataset @@ -128,7 +129,7 @@ def __init__(self, path, img_size=640): elif os.path.isfile(p): files = [p] # files else: - raise Exception('ERROR: %s does not exist' % p) + raise Exception(f'ERROR: {p} does not exist') images = [x for x in files if x.split('.')[-1].lower() in img_formats] videos = [x for x in files if x.split('.')[-1].lower() in vid_formats] @@ -138,13 +139,13 @@ def __init__(self, path, img_size=640): self.files = images + videos self.nf = ni + nv # number of files self.video_flag = [False] * ni + [True] * nv - self.mode = 'images' + self.mode = 'image' if any(videos): self.new_video(videos[0]) # new video else: self.cap = None - assert self.nf > 0, 'No images or videos found in %s. Supported formats are:\nimages: %s\nvideos: %s' % \ - (p, img_formats, vid_formats) + assert self.nf > 0, f'No images or videos found in {p}. 
' \ + f'Supported formats are:\nimages: {img_formats}\nvideos: {vid_formats}' def __iter__(self): self.count = 0 @@ -170,14 +171,14 @@ def __next__(self): ret_val, img0 = self.cap.read() self.frame += 1 - print('video %g/%g (%g/%g) %s: ' % (self.count + 1, self.nf, self.frame, self.nframes, path), end='') + print(f'video {self.count + 1}/{self.nf} ({self.frame}/{self.nframes}) {path}: ', end='') else: # Read image self.count += 1 img0 = cv2.imread(path) # BGR assert img0 is not None, 'Image Not Found ' + path - print('image %g/%g %s: ' % (self.count, self.nf, path), end='') + print(f'image {self.count}/{self.nf} {path}: ', end='') # Padded resize img = letterbox(img0, new_shape=self.img_size)[0] @@ -237,9 +238,9 @@ def __next__(self): break # Print - assert ret_val, 'Camera Error %s' % self.pipe + assert ret_val, f'Camera Error {self.pipe}' img_path = 'webcam.jpg' - print('webcam %g: ' % self.count, end='') + print(f'webcam {self.count}: ', end='') # Padded resize img = letterbox(img0, new_shape=self.img_size)[0] @@ -256,7 +257,7 @@ def __len__(self): class LoadStreams: # multiple IP or RTSP cameras def __init__(self, sources='streams.txt', img_size=640): - self.mode = 'images' + self.mode = 'stream' self.img_size = img_size if os.path.isfile(sources): @@ -267,18 +268,18 @@ def __init__(self, sources='streams.txt', img_size=640): n = len(sources) self.imgs = [None] * n - self.sources = sources + self.sources = [clean_str(x) for x in sources] # clean source names for later for i, s in enumerate(sources): # Start the thread to read frames from the video stream - print('%g/%g: %s... ' % (i + 1, n, s), end='') + print(f'{i + 1}/{n}: {s}... ', end='') cap = cv2.VideoCapture(eval(s) if s.isnumeric() else s) - assert cap.isOpened(), 'Failed to open %s' % s + assert cap.isOpened(), f'Failed to open {s}' w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) fps = cap.get(cv2.CAP_PROP_FPS) % 100 _, self.imgs[i] = cap.read() # guarantee first frame thread = Thread(target=self.update, args=([i, cap]), daemon=True) - print(' success (%gx%g at %.2f FPS).' 
% (w, h, fps)) + print(f' success ({w}x{h} at {fps:.2f} FPS).') thread.start() print('') # newline @@ -335,7 +336,7 @@ def img2label_paths(img_paths): class LoadImagesAndLabels(Dataset): # for training/testing def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False, - cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1): + cache_images=False, single_cls=False, stride=32, pad=0.0, prefix=''): self.img_size = img_size self.augment = augment self.hyp = hyp @@ -357,11 +358,11 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r parent = str(p.parent) + os.sep f += [x.replace('./', parent) if x.startswith('./') else x for x in t] # local to global path else: - raise Exception('%s does not exist' % p) + raise Exception(f'{prefix}{p} does not exist') self.img_files = sorted([x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in img_formats]) - assert self.img_files, 'No images found' + assert self.img_files, f'{prefix}No images found' except Exception as e: - raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url)) + raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {help_url}') # Check cache self.label_files = img2label_paths(self.img_files) # labels @@ -369,15 +370,15 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r if cache_path.is_file(): cache = torch.load(cache_path) # load if cache['hash'] != get_hash(self.label_files + self.img_files) or 'results' not in cache: # changed - cache = self.cache_labels(cache_path) # re-cache + cache = self.cache_labels(cache_path, prefix) # re-cache else: - cache = self.cache_labels(cache_path) # cache + cache = self.cache_labels(cache_path, prefix) # cache # Display cache [nf, nm, ne, nc, n] = cache.pop('results') # found, missing, empty, corrupted, total desc = f"Scanning '{cache_path}' for images and labels... {nf} found, {nm} missing, {ne} empty, {nc} corrupted" - tqdm(None, desc=desc, total=n, initial=n) - assert nf > 0 or not augment, f'No labels found in {cache_path}. Can not train without labels. See {help_url}' + tqdm(None, desc=prefix + desc, total=n, initial=n) + assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {help_url}' # Read cache cache.pop('hash') # remove hash @@ -431,9 +432,9 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r for i, x in pbar: self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i) gb += self.imgs[i].nbytes - pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9) + pbar.desc = f'{prefix}Caching images ({gb / 1E9:.1f}GB)' - def cache_labels(self, path=Path('./labels.cache')): + def cache_labels(self, path=Path('./labels.cache'), prefix=''): # Cache dataset labels, check images and read shapes x = {} # dict nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate @@ -465,18 +466,18 @@ def cache_labels(self, path=Path('./labels.cache')): x[im_file] = [l, shape] except Exception as e: nc += 1 - print('WARNING: Ignoring corrupted image and/or label %s: %s' % (im_file, e)) + print(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}') - pbar.desc = f"Scanning '{path.parent / path.stem}' for images and labels... " \ + pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' for images and labels... 
" \ f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted" if nf == 0: - print(f'WARNING: No labels found in {path}. See {help_url}') + print(f'{prefix}WARNING: No labels found in {path}. See {help_url}') x['hash'] = get_hash(self.label_files + self.img_files) x['results'] = [nf, nm, ne, nc, i + 1] torch.save(x, path) # save for next time - logging.info(f"New cache created: {path}") + logging.info(f'{prefix}New cache created: {path}') return x def __len__(self): @@ -578,6 +579,32 @@ def collate_fn(batch): l[:, 0] = i # add target image index for build_targets() return torch.stack(img, 0), torch.cat(label, 0), path, shapes + @staticmethod + def collate_fn4(batch): + img, label, path, shapes = zip(*batch) # transposed + n = len(shapes) // 4 + img4, label4, path4, shapes4 = [], [], path[:n], shapes[:n] + + ho = torch.tensor([[0., 0, 0, 1, 0, 0]]) + wo = torch.tensor([[0., 0, 1, 0, 0, 0]]) + s = torch.tensor([[1, 1, .5, .5, .5, .5]]) # scale + for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW + i *= 4 + if random.random() < 0.5: + im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2., mode='bilinear', align_corners=False)[ + 0].type(img[i].type()) + l = label[i] + else: + im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2) + l = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s + img4.append(im) + label4.append(l) + + for i, l in enumerate(label4): + l[:, 0] = i # add target image index for build_targets() + + return torch.stack(img4, 0), torch.cat(label4, 0), path4, shapes4 + # Ancillary functions -------------------------------------------------------------------------------------------------- def load_image(self, index): @@ -617,7 +644,7 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5): def load_mosaic(self, index): - # loads images in a mosaic + # loads images in a 4-mosaic labels4 = [] s = self.img_size @@ -674,6 +701,80 @@ def load_mosaic(self, index): return img4, labels4 +def load_mosaic9(self, index): + # loads images in a 9-mosaic + + labels9 = [] + s = self.img_size + indices = [index] + [self.indices[random.randint(0, self.n - 1)] for _ in range(8)] # 8 additional image indices + for i, index in enumerate(indices): + # Load image + img, _, (h, w) = load_image(self, index) + + # place img in img9 + if i == 0: # center + img9 = np.full((s * 3, s * 3, img.shape[2]), 114, dtype=np.uint8) # base image with 4 tiles + h0, w0 = h, w + c = s, s, s + w, s + h # xmin, ymin, xmax, ymax (base) coordinates + elif i == 1: # top + c = s, s - h, s + w, s + elif i == 2: # top right + c = s + wp, s - h, s + wp + w, s + elif i == 3: # right + c = s + w0, s, s + w0 + w, s + h + elif i == 4: # bottom right + c = s + w0, s + hp, s + w0 + w, s + hp + h + elif i == 5: # bottom + c = s + w0 - w, s + h0, s + w0, s + h0 + h + elif i == 6: # bottom left + c = s + w0 - wp - w, s + h0, s + w0 - wp, s + h0 + h + elif i == 7: # left + c = s - w, s + h0 - h, s, s + h0 + elif i == 8: # top left + c = s - w, s + h0 - hp - h, s, s + h0 - hp + + padx, pady = c[:2] + x1, y1, x2, y2 = [max(x, 0) for x in c] # allocate coords + + # Labels + x = self.labels[index] + labels = x.copy() + if x.size > 0: # Normalized xywh to pixel xyxy format + labels[:, 1] = w * (x[:, 1] - x[:, 3] / 2) + padx + labels[:, 2] = h * (x[:, 2] - x[:, 4] / 2) + pady + labels[:, 3] = w * (x[:, 1] + x[:, 3] / 2) + padx + labels[:, 4] = h * (x[:, 2] + x[:, 4] / 2) + pady + labels9.append(labels) + + # Image + img9[y1:y2, x1:x2] = img[y1 
- pady:, x1 - padx:] # img9[ymin:ymax, xmin:xmax] + hp, wp = h, w # height, width previous + + # Offset + yc, xc = [int(random.uniform(0, s)) for x in self.mosaic_border] # mosaic center x, y + img9 = img9[yc:yc + 2 * s, xc:xc + 2 * s] + + # Concat/clip labels + if len(labels9): + labels9 = np.concatenate(labels9, 0) + labels9[:, [1, 3]] -= xc + labels9[:, [2, 4]] -= yc + + np.clip(labels9[:, 1:], 0, 2 * s, out=labels9[:, 1:]) # use with random_perspective + # img9, labels9 = replicate(img9, labels9) # replicate + + # Augment + img9, labels9 = random_perspective(img9, labels9, + degrees=self.hyp['degrees'], + translate=self.hyp['translate'], + scale=self.hyp['scale'], + shear=self.hyp['shear'], + perspective=self.hyp['perspective'], + border=self.mosaic_border) # border to remove + + return img9, labels9 + + def replicate(img, labels): # Replicate labels h, w = img.shape[:2] @@ -811,12 +912,12 @@ def random_perspective(img, targets=(), degrees=10, translate=.1, scale=.1, shea return img, targets -def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1): # box1(4,n), box2(4,n) +def box_candidates(box1, box2, wh_thr=2, ar_thr=20, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio w1, h1 = box1[2] - box1[0], box1[3] - box1[1] w2, h2 = box2[2] - box2[0], box2[3] - box2[1] - ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio - return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) & (ar < ar_thr) # candidates + ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio + return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates def cutout(image, labels): diff --git a/utils/general.py b/utils/general.py index 22647f6cd2..5126a45ac0 100755 --- a/utils/general.py +++ b/utils/general.py @@ -2,8 +2,8 @@ import glob import logging +import math import os -import platform import random import re import subprocess @@ -11,7 +11,6 @@ from pathlib import Path import cv2 -import math import numpy as np import torch import torchvision @@ -25,6 +24,7 @@ torch.set_printoptions(linewidth=320, precision=5, profile='long') np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) +os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8)) # NumExpr max threads def set_logging(rank=-1): @@ -34,6 +34,7 @@ def set_logging(rank=-1): def init_seeds(seed=0): + # Initialize random number generator (RNG) seeds random.seed(seed) np.random.seed(seed) init_torch_seeds(seed) @@ -45,12 +46,41 @@ def get_latest_run(search_dir='.'): return max(last_list, key=os.path.getctime) if last_list else '' +def check_online(): + # Check internet connectivity + import socket + try: + socket.create_connection(("1.1.1.1", 53)) # check host accesability + return True + except OSError: + return False + + def check_git_status(): - # Suggest 'git pull' if repo is out of date - if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'): - s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8') - if 'Your branch is behind' in s: - print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n') + # Suggest 'git pull' if YOLOv5 is out of date + print(colorstr('github: '), end='') + try: + if 
diff --git a/utils/general.py b/utils/general.py
index 22647f6cd2..5126a45ac0 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -2,8 +2,8 @@
 
 import glob
 import logging
+import math
 import os
-import platform
 import random
 import re
 import subprocess
@@ -11,7 +11,6 @@
 from pathlib import Path
 
 import cv2
-import math
 import numpy as np
 import torch
 import torchvision
@@ -25,6 +24,7 @@
 torch.set_printoptions(linewidth=320, precision=5, profile='long')
 np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format})  # format short g, %precision=5
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
+os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads
 
 
 def set_logging(rank=-1):
@@ -34,6 +34,7 @@ def set_logging(rank=-1):
 
 
 def init_seeds(seed=0):
+    # Initialize random number generator (RNG) seeds
     random.seed(seed)
     np.random.seed(seed)
     init_torch_seeds(seed)
@@ -45,12 +46,41 @@ def get_latest_run(search_dir='.'):
     return max(last_list, key=os.path.getctime) if last_list else ''
 
 
+def check_online():
+    # Check internet connectivity
+    import socket
+    try:
+        socket.create_connection(("1.1.1.1", 53))  # check host accessibility
+        return True
+    except OSError:
+        return False
+
+
 def check_git_status():
-    # Suggest 'git pull' if repo is out of date
-    if platform.system() in ['Linux', 'Darwin'] and not os.path.isfile('/.dockerenv'):
-        s = subprocess.check_output('if [ -d .git ]; then git fetch && git status -uno; fi', shell=True).decode('utf-8')
-        if 'Your branch is behind' in s:
-            print(s[s.find('Your branch is behind'):s.find('\n\n')] + '\n')
+    # Suggest 'git pull' if YOLOv3 is out of date
+    print(colorstr('github: '), end='')
+    try:
+        if Path('.git').exists() and check_online():
+            url = subprocess.check_output(
+                'git fetch && git config --get remote.origin.url', shell=True).decode('utf-8')[:-1]
+            n = int(subprocess.check_output(
+                'git rev-list $(git rev-parse --abbrev-ref HEAD)..origin/master --count', shell=True))  # commits behind
+            if n > 0:
+                s = f"⚠️ WARNING: code is out of date by {n} {'commits' if n > 1 else 'commit'}. " \
+                    f"Use 'git pull' to update or 'git clone {url}' to download latest."
+            else:
+                s = f'up to date with {url} ✅'
+    except Exception as e:
+        s = str(e)
+    print(s)
+
+
+def check_requirements(file='requirements.txt'):
+    # Check installed dependencies meet requirements
+    import pkg_resources
+    requirements = pkg_resources.parse_requirements(Path(file).open())
+    requirements = [x.name + ''.join(*x.specs) if len(x.specs) else x.name for x in requirements]
+    pkg_resources.require(requirements)  # DistributionNotFound or VersionConflict exception if requirements not met
 
 
 def check_img_size(img_size, s=32):
@@ -97,6 +127,41 @@ def make_divisible(x, divisor):
     return math.ceil(x / divisor) * divisor
 
 
+def clean_str(s):
+    # Cleans a string by replacing special characters with underscore _
+    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
+
+
+def one_cycle(y1=0.0, y2=1.0, steps=100):
+    # lambda function for sinusoidal ramp from y1 to y2
+    return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1
+
+
+def colorstr(*input):
+    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
+    *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
+    colors = {'black': '\033[30m',  # basic colors
+              'red': '\033[31m',
+              'green': '\033[32m',
+              'yellow': '\033[33m',
+              'blue': '\033[34m',
+              'magenta': '\033[35m',
+              'cyan': '\033[36m',
+              'white': '\033[37m',
+              'bright_black': '\033[90m',  # bright colors
+              'bright_red': '\033[91m',
+              'bright_green': '\033[92m',
+              'bright_yellow': '\033[93m',
+              'bright_blue': '\033[94m',
+              'bright_magenta': '\033[95m',
+              'bright_cyan': '\033[96m',
+              'bright_white': '\033[97m',
+              'end': '\033[0m',  # misc
+              'bold': '\033[1m',
+              'underline': '\033[4m'}
+    return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+
+
 def labels_to_class_weights(labels, nc=80):
     # Get class weights (inverse frequency) from training labels
     if labels[0] is None:  # no labels loaded
@@ -271,6 +336,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     # Settings
     min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
     max_det = 300  # maximum number of detections per image
+    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
     time_limit = 10.0  # seconds to quit after
     redundant = True  # require redundant detections
     multi_label = nc > 1  # multiple labels per box (adds 0.5ms/img)
@@ -311,20 +377,19 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
 
         # Filter by class
-        if classes:
+        if classes is not None:
             x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
 
         # Apply finite constraint
         # if not torch.isfinite(x).all():
         #     x = x[torch.isfinite(x).all(1)]
 
-        # If none remain process next image
+        # Check shape
         n = x.shape[0]  # number of boxes
-        if not n:
+        if not n:  # no boxes
             continue
-
-        # Sort by confidence
-        # x = x[x[:, 4].argsort(descending=True)]
+        elif n > max_nms:  # excess boxes
+            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence
 
         # Batched NMS
         c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
@@ -342,6 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
             output[xi] = x[i]
 
         if (time.time() - t) > time_limit:
+            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded
 
     return output
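
One usage note on the helpers above: one_cycle() returns a cosine-ramp lambda that pairs naturally with torch.optim.lr_scheduler.LambdaLR. A minimal sketch (the stand-in model, SGD settings, and the 0.2 final LR fraction are illustrative assumptions, not values from this patch):

import torch
from torch.optim.lr_scheduler import LambdaLR

from utils.general import one_cycle

model = torch.nn.Linear(10, 1)  # stand-in module
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
lf = one_cycle(1, 0.2, 300)  # cosine ramp 1.0 -> 0.2 over 300 epochs
scheduler = LambdaLR(optimizer, lr_lambda=lf)  # lr = 0.01 * lf(epoch)
for epoch in range(300):
    pass  # ... train one epoch, then step the schedule
scheduler.step()
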
diff --git a/utils/loss.py b/utils/loss.py
index 4893c99918..2cfd0967b9 100644
--- a/utils/loss.py
+++ b/utils/loss.py
@@ -59,6 +59,32 @@ def forward(self, pred, true):
         return loss
 
 
+class QFocalLoss(nn.Module):
+    # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+        super(QFocalLoss, self).__init__()
+        self.loss_fcn = loss_fcn  # must be nn.BCEWithLogitsLoss()
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = loss_fcn.reduction
+        self.loss_fcn.reduction = 'none'  # required to apply FL to each element
+
+    def forward(self, pred, true):
+        loss = self.loss_fcn(pred, true)
+
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
+        modulating_factor = torch.abs(true - pred_prob) ** self.gamma
+        loss *= alpha_factor * modulating_factor
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
+
+
 def compute_loss(p, targets, model):  # predictions, targets, model
     device = targets.device
     lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
@@ -66,8 +92,8 @@ def compute_loss(p, targets, model):  # predictions, targets, model
     h = model.hyp  # hyperparameters
 
     # Define criteria
-    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['cls_pw']])).to(device)
-    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([h['obj_pw']])).to(device)
+    BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))  # weight=model.class_weights)
+    BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
 
     # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
     cp, cn = smooth_BCE(eps=0.0)
@@ -79,8 +105,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
     # Losses
     nt = 0  # number of targets
-    no = len(p)  # number of outputs
-    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6
+    balance = [4.0, 1.0, 0.4, 0.1]  # P3-P6
     for i, pi in enumerate(p):  # layer index, layer predictions
         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
         tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj
@@ -93,7 +118,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
             # Regression
             pxy = ps[:, :2].sigmoid() * 2. - 0.5
             pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
-            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
+            pbox = torch.cat((pxy, pwh), 1)  # predicted box
             iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
             lbox += (1.0 - iou).mean()  # iou loss
@@ -112,10 +137,9 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
         lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss
 
-    s = 3 / no  # output count scaling
-    lbox *= h['box'] * s
-    lobj *= h['obj'] * s * (1.4 if no == 4 else 1.)
-    lcls *= h['cls'] * s
+    lbox *= h['box']
+    lobj *= h['obj']
+    lcls *= h['cls']
 
     bs = tobj.shape[0]  # batch size
     loss = lbox + lobj + lcls
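
The new QFocalLoss above is written as a drop-in alternative to the existing FocalLoss wrapper around nn.BCEWithLogitsLoss. A minimal sketch of swapping it in (the gamma gate mirrors how FocalLoss is enabled via hyp['fl_gamma']; the tensor shapes are illustrative):

import torch
import torch.nn as nn

from utils.loss import QFocalLoss

BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.0]))
g = 1.5  # e.g. hyp['fl_gamma'], > 0 enables the focal variant
if g > 0:
    BCEcls = QFocalLoss(BCEcls, gamma=g)  # in place of FocalLoss(BCEcls, g)

pred = torch.randn(16, 80)  # raw class logits
true = torch.zeros(16, 80)  # binary (or label-smoothed) targets
loss = BCEcls(pred, true)  # scalar; 'mean' reduction inherited from the wrapped loss
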
diff --git a/utils/plots.py b/utils/plots.py
index 8fff8ec6dd..47cd707760 100644
--- a/utils/plots.py
+++ b/utils/plots.py
@@ -1,16 +1,18 @@
 # Plotting utils
 
 import glob
+import math
 import os
 import random
 from copy import copy
 from pathlib import Path
 
 import cv2
-import math
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
+import seaborn as sns
 import torch
 import yaml
 from PIL import Image, ImageDraw
@@ -21,7 +23,7 @@
 
 # Settings
 matplotlib.rc('font', **{'size': 11})
-matplotlib.use('svg')  # for writing to files only
+matplotlib.use('Agg')  # for writing to files only
 
 
 def color_list():
@@ -188,6 +190,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
     plt.xlim(0, epochs)
     plt.ylim(0)
     plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
+    plt.close()
 
 
 def plot_test_txt():  # from utils.plots import *; plot_test()
@@ -220,13 +223,13 @@ def plot_targets_txt():  # from utils.plots import *; plot_targets_txt()
     plt.savefig('targets.jpg', dpi=200)
 
 
-def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_txt()
+def plot_study_txt(path='study/', x=None):  # from utils.plots import *; plot_study_txt()
     # Plot study.txt generated by test.py
     fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
     ax = ax.ravel()
 
     fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
-    for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov3', 'yolov3-spp', 'yolov3-tiny']]:
+    for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s', 'yolov5m', 'yolov5l', 'yolov5x']]:
         y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
         x = np.arange(y.shape[1]) if x is None else np.array(x)
         s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
@@ -242,9 +245,9 @@ def plot_study_txt(path='study/', x=None):  # from utils.plots import *; plot_study_tx
              'k.-', linewidth=2, markersize=8, alpha=.25, label='EfficientDet')
 
     ax2.grid()
+    ax2.set_yticks(np.arange(30, 60, 5))
     ax2.set_xlim(0, 30)
-    ax2.set_ylim(15, 50)
-    ax2.set_yticks(np.arange(15, 55, 5))
+    ax2.set_ylim(29, 51)
     ax2.set_xlabel('GPU Speed (ms/img)')
     ax2.set_ylabel('COCO AP val')
     ax2.legend(loc='lower right')
@@ -253,34 +256,24 @@ def plot_study_txt(path='', x=None):  # from utils.plots import *; plot_study_tx
 
 def plot_labels(labels, save_dir=Path(''), loggers=None):
     # plot dataset labels
+    print('Plotting labels... ')
     c, b = labels[:, 0], labels[:, 1:].transpose()  # classes, boxes
     nc = int(c.max() + 1)  # number of classes
     colors = color_list()
+    x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
 
     # seaborn correlogram
-    try:
-        import seaborn as sns
-        import pandas as pd
-        x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
-        sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
-                     plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
-                     diag_kws=dict(bins=50))
-        plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
-        plt.close()
-    except Exception as e:
-        pass
+    sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
+    plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
+    plt.close()
 
     # matplotlib labels
     matplotlib.use('svg')  # faster
     ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
     ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
     ax[0].set_xlabel('classes')
-    ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
-    ax[2].set_xlabel('x')
-    ax[2].set_ylabel('y')
-    ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
-    ax[3].set_xlabel('width')
-    ax[3].set_ylabel('height')
+    sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
+    sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)
 
     # rectangles
     labels[:, 1:3] = 0.5  # center
@@ -329,6 +322,38 @@ def plot_evolution(yaml_file='data/hyp.finetune.yaml'):  # from utils.plots impo
     print('\nPlot saved as evolve.png')
 
 
+def profile_idetection(start=0, stop=0, labels=(), save_dir=''):
+    # Plot iDetection '*.txt' per-image logs. from utils.plots import *; profile_idetection()
+    ax = plt.subplots(2, 4, figsize=(12, 6), tight_layout=True)[1].ravel()
+    s = ['Images', 'Free Storage (GB)', 'RAM Usage (GB)', 'Battery', 'dt_raw (ms)', 'dt_smooth (ms)', 'real-world FPS']
+    files = list(Path(save_dir).glob('frames*.txt'))
+    for fi, f in enumerate(files):
+        try:
+            results = np.loadtxt(f, ndmin=2).T[:, 90:-30]  # clip first and last rows
+            n = results.shape[1]  # number of rows
+            x = np.arange(start, min(stop, n) if stop else n)
+            results = results[:, x]
+            t = (results[0] - results[0].min())  # set t0=0s
+            results[0] = x
+            for i, a in enumerate(ax):
+                if i < len(results):
+                    label = labels[fi] if len(labels) else f.stem.replace('frames_', '')
+                    a.plot(t, results[i], marker='.', label=label, linewidth=1, markersize=5)
+                    a.set_title(s[i])
+                    a.set_xlabel('time (s)')
+                    # if fi == len(files) - 1:
+                    #     a.set_ylim(bottom=0)
+                    for side in ['top', 'right']:
+                        a.spines[side].set_visible(False)
+                else:
+                    a.remove()
+        except Exception as e:
+            print('Warning: Plotting error for %s; %s' % (f, e))
+
+    ax[1].legend()
+    plt.savefig(Path(save_dir) / 'idetection_profile.png', dpi=200)
+
+
 def plot_results_overlay(start=0, stop=0):  # from utils.plots import *; plot_results_overlay()
     # Plot training 'results*.txt', overlaying train and val losses
     s = ['train', 'train', 'train', 'Precision', 'mAP@0.5', 'val', 'val', 'val', 'Recall', 'mAP@0.5:0.95']  # legends
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 82fc731c02..231dcfd7a5 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -61,7 +61,7 @@ def select_device(device='', batch_size=None):
         os.environ['CUDA_VISIBLE_DEVICES'] = device  # set environment variable
         assert torch.cuda.is_available(), f'CUDA unavailable, invalid device {device} requested'  # check availability
 
-    cuda = torch.cuda.is_available() and not cpu
+    cuda = not cpu and torch.cuda.is_available()
     if cuda:
         n = torch.cuda.device_count()
         if n > 1 and batch_size:  # check that batch_size is compatible with device_count