├── README.md ├── detect_explorer.py ├── detect_train.py ├── export.py ├── image ├── 2.jpg ├── 6.jpg ├── 7.png ├── 8.jpg ├── 9.jpg ├── chepai.png ├── chepaiwenzi.png ├── gradio1.png ├── hecheng.png └── result.jpg ├── kenshutsu.py ├── ocr_config.py ├── ocr_test.py ├── read_plate.py ├── requirements.txt └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-310.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── __init__.cpython-39.pyc ├── activations.cpython-38.pyc ├── augmentations.cpython-310.pyc ├── augmentations.cpython-36.pyc ├── augmentations.cpython-38.pyc ├── augmentations.cpython-39.pyc ├── autoanchor.cpython-310.pyc ├── autoanchor.cpython-36.pyc ├── autoanchor.cpython-38.pyc ├── autoanchor.cpython-39.pyc ├── autobatch.cpython-38.pyc ├── benchmarks.cpython-38.pyc ├── callbacks.cpython-38.pyc ├── dataloaders.cpython-310.pyc ├── dataloaders.cpython-36.pyc ├── dataloaders.cpython-38.pyc ├── dataloaders.cpython-39.pyc ├── downloads.cpython-310.pyc ├── downloads.cpython-36.pyc ├── downloads.cpython-38.pyc ├── downloads.cpython-39.pyc ├── general.cpython-310.pyc ├── general.cpython-36.pyc ├── general.cpython-38.pyc ├── general.cpython-39.pyc ├── loss.cpython-38.pyc ├── metrics.cpython-310.pyc ├── metrics.cpython-36.pyc ├── metrics.cpython-38.pyc ├── metrics.cpython-39.pyc ├── plots.cpython-310.pyc ├── plots.cpython-36.pyc ├── plots.cpython-38.pyc ├── plots.cpython-39.pyc ├── torch_utils.cpython-310.pyc ├── torch_utils.cpython-36.pyc ├── torch_utils.cpython-38.pyc └── torch_utils.cpython-39.pyc ├── activations.py ├── augmentations.py ├── autoanchor.py ├── autobatch.py ├── aws ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── resume.cpython-38.pyc ├── mime.sh ├── resume.py └── userdata.sh ├── benchmarks.py ├── callbacks.py ├── dataloaders.py ├── docker ├── Dockerfile ├── Dockerfile-arm64 └── Dockerfile-cpu ├── downloads.py ├── flask_rest_api ├── README.md ├── __pycache__ │ ├── example_request.cpython-38.pyc │ └── restapi.cpython-38.pyc ├── example_request.py └── restapi.py ├── general.py ├── google_app_engine ├── Dockerfile ├── additional_requirements.txt └── app.yaml └── loggers ├── __init__.py ├── __pycache__ └── __init__.cpython-38.pyc └── wandb ├── README.md ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── log_dataset.cpython-38.pyc ├── sweep.cpython-38.pyc └── wandb_utils.cpython-38.pyc ├── log_dataset.py └── sweep.py /README.md: -------------------------------------------------------------------------------- 1 | # 计算机毕业设计--基于深度学习的车牌检测与识别算法 2 | 3 | 4 | #### 作者主页中还有其他方向的深度学习毕业设计项目,例如图像破损修复、老照片图像恢复、灰白照片色彩复原、医学图像分割等,具体参考:👇 5 | [深度学习方向毕业设计](https://blog.csdn.net/qq_45566099/category_12507289.html) 6 | 7 | ## :sparkles: Demo 8 | 9 | 10 | 11 |
12 | 13 | ## :sparkles: 车牌识别Web端在线体验连接 14 | 15 | 👇 16 | **Web端在线体验地址:**:white_check_mark:[访问这里进行车牌识别在线体验](http://zxxserver.w1.luyouxia.net/lpd/):white_check_mark: 17 | 18 | PS1:在线体验地址集成了图片识别和视频检测,点击页面下方页面选项即可分别体验!在网页下方提供了若干输入样例,点击样例自动填充到相应位置后即可点击开始修复查看效果。 19 | 20 | PS2:在线体验链接中部署的模型作者只用了少量数据进行了简单的训练,如果需要更高精度以及更好的性能,自行部署项目后可以使用更全面的数据集进行训练 21 | ☝ 22 | 23 | 24 | 25 |
26 | 27 | 28 | ## 介绍 29 |   本项目利用深度学习(卷积神经网络)设计了一个基于深度学习的车牌检测识别系统,适合作为本科毕业论文的研究课题。该系统提供两种检测方式:一种是对上传的图片进行车牌检测识别,另一种是通过视频流自动识别车牌和车牌信息。只需提供包含车牌的图片(无论位置或角度如何),系统即可标记图片中的车牌位置并输出车牌号码。如果你有摄像头,可以通过训练好的模型调用摄像头进行动态车牌监测,或者将.mp4格式的视频文件输入模型,模型将返回标记好车牌位置的图片并输出检测到的车牌号码。 30 | 31 | ## 模型结构设计概述 32 |   模型设计部分,车辆检测网络采用了改良的**YOLOv7**模型进行检测(改进方法是加入了注意力机制以及改良了原模型中的GSConv_slimneck卷积),而车牌检测网络则自行设计了一种**ResNet与Transformr**相结合的网络,该模型输出检测边框的仿射变换矩阵,可识别任意角度的车牌文字。训练数据集使用CCPD2019、2020以及部分自生成数据集。 33 | 34 |   在训练检测模型时,使用了数据增强方法以增强模型的泛化能力。对于车牌号码的序列识别,使用程序生成的车牌图片进行训练,并结合适当的图像增强手段。模型训练采用端到端的方式,输入图片后直接输出车牌号码序列,并将车牌号码打印在原始图片上。 35 | 36 | ## 自构建数据集 37 |   由于所能获取的车牌文字数据集量少且质量较差,本课题通过自动生成多样化车牌用于补充数据训练模型,如下图所示,本课题使用程序构建了不同颜色的车牌底板用于模拟国内五种车牌底色。 38 | 39 | 40 | 41 |   同时,如下图所示,本课题还使用程序构建了不同省份的车牌文字信息,用于模拟车牌内容文字。 42 | 43 | 44 | 45 |   最终,如下图所示,通过程序,本课题构建了大量的车牌信息数据集用于文字识别训练,丰富了训练数据的多样性,显著提升模型对各类复杂真实场景的适应性。 46 | 47 | 48 | 49 |
50 | 51 | ## :rocket: 运行要求 52 | - 运行算法与Web前端需要 Python >= 3.8 53 | - 建议使用带有nvidia系列的显卡(比如1060、3050、3090、4090、5060都是nvidia系列的) 54 | - 如果没有显卡该项目也可以通过CPU+内存的方式部署,平均单图推理速度约为0.7s 55 | 56 |
57 | 58 | ## :zap:项目使用方式 59 | #### 环境配置(推荐使用conda安装环境) 60 | ``` 61 | # 从github上Clone项目 62 | git clone https://github.com/zxx1218/LicensePlateDetection.git 63 | 64 | # 使用conda创建环境 65 | conda create -n lpd python=3.9 -y 66 | conda activate lpd 67 | 68 | # 安装依赖 69 | pip install -r requirements.txt 70 | ``` 71 | ## 检测展示: 72 | 在视频中检测车牌: 73 | - 由于Github上传视频限制,请移步至我的CSDN观看:https://blog.csdn.net/qq_45566099/article/details/134574209 74 | 75 | 使用模型检测单张图片: 76 | 77 | 78 | 79 |
80 | 81 | 82 | 83 | ## 作者联系方式: 84 | - **VX:Accddvva** 85 | - **QQ:1144968929** 86 | - Github提供训练好的模型文件以及调用该文件测试的代码(clone代码后安装环境即可进行测试,但github上代码不包含模型源码) 87 | - 本项目完整代码 + 远程部署服务 == **价格80RMB** 88 | 89 | 90 |
91 | 92 | 93 | ## 广告 94 | 95 | - 作者于浙江某985高校就读人工智能方向研究生(CSDN已认证),可以定制模型,并提供相应技术文档以及各种需要,只需要描述需求即可 96 | - 人工智能、深度学习领域,尤其是计算机视觉(Computer vision,CV)方向的模型or毕业设计,只要你想得出,没有做不出 97 | -------------------------------------------------------------------------------- /detect_explorer.py: -------------------------------------------------------------------------------- 1 | import detect_config as config 2 | import cv2 3 | import torch 4 | from einops import rearrange 5 | import matplotlib.pyplot as plt 6 | import os 7 | import numpy 8 | 9 | 10 | class DExplorer: 11 | 12 | def __init__(self): 13 | self.net = config.net() 14 | if os.path.exists(config.weight): 15 | self.net.load_state_dict(torch.load(config.weight, map_location='cpu')) 16 | else: 17 | raise RuntimeError('Model parameters are not loaded') 18 | # self.net.to(config.device) 19 | self.net.eval() 20 | 21 | def __call__(self, image_o): 22 | image = image_o.copy() 23 | h, w, c = image.shape 24 | f = min(288 * max(h, w) / min(h, w), 608) / min(h, w) 25 | _w = int(w * f) + (0 if w % 16 == 0 else 16 - w % 16) 26 | _h = int(h * f) + (0 if h % 16 == 0 else 16 - h % 16) 27 | image = cv2.resize(image, (_w, _h), interpolation=cv2.INTER_AREA) 28 | image_tensor = torch.from_numpy(image) / 255 29 | image_tensor = rearrange(image_tensor, 'h w c ->() c h w') 30 | # print(image_tensor.shape) 31 | with torch.no_grad(): 32 | y = self.net(image_tensor).cpu() 33 | points = self.select_box(y, (_w, _h)) 34 | # for point, c in points: 35 | # x1, x2, x3, x4, y1, y2, y3, y4 = point.reshape(-1) 36 | # x1, x2, x3, x4 = x1 * _w, x2 * _w, x3 * _w, x4 * _w 37 | # y1, y2, y3, y4 = y1 * _h, y2 * _h, y3 * _h, y4 * _h 38 | # i = 1 39 | # for x, y in [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]: 40 | # image = cv2.circle(image, (int(x), int(y)), 2, (0, 0, 255), -1) 41 | # image = cv2.putText(image, str(i), (int(x), int(y)), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) 42 | # i += 1 43 | # cv2.imshow('a', image) 44 | # 45 | # # print(points) 46 | # cv2.waitKey() 47 | return points 48 | 49 | def select_box(self, predict, size, dims=208, stride=16): 50 | wh = numpy.array([[size[0]], [size[1]]]) 51 | probs = predict[0, :, :, 0:2] 52 | # # a = probs[:,:,1]>0.9 53 | # print(a) 54 | probs = torch.softmax(probs, dim=-1).numpy() 55 | # a = probs[:, :, 1] > 0.9 56 | # print(a) 57 | # plt.imshow(a.astype("uint8")) 58 | # plt.show() 59 | # print(predict.shape) 60 | affines = torch.cat( 61 | ( 62 | predict[0, :, :, 2:3], 63 | predict[0, :, :, 3:4], 64 | predict[0, :, :, 4:5], 65 | predict[0, :, :, 5:6], 66 | predict[0, :, :, 6:7], 67 | predict[0, :, :, 7:8] 68 | ), 69 | dim=2 70 | ) 71 | h, w, c = affines.shape 72 | affines = affines.reshape(h, w, 2, 3).numpy() 73 | scale = ((dims + 40.0) / 2.0) / stride 74 | unit = numpy.array([[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]]).transpose((1, 0)) 75 | h, w, _ = probs.shape 76 | candidates = [] 77 | for i in range(h): 78 | for j in range(w): 79 | if probs[i, j, 1] > config.confidence_threshold: 80 | affine = affines[i, j] 81 | pts = affine @ unit 82 | # print(affine) 83 | # print(affine) 84 | pts *= scale 85 | pts += numpy.array([[j + 0.5], [i + 0.5]]) 86 | pts *= stride 87 | # print(pts) 88 | pts /= wh 89 | # exit() 90 | candidates.append((pts, probs[i, j, 1])) 91 | # break 92 | 93 | candidates.sort(key=lambda x: x[1], reverse=True) 94 | # print(candidates) 95 | labels = [] 96 | # exit() 97 | '''非极大值抑制''' 98 | for pts_c, prob_c in candidates: 99 | tl_c = pts_c.min(axis=1) 100 | # print('tl_c:',tl_c) 101 | # exit() 102 | br_c = pts_c.max(axis=1) 103 | overlap = False 104 | for pts_l, _ in labels: 105 | tl_l = pts_l.min(axis=1) 106 | br_l = pts_l.max(axis=1) 107 | if self.iou(tl_c, br_c, tl_l, br_l) > 0.1: 108 | overlap = True 109 | break 110 | if not overlap: 111 | labels.append((pts_c, prob_c)) 112 | return labels 113 | 114 | @staticmethod 115 | def iou(tl1, br1, tl2, br2): 116 | x1, y1 = tl1 117 | x2, y2 = br1 118 | x3, y3 = tl2 119 | x4, y4 = br2 120 | wh1 = br1 - tl1 121 | wh2 = br2 - tl2 122 | assert ((wh1 >= 0).sum() > 0 and (wh2 >= 0).sum() > 0) 123 | s1 = (y2 - y1) * (x2 - x1) 124 | s2 = (y4 - y3) * (x4 - x3) 125 | _x1 = max(x1, x3) 126 | _y1 = max(y1, y3) 127 | _x2 = min(x2, x4) 128 | _y2 = max(y2, y4) 129 | w = max(0, _x2 - _x1) 130 | h = max(0, _y2 - _y1) 131 | i = w * h 132 | return i / (s1 + s2 - i) 133 | 134 | 135 | if __name__ == '__main__': 136 | # import numpy 137 | 138 | e = DExplorer() 139 | image = cv2.imread('test_image.jpg') 140 | # image = numpy.zeros((208, 208, 3), dtype=numpy.uint8) 141 | labe = e(image) 142 | print(labe) 143 | -------------------------------------------------------------------------------- /detect_train.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from torch import nn 3 | from utils_1.lossfunction import FocalLossManyClassification 4 | from utils_1.dataset import DetectDataset 5 | from einops import rearrange 6 | from tqdm import tqdm 7 | import detect_config as config 8 | import torch 9 | import os 10 | 11 | 12 | class Trainer: 13 | 14 | def __init__(self): 15 | self.net = config.net() 16 | if os.path.exists(config.weight): 17 | self.net.load_state_dict(torch.load(config.weight, map_location='cpu')) 18 | print('成功加载网络参数') 19 | else: 20 | print('未加载网络参数') 21 | 22 | self.l1_loss = nn.L1Loss() 23 | self.c_loss = FocalLossManyClassification(2) 24 | self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00001) 25 | self.dataset = DetectDataset() 26 | self.data_loader = DataLoader(self.dataset, config.batch_size, drop_last=True) 27 | self.net.to(config.device) 28 | 29 | def train(self): 30 | 31 | for epoch in range(config.epoch): 32 | self.net.train() 33 | loss_sum = 0 34 | for i, (images, labels) in enumerate(self.data_loader): 35 | images = images.to(config.device) 36 | labels = labels.to(config.device) 37 | 38 | predict = self.net(images) 39 | loss_c, loss_p = self.count_loss(predict, labels) 40 | loss = loss_c + loss_p 41 | self.optimizer.zero_grad() 42 | loss.backward() 43 | self.optimizer.step() 44 | if i % 10 == 0: 45 | print(epoch, i, loss.item(), 'loss_c:', loss_c.item(), 'loss_p:', loss_p.item()) 46 | if i % 100 == 0: 47 | torch.save(self.net.state_dict(), config.weight) 48 | loss_sum += loss.item() 49 | logs = f'epoch:{epoch},loss:{loss_sum / len(self.data_loader)}' 50 | print(logs) 51 | torch.save(self.net.state_dict(), config.weight) 52 | 53 | def count_loss(self, predict, target): 54 | condition_positive = target[:, :, :, 0] == 1 55 | condition_negative = target[:, :, :, 0] == 0 56 | 57 | predict_positive = predict[condition_positive] 58 | predict_negative = predict[condition_negative] 59 | 60 | target_positive = target[condition_positive] 61 | target_negative = target[condition_negative] 62 | # print(target_positive.shape) 63 | n, v = predict_positive.shape 64 | if n > 0: 65 | loss_c_positive = self.c_loss(predict_positive[:, 0:2], target_positive[:, 0].long()) 66 | else: 67 | loss_c_positive = 0 68 | loss_c_nagative = self.c_loss(predict_negative[:, 0:2], target_negative[:, 0].long()) 69 | loss_c = loss_c_nagative + loss_c_positive 70 | 71 | if n > 0: 72 | affine = torch.cat( 73 | ( 74 | predict_positive[:, 2:3], 75 | predict_positive[:, 3:4], 76 | predict_positive[:, 4:5], 77 | predict_positive[:, 5:6], 78 | predict_positive[:, 6:7], 79 | predict_positive[:, 7:8] 80 | ), 81 | dim=1 82 | ) 83 | # print(affine.shape) 84 | # exit() 85 | trans_m = affine.reshape(-1, 2, 3) 86 | unit = torch.tensor([[-0.5, -0.5, 1], [0.5, -0.5, 1], [0.5, 0.5, 1], [-0.5, 0.5, 1]]).transpose(0, 1).to( 87 | trans_m.device).float() 88 | # print(unit) 89 | point_pred = torch.einsum('n j k, k d -> n j d', trans_m, unit) 90 | point_pred = rearrange(point_pred, 'n j k -> n (j k)') 91 | loss_p = self.l1_loss(point_pred, target_positive[:, 1:]) 92 | else: 93 | loss_p = 0 94 | # exit() 95 | return loss_c, loss_p 96 | 97 | # return loss 98 | 99 | 100 | if __name__ == '__main__': 101 | trainer = Trainer() 102 | trainer.train() 103 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Export a YOLOv5 PyTorch model to other formats. TensorFlow exports authored by https://github.com/zldrobit 4 | 5 | Format | `export.py --include` | Model 6 | --- | --- | --- 7 | PyTorch | - | yolov5s.pt 8 | TorchScript | `torchscript` | yolov5s.torchscript 9 | ONNX | `onnx` | yolov5s.onnx 10 | OpenVINO | `openvino` | yolov5s_openvino_model/ 11 | TensorRT | `engine` | yolov5s.engine 12 | CoreML | `coreml` | yolov5s.mlmodel 13 | TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ 14 | TensorFlow GraphDef | `pb` | yolov5s.pb 15 | TensorFlow Lite | `tflite` | yolov5s.tflite 16 | TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite 17 | TensorFlow.js | `tfjs` | yolov5s_web_model/ 18 | 19 | Requirements: 20 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU 21 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU 22 | 23 | Usage: 24 | $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ... 25 | 26 | Inference: 27 | $ python path/to/detect.py --weights yolov5s.pt # PyTorch 28 | yolov5s.torchscript # TorchScript 29 | yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn 30 | yolov5s.xml # OpenVINO 31 | yolov5s.engine # TensorRT 32 | yolov5s.mlmodel # CoreML (macOS-only) 33 | yolov5s_saved_model # TensorFlow SavedModel 34 | yolov5s.pb # TensorFlow GraphDef 35 | yolov5s.tflite # TensorFlow Lite 36 | yolov5s_edgetpu.tflite # TensorFlow Edge TPU 37 | 38 | TensorFlow.js: 39 | $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example 40 | $ npm install 41 | $ ln -s ../../yolov5/yolov5s_web_model public/yolov5s_web_model 42 | $ npm start 43 | """ 44 | 45 | import argparse 46 | import json 47 | import os 48 | import platform 49 | import subprocess 50 | import sys 51 | import time 52 | import warnings 53 | from pathlib import Path 54 | 55 | import pandas as pd 56 | import torch 57 | import yaml 58 | from torch.utils.mobile_optimizer import optimize_for_mobile 59 | 60 | FILE = Path(__file__).resolve() 61 | ROOT = FILE.parents[0] # YOLOv5 root directory 62 | if str(ROOT) not in sys.path: 63 | sys.path.append(str(ROOT)) # add ROOT to PATH 64 | if platform.system() != 'Windows': 65 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 66 | 67 | from models.experimental import attempt_load 68 | from models.yolo import Detect 69 | from utils.dataloaders import LoadImages 70 | from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr, 71 | file_size, print_args, url2file) 72 | from utils.torch_utils import select_device 73 | 74 | 75 | def export_formats(): 76 | # YOLOv5 export formats 77 | x = [ 78 | ['PyTorch', '-', '.pt', True], 79 | ['TorchScript', 'torchscript', '.torchscript', True], 80 | ['ONNX', 'onnx', '.onnx', True], 81 | ['OpenVINO', 'openvino', '_openvino_model', False], 82 | ['TensorRT', 'engine', '.engine', True], 83 | ['CoreML', 'coreml', '.mlmodel', False], 84 | ['TensorFlow SavedModel', 'saved_model', '_saved_model', True], 85 | ['TensorFlow GraphDef', 'pb', '.pb', True], 86 | ['TensorFlow Lite', 'tflite', '.tflite', False], 87 | ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False], 88 | ['TensorFlow.js', 'tfjs', '_web_model', False],] 89 | return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'GPU']) 90 | 91 | 92 | def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): 93 | # YOLOv5 TorchScript model export 94 | try: 95 | LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') 96 | f = file.with_suffix('.torchscript') 97 | 98 | ts = torch.jit.trace(model, im, strict=False) 99 | d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names} 100 | extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap() 101 | if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html 102 | optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) 103 | else: 104 | ts.save(str(f), _extra_files=extra_files) 105 | 106 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 107 | return f 108 | except Exception as e: 109 | LOGGER.info(f'{prefix} export failure: {e}') 110 | 111 | 112 | def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')): 113 | # YOLOv5 ONNX export 114 | try: 115 | check_requirements(('onnx',)) 116 | import onnx 117 | 118 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 119 | f = file.with_suffix('.onnx') 120 | 121 | torch.onnx.export( 122 | model, 123 | im, 124 | f, 125 | verbose=False, 126 | opset_version=opset, 127 | training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL, 128 | do_constant_folding=not train, 129 | input_names=['images'], 130 | output_names=['output'], 131 | dynamic_axes={ 132 | 'images': { 133 | 0: 'batch', 134 | 2: 'height', 135 | 3: 'width'}, # shape(1,3,640,640) 136 | 'output': { 137 | 0: 'batch', 138 | 1: 'anchors'} # shape(1,25200,85) 139 | } if dynamic else None) 140 | 141 | # Checks 142 | model_onnx = onnx.load(f) # load onnx model 143 | onnx.checker.check_model(model_onnx) # check onnx model 144 | 145 | # Metadata 146 | d = {'stride': int(max(model.stride)), 'names': model.names} 147 | for k, v in d.items(): 148 | meta = model_onnx.metadata_props.add() 149 | meta.key, meta.value = k, str(v) 150 | onnx.save(model_onnx, f) 151 | 152 | # Simplify 153 | if simplify: 154 | try: 155 | check_requirements(('onnx-simplifier',)) 156 | import onnxsim 157 | 158 | LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') 159 | model_onnx, check = onnxsim.simplify(model_onnx, 160 | dynamic_input_shape=dynamic, 161 | input_shapes={'images': list(im.shape)} if dynamic else None) 162 | assert check, 'assert check failed' 163 | onnx.save(model_onnx, f) 164 | except Exception as e: 165 | LOGGER.info(f'{prefix} simplifier failure: {e}') 166 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 167 | return f 168 | except Exception as e: 169 | LOGGER.info(f'{prefix} export failure: {e}') 170 | 171 | 172 | def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')): 173 | # YOLOv5 OpenVINO export 174 | try: 175 | check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/ 176 | import openvino.inference_engine as ie 177 | 178 | LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') 179 | f = str(file).replace('.pt', f'_openvino_model{os.sep}') 180 | 181 | cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}" 182 | subprocess.check_output(cmd.split()) # export 183 | with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g: 184 | yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml 185 | 186 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 187 | return f 188 | except Exception as e: 189 | LOGGER.info(f'\n{prefix} export failure: {e}') 190 | 191 | 192 | def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): 193 | # YOLOv5 CoreML export 194 | try: 195 | check_requirements(('coremltools',)) 196 | import coremltools as ct 197 | 198 | LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') 199 | f = file.with_suffix('.mlmodel') 200 | 201 | ts = torch.jit.trace(model, im, strict=False) # TorchScript model 202 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) 203 | bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None) 204 | if bits < 32: 205 | if platform.system() == 'Darwin': # quantization only supported on macOS 206 | with warnings.catch_warnings(): 207 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning 208 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) 209 | else: 210 | print(f'{prefix} quantization only supported on macOS, skipping...') 211 | ct_model.save(f) 212 | 213 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 214 | return ct_model, f 215 | except Exception as e: 216 | LOGGER.info(f'\n{prefix} export failure: {e}') 217 | return None, None 218 | 219 | 220 | def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')): 221 | # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt 222 | try: 223 | assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' 224 | try: 225 | import tensorrt as trt 226 | except Exception: 227 | if platform.system() == 'Linux': 228 | check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',)) 229 | import tensorrt as trt 230 | 231 | if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 232 | grid = model.model[-1].anchor_grid 233 | model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] 234 | export_onnx(model, im, file, 12, train, False, simplify) # opset 12 235 | model.model[-1].anchor_grid = grid 236 | else: # TensorRT >= 8 237 | check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 238 | export_onnx(model, im, file, 13, train, False, simplify) # opset 13 239 | onnx = file.with_suffix('.onnx') 240 | 241 | LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') 242 | assert onnx.exists(), f'failed to export ONNX file: {onnx}' 243 | f = file.with_suffix('.engine') # TensorRT engine file 244 | logger = trt.Logger(trt.Logger.INFO) 245 | if verbose: 246 | logger.min_severity = trt.Logger.Severity.VERBOSE 247 | 248 | builder = trt.Builder(logger) 249 | config = builder.create_builder_config() 250 | config.max_workspace_size = workspace * 1 << 30 251 | # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice 252 | 253 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 254 | network = builder.create_network(flag) 255 | parser = trt.OnnxParser(network, logger) 256 | if not parser.parse_from_file(str(onnx)): 257 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 258 | 259 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 260 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 261 | LOGGER.info(f'{prefix} Network Description:') 262 | for inp in inputs: 263 | LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}') 264 | for out in outputs: 265 | LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}') 266 | 267 | LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 else 32} engine in {f}') 268 | if builder.platform_has_fast_fp16: 269 | config.set_flag(trt.BuilderFlag.FP16) 270 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 271 | t.write(engine.serialize()) 272 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 273 | return f 274 | except Exception as e: 275 | LOGGER.info(f'\n{prefix} export failure: {e}') 276 | 277 | 278 | def export_saved_model(model, 279 | im, 280 | file, 281 | dynamic, 282 | tf_nms=False, 283 | agnostic_nms=False, 284 | topk_per_class=100, 285 | topk_all=100, 286 | iou_thres=0.45, 287 | conf_thres=0.25, 288 | keras=False, 289 | prefix=colorstr('TensorFlow SavedModel:')): 290 | # YOLOv5 TensorFlow SavedModel export 291 | try: 292 | import tensorflow as tf 293 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 294 | 295 | from models.tf import TFDetect, TFModel 296 | 297 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 298 | f = str(file).replace('.pt', '_saved_model') 299 | batch_size, ch, *imgsz = list(im.shape) # BCHW 300 | 301 | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) 302 | im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow 303 | _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 304 | inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size) 305 | outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 306 | keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) 307 | keras_model.trainable = False 308 | keras_model.summary() 309 | if keras: 310 | keras_model.save(f, save_format='tf') 311 | else: 312 | spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) 313 | m = tf.function(lambda x: keras_model(x)) # full model 314 | m = m.get_concrete_function(spec) 315 | frozen_func = convert_variables_to_constants_v2(m) 316 | tfm = tf.Module() 317 | tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec]) 318 | tfm.__call__(im) 319 | tf.saved_model.save(tfm, 320 | f, 321 | options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) 322 | if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions()) 323 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 324 | return keras_model, f 325 | except Exception as e: 326 | LOGGER.info(f'\n{prefix} export failure: {e}') 327 | return None, None 328 | 329 | 330 | def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): 331 | # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow 332 | try: 333 | import tensorflow as tf 334 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 335 | 336 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 337 | f = file.with_suffix('.pb') 338 | 339 | m = tf.function(lambda x: keras_model(x)) # full model 340 | m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) 341 | frozen_func = convert_variables_to_constants_v2(m) 342 | frozen_func.graph.as_graph_def() 343 | tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) 344 | 345 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 346 | return f 347 | except Exception as e: 348 | LOGGER.info(f'\n{prefix} export failure: {e}') 349 | 350 | 351 | def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): 352 | # YOLOv5 TensorFlow Lite export 353 | try: 354 | import tensorflow as tf 355 | 356 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 357 | batch_size, ch, *imgsz = list(im.shape) # BCHW 358 | f = str(file).replace('.pt', '-fp16.tflite') 359 | 360 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) 361 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] 362 | converter.target_spec.supported_types = [tf.float16] 363 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 364 | if int8: 365 | from models.tf import representative_dataset_gen 366 | dataset = LoadImages(check_dataset(data)['train'], img_size=imgsz, auto=False) # representative data 367 | converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100) 368 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] 369 | converter.target_spec.supported_types = [] 370 | converter.inference_input_type = tf.uint8 # or tf.int8 371 | converter.inference_output_type = tf.uint8 # or tf.int8 372 | converter.experimental_new_quantizer = True 373 | f = str(file).replace('.pt', '-int8.tflite') 374 | if nms or agnostic_nms: 375 | converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS) 376 | 377 | tflite_model = converter.convert() 378 | open(f, "wb").write(tflite_model) 379 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 380 | return f 381 | except Exception as e: 382 | LOGGER.info(f'\n{prefix} export failure: {e}') 383 | 384 | 385 | def export_edgetpu(file, prefix=colorstr('Edge TPU:')): 386 | # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ 387 | try: 388 | cmd = 'edgetpu_compiler --version' 389 | help_url = 'https://coral.ai/docs/edgetpu/compiler/' 390 | assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' 391 | if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0: 392 | LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') 393 | sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system 394 | for c in ( 395 | 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', 396 | 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 397 | 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): 398 | subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) 399 | ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] 400 | 401 | LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') 402 | f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model 403 | f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model 404 | 405 | cmd = f"edgetpu_compiler -s -o {file.parent} {f_tfl}" 406 | subprocess.run(cmd.split(), check=True) 407 | 408 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 409 | return f 410 | except Exception as e: 411 | LOGGER.info(f'\n{prefix} export failure: {e}') 412 | 413 | 414 | def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): 415 | # YOLOv5 TensorFlow.js export 416 | try: 417 | check_requirements(('tensorflowjs',)) 418 | import re 419 | 420 | import tensorflowjs as tfjs 421 | 422 | LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') 423 | f = str(file).replace('.pt', '_web_model') # js dir 424 | f_pb = file.with_suffix('.pb') # *.pb path 425 | f_json = f'{f}/model.json' # *.json path 426 | 427 | cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \ 428 | f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}' 429 | subprocess.run(cmd.split()) 430 | 431 | with open(f_json) as j: 432 | json = j.read() 433 | with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order 434 | subst = re.sub( 435 | r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, ' 436 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 437 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 438 | r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' 439 | r'"Identity_1": {"name": "Identity_1"}, ' 440 | r'"Identity_2": {"name": "Identity_2"}, ' 441 | r'"Identity_3": {"name": "Identity_3"}}}', json) 442 | j.write(subst) 443 | 444 | LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)') 445 | return f 446 | except Exception as e: 447 | LOGGER.info(f'\n{prefix} export failure: {e}') 448 | 449 | 450 | @torch.no_grad() 451 | def run( 452 | data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path' 453 | weights=ROOT / 'yolov5s.pt', # weights path 454 | imgsz=(640, 640), # image (height, width) 455 | batch_size=1, # batch size 456 | device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu 457 | include=('torchscript', 'onnx'), # include formats 458 | half=False, # FP16 half-precision export 459 | inplace=False, # set YOLOv5 Detect() inplace=True 460 | train=False, # model.train() mode 461 | keras=False, # use Keras 462 | optimize=False, # TorchScript: optimize for mobile 463 | int8=False, # CoreML/TF INT8 quantization 464 | dynamic=False, # ONNX/TF: dynamic axes 465 | simplify=False, # ONNX: simplify model 466 | opset=12, # ONNX: opset version 467 | verbose=False, # TensorRT: verbose log 468 | workspace=4, # TensorRT: workspace size (GB) 469 | nms=False, # TF: add NMS to model 470 | agnostic_nms=False, # TF: add agnostic NMS to model 471 | topk_per_class=100, # TF.js NMS: topk per class to keep 472 | topk_all=100, # TF.js NMS: topk for all classes to keep 473 | iou_thres=0.45, # TF.js NMS: IoU threshold 474 | conf_thres=0.25, # TF.js NMS: confidence threshold 475 | ): 476 | t = time.time() 477 | include = [x.lower() for x in include] # to lowercase 478 | formats = tuple(export_formats()['Argument'][1:]) # --include arguments 479 | flags = [x in include for x in formats] 480 | assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {formats}' 481 | jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = flags # export booleans 482 | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights 483 | 484 | # Load PyTorch model 485 | device = select_device(device) 486 | if half: 487 | assert device.type != 'cpu' or coreml or xml, '--half only compatible with GPU export, i.e. use --device 0' 488 | assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' 489 | model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model 490 | nc, names = model.nc, model.names # number of classes, class names 491 | 492 | # Checks 493 | imgsz *= 2 if len(imgsz) == 1 else 1 # expand 494 | assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}' 495 | 496 | # Input 497 | gs = int(max(model.stride)) # grid size (max stride) 498 | imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples 499 | im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 500 | 501 | # Update model 502 | if half and not (coreml or xml): 503 | im, model = im.half(), model.half() # to FP16 504 | model.train() if train else model.eval() # training mode = no Detect() layer grid construction 505 | for k, m in model.named_modules(): 506 | if isinstance(m, Detect): 507 | m.inplace = inplace 508 | m.onnx_dynamic = dynamic 509 | m.export = True 510 | 511 | for _ in range(2): 512 | y = model(im) # dry runs 513 | shape = tuple(y[0].shape) # model output shape 514 | LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") 515 | 516 | # Exports 517 | f = [''] * 10 # exported filenames 518 | warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning 519 | if jit: 520 | f[0] = export_torchscript(model, im, file, optimize) 521 | if engine: # TensorRT required before ONNX 522 | f[1] = export_engine(model, im, file, train, half, simplify, workspace, verbose) 523 | if onnx or xml: # OpenVINO requires ONNX 524 | f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify) 525 | if xml: # OpenVINO 526 | f[3] = export_openvino(model, file, half) 527 | if coreml: 528 | _, f[4] = export_coreml(model, im, file, int8, half) 529 | 530 | # TensorFlow Exports 531 | if any((saved_model, pb, tflite, edgetpu, tfjs)): 532 | if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707 533 | check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow` 534 | assert not (tflite and tfjs), 'TFLite and TF.js models must be exported separately, please pass only one type.' 535 | model, f[5] = export_saved_model(model.cpu(), 536 | im, 537 | file, 538 | dynamic, 539 | tf_nms=nms or agnostic_nms or tfjs, 540 | agnostic_nms=agnostic_nms or tfjs, 541 | topk_per_class=topk_per_class, 542 | topk_all=topk_all, 543 | iou_thres=iou_thres, 544 | conf_thres=conf_thres, 545 | keras=keras) 546 | if pb or tfjs: # pb prerequisite to tfjs 547 | f[6] = export_pb(model, file) 548 | if tflite or edgetpu: 549 | f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) 550 | if edgetpu: 551 | f[8] = export_edgetpu(file) 552 | if tfjs: 553 | f[9] = export_tfjs(file) 554 | 555 | # Finish 556 | f = [str(x) for x in f if x] # filter out '' and None 557 | if any(f): 558 | LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)' 559 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}" 560 | f"\nDetect: python detect.py --weights {f[-1]}" 561 | f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}')" 562 | f"\nValidate: python val.py --weights {f[-1]}" 563 | f"\nVisualize: https://netron.app") 564 | return f # return list of exported files/dirs 565 | 566 | 567 | def parse_opt(): 568 | parser = argparse.ArgumentParser() 569 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') 570 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)') 571 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') 572 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 573 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 574 | parser.add_argument('--half', action='store_true', help='FP16 half-precision export') 575 | parser.add_argument('--inplace', action='store_true', help='set YOLOv5 Detect() inplace=True') 576 | parser.add_argument('--train', action='store_true', help='model.train() mode') 577 | parser.add_argument('--keras', action='store_true', help='TF: use Keras') 578 | parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') 579 | parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') 580 | parser.add_argument('--dynamic', action='store_true', help='ONNX/TF: dynamic axes') 581 | parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') 582 | parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version') 583 | parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') 584 | parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') 585 | parser.add_argument('--nms', action='store_true', help='TF: add NMS to model') 586 | parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model') 587 | parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') 588 | parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep') 589 | parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold') 590 | parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold') 591 | parser.add_argument('--include', 592 | nargs='+', 593 | default=['torchscript', 'onnx'], 594 | help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs') 595 | opt = parser.parse_args() 596 | print_args(vars(opt)) 597 | return opt 598 | 599 | 600 | def main(opt): 601 | for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]): 602 | run(**vars(opt)) 603 | 604 | 605 | if __name__ == "__main__": 606 | opt = parse_opt() 607 | main(opt) 608 | -------------------------------------------------------------------------------- /image/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/2.jpg -------------------------------------------------------------------------------- /image/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/6.jpg -------------------------------------------------------------------------------- /image/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/7.png -------------------------------------------------------------------------------- /image/8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/8.jpg -------------------------------------------------------------------------------- /image/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/9.jpg -------------------------------------------------------------------------------- /image/chepai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/chepai.png -------------------------------------------------------------------------------- /image/chepaiwenzi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/chepaiwenzi.png -------------------------------------------------------------------------------- /image/gradio1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/gradio1.png -------------------------------------------------------------------------------- /image/hecheng.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/hecheng.png -------------------------------------------------------------------------------- /image/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/image/result.jpg -------------------------------------------------------------------------------- /kenshutsu.py: -------------------------------------------------------------------------------- 1 | # import argparse 2 | import os 3 | import sys 4 | from pathlib import Path 5 | import numpy 6 | import torch 7 | # import torch.backends.cudnn as cudnn 8 | from read_plate import ReadPlate 9 | 10 | FILE = Path(__file__).resolve() 11 | ROOT = FILE.parents[0] # YOLOv5 root directory 12 | if str(ROOT) not in sys.path: 13 | sys.path.append(str(ROOT)) # add ROOT to PATH 14 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 15 | 16 | from models.common import DetectMultiBackend 17 | # from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams 18 | from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2, 19 | increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh) 20 | # from utils.plots import Annotator, colors, save_one_box 21 | from utils.torch_utils import select_device, time_sync 22 | from PIL import Image, ImageDraw, ImageFont 23 | 24 | 25 | class Kenshutsu(object): 26 | 27 | def __init__(self, is_cuda): 28 | device = '0' if is_cuda and torch.cuda.is_available() else 'cpu' 29 | weights = './weights/yolov7.pt' 30 | if not os.path.exists(weights): 31 | raise RuntimeError('Model parameters not found') 32 | self.device = select_device(device) 33 | self.model = DetectMultiBackend(weights, device=self.device) 34 | imgsz = (640, 640) 35 | stride, names, pt = self.model.stride, self.model.names, self.model.pt 36 | imgsz = check_img_size(imgsz, s=stride) 37 | bs = 1 38 | self.model.warmup(imgsz=(1 if pt else bs, 3, *imgsz)) # warmup 39 | self.agnostic_nms = False 40 | self.classes = None 41 | self.iou_thres = 0.45 42 | self.conf_thres = 0.25 43 | 44 | def __call__(self, image): 45 | h, w, c = image.shape 46 | image, h2, w2, fx = self.square_picture(image, 640) 47 | image_tensor = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 48 | image_tensor = numpy.transpose(image_tensor, axes=(2, 0, 1)) / 255 49 | image_tensor = torch.from_numpy(image_tensor).float().to(self.device) 50 | image_tensor = image_tensor.unsqueeze(0) 51 | pred = self.model(image_tensor) 52 | pred = \ 53 | non_max_suppression(pred, self.conf_thres, self.iou_thres, self.classes, self.agnostic_nms, max_det=1000)[0] 54 | boxes = pred.cpu() 55 | result = [] 56 | for box in boxes: 57 | x1, y1, x2, y2, the, c = box 58 | x1, y1, x2, y2 = max(0, int((x1 - (640 // 2 - w2 // 2)) / fx)), max(0, int((y1 - ( 59 | 640 // 2 - h2 // 2)) / fx)), min(w, int((x2 - (640 // 2 - w2 // 2)) / fx)), min(h, int((y2 - ( 60 | 640 // 2 - h2 // 2)) / fx)) 61 | result.append([x1, y1, x2, y2, the, c]) 62 | return result 63 | 64 | @staticmethod 65 | def square_picture(image, image_size): 66 | h1, w1, _ = image.shape 67 | max_len = max(h1, w1) 68 | if max_len >= image_size: 69 | fx = image_size / max_len 70 | fy = image_size / max_len 71 | image = cv2.resize(image, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) 72 | h2, w2, _ = image.shape 73 | background = numpy.zeros((image_size, image_size, 3), dtype=numpy.uint8) 74 | background[:, :, :] = 127 75 | s_h = image_size // 2 - h2 // 2 76 | s_w = image_size // 2 - w2 // 2 77 | background[s_h:s_h + h2, s_w:s_w + w2] = image 78 | return background, h2, w2, fx 79 | else: 80 | h2, w2, _ = image.shape 81 | background = numpy.zeros((image_size, image_size, 3), dtype=numpy.uint8) 82 | background[:, :, :] = 127 83 | s_h = image_size // 2 - h2 // 2 84 | s_w = image_size // 2 - w2 // 2 85 | background[s_h:s_h + h2, s_w:s_w + w2] = image 86 | return background, h2, w2, 1 87 | 88 | 89 | def DrawChinese(img, text, positive, fontSize=20, fontColor=( 90 | 255, 0, 0)): # args-(img:numpy.ndarray, text:中文文本, positive:位置, fontSize:字体大小默认20, fontColor:字体颜色默认绿色) 91 | cv2img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # cv2和PIL中颜色的hex码的储存顺序不同 92 | pilimg = Image.fromarray(cv2img) 93 | # PIL图片上打印汉字 94 | draw = ImageDraw.Draw(pilimg) # 图片上打印 95 | font = ImageFont.truetype("MSJHL.TTC", fontSize, encoding="utf-8") # 参数1:字体文件路径,参数2:字体大小 96 | draw.text(positive, text, fontColor, font=font) # 参数1:打印坐标,参数2:文本,参数3:字体颜色,参数4:字体格式 97 | print(text) 98 | cv2charimg = cv2.cvtColor(numpy.array(pilimg), cv2.COLOR_RGB2BGR) # PIL图片转cv2 图片 99 | return cv2charimg 100 | 101 | 102 | if __name__ == '__main__': 103 | import cv2 104 | from shutil import copy 105 | 106 | class_name = ['main'] 107 | root = 'videoToImg' # 检测视屏帧用这个 108 | # root = "test_image" # 测试图片用这个 109 | detecter = Kenshutsu(False) 110 | read_plate = ReadPlate() 111 | count = 0 112 | 113 | output_dir = 'imgToVideo' # 新建保存图像的目录 114 | os.makedirs(output_dir, exist_ok=True) # 确保目录存在 115 | 116 | for image_name in os.listdir(root): 117 | image_path = f'{root}/{image_name}' 118 | image = cv2.imread(image_path) 119 | boxes = detecter(image) 120 | plates = [] 121 | for box in boxes: 122 | x1, y1, x2, y2, the, c = box 123 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 124 | if c == 2 or c == 5: 125 | image_ = image[y1:y2, x1:x2] 126 | result = read_plate(image_) 127 | if result: 128 | plate, (x11, y11, x22, y22) = result[0] 129 | plates.append((x1, y1, x2, y2, plate, x11 + x1, y11 + y1, x22 + x1, y22 + y1)) 130 | for plate in plates: 131 | x1, y1, x2, y2, plate_name, x11, y11, x22, y22 = plate 132 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 133 | x11, y11, x22, y22 = int(x11), int(y11), int(x22), int(y22) 134 | image = cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2) 135 | image = cv2.rectangle(image, (x11 - 5, y11 - 5), (x22 + 5, y22 + 5), (0, 0, 255), 2) 136 | image = DrawChinese(image, plate_name, (x11, y22), 30) 137 | 138 | # 在指定文件夹下保存图像 139 | output_path = os.path.join(output_dir, image_name) 140 | cv2.imwrite(output_path, image) 141 | 142 | print(image_name) 143 | # cv2.imshow('a', image) # 如果要边检测边查看检测结果图,就把这句和下面一句话打开 144 | # cv2.waitKey() 145 | -------------------------------------------------------------------------------- /ocr_config.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class_name = ['*', "皖", "沪", "津", "渝", "冀", "晋", "蒙", "辽", "吉", "黑", "苏", "浙", "京", "闽", "赣", "鲁", "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵", "云", "藏", "陕", "甘", "青", "宁", "新", "警", "学",'港','澳','A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 4 | 'Y', 'Z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 5 | 6 | device = 'cuda:0' 7 | print(len(class_name)) 8 | device = torch.device(device if torch.cuda.is_available() else 'cpu') 9 | num_class = len(class_name) 10 | weight = 'weights/ocr_net2.pth' 11 | # print(num_class) -------------------------------------------------------------------------------- /ocr_test.py: -------------------------------------------------------------------------------- 1 | from fake_chs_lp.random_plate import Draw 2 | # from models.ocr_net2 import OcrNet 3 | from ocr_explorer import Explorer 4 | import cv2 5 | import torch 6 | 7 | draw = Draw() 8 | explorer = Explorer() 9 | yes = 0 10 | count = 0 11 | for i in range(1000): 12 | plate, label = draw() 13 | plate = cv2.cvtColor(plate, cv2.COLOR_RGB2BGR) 14 | plate = cv2.resize(plate, (144, 48)) 15 | cv2.imshow('a', plate) 16 | a = explorer(plate) 17 | if a == label: 18 | yes += 1 19 | count += 1 20 | print(a) 21 | # print(a) 22 | # cv2.waitKey(0) 23 | print(yes / count, yes, count) 24 | # cv2.waitKey() 25 | -------------------------------------------------------------------------------- /read_plate.py: -------------------------------------------------------------------------------- 1 | from detect_explorer import DExplorer 2 | from ocr_explorer import Explorer 3 | import cv2 4 | import numpy 5 | 6 | 7 | class ReadPlate: 8 | """ 9 | 读取车牌号 10 | 传入侦测到的车辆图片,即可识别车牌号。 11 | 返回: 12 | [[车牌号,回归框],……] 13 | """ 14 | def __init__(self): 15 | self.detect_exp = DExplorer() 16 | self.ocr_exp = Explorer() 17 | 18 | def __call__(self, image): 19 | points = self.detect_exp(image) 20 | h, w, _ = image.shape 21 | result = [] 22 | # print(points) 23 | for point, _ in points: 24 | plate, box = self.cutout_plate(image, point) 25 | # print(box) 26 | lp = self.ocr_exp(plate) 27 | result.append([lp, box]) 28 | # cv2.imshow('b', plate) 29 | # cv2.waitKey() 30 | return result 31 | 32 | def cutout_plate(self, image, point): 33 | h, w, _ = image.shape 34 | x1, x2, x3, x4, y1, y2, y3, y4 = point.reshape(-1) 35 | x1, x2, x3, x4 = x1 * w, x2 * w, x3 * w, x4 * w 36 | y1, y2, y3, y4 = y1 * h, y2 * h, y3 * h, y4 * h 37 | src = numpy.array([[x1, y1], [x2, y2], [x4, y4], [x3, y3]], dtype="float32") 38 | dst = numpy.array([[0, 0], [144, 0], [0, 48], [144, 48]], dtype="float32") 39 | box = [min(x1, x2, x3, x4), min(y1, y2, y3, y4), max(x1, x2, x3, x4), max(y1, y2, y3, y4)] 40 | M = cv2.getPerspectiveTransform(src, dst) 41 | out_img = cv2.warpPerspective(image, M, (144, 48)) 42 | return out_img, box 43 | 44 | 45 | if __name__ == '__main__': 46 | read_plate = ReadPlate() 47 | image = cv2.imread('test_image.jpg') 48 | # image = cv2.imread('2.png') 49 | boxes = read_plate(image) 50 | print(boxes) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch~=1.12.1 2 | pyyaml~=6.0 3 | tqdm~=4.64.1 4 | opencv-python~=4.6.0.66 5 | matplotlib~=3.6.2 6 | numpy~=1.24.1 7 | pandas~=1.5.2 8 | seaborn~=0.12.1 9 | pillow~=8.2.0 10 | torchvision~=0.13.1 11 | requests~=2.28.1 12 | flask~=2.2.2 13 | thop~=0.1.1-2209072238 14 | einops~=0.6.0 -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | utils/initialization 4 | """ 5 | 6 | 7 | def notebook_init(verbose=True): 8 | # Check system software and hardware 9 | print('Checking setup...') 10 | 11 | import os 12 | import shutil 13 | 14 | from utils.general import check_requirements, emojis, is_colab 15 | from utils.torch_utils import select_device # imports 16 | 17 | check_requirements(('psutil', 'IPython')) 18 | import psutil 19 | from IPython import display # to display images and clear console output 20 | 21 | if is_colab(): 22 | shutil.rmtree('/content/sample_data', ignore_errors=True) # remove colab /sample_data directory 23 | 24 | # System info 25 | if verbose: 26 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 27 | ram = psutil.virtual_memory().total 28 | total, used, free = shutil.disk_usage("/") 29 | display.clear_output() 30 | s = f'({os.cpu_count()} CPUs, {ram / gb:.1f} GB RAM, {(total - free) / gb:.1f}/{total / gb:.1f} GB disk)' 31 | else: 32 | s = '' 33 | 34 | select_device(newline=False) 35 | print(emojis(f'Setup complete ✅ {s}')) 36 | return display 37 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/activations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/activations.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/augmentations.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/augmentations.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/augmentations.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/augmentations.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoanchor.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/autoanchor.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoanchor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/autoanchor.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoanchor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/autoanchor.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autoanchor.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/autoanchor.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/autobatch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/autobatch.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/benchmarks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/benchmarks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloaders.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/dataloaders.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloaders.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/dataloaders.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloaders.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/dataloaders.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloaders.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/dataloaders.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/downloads.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/downloads.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/downloads.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/downloads.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/downloads.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/downloads.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/downloads.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/downloads.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/general.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/general.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/general.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/general.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/general.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/general.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/general.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/general.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/metrics.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plots.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/plots.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plots.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/plots.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plots.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/plots.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/plots.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/plots.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/torch_utils.cpython-310.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/torch_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/torch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/torch_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/__pycache__/torch_utils.cpython-39.pyc -------------------------------------------------------------------------------- /utils/activations.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Activation functions 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class SiLU(nn.Module): 12 | # SiLU activation https://arxiv.org/pdf/1606.08415.pdf 13 | @staticmethod 14 | def forward(x): 15 | return x * torch.sigmoid(x) 16 | 17 | 18 | class Hardswish(nn.Module): 19 | # Hard-SiLU activation 20 | @staticmethod 21 | def forward(x): 22 | # return x * F.hardsigmoid(x) # for TorchScript and CoreML 23 | return x * F.hardtanh(x + 3, 0.0, 6.0) / 6.0 # for TorchScript, CoreML and ONNX 24 | 25 | 26 | class Mish(nn.Module): 27 | # Mish activation https://github.com/digantamisra98/Mish 28 | @staticmethod 29 | def forward(x): 30 | return x * F.softplus(x).tanh() 31 | 32 | 33 | class MemoryEfficientMish(nn.Module): 34 | # Mish activation memory-efficient 35 | class F(torch.autograd.Function): 36 | 37 | @staticmethod 38 | def forward(ctx, x): 39 | ctx.save_for_backward(x) 40 | return x.mul(torch.tanh(F.softplus(x))) # x * tanh(ln(1 + exp(x))) 41 | 42 | @staticmethod 43 | def backward(ctx, grad_output): 44 | x = ctx.saved_tensors[0] 45 | sx = torch.sigmoid(x) 46 | fx = F.softplus(x).tanh() 47 | return grad_output * (fx + x * sx * (1 - fx * fx)) 48 | 49 | def forward(self, x): 50 | return self.F.apply(x) 51 | 52 | 53 | class FReLU(nn.Module): 54 | # FReLU activation https://arxiv.org/abs/2007.11824 55 | def __init__(self, c1, k=3): # ch_in, kernel 56 | super().__init__() 57 | self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False) 58 | self.bn = nn.BatchNorm2d(c1) 59 | 60 | def forward(self, x): 61 | return torch.max(x, self.bn(self.conv(x))) 62 | 63 | 64 | class AconC(nn.Module): 65 | r""" ACON activation (activate or not) 66 | AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameter 67 | according to "Activate or Not: Learning Customized Activation" . 68 | """ 69 | 70 | def __init__(self, c1): 71 | super().__init__() 72 | self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) 73 | self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) 74 | self.beta = nn.Parameter(torch.ones(1, c1, 1, 1)) 75 | 76 | def forward(self, x): 77 | dpx = (self.p1 - self.p2) * x 78 | return dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x 79 | 80 | 81 | class MetaAconC(nn.Module): 82 | r""" ACON activation (activate or not) 83 | MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small network 84 | according to "Activate or Not: Learning Customized Activation" . 85 | """ 86 | 87 | def __init__(self, c1, k=1, s=1, r=16): # ch_in, kernel, stride, r 88 | super().__init__() 89 | c2 = max(r, c1 // r) 90 | self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1)) 91 | self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1)) 92 | self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True) 93 | self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True) 94 | # self.bn1 = nn.BatchNorm2d(c2) 95 | # self.bn2 = nn.BatchNorm2d(c1) 96 | 97 | def forward(self, x): 98 | y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True) 99 | # batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891 100 | # beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y))))) # bug/unstable 101 | beta = torch.sigmoid(self.fc2(self.fc1(y))) # bug patch BN layers removed 102 | dpx = (self.p1 - self.p2) * x 103 | return dpx * torch.sigmoid(beta * dpx) + self.p2 * x 104 | -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Image augmentation functions 4 | """ 5 | 6 | import math 7 | import random 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box 13 | from utils.metrics import bbox_ioa 14 | 15 | 16 | class Albumentations: 17 | # YOLOv5 Albumentations class (optional, only used if package is installed) 18 | def __init__(self): 19 | self.transform = None 20 | try: 21 | import albumentations as A 22 | check_version(A.__version__, '1.0.3', hard=True) # version requirement 23 | 24 | T = [ 25 | A.Blur(p=0.01), 26 | A.MedianBlur(p=0.01), 27 | A.ToGray(p=0.01), 28 | A.CLAHE(p=0.01), 29 | A.RandomBrightnessContrast(p=0.0), 30 | A.RandomGamma(p=0.0), 31 | A.ImageCompression(quality_lower=75, p=0.0)] # transforms 32 | self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels'])) 33 | 34 | LOGGER.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p)) 35 | except ImportError: # package not installed, skip 36 | pass 37 | except Exception as e: 38 | LOGGER.info(colorstr('albumentations: ') + f'{e}') 39 | 40 | def __call__(self, im, labels, p=1.0): 41 | if self.transform and random.random() < p: 42 | new = self.transform(image=im, bboxes=labels[:, 1:], class_labels=labels[:, 0]) # transformed 43 | im, labels = new['image'], np.array([[c, *b] for c, b in zip(new['class_labels'], new['bboxes'])]) 44 | return im, labels 45 | 46 | 47 | def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): 48 | # HSV color-space augmentation 49 | if hgain or sgain or vgain: 50 | r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains 51 | hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV)) 52 | dtype = im.dtype # uint8 53 | 54 | x = np.arange(0, 256, dtype=r.dtype) 55 | lut_hue = ((x * r[0]) % 180).astype(dtype) 56 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 57 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 58 | 59 | im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 60 | cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed 61 | 62 | 63 | def hist_equalize(im, clahe=True, bgr=False): 64 | # Equalize histogram on BGR image 'im' with im.shape(n,m,3) and range 0-255 65 | yuv = cv2.cvtColor(im, cv2.COLOR_BGR2YUV if bgr else cv2.COLOR_RGB2YUV) 66 | if clahe: 67 | c = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) 68 | yuv[:, :, 0] = c.apply(yuv[:, :, 0]) 69 | else: 70 | yuv[:, :, 0] = cv2.equalizeHist(yuv[:, :, 0]) # equalize Y channel histogram 71 | return cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR if bgr else cv2.COLOR_YUV2RGB) # convert YUV image to RGB 72 | 73 | 74 | def replicate(im, labels): 75 | # Replicate labels 76 | h, w = im.shape[:2] 77 | boxes = labels[:, 1:].astype(int) 78 | x1, y1, x2, y2 = boxes.T 79 | s = ((x2 - x1) + (y2 - y1)) / 2 # side length (pixels) 80 | for i in s.argsort()[:round(s.size * 0.5)]: # smallest indices 81 | x1b, y1b, x2b, y2b = boxes[i] 82 | bh, bw = y2b - y1b, x2b - x1b 83 | yc, xc = int(random.uniform(0, h - bh)), int(random.uniform(0, w - bw)) # offset x, y 84 | x1a, y1a, x2a, y2a = [xc, yc, xc + bw, yc + bh] 85 | im[y1a:y2a, x1a:x2a] = im[y1b:y2b, x1b:x2b] # im4[ymin:ymax, xmin:xmax] 86 | labels = np.append(labels, [[labels[i, 0], x1a, y1a, x2a, y2a]], axis=0) 87 | 88 | return im, labels 89 | 90 | 91 | def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32): 92 | # Resize and pad image while meeting stride-multiple constraints 93 | shape = im.shape[:2] # current shape [height, width] 94 | if isinstance(new_shape, int): 95 | new_shape = (new_shape, new_shape) 96 | 97 | # Scale ratio (new / old) 98 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 99 | if not scaleup: # only scale down, do not scale up (for better val mAP) 100 | r = min(r, 1.0) 101 | 102 | # Compute padding 103 | ratio = r, r # width, height ratios 104 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 105 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 106 | if auto: # minimum rectangle 107 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding 108 | elif scaleFill: # stretch 109 | dw, dh = 0.0, 0.0 110 | new_unpad = (new_shape[1], new_shape[0]) 111 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 112 | 113 | dw /= 2 # divide padding into 2 sides 114 | dh /= 2 115 | 116 | if shape[::-1] != new_unpad: # resize 117 | im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) 118 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 119 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 120 | im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 121 | return im, ratio, (dw, dh) 122 | 123 | 124 | def random_perspective(im, 125 | targets=(), 126 | segments=(), 127 | degrees=10, 128 | translate=.1, 129 | scale=.1, 130 | shear=10, 131 | perspective=0.0, 132 | border=(0, 0)): 133 | # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10)) 134 | # targets = [cls, xyxy] 135 | 136 | height = im.shape[0] + border[0] * 2 # shape(h,w,c) 137 | width = im.shape[1] + border[1] * 2 138 | 139 | # Center 140 | C = np.eye(3) 141 | C[0, 2] = -im.shape[1] / 2 # x translation (pixels) 142 | C[1, 2] = -im.shape[0] / 2 # y translation (pixels) 143 | 144 | # Perspective 145 | P = np.eye(3) 146 | P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y) 147 | P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) 148 | 149 | # Rotation and Scale 150 | R = np.eye(3) 151 | a = random.uniform(-degrees, degrees) 152 | # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations 153 | s = random.uniform(1 - scale, 1 + scale) 154 | # s = 2 ** random.uniform(-scale, scale) 155 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) 156 | 157 | # Shear 158 | S = np.eye(3) 159 | S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) 160 | S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) 161 | 162 | # Translation 163 | T = np.eye(3) 164 | T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) 165 | T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) 166 | 167 | # Combined rotation matrix 168 | M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT 169 | if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed 170 | if perspective: 171 | im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114)) 172 | else: # affine 173 | im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) 174 | 175 | # Visualize 176 | # import matplotlib.pyplot as plt 177 | # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() 178 | # ax[0].imshow(im[:, :, ::-1]) # base 179 | # ax[1].imshow(im2[:, :, ::-1]) # warped 180 | 181 | # Transform label coordinates 182 | n = len(targets) 183 | if n: 184 | use_segments = any(x.any() for x in segments) 185 | new = np.zeros((n, 4)) 186 | if use_segments: # warp segments 187 | segments = resample_segments(segments) # upsample 188 | for i, segment in enumerate(segments): 189 | xy = np.ones((len(segment), 3)) 190 | xy[:, :2] = segment 191 | xy = xy @ M.T # transform 192 | xy = xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2] # perspective rescale or affine 193 | 194 | # clip 195 | new[i] = segment2box(xy, width, height) 196 | 197 | else: # warp boxes 198 | xy = np.ones((n * 4, 3)) 199 | xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 200 | xy = xy @ M.T # transform 201 | xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine 202 | 203 | # create new boxes 204 | x = xy[:, [0, 2, 4, 6]] 205 | y = xy[:, [1, 3, 5, 7]] 206 | new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T 207 | 208 | # clip 209 | new[:, [0, 2]] = new[:, [0, 2]].clip(0, width) 210 | new[:, [1, 3]] = new[:, [1, 3]].clip(0, height) 211 | 212 | # filter candidates 213 | i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10) 214 | targets = targets[i] 215 | targets[:, 1:5] = new[i] 216 | 217 | return im, targets 218 | 219 | 220 | def copy_paste(im, labels, segments, p=0.5): 221 | # Implement Copy-Paste augmentation https://arxiv.org/abs/2012.07177, labels as nx5 np.array(cls, xyxy) 222 | n = len(segments) 223 | if p and n: 224 | h, w, c = im.shape # height, width, channels 225 | im_new = np.zeros(im.shape, np.uint8) 226 | for j in random.sample(range(n), k=round(p * n)): 227 | l, s = labels[j], segments[j] 228 | box = w - l[3], l[2], w - l[1], l[4] 229 | ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area 230 | if (ioa < 0.30).all(): # allow 30% obscuration of existing labels 231 | labels = np.concatenate((labels, [[l[0], *box]]), 0) 232 | segments.append(np.concatenate((w - s[:, 0:1], s[:, 1:2]), 1)) 233 | cv2.drawContours(im_new, [segments[j].astype(np.int32)], -1, (255, 255, 255), cv2.FILLED) 234 | 235 | result = cv2.bitwise_and(src1=im, src2=im_new) 236 | result = cv2.flip(result, 1) # augment segments (flip left-right) 237 | i = result > 0 # pixels to replace 238 | # i[:, :] = result.max(2).reshape(h, w, 1) # act over ch 239 | im[i] = result[i] # cv2.imwrite('debug.jpg', im) # debug 240 | 241 | return im, labels, segments 242 | 243 | 244 | def cutout(im, labels, p=0.5): 245 | # Applies image cutout augmentation https://arxiv.org/abs/1708.04552 246 | if random.random() < p: 247 | h, w = im.shape[:2] 248 | scales = [0.5] * 1 + [0.25] * 2 + [0.125] * 4 + [0.0625] * 8 + [0.03125] * 16 # image size fraction 249 | for s in scales: 250 | mask_h = random.randint(1, int(h * s)) # create random masks 251 | mask_w = random.randint(1, int(w * s)) 252 | 253 | # box 254 | xmin = max(0, random.randint(0, w) - mask_w // 2) 255 | ymin = max(0, random.randint(0, h) - mask_h // 2) 256 | xmax = min(w, xmin + mask_w) 257 | ymax = min(h, ymin + mask_h) 258 | 259 | # apply random color mask 260 | im[ymin:ymax, xmin:xmax] = [random.randint(64, 191) for _ in range(3)] 261 | 262 | # return unobscured labels 263 | if len(labels) and s > 0.03: 264 | box = np.array([xmin, ymin, xmax, ymax], dtype=np.float32) 265 | ioa = bbox_ioa(box, labels[:, 1:5]) # intersection over area 266 | labels = labels[ioa < 0.60] # remove >60% obscured labels 267 | 268 | return labels 269 | 270 | 271 | def mixup(im, labels, im2, labels2): 272 | # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf 273 | r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0 274 | im = (im * r + im2 * (1 - r)).astype(np.uint8) 275 | labels = np.concatenate((labels, labels2), 0) 276 | return im, labels 277 | 278 | 279 | def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16): # box1(4,n), box2(4,n) 280 | # Compute candidate boxes: box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio 281 | w1, h1 = box1[2] - box1[0], box1[3] - box1[1] 282 | w2, h2 = box2[2] - box2[0], box2[3] - box2[1] 283 | ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio 284 | return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates 285 | -------------------------------------------------------------------------------- /utils/autoanchor.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | AutoAnchor utils 4 | """ 5 | 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | import yaml 11 | from tqdm import tqdm 12 | 13 | from utils.general import LOGGER, colorstr, emojis 14 | 15 | PREFIX = colorstr('AutoAnchor: ') 16 | 17 | 18 | def check_anchor_order(m): 19 | # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary 20 | a = m.anchors.prod(-1).mean(-1).view(-1) # mean anchor area per output layer 21 | da = a[-1] - a[0] # delta a 22 | ds = m.stride[-1] - m.stride[0] # delta s 23 | if da and (da.sign() != ds.sign()): # same order 24 | LOGGER.info(f'{PREFIX}Reversing anchor order') 25 | m.anchors[:] = m.anchors.flip(0) 26 | 27 | 28 | def check_anchors(dataset, model, thr=4.0, imgsz=640): 29 | # Check anchor fit to data, recompute if necessary 30 | m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect() 31 | shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True) 32 | scale = np.random.uniform(0.9, 1.1, size=(shapes.shape[0], 1)) # augment scale 33 | wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes * scale, dataset.labels)])).float() # wh 34 | 35 | def metric(k): # compute metric 36 | r = wh[:, None] / k[None] 37 | x = torch.min(r, 1 / r).min(2)[0] # ratio metric 38 | best = x.max(1)[0] # best_x 39 | aat = (x > 1 / thr).float().sum(1).mean() # anchors above threshold 40 | bpr = (best > 1 / thr).float().mean() # best possible recall 41 | return bpr, aat 42 | 43 | stride = m.stride.to(m.anchors.device).view(-1, 1, 1) # model strides 44 | anchors = m.anchors.clone() * stride # current anchors 45 | bpr, aat = metric(anchors.cpu().view(-1, 2)) 46 | s = f'\n{PREFIX}{aat:.2f} anchors/target, {bpr:.3f} Best Possible Recall (BPR). ' 47 | if bpr > 0.98: # threshold to recompute 48 | LOGGER.info(emojis(f'{s}Current anchors are a good fit to dataset ✅')) 49 | else: 50 | LOGGER.info(emojis(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')) 51 | na = m.anchors.numel() // 2 # number of anchors 52 | try: 53 | anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False) 54 | except Exception as e: 55 | LOGGER.info(f'{PREFIX}ERROR: {e}') 56 | new_bpr = metric(anchors)[0] 57 | if new_bpr > bpr: # replace anchors 58 | anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors) 59 | m.anchors[:] = anchors.clone().view_as(m.anchors) 60 | check_anchor_order(m) # must be in pixel-space (not grid-space) 61 | m.anchors /= stride 62 | s = f'{PREFIX}Done ✅ (optional: update model *.yaml to use these anchors in the future)' 63 | else: 64 | s = f'{PREFIX}Done ⚠️ (original anchors better than new anchors, proceeding with original anchors)' 65 | LOGGER.info(emojis(s)) 66 | 67 | 68 | def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=1000, verbose=True): 69 | """ Creates kmeans-evolved anchors from training dataset 70 | 71 | Arguments: 72 | dataset: path to data.yaml, or a loaded dataset 73 | n: number of anchors 74 | img_size: image size used for training 75 | thr: anchor-label wh ratio threshold hyperparameter hyp['anchor_t'] used for training, default=4.0 76 | gen: generations to evolve anchors using genetic algorithm 77 | verbose: print all results 78 | 79 | Return: 80 | k: kmeans evolved anchors 81 | 82 | Usage: 83 | from utils.autoanchor import *; _ = kmean_anchors() 84 | """ 85 | from scipy.cluster.vq import kmeans 86 | 87 | npr = np.random 88 | thr = 1 / thr 89 | 90 | def metric(k, wh): # compute metrics 91 | r = wh[:, None] / k[None] 92 | x = torch.min(r, 1 / r).min(2)[0] # ratio metric 93 | # x = wh_iou(wh, torch.tensor(k)) # iou metric 94 | return x, x.max(1)[0] # x, best_x 95 | 96 | def anchor_fitness(k): # mutation fitness 97 | _, best = metric(torch.tensor(k, dtype=torch.float32), wh) 98 | return (best * (best > thr).float()).mean() # fitness 99 | 100 | def print_results(k, verbose=True): 101 | k = k[np.argsort(k.prod(1))] # sort small to large 102 | x, best = metric(k, wh0) 103 | bpr, aat = (best > thr).float().mean(), (x > thr).float().mean() * n # best possible recall, anch > thr 104 | s = f'{PREFIX}thr={thr:.2f}: {bpr:.4f} best possible recall, {aat:.2f} anchors past thr\n' \ 105 | f'{PREFIX}n={n}, img_size={img_size}, metric_all={x.mean():.3f}/{best.mean():.3f}-mean/best, ' \ 106 | f'past_thr={x[x > thr].mean():.3f}-mean: ' 107 | for x in k: 108 | s += '%i,%i, ' % (round(x[0]), round(x[1])) 109 | if verbose: 110 | LOGGER.info(s[:-2]) 111 | return k 112 | 113 | if isinstance(dataset, str): # *.yaml file 114 | with open(dataset, errors='ignore') as f: 115 | data_dict = yaml.safe_load(f) # model dict 116 | from utils.dataloaders import LoadImagesAndLabels 117 | dataset = LoadImagesAndLabels(data_dict['train'], augment=True, rect=True) 118 | 119 | # Get label wh 120 | shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True) 121 | wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh 122 | 123 | # Filter 124 | i = (wh0 < 3.0).any(1).sum() 125 | if i: 126 | LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size') 127 | wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels 128 | # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1 129 | 130 | # Kmeans init 131 | try: 132 | LOGGER.info(f'{PREFIX}Running kmeans for {n} anchors on {len(wh)} points...') 133 | assert n <= len(wh) # apply overdetermined constraint 134 | s = wh.std(0) # sigmas for whitening 135 | k = kmeans(wh / s, n, iter=30)[0] * s # points 136 | assert n == len(k) # kmeans may return fewer points than requested if wh is insufficient or too similar 137 | except Exception: 138 | LOGGER.warning(f'{PREFIX}WARNING: switching strategies from kmeans to random init') 139 | k = np.sort(npr.rand(n * 2)).reshape(n, 2) * img_size # random init 140 | wh, wh0 = (torch.tensor(x, dtype=torch.float32) for x in (wh, wh0)) 141 | k = print_results(k, verbose=False) 142 | 143 | # Plot 144 | # k, d = [None] * 20, [None] * 20 145 | # for i in tqdm(range(1, 21)): 146 | # k[i-1], d[i-1] = kmeans(wh / s, i) # points, mean distance 147 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7), tight_layout=True) 148 | # ax = ax.ravel() 149 | # ax[0].plot(np.arange(1, 21), np.array(d) ** 2, marker='.') 150 | # fig, ax = plt.subplots(1, 2, figsize=(14, 7)) # plot wh 151 | # ax[0].hist(wh[wh[:, 0]<100, 0],400) 152 | # ax[1].hist(wh[wh[:, 1]<100, 1],400) 153 | # fig.savefig('wh.png', dpi=200) 154 | 155 | # Evolve 156 | f, sh, mp, s = anchor_fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma 157 | pbar = tqdm(range(gen), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar 158 | for _ in pbar: 159 | v = np.ones(sh) 160 | while (v == 1).all(): # mutate until a change occurs (prevent duplicates) 161 | v = ((npr.random(sh) < mp) * random.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0) 162 | kg = (k.copy() * v).clip(min=2.0) 163 | fg = anchor_fitness(kg) 164 | if fg > f: 165 | f, k = fg, kg.copy() 166 | pbar.desc = f'{PREFIX}Evolving anchors with Genetic Algorithm: fitness = {f:.4f}' 167 | if verbose: 168 | print_results(k, verbose) 169 | 170 | return print_results(k) 171 | -------------------------------------------------------------------------------- /utils/autobatch.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Auto-batch utils 4 | """ 5 | 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from utils.general import LOGGER, colorstr 12 | from utils.torch_utils import profile 13 | 14 | 15 | def check_train_batch_size(model, imgsz=640, amp=True): 16 | # Check YOLOv5 training batch size 17 | with torch.cuda.amp.autocast(amp): 18 | return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size 19 | 20 | 21 | def autobatch(model, imgsz=640, fraction=0.9, batch_size=16): 22 | # Automatically estimate best batch size to use `fraction` of available CUDA memory 23 | # Usage: 24 | # import torch 25 | # from utils.autobatch import autobatch 26 | # model = torch.hub.load('ultralytics/yolov5', 'yolov5s', autoshape=False) 27 | # print(autobatch(model)) 28 | 29 | prefix = colorstr('AutoBatch: ') 30 | LOGGER.info(f'{prefix}Computing optimal batch size for --imgsz {imgsz}') 31 | device = next(model.parameters()).device # get model device 32 | if device.type == 'cpu': 33 | LOGGER.info(f'{prefix}CUDA not detected, using default CPU batch-size {batch_size}') 34 | return batch_size 35 | 36 | gb = 1 << 30 # bytes to GiB (1024 ** 3) 37 | d = str(device).upper() # 'CUDA:0' 38 | properties = torch.cuda.get_device_properties(device) # device properties 39 | t = properties.total_memory / gb # (GiB) 40 | r = torch.cuda.memory_reserved(device) / gb # (GiB) 41 | a = torch.cuda.memory_allocated(device) / gb # (GiB) 42 | f = t - (r + a) # free inside reserved 43 | LOGGER.info(f'{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free') 44 | 45 | batch_sizes = [1, 2, 4, 8, 16] 46 | try: 47 | img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes] 48 | y = profile(img, model, n=3, device=device) 49 | except Exception as e: 50 | LOGGER.warning(f'{prefix}{e}') 51 | 52 | y = [x[2] for x in y if x] # memory [2] 53 | batch_sizes = batch_sizes[:len(y)] 54 | p = np.polyfit(batch_sizes, y, deg=1) # first degree polynomial fit 55 | b = int((f * fraction - p[1]) / p[0]) # y intercept (optimal batch size) 56 | LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%)') 57 | return b 58 | -------------------------------------------------------------------------------- /utils/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/aws/__init__.py -------------------------------------------------------------------------------- /utils/aws/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/aws/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/aws/__pycache__/resume.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/aws/__pycache__/resume.cpython-38.pyc -------------------------------------------------------------------------------- /utils/aws/mime.sh: -------------------------------------------------------------------------------- 1 | # AWS EC2 instance startup 'MIME' script https://aws.amazon.com/premiumsupport/knowledge-center/execute-user-data-ec2/ 2 | # This script will run on every instance restart, not only on first start 3 | # --- DO NOT COPY ABOVE COMMENTS WHEN PASTING INTO USERDATA --- 4 | 5 | Content-Type: multipart/mixed; boundary="//" 6 | MIME-Version: 1.0 7 | 8 | --// 9 | Content-Type: text/cloud-config; charset="us-ascii" 10 | MIME-Version: 1.0 11 | Content-Transfer-Encoding: 7bit 12 | Content-Disposition: attachment; filename="cloud-config.txt" 13 | 14 | #cloud-config 15 | cloud_final_modules: 16 | - [scripts-user, always] 17 | 18 | --// 19 | Content-Type: text/x-shellscript; charset="us-ascii" 20 | MIME-Version: 1.0 21 | Content-Transfer-Encoding: 7bit 22 | Content-Disposition: attachment; filename="userdata.txt" 23 | 24 | #!/bin/bash 25 | # --- paste contents of userdata.sh here --- 26 | --// 27 | -------------------------------------------------------------------------------- /utils/aws/resume.py: -------------------------------------------------------------------------------- 1 | # Resume all interrupted trainings in yolov5/ dir including DDP trainings 2 | # Usage: $ python utils/aws/resume.py 3 | 4 | import os 5 | import sys 6 | from pathlib import Path 7 | 8 | import torch 9 | import yaml 10 | 11 | FILE = Path(__file__).resolve() 12 | ROOT = FILE.parents[2] # YOLOv5 root directory 13 | if str(ROOT) not in sys.path: 14 | sys.path.append(str(ROOT)) # add ROOT to PATH 15 | 16 | port = 0 # --master_port 17 | path = Path('').resolve() 18 | for last in path.rglob('*/**/last.pt'): 19 | ckpt = torch.load(last) 20 | if ckpt['optimizer'] is None: 21 | continue 22 | 23 | # Load opt.yaml 24 | with open(last.parent.parent / 'opt.yaml', errors='ignore') as f: 25 | opt = yaml.safe_load(f) 26 | 27 | # Get device count 28 | d = opt['device'].split(',') # devices 29 | nd = len(d) # number of devices 30 | ddp = nd > 1 or (nd == 0 and torch.cuda.device_count() > 1) # distributed data parallel 31 | 32 | if ddp: # multi-GPU 33 | port += 1 34 | cmd = f'python -m torch.distributed.run --nproc_per_node {nd} --master_port {port} train.py --resume {last}' 35 | else: # single-GPU 36 | cmd = f'python train.py --resume {last}' 37 | 38 | cmd += ' > /dev/null 2>&1 &' # redirect output to dev/null and run in daemon thread 39 | print(cmd) 40 | os.system(cmd) 41 | -------------------------------------------------------------------------------- /utils/aws/userdata.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # AWS EC2 instance startup script https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html 3 | # This script will run only once on first instance start (for a re-start script see mime.sh) 4 | # /home/ubuntu (ubuntu) or /home/ec2-user (amazon-linux) is working dir 5 | # Use >300 GB SSD 6 | 7 | cd home/ubuntu 8 | if [ ! -d yolov5 ]; then 9 | echo "Running first-time script." # install dependencies, download COCO, pull Docker 10 | git clone https://github.com/ultralytics/yolov5 -b master && sudo chmod -R 777 yolov5 11 | cd yolov5 12 | bash data/scripts/get_coco.sh && echo "COCO done." & 13 | sudo docker pull ultralytics/yolov5:latest && echo "Docker done." & 14 | python -m pip install --upgrade pip && pip install -r requirements.txt && python detect.py && echo "Requirements done." & 15 | wait && echo "All tasks done." # finish background tasks 16 | else 17 | echo "Running re-start script." # resume interrupted runs 18 | i=0 19 | list=$(sudo docker ps -qa) # container list i.e. $'one\ntwo\nthree\nfour' 20 | while IFS= read -r id; do 21 | ((i++)) 22 | echo "restarting container $i: $id" 23 | sudo docker start $id 24 | # sudo docker exec -it $id python train.py --resume # single-GPU 25 | sudo docker exec -d $id python utils/aws/resume.py # multi-scenario 26 | done <<<"$list" 27 | fi 28 | -------------------------------------------------------------------------------- /utils/benchmarks.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Run YOLOv5 benchmarks on all supported export formats 4 | 5 | Format | `export.py --include` | Model 6 | --- | --- | --- 7 | PyTorch | - | yolov5s.pt 8 | TorchScript | `torchscript` | yolov5s.torchscript 9 | ONNX | `onnx` | yolov5s.onnx 10 | OpenVINO | `openvino` | yolov5s_openvino_model/ 11 | TensorRT | `engine` | yolov5s.engine 12 | CoreML | `coreml` | yolov5s.mlmodel 13 | TensorFlow SavedModel | `saved_model` | yolov5s_saved_model/ 14 | TensorFlow GraphDef | `pb` | yolov5s.pb 15 | TensorFlow Lite | `tflite` | yolov5s.tflite 16 | TensorFlow Edge TPU | `edgetpu` | yolov5s_edgetpu.tflite 17 | TensorFlow.js | `tfjs` | yolov5s_web_model/ 18 | 19 | Requirements: 20 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu # CPU 21 | $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU 22 | $ pip install -U nvidia-tensorrt --index-url https://pypi.ngc.nvidia.com # TensorRT 23 | 24 | Usage: 25 | $ python utils/benchmarks.py --weights yolov5s.pt --img 640 26 | """ 27 | 28 | import argparse 29 | import sys 30 | import time 31 | from pathlib import Path 32 | 33 | import pandas as pd 34 | 35 | FILE = Path(__file__).resolve() 36 | ROOT = FILE.parents[1] # YOLOv5 root directory 37 | if str(ROOT) not in sys.path: 38 | sys.path.append(str(ROOT)) # add ROOT to PATH 39 | # ROOT = ROOT.relative_to(Path.cwd()) # relative 40 | 41 | import export 42 | import val 43 | from utils import notebook_init 44 | from utils.general import LOGGER, check_yaml, print_args 45 | from utils.torch_utils import select_device 46 | 47 | 48 | def run( 49 | weights=ROOT / 'yolov5s.pt', # weights path 50 | imgsz=640, # inference size (pixels) 51 | batch_size=1, # batch size 52 | data=ROOT / 'data/coco128.yaml', # dataset.yaml path 53 | device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu 54 | half=False, # use FP16 half-precision inference 55 | test=False, # test exports only 56 | pt_only=False, # test PyTorch only 57 | ): 58 | y, t = [], time.time() 59 | formats = export.export_formats() 60 | device = select_device(device) 61 | for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) 62 | try: 63 | assert i != 9, 'Edge TPU not supported' 64 | assert i != 10, 'TF.js not supported' 65 | if device.type != 'cpu': 66 | assert gpu, f'{name} inference not supported on GPU' 67 | 68 | # Export 69 | if f == '-': 70 | w = weights # PyTorch format 71 | else: 72 | w = export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # all others 73 | assert suffix in str(w), 'export failed' 74 | 75 | # Validate 76 | result = val.run(data, w, batch_size, imgsz, plots=False, device=device, task='benchmark', half=half) 77 | metrics = result[0] # metrics (mp, mr, map50, map, *losses(box, obj, cls)) 78 | speeds = result[2] # times (preprocess, inference, postprocess) 79 | y.append([name, round(metrics[3], 4), round(speeds[1], 2)]) # mAP, t_inference 80 | except Exception as e: 81 | LOGGER.warning(f'WARNING: Benchmark failure for {name}: {e}') 82 | y.append([name, None, None]) # mAP, t_inference 83 | if pt_only and i == 0: 84 | break # break after PyTorch 85 | 86 | # Print results 87 | LOGGER.info('\n') 88 | parse_opt() 89 | notebook_init() # print system info 90 | py = pd.DataFrame(y, columns=['Format', 'mAP@0.5:0.95', 'Inference time (ms)'] if map else ['Format', 'Export', '']) 91 | LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)') 92 | LOGGER.info(str(py if map else py.iloc[:, :2])) 93 | return py 94 | 95 | 96 | def test( 97 | weights=ROOT / 'yolov5s.pt', # weights path 98 | imgsz=640, # inference size (pixels) 99 | batch_size=1, # batch size 100 | data=ROOT / 'data/coco128.yaml', # dataset.yaml path 101 | device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu 102 | half=False, # use FP16 half-precision inference 103 | test=False, # test exports only 104 | pt_only=False, # test PyTorch only 105 | ): 106 | y, t = [], time.time() 107 | formats = export.export_formats() 108 | device = select_device(device) 109 | for i, (name, f, suffix, gpu) in formats.iterrows(): # index, (name, file, suffix, gpu-capable) 110 | try: 111 | w = weights if f == '-' else \ 112 | export.run(weights=weights, imgsz=[imgsz], include=[f], device=device, half=half)[-1] # weights 113 | assert suffix in str(w), 'export failed' 114 | y.append([name, True]) 115 | except Exception: 116 | y.append([name, False]) # mAP, t_inference 117 | 118 | # Print results 119 | LOGGER.info('\n') 120 | parse_opt() 121 | notebook_init() # print system info 122 | py = pd.DataFrame(y, columns=['Format', 'Export']) 123 | LOGGER.info(f'\nExports complete ({time.time() - t:.2f}s)') 124 | LOGGER.info(str(py)) 125 | return py 126 | 127 | 128 | def parse_opt(): 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument('--weights', type=str, default=ROOT / 'yolov5s.pt', help='weights path') 131 | parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='inference size (pixels)') 132 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 133 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') 134 | parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') 135 | parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference') 136 | parser.add_argument('--test', action='store_true', help='test exports only') 137 | parser.add_argument('--pt-only', action='store_true', help='test PyTorch only') 138 | opt = parser.parse_args() 139 | opt.data = check_yaml(opt.data) # check YAML 140 | print_args(vars(opt)) 141 | return opt 142 | 143 | 144 | def main(opt): 145 | test(**vars(opt)) if opt.test else run(**vars(opt)) 146 | 147 | 148 | if __name__ == "__main__": 149 | opt = parse_opt() 150 | main(opt) 151 | -------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Callback utils 4 | """ 5 | 6 | 7 | class Callbacks: 8 | """" 9 | Handles all registered callbacks for YOLOv5 Hooks 10 | """ 11 | 12 | def __init__(self): 13 | # Define the available callbacks 14 | self._callbacks = { 15 | 'on_pretrain_routine_start': [], 16 | 'on_pretrain_routine_end': [], 17 | 'on_train_start': [], 18 | 'on_train_epoch_start': [], 19 | 'on_train_batch_start': [], 20 | 'optimizer_step': [], 21 | 'on_before_zero_grad': [], 22 | 'on_train_batch_end': [], 23 | 'on_train_epoch_end': [], 24 | 'on_val_start': [], 25 | 'on_val_batch_start': [], 26 | 'on_val_image_end': [], 27 | 'on_val_batch_end': [], 28 | 'on_val_end': [], 29 | 'on_fit_epoch_end': [], # fit = train + val 30 | 'on_model_save': [], 31 | 'on_train_end': [], 32 | 'on_params_update': [], 33 | 'teardown': [],} 34 | self.stop_training = False # set True to interrupt training 35 | 36 | def register_action(self, hook, name='', callback=None): 37 | """ 38 | Register a new action to a callback hook 39 | 40 | Args: 41 | hook: The callback hook name to register the action to 42 | name: The name of the action for later reference 43 | callback: The callback to fire 44 | """ 45 | assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" 46 | assert callable(callback), f"callback '{callback}' is not callable" 47 | self._callbacks[hook].append({'name': name, 'callback': callback}) 48 | 49 | def get_registered_actions(self, hook=None): 50 | """" 51 | Returns all the registered actions by callback hook 52 | 53 | Args: 54 | hook: The name of the hook to check, defaults to all 55 | """ 56 | return self._callbacks[hook] if hook else self._callbacks 57 | 58 | def run(self, hook, *args, **kwargs): 59 | """ 60 | Loop through the registered actions and fire all callbacks 61 | 62 | Args: 63 | hook: The name of the hook to check, defaults to all 64 | args: Arguments to receive from YOLOv5 65 | kwargs: Keyword Arguments to receive from YOLOv5 66 | """ 67 | 68 | assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" 69 | 70 | for logger in self._callbacks[hook]: 71 | logger['callback'](*args, **kwargs) 72 | -------------------------------------------------------------------------------- /utils/docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | 3 | # Start FROM NVIDIA PyTorch image https://ngc.nvidia.com/catalog/containers/nvidia:pytorch 4 | FROM nvcr.io/nvidia/pytorch:22.04-py3 5 | RUN rm -rf /opt/pytorch # remove 1.2GB dir 6 | 7 | # Downloads to user config dir 8 | ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ 9 | 10 | # Install linux packages 11 | RUN apt update && apt install --no-install-recommends -y zip htop screen libgl1-mesa-glx 12 | 13 | # Install pip packages 14 | COPY requirements.txt . 15 | RUN python -m pip install --upgrade pip 16 | RUN pip uninstall -y torch torchvision torchtext Pillow 17 | RUN pip install --no-cache -r requirements.txt albumentations wandb gsutil notebook Pillow>=9.1.0 \ 18 | --extra-index-url https://download.pytorch.org/whl/cu113 19 | 20 | # Create working directory 21 | RUN mkdir -p /usr/src/app 22 | WORKDIR /usr/src/app 23 | 24 | # Copy contents 25 | COPY . /usr/src/app 26 | RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5 27 | 28 | # Set environment variables 29 | ENV OMP_NUM_THREADS=8 30 | 31 | 32 | # Usage Examples ------------------------------------------------------------------------------------------------------- 33 | 34 | # Build and Push 35 | # t=ultralytics/yolov5:latest && sudo docker build -f utils/docker/Dockerfile -t $t . && sudo docker push $t 36 | 37 | # Pull and Run 38 | # t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all $t 39 | 40 | # Pull and Run with local directory access 41 | # t=ultralytics/yolov5:latest && sudo docker pull $t && sudo docker run -it --ipc=host --gpus all -v "$(pwd)"/datasets:/usr/src/datasets $t 42 | 43 | # Kill all 44 | # sudo docker kill $(sudo docker ps -q) 45 | 46 | # Kill all image-based 47 | # sudo docker kill $(sudo docker ps -qa --filter ancestor=ultralytics/yolov5:latest) 48 | 49 | # Bash into running container 50 | # sudo docker exec -it 5a9b5863d93d bash 51 | 52 | # Bash into stopped container 53 | # id=$(sudo docker ps -qa) && sudo docker start $id && sudo docker exec -it $id bash 54 | 55 | # Clean up 56 | # docker system prune -a --volumes 57 | 58 | # Update Ubuntu drivers 59 | # https://www.maketecheasier.com/install-nvidia-drivers-ubuntu/ 60 | 61 | # DDP test 62 | # python -m torch.distributed.run --nproc_per_node 2 --master_port 1 train.py --epochs 3 63 | 64 | # GCP VM from Image 65 | # docker.io/ultralytics/yolov5:latest 66 | -------------------------------------------------------------------------------- /utils/docker/Dockerfile-arm64: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | # aarch64-compatible YOLOv5 Docker image for use with Apple M1 and other ARM architectures like Jetson Nano and Raspberry Pi 3 | 4 | # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu 5 | FROM arm64v8/ubuntu:20.04 6 | 7 | # Downloads to user config dir 8 | ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ 9 | 10 | # Install linux packages 11 | RUN apt update 12 | RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata 13 | RUN apt install --no-install-recommends -y python3-pip git zip curl htop gcc \ 14 | libgl1-mesa-glx libglib2.0-0 libpython3.8-dev 15 | # RUN alias python=python3 16 | 17 | # Install pip packages 18 | COPY requirements.txt . 19 | RUN python3 -m pip install --upgrade pip 20 | RUN pip install --no-cache -r requirements.txt gsutil notebook \ 21 | tensorflow-aarch64 22 | # tensorflowjs \ 23 | # onnx onnx-simplifier onnxruntime \ 24 | # coremltools openvino-dev \ 25 | 26 | # Create working directory 27 | RUN mkdir -p /usr/src/app 28 | WORKDIR /usr/src/app 29 | 30 | # Copy contents 31 | COPY . /usr/src/app 32 | RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5 33 | 34 | 35 | # Usage Examples ------------------------------------------------------------------------------------------------------- 36 | 37 | # Build and Push 38 | # t=ultralytics/yolov5:latest-M1 && sudo docker build --platform linux/arm64 -f utils/docker/Dockerfile-arm64 -t $t . && sudo docker push $t 39 | 40 | # Pull and Run 41 | # t=ultralytics/yolov5:latest-M1 && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t 42 | -------------------------------------------------------------------------------- /utils/docker/Dockerfile-cpu: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | 3 | # Start FROM Ubuntu image https://hub.docker.com/_/ubuntu 4 | FROM ubuntu:20.04 5 | 6 | # Downloads to user config dir 7 | ADD https://ultralytics.com/assets/Arial.ttf https://ultralytics.com/assets/Arial.Unicode.ttf /root/.config/Ultralytics/ 8 | 9 | # Install linux packages 10 | RUN apt update 11 | RUN DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt install -y tzdata 12 | RUN apt install --no-install-recommends -y python3-pip git zip curl htop libgl1-mesa-glx libglib2.0-0 libpython3.8-dev 13 | # RUN alias python=python3 14 | 15 | # Install pip packages 16 | COPY requirements.txt . 17 | RUN python3 -m pip install --upgrade pip 18 | RUN pip install --no-cache -r requirements.txt albumentations gsutil notebook \ 19 | coremltools onnx onnx-simplifier onnxruntime openvino-dev tensorflow-cpu tensorflowjs \ 20 | --extra-index-url https://download.pytorch.org/whl/cpu 21 | 22 | # Create working directory 23 | RUN mkdir -p /usr/src/app 24 | WORKDIR /usr/src/app 25 | 26 | # Copy contents 27 | COPY . /usr/src/app 28 | RUN git clone https://github.com/ultralytics/yolov5 /usr/src/yolov5 29 | 30 | 31 | # Usage Examples ------------------------------------------------------------------------------------------------------- 32 | 33 | # Build and Push 34 | # t=ultralytics/yolov5:latest-cpu && sudo docker build -f utils/docker/Dockerfile-cpu -t $t . && sudo docker push $t 35 | 36 | # Pull and Run 37 | # t=ultralytics/yolov5:latest-cpu && sudo docker pull $t && sudo docker run -it --ipc=host -v "$(pwd)"/datasets:/usr/src/datasets $t 38 | -------------------------------------------------------------------------------- /utils/downloads.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Download utils 4 | """ 5 | 6 | import logging 7 | import os 8 | import platform 9 | import subprocess 10 | import time 11 | import urllib 12 | from pathlib import Path 13 | from zipfile import ZipFile 14 | 15 | import requests 16 | import torch 17 | 18 | 19 | def is_url(url): 20 | # Check if online file exists 21 | try: 22 | r = urllib.request.urlopen(url) # response 23 | return r.getcode() == 200 24 | except urllib.request.HTTPError: 25 | return False 26 | 27 | 28 | def gsutil_getsize(url=''): 29 | # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du 30 | s = subprocess.check_output(f'gsutil du {url}', shell=True).decode('utf-8') 31 | return eval(s.split(' ')[0]) if len(s) else 0 # bytes 32 | 33 | 34 | def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''): 35 | # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes 36 | from utils.general import LOGGER 37 | 38 | file = Path(file) 39 | assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}" 40 | try: # url1 41 | LOGGER.info(f'Downloading {url} to {file}...') 42 | torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO) 43 | assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check 44 | except Exception as e: # url2 45 | file.unlink(missing_ok=True) # remove partial downloads 46 | LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...') 47 | os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail 48 | finally: 49 | if not file.exists() or file.stat().st_size < min_bytes: # check 50 | file.unlink(missing_ok=True) # remove partial downloads 51 | LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}") 52 | LOGGER.info('') 53 | 54 | 55 | def attempt_download(file, repo='ultralytics/yolov5', release='v6.1'): 56 | # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.1', etc. 57 | from utils.general import LOGGER 58 | 59 | def github_assets(repository, version='latest'): 60 | # Return GitHub repo tag (i.e. 'v6.1') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...]) 61 | if version != 'latest': 62 | version = f'tags/{version}' # i.e. tags/v6.1 63 | response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api 64 | return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets 65 | 66 | file = Path(str(file).strip().replace("'", '')) 67 | if not file.exists(): 68 | # URL specified 69 | name = Path(urllib.parse.unquote(str(file))).name # decode '%2F' to '/' etc. 70 | if str(file).startswith(('http:/', 'https:/')): # download 71 | url = str(file).replace(':/', '://') # Pathlib turns :// -> :/ 72 | file = name.split('?')[0] # parse authentication https://url.com/file.txt?auth... 73 | if Path(file).is_file(): 74 | LOGGER.info(f'Found {url} locally at {file}') # file already exists 75 | else: 76 | safe_download(file=file, url=url, min_bytes=1E5) 77 | return file 78 | 79 | # GitHub assets 80 | assets = [ 81 | 'yolov5n.pt', 'yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov5n6.pt', 'yolov5s6.pt', 82 | 'yolov5m6.pt', 'yolov5l6.pt', 'yolov5x6.pt'] 83 | try: 84 | tag, assets = github_assets(repo, release) 85 | except Exception: 86 | try: 87 | tag, assets = github_assets(repo) # latest release 88 | except Exception: 89 | try: 90 | tag = subprocess.check_output('git tag', shell=True, stderr=subprocess.STDOUT).decode().split()[-1] 91 | except Exception: 92 | tag = release 93 | 94 | file.parent.mkdir(parents=True, exist_ok=True) # make parent dir (if required) 95 | if name in assets: 96 | url3 = 'https://drive.google.com/drive/folders/1EFQTEUeXWSFww0luse2jB9M1QNZQGwNl' # backup gdrive mirror 97 | safe_download( 98 | file, 99 | url=f'https://github.com/{repo}/releases/download/{tag}/{name}', 100 | url2=f'https://storage.googleapis.com/{repo}/{tag}/{name}', # backup url (optional) 101 | min_bytes=1E5, 102 | error_msg=f'{file} missing, try downloading from https://github.com/{repo}/releases/{tag} or {url3}') 103 | 104 | return str(file) 105 | 106 | 107 | def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'): 108 | # Downloads a file from Google Drive. from yolov5.utils.downloads import *; gdrive_download() 109 | t = time.time() 110 | file = Path(file) 111 | cookie = Path('cookie') # gdrive cookie 112 | print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='') 113 | file.unlink(missing_ok=True) # remove existing file 114 | cookie.unlink(missing_ok=True) # remove existing cookie 115 | 116 | # Attempt file download 117 | out = "NUL" if platform.system() == "Windows" else "/dev/null" 118 | os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}') 119 | if os.path.exists('cookie'): # large file 120 | s = f'curl -Lb ./cookie "drive.google.com/uc?export=download&confirm={get_token()}&id={id}" -o {file}' 121 | else: # small file 122 | s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"' 123 | r = os.system(s) # execute, capture return 124 | cookie.unlink(missing_ok=True) # remove existing cookie 125 | 126 | # Error check 127 | if r != 0: 128 | file.unlink(missing_ok=True) # remove partial 129 | print('Download error ') # raise Exception('Download error') 130 | return r 131 | 132 | # Unzip if archive 133 | if file.suffix == '.zip': 134 | print('unzipping... ', end='') 135 | ZipFile(file).extractall(path=file.parent) # unzip 136 | file.unlink() # remove zip 137 | 138 | print(f'Done ({time.time() - t:.1f}s)') 139 | return r 140 | 141 | 142 | def get_token(cookie="./cookie"): 143 | with open(cookie) as f: 144 | for line in f: 145 | if "download" in line: 146 | return line.split()[-1] 147 | return "" 148 | 149 | 150 | # Google utils: https://cloud.google.com/storage/docs/reference/libraries ---------------------------------------------- 151 | # 152 | # 153 | # def upload_blob(bucket_name, source_file_name, destination_blob_name): 154 | # # Uploads a file to a bucket 155 | # # https://cloud.google.com/storage/docs/uploading-objects#storage-upload-object-python 156 | # 157 | # storage_client = storage.Client() 158 | # bucket = storage_client.get_bucket(bucket_name) 159 | # blob = bucket.blob(destination_blob_name) 160 | # 161 | # blob.upload_from_filename(source_file_name) 162 | # 163 | # print('File {} uploaded to {}.'.format( 164 | # source_file_name, 165 | # destination_blob_name)) 166 | # 167 | # 168 | # def download_blob(bucket_name, source_blob_name, destination_file_name): 169 | # # Uploads a blob from a bucket 170 | # storage_client = storage.Client() 171 | # bucket = storage_client.get_bucket(bucket_name) 172 | # blob = bucket.blob(source_blob_name) 173 | # 174 | # blob.download_to_filename(destination_file_name) 175 | # 176 | # print('Blob {} downloaded to {}.'.format( 177 | # source_blob_name, 178 | # destination_file_name)) 179 | -------------------------------------------------------------------------------- /utils/flask_rest_api/README.md: -------------------------------------------------------------------------------- 1 | # Flask REST API 2 | 3 | [REST](https://en.wikipedia.org/wiki/Representational_state_transfer) [API](https://en.wikipedia.org/wiki/API)s are 4 | commonly used to expose Machine Learning (ML) models to other services. This folder contains an example REST API 5 | created using Flask to expose the YOLOv5s model from [PyTorch Hub](https://pytorch.org/hub/ultralytics_yolov5/). 6 | 7 | ## Requirements 8 | 9 | [Flask](https://palletsprojects.com/p/flask/) is required. Install with: 10 | 11 | ```shell 12 | $ pip install Flask 13 | ``` 14 | 15 | ## Run 16 | 17 | After Flask installation run: 18 | 19 | ```shell 20 | $ python3 restapi.py --port 5000 21 | ``` 22 | 23 | Then use [curl](https://curl.se/) to perform a request: 24 | 25 | ```shell 26 | $ curl -X POST -F image=@zidane.jpg 'http://localhost:5000/v1/object-detection/yolov5s' 27 | ``` 28 | 29 | The model inference results are returned as a JSON response: 30 | 31 | ```json 32 | [ 33 | { 34 | "class": 0, 35 | "confidence": 0.8900438547, 36 | "height": 0.9318675399, 37 | "name": "person", 38 | "width": 0.3264600933, 39 | "xcenter": 0.7438579798, 40 | "ycenter": 0.5207948685 41 | }, 42 | { 43 | "class": 0, 44 | "confidence": 0.8440024257, 45 | "height": 0.7155083418, 46 | "name": "person", 47 | "width": 0.6546785235, 48 | "xcenter": 0.427829951, 49 | "ycenter": 0.6334488392 50 | }, 51 | { 52 | "class": 27, 53 | "confidence": 0.3771208823, 54 | "height": 0.3902671337, 55 | "name": "tie", 56 | "width": 0.0696444362, 57 | "xcenter": 0.3675483763, 58 | "ycenter": 0.7991207838 59 | }, 60 | { 61 | "class": 27, 62 | "confidence": 0.3527112305, 63 | "height": 0.1540903747, 64 | "name": "tie", 65 | "width": 0.0336618312, 66 | "xcenter": 0.7814827561, 67 | "ycenter": 0.5065554976 68 | } 69 | ] 70 | ``` 71 | 72 | An example python script to perform inference using [requests](https://docs.python-requests.org/en/master/) is given 73 | in `example_request.py` 74 | -------------------------------------------------------------------------------- /utils/flask_rest_api/__pycache__/example_request.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/flask_rest_api/__pycache__/example_request.cpython-38.pyc -------------------------------------------------------------------------------- /utils/flask_rest_api/__pycache__/restapi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/flask_rest_api/__pycache__/restapi.cpython-38.pyc -------------------------------------------------------------------------------- /utils/flask_rest_api/example_request.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Perform test request 4 | """ 5 | 6 | import pprint 7 | 8 | import requests 9 | 10 | DETECTION_URL = "http://localhost:5000/v1/object-detection/yolov5s" 11 | IMAGE = "zidane.jpg" 12 | 13 | # Read image 14 | with open(IMAGE, "rb") as f: 15 | image_data = f.read() 16 | 17 | response = requests.post(DETECTION_URL, files={"image": image_data}).json() 18 | 19 | pprint.pprint(response) 20 | -------------------------------------------------------------------------------- /utils/flask_rest_api/restapi.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Run a Flask REST API exposing a YOLOv5s model 4 | """ 5 | 6 | import argparse 7 | import io 8 | 9 | import torch 10 | from flask import Flask, request 11 | from PIL import Image 12 | 13 | app = Flask(__name__) 14 | 15 | DETECTION_URL = "/v1/object-detection/yolov5s" 16 | 17 | 18 | @app.route(DETECTION_URL, methods=["POST"]) 19 | def predict(): 20 | if request.method != "POST": 21 | return 22 | 23 | if request.files.get("image"): 24 | # Method 1 25 | # with request.files["image"] as f: 26 | # im = Image.open(io.BytesIO(f.read())) 27 | 28 | # Method 2 29 | im_file = request.files["image"] 30 | im_bytes = im_file.read() 31 | im = Image.open(io.BytesIO(im_bytes)) 32 | 33 | results = model(im, size=640) # reduce size=320 for faster inference 34 | return results.pandas().xyxy[0].to_json(orient="records") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser(description="Flask API exposing YOLOv5 model") 39 | parser.add_argument("--port", default=5000, type=int, help="port number") 40 | opt = parser.parse_args() 41 | 42 | # Fix known issue urllib.error.HTTPError 403: rate limit exceeded https://github.com/ultralytics/yolov5/pull/7210 43 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True 44 | 45 | model = torch.hub.load("ultralytics/yolov5", "yolov5s", force_reload=True) # force_reload to recache 46 | app.run(host="0.0.0.0", port=opt.port) # debug=True causes Restarting with stat 47 | -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | General utils 4 | """ 5 | 6 | import contextlib 7 | import glob 8 | import inspect 9 | import logging 10 | import math 11 | import os 12 | import platform 13 | import random 14 | import re 15 | import shutil 16 | import signal 17 | import threading 18 | import time 19 | import urllib 20 | from datetime import datetime 21 | from itertools import repeat 22 | from multiprocessing.pool import ThreadPool 23 | from pathlib import Path 24 | from subprocess import check_output 25 | from typing import Optional 26 | from zipfile import ZipFile 27 | 28 | import cv2 29 | import numpy as np 30 | import pandas as pd 31 | import pkg_resources as pkg 32 | import torch 33 | import torchvision 34 | import yaml 35 | 36 | from utils.downloads import gsutil_getsize 37 | from utils.metrics import box_iou, fitness 38 | 39 | FILE = Path(__file__).resolve() 40 | ROOT = FILE.parents[1] # YOLOv5 root directory 41 | RANK = int(os.getenv('RANK', -1)) 42 | 43 | # Settings 44 | DATASETS_DIR = ROOT.parent / 'datasets' # YOLOv5 datasets directory 45 | NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLOv5 multiprocessing threads 46 | AUTOINSTALL = str(os.getenv('YOLOv5_AUTOINSTALL', True)).lower() == 'true' # global auto-install mode 47 | VERBOSE = str(os.getenv('YOLOv5_VERBOSE', True)).lower() == 'true' # global verbose mode 48 | FONT = 'Arial.ttf' # https://ultralytics.com/assets/Arial.ttf 49 | 50 | torch.set_printoptions(linewidth=320, precision=5, profile='long') 51 | np.set_printoptions(linewidth=320, formatter={'float_kind': '{:11.5g}'.format}) # format short g, %precision=5 52 | pd.options.display.max_columns = 10 53 | cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) 54 | os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads 55 | os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS) # OpenMP max threads (PyTorch and SciPy) 56 | 57 | 58 | def is_kaggle(): 59 | # Is environment a Kaggle Notebook? 60 | try: 61 | assert os.environ.get('PWD') == '/kaggle/working' 62 | assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com' 63 | return True 64 | except AssertionError: 65 | return False 66 | 67 | 68 | def is_writeable(dir, test=False): 69 | # Return True if directory has write permissions, test opening a file with write permissions if test=True 70 | if not test: 71 | return os.access(dir, os.R_OK) # possible issues on Windows 72 | file = Path(dir) / 'tmp.txt' 73 | try: 74 | with open(file, 'w'): # open file with write permissions 75 | pass 76 | file.unlink() # remove file 77 | return True 78 | except OSError: 79 | return False 80 | 81 | 82 | def set_logging(name=None, verbose=VERBOSE): 83 | # Sets level and returns logger 84 | if is_kaggle(): 85 | for h in logging.root.handlers: 86 | logging.root.removeHandler(h) # remove all handlers associated with the root logger object 87 | rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings 88 | level = logging.INFO if verbose and rank in {-1, 0} else logging.WARNING 89 | log = logging.getLogger(name) 90 | log.setLevel(level) 91 | handler = logging.StreamHandler() 92 | handler.setFormatter(logging.Formatter("%(message)s")) 93 | handler.setLevel(level) 94 | log.addHandler(handler) 95 | 96 | 97 | set_logging() # run before defining LOGGER 98 | LOGGER = logging.getLogger("yolov5") # define globally (used in train.py, val.py, detect.py, etc.) 99 | 100 | 101 | def user_config_dir(dir='Ultralytics', env_var='YOLOV5_CONFIG_DIR'): 102 | # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. 103 | env = os.getenv(env_var) 104 | if env: 105 | path = Path(env) # use environment variable 106 | else: 107 | cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs 108 | path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir 109 | path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable 110 | path.mkdir(exist_ok=True) # make if required 111 | return path 112 | 113 | 114 | CONFIG_DIR = user_config_dir() # Ultralytics settings dir 115 | 116 | 117 | class Profile(contextlib.ContextDecorator): 118 | # Usage: @Profile() decorator or 'with Profile():' context manager 119 | def __enter__(self): 120 | self.start = time.time() 121 | 122 | def __exit__(self, type, value, traceback): 123 | print(f'Profile results: {time.time() - self.start:.5f}s') 124 | 125 | 126 | class Timeout(contextlib.ContextDecorator): 127 | # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager 128 | def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True): 129 | self.seconds = int(seconds) 130 | self.timeout_message = timeout_msg 131 | self.suppress = bool(suppress_timeout_errors) 132 | 133 | def _timeout_handler(self, signum, frame): 134 | raise TimeoutError(self.timeout_message) 135 | 136 | def __enter__(self): 137 | if platform.system() != 'Windows': # not supported on Windows 138 | signal.signal(signal.SIGALRM, self._timeout_handler) # Set handler for SIGALRM 139 | signal.alarm(self.seconds) # start countdown for SIGALRM to be raised 140 | 141 | def __exit__(self, exc_type, exc_val, exc_tb): 142 | if platform.system() != 'Windows': 143 | signal.alarm(0) # Cancel SIGALRM if it's scheduled 144 | if self.suppress and exc_type is TimeoutError: # Suppress TimeoutError 145 | return True 146 | 147 | 148 | class WorkingDirectory(contextlib.ContextDecorator): 149 | # Usage: @WorkingDirectory(dir) decorator or 'with WorkingDirectory(dir):' context manager 150 | def __init__(self, new_dir): 151 | self.dir = new_dir # new dir 152 | self.cwd = Path.cwd().resolve() # current dir 153 | 154 | def __enter__(self): 155 | os.chdir(self.dir) 156 | 157 | def __exit__(self, exc_type, exc_val, exc_tb): 158 | os.chdir(self.cwd) 159 | 160 | 161 | def try_except(func): 162 | # try-except function. Usage: @try_except decorator 163 | def handler(*args, **kwargs): 164 | try: 165 | func(*args, **kwargs) 166 | except Exception as e: 167 | print(e) 168 | 169 | return handler 170 | 171 | 172 | def threaded(func): 173 | # Multi-threads a target function and returns thread. Usage: @threaded decorator 174 | def wrapper(*args, **kwargs): 175 | thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) 176 | thread.start() 177 | return thread 178 | 179 | return wrapper 180 | 181 | 182 | def methods(instance): 183 | # Get class/instance methods 184 | return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")] 185 | 186 | 187 | def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False): 188 | # Print function arguments (optional args dict) 189 | x = inspect.currentframe().f_back # previous frame 190 | file, _, fcn, _, _ = inspect.getframeinfo(x) 191 | if args is None: # get args automatically 192 | args, _, _, frm = inspect.getargvalues(x) 193 | args = {k: v for k, v in frm.items() if k in args} 194 | s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '') 195 | LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items())) 196 | 197 | 198 | def init_seeds(seed=0): 199 | # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html 200 | # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible 201 | import torch.backends.cudnn as cudnn 202 | random.seed(seed) 203 | np.random.seed(seed) 204 | torch.manual_seed(seed) 205 | cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False) 206 | 207 | 208 | def intersect_dicts(da, db, exclude=()): 209 | # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values 210 | return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape} 211 | 212 | 213 | def get_latest_run(search_dir='.'): 214 | # Return path to most recent 'last.pt' in /runs (i.e. to --resume from) 215 | last_list = glob.glob(f'{search_dir}/**/last*.pt', recursive=True) 216 | return max(last_list, key=os.path.getctime) if last_list else '' 217 | 218 | 219 | def is_docker(): 220 | # Is environment a Docker container? 221 | return Path('/workspace').exists() # or Path('/.dockerenv').exists() 222 | 223 | 224 | def is_colab(): 225 | # Is environment a Google Colab instance? 226 | try: 227 | import google.colab 228 | return True 229 | except ImportError: 230 | return False 231 | 232 | 233 | def is_pip(): 234 | # Is file in a pip package? 235 | return 'site-packages' in Path(__file__).resolve().parts 236 | 237 | 238 | def is_ascii(s=''): 239 | # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) 240 | s = str(s) # convert list, tuple, None, etc. to str 241 | return len(s.encode().decode('ascii', 'ignore')) == len(s) 242 | 243 | 244 | def is_chinese(s='人工智能'): 245 | # Is string composed of any Chinese characters? 246 | return bool(re.search('[\u4e00-\u9fff]', str(s))) 247 | 248 | 249 | def emojis(str=''): 250 | # Return platform-dependent emoji-safe version of string 251 | return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str 252 | 253 | 254 | def file_age(path=__file__): 255 | # Return days since last file update 256 | dt = (datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime)) # delta 257 | return dt.days # + dt.seconds / 86400 # fractional days 258 | 259 | 260 | def file_date(path=__file__): 261 | # Return human-readable file modification date, i.e. '2021-3-26' 262 | t = datetime.fromtimestamp(Path(path).stat().st_mtime) 263 | return f'{t.year}-{t.month}-{t.day}' 264 | 265 | 266 | def file_size(path): 267 | # Return file/dir size (MB) 268 | mb = 1 << 20 # bytes to MiB (1024 ** 2) 269 | path = Path(path) 270 | if path.is_file(): 271 | return path.stat().st_size / mb 272 | elif path.is_dir(): 273 | return sum(f.stat().st_size for f in path.glob('**/*') if f.is_file()) / mb 274 | else: 275 | return 0.0 276 | 277 | 278 | def check_online(): 279 | # Check internet connectivity 280 | import socket 281 | try: 282 | socket.create_connection(("1.1.1.1", 443), 5) # check host accessibility 283 | return True 284 | except OSError: 285 | return False 286 | 287 | 288 | def git_describe(path=ROOT): # path must be a directory 289 | # Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe 290 | try: 291 | assert (Path(path) / '.git').is_dir() 292 | return check_output(f'git -C {path} describe --tags --long --always', shell=True).decode()[:-1] 293 | except Exception: 294 | return '' 295 | 296 | 297 | @try_except 298 | @WorkingDirectory(ROOT) 299 | def check_git_status(): 300 | # Recommend 'git pull' if code is out of date 301 | msg = ', for updates see https://github.com/ultralytics/yolov5' 302 | s = colorstr('github: ') # string 303 | assert Path('.git').exists(), s + 'skipping check (not a git repository)' + msg 304 | assert not is_docker(), s + 'skipping check (Docker image)' + msg 305 | assert check_online(), s + 'skipping check (offline)' + msg 306 | 307 | cmd = 'git fetch && git config --get remote.origin.url' 308 | url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git') # git fetch 309 | branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out 310 | n = int(check_output(f'git rev-list {branch}..origin/master --count', shell=True)) # commits behind 311 | if n > 0: 312 | s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `git pull` or `git clone {url}` to update." 313 | else: 314 | s += f'up to date with {url} ✅' 315 | LOGGER.info(emojis(s)) # emoji-safe 316 | 317 | 318 | def check_python(minimum='3.7.0'): 319 | # Check current python version vs. required python version 320 | check_version(platform.python_version(), minimum, name='Python ', hard=True) 321 | 322 | 323 | def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False): 324 | # Check version vs. required version 325 | current, minimum = (pkg.parse_version(x) for x in (current, minimum)) 326 | result = (current == minimum) if pinned else (current >= minimum) # bool 327 | s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string 328 | if hard: 329 | assert result, s # assert min requirements met 330 | if verbose and not result: 331 | LOGGER.warning(s) 332 | return result 333 | 334 | 335 | @try_except 336 | def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()): 337 | # Check installed dependencies meet requirements (pass *.txt file or list of packages) 338 | prefix = colorstr('red', 'bold', 'requirements:') 339 | check_python() # check python version 340 | if isinstance(requirements, (str, Path)): # requirements.txt file 341 | file = Path(requirements) 342 | assert file.exists(), f"{prefix} {file.resolve()} not found, check failed." 343 | with file.open() as f: 344 | requirements = [f'{x.name}{x.specifier}' for x in pkg.parse_requirements(f) if x.name not in exclude] 345 | else: # list or tuple of packages 346 | requirements = [x for x in requirements if x not in exclude] 347 | 348 | n = 0 # number of packages updates 349 | for i, r in enumerate(requirements): 350 | try: 351 | pkg.require(r) 352 | except Exception: # DistributionNotFound or VersionConflict if requirements not met 353 | s = f"{prefix} {r} not found and is required by YOLOv5" 354 | if install and AUTOINSTALL: # check environment variable 355 | LOGGER.info(f"{s}, attempting auto-update...") 356 | try: 357 | assert check_online(), f"'pip install {r}' skipped (offline)" 358 | LOGGER.info(check_output(f"pip install '{r}' {cmds[i] if cmds else ''}", shell=True).decode()) 359 | n += 1 360 | except Exception as e: 361 | LOGGER.warning(f'{prefix} {e}') 362 | else: 363 | LOGGER.info(f'{s}. Please install and rerun your command.') 364 | 365 | if n: # if packages updated 366 | source = file.resolve() if 'file' in locals() else requirements 367 | s = f"{prefix} {n} package{'s' * (n > 1)} updated per {source}\n" \ 368 | f"{prefix} ⚠️ {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n" 369 | LOGGER.info(emojis(s)) 370 | 371 | 372 | def check_img_size(imgsz, s=32, floor=0): 373 | # Verify image size is a multiple of stride s in each dimension 374 | if isinstance(imgsz, int): # integer i.e. img_size=640 375 | new_size = max(make_divisible(imgsz, int(s)), floor) 376 | else: # list i.e. img_size=[640, 480] 377 | imgsz = list(imgsz) # convert to list if tuple 378 | new_size = [max(make_divisible(x, int(s)), floor) for x in imgsz] 379 | if new_size != imgsz: 380 | LOGGER.warning(f'WARNING: --img-size {imgsz} must be multiple of max stride {s}, updating to {new_size}') 381 | return new_size 382 | 383 | 384 | def check_imshow(): 385 | # Check if environment supports image displays 386 | try: 387 | assert not is_docker(), 'cv2.imshow() is disabled in Docker environments' 388 | assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments' 389 | cv2.imshow('test', np.zeros((1, 1, 3))) 390 | cv2.waitKey(1) 391 | cv2.destroyAllWindows() 392 | cv2.waitKey(1) 393 | return True 394 | except Exception as e: 395 | LOGGER.warning(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{e}') 396 | return False 397 | 398 | 399 | def check_suffix(file='yolov5s.pt', suffix=('.pt',), msg=''): 400 | # Check file(s) for acceptable suffix 401 | if file and suffix: 402 | if isinstance(suffix, str): 403 | suffix = [suffix] 404 | for f in file if isinstance(file, (list, tuple)) else [file]: 405 | s = Path(f).suffix.lower() # file suffix 406 | if len(s): 407 | assert s in suffix, f"{msg}{f} acceptable suffix is {suffix}" 408 | 409 | 410 | def check_yaml(file, suffix=('.yaml', '.yml')): 411 | # Search/download YAML file (if necessary) and return path, checking suffix 412 | return check_file(file, suffix) 413 | 414 | 415 | def check_file(file, suffix=''): 416 | # Search/download file (if necessary) and return path 417 | check_suffix(file, suffix) # optional 418 | file = str(file) # convert to str() 419 | if Path(file).is_file() or not file: # exists 420 | return file 421 | elif file.startswith(('http:/', 'https:/')): # download 422 | url = file # warning: Pathlib turns :// -> :/ 423 | file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth 424 | if Path(file).is_file(): 425 | LOGGER.info(f'Found {url} locally at {file}') # file already exists 426 | else: 427 | LOGGER.info(f'Downloading {url} to {file}...') 428 | torch.hub.download_url_to_file(url, file) 429 | assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check 430 | return file 431 | else: # search 432 | files = [] 433 | for d in 'data', 'models', 'utils': # search directories 434 | files.extend(glob.glob(str(ROOT / d / '**' / file), recursive=True)) # find file 435 | assert len(files), f'File not found: {file}' # assert file was found 436 | assert len(files) == 1, f"Multiple files match '{file}', specify exact path: {files}" # assert unique 437 | return files[0] # return file 438 | 439 | 440 | def check_font(font=FONT, progress=False): 441 | # Download font to CONFIG_DIR if necessary 442 | font = Path(font) 443 | file = CONFIG_DIR / font.name 444 | if not font.exists() and not file.exists(): 445 | url = "https://ultralytics.com/assets/" + font.name 446 | LOGGER.info(f'Downloading {url} to {file}...') 447 | torch.hub.download_url_to_file(url, str(file), progress=progress) 448 | 449 | 450 | def check_dataset(data, autodownload=True): 451 | # Download and/or unzip dataset if not found locally 452 | # Usage: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128_with_yaml.zip 453 | 454 | # Download (optional) 455 | extract_dir = '' 456 | if isinstance(data, (str, Path)) and str(data).endswith('.zip'): # i.e. gs://bucket/dir/coco128.zip 457 | download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False, threads=1) 458 | data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) 459 | extract_dir, autodownload = data.parent, False 460 | 461 | # Read yaml (optional) 462 | if isinstance(data, (str, Path)): 463 | with open(data, errors='ignore') as f: 464 | data = yaml.safe_load(f) # dictionary 465 | 466 | # Resolve paths 467 | path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.' 468 | if not path.is_absolute(): 469 | path = (ROOT / path).resolve() 470 | for k in 'train', 'val', 'test': 471 | if data.get(k): # prepend path 472 | data[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]] 473 | 474 | # Parse yaml 475 | assert 'nc' in data, "Dataset 'nc' key missing." 476 | if 'names' not in data: 477 | data['names'] = [f'class{i}' for i in range(data['nc'])] # assign class names if missing 478 | train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download')) 479 | if val: 480 | val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path 481 | if not all(x.exists() for x in val): 482 | LOGGER.info(emojis('\nDataset not found ⚠, missing paths %s' % [str(x) for x in val if not x.exists()])) 483 | if not s or not autodownload: 484 | raise Exception(emojis('Dataset not found ❌')) 485 | t = time.time() 486 | root = path.parent if 'path' in data else '..' # unzip directory i.e. '../' 487 | if s.startswith('http') and s.endswith('.zip'): # URL 488 | f = Path(s).name # filename 489 | LOGGER.info(f'Downloading {s} to {f}...') 490 | torch.hub.download_url_to_file(s, f) 491 | Path(root).mkdir(parents=True, exist_ok=True) # create root 492 | ZipFile(f).extractall(path=root) # unzip 493 | Path(f).unlink() # remove zip 494 | r = None # success 495 | elif s.startswith('bash '): # bash script 496 | LOGGER.info(f'Running {s} ...') 497 | r = os.system(s) 498 | else: # python script 499 | r = exec(s, {'yaml': data}) # return None 500 | dt = f'({round(time.time() - t, 1)}s)' 501 | s = f"success ✅ {dt}, saved to {colorstr('bold', root)}" if r in (0, None) else f"failure {dt} ❌" 502 | LOGGER.info(emojis(f"Dataset download {s}")) 503 | check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf', progress=True) # download fonts 504 | return data # dictionary 505 | 506 | 507 | def check_amp(model): 508 | # Check PyTorch Automatic Mixed Precision (AMP) functionality. Return True on correct operation 509 | from models.common import AutoShape, DetectMultiBackend 510 | 511 | def amp_allclose(model, im): 512 | # All close FP32 vs AMP results 513 | m = AutoShape(model, verbose=False) # model 514 | a = m(im).xywhn[0] # FP32 inference 515 | m.amp = True 516 | b = m(im).xywhn[0] # AMP inference 517 | return a.shape == b.shape and torch.allclose(a, b, atol=0.1) # close to 10% absolute tolerance 518 | 519 | prefix = colorstr('AMP: ') 520 | device = next(model.parameters()).device # get model device 521 | if device.type == 'cpu': 522 | return False # AMP disabled on CPU 523 | f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check 524 | im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3)) 525 | try: 526 | assert amp_allclose(model, im) or amp_allclose(DetectMultiBackend('yolov5n.pt', device), im) 527 | LOGGER.info(emojis(f'{prefix}checks passed ✅')) 528 | return True 529 | except Exception: 530 | help_url = 'https://github.com/ultralytics/yolov5/issues/7908' 531 | LOGGER.warning(emojis(f'{prefix}checks failed ❌, disabling Automatic Mixed Precision. See {help_url}')) 532 | return False 533 | 534 | 535 | def url2file(url): 536 | # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt 537 | url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/ 538 | return Path(urllib.parse.unquote(url)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth 539 | 540 | 541 | def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3): 542 | # Multi-threaded file download and unzip function, used in data.yaml for autodownload 543 | def download_one(url, dir): 544 | # Download 1 file 545 | success = True 546 | f = dir / Path(url).name # filename 547 | if Path(url).is_file(): # exists in current path 548 | Path(url).rename(f) # move to dir 549 | elif not f.exists(): 550 | LOGGER.info(f'Downloading {url} to {f}...') 551 | for i in range(retry + 1): 552 | if curl: 553 | s = 'sS' if threads > 1 else '' # silent 554 | r = os.system(f"curl -{s}L '{url}' -o '{f}' --retry 9 -C -") # curl download 555 | success = r == 0 556 | else: 557 | torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download 558 | success = f.is_file() 559 | if success: 560 | break 561 | elif i < retry: 562 | LOGGER.warning(f'Download failure, retrying {i + 1}/{retry} {url}...') 563 | else: 564 | LOGGER.warning(f'Failed to download {url}...') 565 | 566 | if unzip and success and f.suffix in ('.zip', '.gz'): 567 | LOGGER.info(f'Unzipping {f}...') 568 | if f.suffix == '.zip': 569 | ZipFile(f).extractall(path=dir) # unzip 570 | elif f.suffix == '.gz': 571 | os.system(f'tar xfz {f} --directory {f.parent}') # unzip 572 | if delete: 573 | f.unlink() # remove zip 574 | 575 | dir = Path(dir) 576 | dir.mkdir(parents=True, exist_ok=True) # make directory 577 | if threads > 1: 578 | pool = ThreadPool(threads) 579 | pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded 580 | pool.close() 581 | pool.join() 582 | else: 583 | for u in [url] if isinstance(url, (str, Path)) else url: 584 | download_one(u, dir) 585 | 586 | 587 | def make_divisible(x, divisor): 588 | # Returns nearest x divisible by divisor 589 | if isinstance(divisor, torch.Tensor): 590 | divisor = int(divisor.max()) # to int 591 | return math.ceil(x / divisor) * divisor 592 | 593 | 594 | def clean_str(s): 595 | # Cleans a string by replacing special characters with underscore _ 596 | return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) 597 | 598 | 599 | def one_cycle(y1=0.0, y2=1.0, steps=100): 600 | # lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf 601 | return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1 602 | 603 | 604 | def colorstr(*input): 605 | # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world') 606 | *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string 607 | colors = { 608 | 'black': '\033[30m', # basic colors 609 | 'red': '\033[31m', 610 | 'green': '\033[32m', 611 | 'yellow': '\033[33m', 612 | 'blue': '\033[34m', 613 | 'magenta': '\033[35m', 614 | 'cyan': '\033[36m', 615 | 'white': '\033[37m', 616 | 'bright_black': '\033[90m', # bright colors 617 | 'bright_red': '\033[91m', 618 | 'bright_green': '\033[92m', 619 | 'bright_yellow': '\033[93m', 620 | 'bright_blue': '\033[94m', 621 | 'bright_magenta': '\033[95m', 622 | 'bright_cyan': '\033[96m', 623 | 'bright_white': '\033[97m', 624 | 'end': '\033[0m', # misc 625 | 'bold': '\033[1m', 626 | 'underline': '\033[4m'} 627 | return ''.join(colors[x] for x in args) + f'{string}' + colors['end'] 628 | 629 | 630 | def labels_to_class_weights(labels, nc=80): 631 | # Get class weights (inverse frequency) from training labels 632 | if labels[0] is None: # no labels loaded 633 | return torch.Tensor() 634 | 635 | labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO 636 | classes = labels[:, 0].astype(np.int) # labels = [class xywh] 637 | weights = np.bincount(classes, minlength=nc) # occurrences per class 638 | 639 | # Prepend gridpoint count (for uCE training) 640 | # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image 641 | # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start 642 | 643 | weights[weights == 0] = 1 # replace empty bins with 1 644 | weights = 1 / weights # number of targets per class 645 | weights /= weights.sum() # normalize 646 | return torch.from_numpy(weights) 647 | 648 | 649 | def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)): 650 | # Produces image weights based on class_weights and image contents 651 | # Usage: index = random.choices(range(n), weights=image_weights, k=1) # weighted image sample 652 | class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels]) 653 | return (class_weights.reshape(1, nc) * class_counts).sum(1) 654 | 655 | 656 | def coco80_to_coco91_class(): # converts 80-index (val2014) to 91-index (paper) 657 | # https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/ 658 | # a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n') 659 | # b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n') 660 | # x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco 661 | # x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknet 662 | return [ 663 | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 664 | 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 665 | 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 666 | 667 | 668 | def xyxy2xywh(x): 669 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right 670 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 671 | y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center 672 | y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center 673 | y[:, 2] = x[:, 2] - x[:, 0] # width 674 | y[:, 3] = x[:, 3] - x[:, 1] # height 675 | return y 676 | 677 | 678 | def xywh2xyxy(x): 679 | # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 680 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 681 | y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x 682 | y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y 683 | y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x 684 | y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y 685 | return y 686 | 687 | 688 | def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0): 689 | # Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right 690 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 691 | y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw # top left x 692 | y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh # top left y 693 | y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw # bottom right x 694 | y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh # bottom right y 695 | return y 696 | 697 | 698 | def xyxy2xywhn(x, w=640, h=640, clip=False, eps=0.0): 699 | # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right 700 | if clip: 701 | clip_coords(x, (h - eps, w - eps)) # warning: inplace clip 702 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 703 | y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center 704 | y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center 705 | y[:, 2] = (x[:, 2] - x[:, 0]) / w # width 706 | y[:, 3] = (x[:, 3] - x[:, 1]) / h # height 707 | return y 708 | 709 | 710 | def xyn2xy(x, w=640, h=640, padw=0, padh=0): 711 | # Convert normalized segments into pixel segments, shape (n,2) 712 | y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) 713 | y[:, 0] = w * x[:, 0] + padw # top left x 714 | y[:, 1] = h * x[:, 1] + padh # top left y 715 | return y 716 | 717 | 718 | def segment2box(segment, width=640, height=640): 719 | # Convert 1 segment label to 1 box label, applying inside-image constraint, i.e. (xy1, xy2, ...) to (xyxy) 720 | x, y = segment.T # segment xy 721 | inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) 722 | x, y, = x[inside], y[inside] 723 | return np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4)) # xyxy 724 | 725 | 726 | def segments2boxes(segments): 727 | # Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh) 728 | boxes = [] 729 | for s in segments: 730 | x, y = s.T # segment xy 731 | boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy 732 | return xyxy2xywh(np.array(boxes)) # cls, xywh 733 | 734 | 735 | def resample_segments(segments, n=1000): 736 | # Up-sample an (n,2) segment 737 | for i, s in enumerate(segments): 738 | x = np.linspace(0, len(s) - 1, n) 739 | xp = np.arange(len(s)) 740 | segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T # segment xy 741 | return segments 742 | 743 | 744 | def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): 745 | # Rescale coords (xyxy) from img1_shape to img0_shape 746 | if ratio_pad is None: # calculate from img0_shape 747 | gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new 748 | pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding 749 | else: 750 | gain = ratio_pad[0][0] 751 | pad = ratio_pad[1] 752 | 753 | coords[:, [0, 2]] -= pad[0] # x padding 754 | coords[:, [1, 3]] -= pad[1] # y padding 755 | coords[:, :4] /= gain 756 | clip_coords(coords, img0_shape) 757 | return coords 758 | 759 | 760 | def clip_coords(boxes, shape): 761 | # Clip bounding xyxy bounding boxes to image shape (height, width) 762 | if isinstance(boxes, torch.Tensor): # faster individually 763 | boxes[:, 0].clamp_(0, shape[1]) # x1 764 | boxes[:, 1].clamp_(0, shape[0]) # y1 765 | boxes[:, 2].clamp_(0, shape[1]) # x2 766 | boxes[:, 3].clamp_(0, shape[0]) # y2 767 | else: # np.array (faster grouped) 768 | boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, shape[1]) # x1, x2 769 | boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, shape[0]) # y1, y2 770 | 771 | 772 | def non_max_suppression(prediction, 773 | conf_thres=0.25, 774 | iou_thres=0.45, 775 | classes=None, 776 | agnostic=False, 777 | multi_label=False, 778 | labels=(), 779 | max_det=300): 780 | """Non-Maximum Suppression (NMS) on inference results to reject overlapping bounding boxes 781 | 782 | Returns: 783 | list of detections, on (n,6) tensor per image [xyxy, conf, cls] 784 | """ 785 | 786 | bs = prediction.shape[0] # batch size 787 | nc = prediction.shape[2] - 5 # number of classes 788 | xc = prediction[..., 4] > conf_thres # candidates 789 | 790 | # Checks 791 | assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' 792 | assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' 793 | 794 | # Settings 795 | # min_wh = 2 # (pixels) minimum box width and height 796 | max_wh = 7680 # (pixels) maximum box width and height 797 | max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() 798 | time_limit = 0.3 + 0.03 * bs # seconds to quit after 799 | redundant = True # require redundant detections 800 | multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) 801 | merge = False # use merge-NMS 802 | 803 | t = time.time() 804 | output = [torch.zeros((0, 6), device=prediction.device)] * bs 805 | for xi, x in enumerate(prediction): # image index, image inference 806 | # Apply constraints 807 | # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height 808 | x = x[xc[xi]] # confidence 809 | 810 | # Cat apriori labels if autolabelling 811 | if labels and len(labels[xi]): 812 | lb = labels[xi] 813 | v = torch.zeros((len(lb), nc + 5), device=x.device) 814 | v[:, :4] = lb[:, 1:5] # box 815 | v[:, 4] = 1.0 # conf 816 | v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls 817 | x = torch.cat((x, v), 0) 818 | 819 | # If none remain process next image 820 | if not x.shape[0]: 821 | continue 822 | 823 | # Compute conf 824 | x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf 825 | 826 | # Box (center x, center y, width, height) to (x1, y1, x2, y2) 827 | box = xywh2xyxy(x[:, :4]) 828 | 829 | # Detections matrix nx6 (xyxy, conf, cls) 830 | if multi_label: 831 | i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T 832 | x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) 833 | else: # best class only 834 | conf, j = x[:, 5:].max(1, keepdim=True) 835 | x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] 836 | 837 | # Filter by class 838 | if classes is not None: 839 | x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] 840 | 841 | # Apply finite constraint 842 | # if not torch.isfinite(x).all(): 843 | # x = x[torch.isfinite(x).all(1)] 844 | 845 | # Check shape 846 | n = x.shape[0] # number of boxes 847 | if not n: # no boxes 848 | continue 849 | elif n > max_nms: # excess boxes 850 | x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence 851 | 852 | # Batched NMS 853 | c = x[:, 5:6] * (0 if agnostic else max_wh) # classes 854 | boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores 855 | i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS 856 | if i.shape[0] > max_det: # limit detections 857 | i = i[:max_det] 858 | if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean) 859 | # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) 860 | iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix 861 | weights = iou * scores[None] # box weights 862 | x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes 863 | if redundant: 864 | i = i[iou.sum(1) > 1] # require redundancy 865 | 866 | output[xi] = x[i] 867 | if (time.time() - t) > time_limit: 868 | LOGGER.warning(f'WARNING: NMS time limit {time_limit:.3f}s exceeded') 869 | break # time limit exceeded 870 | 871 | return output 872 | 873 | 874 | def strip_optimizer(f='detector.pt', s=''): # from utils.general import *; strip_optimizer() 875 | # Strip optimizer from 'f' to finalize training, optionally save as 's' 876 | x = torch.load(f, map_location=torch.device('cpu')) 877 | if x.get('ema'): 878 | x['model'] = x['ema'] # replace model with ema 879 | for k in 'optimizer', 'best_fitness', 'wandb_id', 'ema', 'updates': # keys 880 | x[k] = None 881 | x['epoch'] = -1 882 | x['model'].half() # to FP16 883 | for p in x['model'].parameters(): 884 | p.requires_grad = False 885 | torch.save(x, s or f) 886 | mb = os.path.getsize(s or f) / 1E6 # filesize 887 | LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB") 888 | 889 | 890 | def print_mutation(results, hyp, save_dir, bucket, prefix=colorstr('evolve: ')): 891 | evolve_csv = save_dir / 'evolve.csv' 892 | evolve_yaml = save_dir / 'hyp_evolve.yaml' 893 | keys = ('metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', 'val/box_loss', 894 | 'val/obj_loss', 'val/cls_loss') + tuple(hyp.keys()) # [results + hyps] 895 | keys = tuple(x.strip() for x in keys) 896 | vals = results + tuple(hyp.values()) 897 | n = len(keys) 898 | 899 | # Download (optional) 900 | if bucket: 901 | url = f'gs://{bucket}/evolve.csv' 902 | if gsutil_getsize(url) > (evolve_csv.stat().st_size if evolve_csv.exists() else 0): 903 | os.system(f'gsutil cp {url} {save_dir}') # download evolve.csv if larger than local 904 | 905 | # Log to evolve.csv 906 | s = '' if evolve_csv.exists() else (('%20s,' * n % keys).rstrip(',') + '\n') # add header 907 | with open(evolve_csv, 'a') as f: 908 | f.write(s + ('%20.5g,' * n % vals).rstrip(',') + '\n') 909 | 910 | # Save yaml 911 | with open(evolve_yaml, 'w') as f: 912 | data = pd.read_csv(evolve_csv) 913 | data = data.rename(columns=lambda x: x.strip()) # strip keys 914 | i = np.argmax(fitness(data.values[:, :4])) # 915 | generations = len(data) 916 | f.write('# YOLOv5 Hyperparameter Evolution Results\n' + f'# Best generation: {i}\n' + 917 | f'# Last generation: {generations - 1}\n' + '# ' + ', '.join(f'{x.strip():>20s}' for x in keys[:7]) + 918 | '\n' + '# ' + ', '.join(f'{x:>20.5g}' for x in data.values[i, :7]) + '\n\n') 919 | yaml.safe_dump(data.loc[i][7:].to_dict(), f, sort_keys=False) 920 | 921 | # Print to screen 922 | LOGGER.info(prefix + f'{generations} generations finished, current result:\n' + prefix + 923 | ', '.join(f'{x.strip():>20s}' for x in keys) + '\n' + prefix + ', '.join(f'{x:20.5g}' 924 | for x in vals) + '\n\n') 925 | 926 | if bucket: 927 | os.system(f'gsutil cp {evolve_csv} {evolve_yaml} gs://{bucket}') # upload 928 | 929 | 930 | def apply_classifier(x, model, img, im0): 931 | # Apply a second stage classifier to YOLO outputs 932 | # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval() 933 | im0 = [im0] if isinstance(im0, np.ndarray) else im0 934 | for i, d in enumerate(x): # per image 935 | if d is not None and len(d): 936 | d = d.clone() 937 | 938 | # Reshape and pad cutouts 939 | b = xyxy2xywh(d[:, :4]) # boxes 940 | b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # rectangle to square 941 | b[:, 2:] = b[:, 2:] * 1.3 + 30 # pad 942 | d[:, :4] = xywh2xyxy(b).long() 943 | 944 | # Rescale boxes from img_size to im0 size 945 | scale_coords(img.shape[2:], d[:, :4], im0[i].shape) 946 | 947 | # Classes 948 | pred_cls1 = d[:, 5].long() 949 | ims = [] 950 | for a in d: 951 | cutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])] 952 | im = cv2.resize(cutout, (224, 224)) # BGR 953 | 954 | im = im[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 955 | im = np.ascontiguousarray(im, dtype=np.float32) # uint8 to float32 956 | im /= 255 # 0 - 255 to 0.0 - 1.0 957 | ims.append(im) 958 | 959 | pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1) # classifier prediction 960 | x[i] = x[i][pred_cls1 == pred_cls2] # retain matching class detections 961 | 962 | return x 963 | 964 | 965 | def increment_path(path, exist_ok=False, sep='', mkdir=False): 966 | # Increment file or directory path, i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. 967 | path = Path(path) # os-agnostic 968 | if path.exists() and not exist_ok: 969 | path, suffix = (path.with_suffix(''), path.suffix) if path.is_file() else (path, '') 970 | 971 | # Method 1 972 | for n in range(2, 9999): 973 | p = f'{path}{sep}{n}{suffix}' # increment path 974 | if not os.path.exists(p): # 975 | break 976 | path = Path(p) 977 | 978 | # Method 2 (deprecated) 979 | # dirs = glob.glob(f"{path}{sep}*") # similar paths 980 | # matches = [re.search(rf"{path.stem}{sep}(\d+)", d) for d in dirs] 981 | # i = [int(m.groups()[0]) for m in matches if m] # indices 982 | # n = max(i) + 1 if i else 2 # increment number 983 | # path = Path(f"{path}{sep}{n}{suffix}") # increment path 984 | 985 | if mkdir: 986 | path.mkdir(parents=True, exist_ok=True) # make directory 987 | 988 | return path 989 | 990 | 991 | # OpenCV Chinese-friendly functions ------------------------------------------------------------------------------------ 992 | imshow_ = cv2.imshow # copy to avoid recursion errors 993 | 994 | 995 | def imread(path, flags=cv2.IMREAD_COLOR): 996 | return cv2.imdecode(np.fromfile(path, np.uint8), flags) 997 | 998 | 999 | def imwrite(path, im): 1000 | try: 1001 | cv2.imencode(Path(path).suffix, im)[1].tofile(path) 1002 | return True 1003 | except Exception: 1004 | return False 1005 | 1006 | 1007 | def imshow(path, im): 1008 | imshow_(path.encode('unicode_escape').decode(), im) 1009 | 1010 | 1011 | cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow # redefine 1012 | 1013 | # Variables ------------------------------------------------------------------------------------------------------------ 1014 | NCOLS = 0 if is_docker() else shutil.get_terminal_size().columns # terminal window size for tqdm 1015 | -------------------------------------------------------------------------------- /utils/google_app_engine/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM gcr.io/google-appengine/python 2 | 3 | # Create a virtualenv for dependencies. This isolates these packages from 4 | # system-level packages. 5 | # Use -p python3 or -p python3.7 to select python version. Default is version 2. 6 | RUN virtualenv /env -p python3 7 | 8 | # Setting these environment variables are the same as running 9 | # source /env/bin/activate. 10 | ENV VIRTUAL_ENV /env 11 | ENV PATH /env/bin:$PATH 12 | 13 | RUN apt-get update && apt-get install -y python-opencv 14 | 15 | # Copy the application's requirements.txt and run pip to install all 16 | # dependencies into the virtualenv. 17 | ADD requirements.txt /app/requirements.txt 18 | RUN pip install -r /app/requirements.txt 19 | 20 | # Add the application source code. 21 | ADD . /app 22 | 23 | # Run a WSGI server to serve the application. gunicorn must be declared as 24 | # a dependency in requirements.txt. 25 | CMD gunicorn -b :$PORT main:app 26 | -------------------------------------------------------------------------------- /utils/google_app_engine/additional_requirements.txt: -------------------------------------------------------------------------------- 1 | # add these requirements in your app on top of the existing ones 2 | pip==21.1 3 | Flask==1.0.2 4 | gunicorn==19.9.0 5 | -------------------------------------------------------------------------------- /utils/google_app_engine/app.yaml: -------------------------------------------------------------------------------- 1 | runtime: custom 2 | env: flex 3 | 4 | service: yolov5app 5 | 6 | liveness_check: 7 | initial_delay_sec: 600 8 | 9 | manual_scaling: 10 | instances: 1 11 | resources: 12 | cpu: 1 13 | memory_gb: 4 14 | disk_size_gb: 20 15 | -------------------------------------------------------------------------------- /utils/loggers/__init__.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | """ 3 | Logging utils 4 | """ 5 | 6 | import os 7 | import warnings 8 | 9 | import pkg_resources as pkg 10 | import torch 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | from utils.general import colorstr, cv2, emojis 14 | from utils.loggers.wandb.wandb_utils import WandbLogger 15 | from utils.plots import plot_images, plot_results 16 | from utils.torch_utils import de_parallel 17 | 18 | LOGGERS = ('csv', 'tb', 'wandb') # text-file, TensorBoard, Weights & Biases 19 | RANK = int(os.getenv('RANK', -1)) 20 | 21 | try: 22 | import wandb 23 | 24 | assert hasattr(wandb, '__version__') # verify package import not local dir 25 | if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.2') and RANK in {0, -1}: 26 | try: 27 | wandb_login_success = wandb.login(timeout=30) 28 | except wandb.errors.UsageError: # known non-TTY terminal issue 29 | wandb_login_success = False 30 | if not wandb_login_success: 31 | wandb = None 32 | except (ImportError, AssertionError): 33 | wandb = None 34 | 35 | 36 | class Loggers(): 37 | # YOLOv5 Loggers class 38 | def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None, include=LOGGERS): 39 | self.save_dir = save_dir 40 | self.weights = weights 41 | self.opt = opt 42 | self.hyp = hyp 43 | self.logger = logger # for printing results to console 44 | self.include = include 45 | self.keys = [ 46 | 'train/box_loss', 47 | 'train/obj_loss', 48 | 'train/cls_loss', # train loss 49 | 'metrics/precision', 50 | 'metrics/recall', 51 | 'metrics/mAP_0.5', 52 | 'metrics/mAP_0.5:0.95', # metrics 53 | 'val/box_loss', 54 | 'val/obj_loss', 55 | 'val/cls_loss', # val loss 56 | 'x/lr0', 57 | 'x/lr1', 58 | 'x/lr2'] # params 59 | self.best_keys = ['best/epoch', 'best/precision', 'best/recall', 'best/mAP_0.5', 'best/mAP_0.5:0.95'] 60 | for k in LOGGERS: 61 | setattr(self, k, None) # init empty logger dictionary 62 | self.csv = True # always log to csv 63 | 64 | # Message 65 | if not wandb: 66 | prefix = colorstr('Weights & Biases: ') 67 | s = f"{prefix}run 'pip install wandb' to automatically track and visualize YOLOv5 🚀 runs (RECOMMENDED)" 68 | self.logger.info(emojis(s)) 69 | 70 | # TensorBoard 71 | s = self.save_dir 72 | if 'tb' in self.include and not self.opt.evolve: 73 | prefix = colorstr('TensorBoard: ') 74 | self.logger.info(f"{prefix}Start with 'tensorboard --logdir {s.parent}', view at http://localhost:6006/") 75 | self.tb = SummaryWriter(str(s)) 76 | 77 | # W&B 78 | if wandb and 'wandb' in self.include: 79 | wandb_artifact_resume = isinstance(self.opt.resume, str) and self.opt.resume.startswith('wandb-artifact://') 80 | run_id = torch.load(self.weights).get('wandb_id') if self.opt.resume and not wandb_artifact_resume else None 81 | self.opt.hyp = self.hyp # add hyperparameters 82 | self.wandb = WandbLogger(self.opt, run_id) 83 | # temp warn. because nested artifacts not supported after 0.12.10 84 | if pkg.parse_version(wandb.__version__) >= pkg.parse_version('0.12.11'): 85 | self.logger.warning( 86 | "YOLOv5 temporarily requires wandb version 0.12.10 or below. Some features may not work as expected." 87 | ) 88 | else: 89 | self.wandb = None 90 | 91 | def on_train_start(self): 92 | # Callback runs on train start 93 | pass 94 | 95 | def on_pretrain_routine_end(self): 96 | # Callback runs on pre-train routine end 97 | paths = self.save_dir.glob('*labels*.jpg') # training labels 98 | if self.wandb: 99 | self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) 100 | 101 | def on_train_batch_end(self, ni, model, imgs, targets, paths, plots): 102 | # Callback runs on train batch end 103 | if plots: 104 | if ni == 0: 105 | if not self.opt.sync_bn: # --sync known issue https://github.com/ultralytics/yolov5/issues/3754 106 | with warnings.catch_warnings(): 107 | warnings.simplefilter('ignore') # suppress jit trace warning 108 | self.tb.add_graph(torch.jit.trace(de_parallel(model), imgs[0:1], strict=False), []) 109 | if ni < 3: 110 | f = self.save_dir / f'train_batch{ni}.jpg' # filename 111 | plot_images(imgs, targets, paths, f) 112 | if self.wandb and ni == 10: 113 | files = sorted(self.save_dir.glob('train*.jpg')) 114 | self.wandb.log({'Mosaics': [wandb.Image(str(f), caption=f.name) for f in files if f.exists()]}) 115 | 116 | def on_train_epoch_end(self, epoch): 117 | # Callback runs on train epoch end 118 | if self.wandb: 119 | self.wandb.current_epoch = epoch + 1 120 | 121 | def on_val_image_end(self, pred, predn, path, names, im): 122 | # Callback runs on val image end 123 | if self.wandb: 124 | self.wandb.val_one_image(pred, predn, path, names, im) 125 | 126 | def on_val_end(self): 127 | # Callback runs on val end 128 | if self.wandb: 129 | files = sorted(self.save_dir.glob('val*.jpg')) 130 | self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]}) 131 | 132 | def on_fit_epoch_end(self, vals, epoch, best_fitness, fi): 133 | # Callback runs at the end of each fit (train+val) epoch 134 | x = dict(zip(self.keys, vals)) 135 | if self.csv: 136 | file = self.save_dir / 'results.csv' 137 | n = len(x) + 1 # number of cols 138 | s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header 139 | with open(file, 'a') as f: 140 | f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n') 141 | 142 | if self.tb: 143 | for k, v in x.items(): 144 | self.tb.add_scalar(k, v, epoch) 145 | 146 | if self.wandb: 147 | if best_fitness == fi: 148 | best_results = [epoch] + vals[3:7] 149 | for i, name in enumerate(self.best_keys): 150 | self.wandb.wandb_run.summary[name] = best_results[i] # log best results in the summary 151 | self.wandb.log(x) 152 | self.wandb.end_epoch(best_result=best_fitness == fi) 153 | 154 | def on_model_save(self, last, epoch, final_epoch, best_fitness, fi): 155 | # Callback runs on model save event 156 | if self.wandb: 157 | if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1: 158 | self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi) 159 | 160 | def on_train_end(self, last, best, plots, epoch, results): 161 | # Callback runs on training end 162 | if plots: 163 | plot_results(file=self.save_dir / 'results.csv') # save results.png 164 | files = ['results.png', 'confusion_matrix.png', *(f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R'))] 165 | files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter 166 | self.logger.info(f"Results saved to {colorstr('bold', self.save_dir)}") 167 | 168 | if self.tb: 169 | for f in files: 170 | self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC') 171 | 172 | if self.wandb: 173 | self.wandb.log(dict(zip(self.keys[3:10], results))) 174 | self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]}) 175 | # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model 176 | if not self.opt.evolve: 177 | wandb.log_artifact(str(best if best.exists() else last), 178 | type='model', 179 | name=f'run_{self.wandb.wandb_run.id}_model', 180 | aliases=['latest', 'best', 'stripped']) 181 | self.wandb.finish_run() 182 | 183 | def on_params_update(self, params): 184 | # Update hyperparams or configs of the experiment 185 | # params: A dict containing {param: value} pairs 186 | if self.wandb: 187 | self.wandb.wandb_run.config.update(params, allow_val_change=True) 188 | -------------------------------------------------------------------------------- /utils/loggers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/loggers/wandb/README.md: -------------------------------------------------------------------------------- 1 | 📚 This guide explains how to use **Weights & Biases** (W&B) with YOLOv5 🚀. UPDATED 29 September 2021. 2 | 3 | - [About Weights & Biases](#about-weights-&-biases) 4 | - [First-Time Setup](#first-time-setup) 5 | - [Viewing runs](#viewing-runs) 6 | - [Disabling wandb](#disabling-wandb) 7 | - [Advanced Usage: Dataset Versioning and Evaluation](#advanced-usage) 8 | - [Reports: Share your work with the world!](#reports) 9 | 10 | ## About Weights & Biases 11 | 12 | Think of [W&B](https://wandb.ai/site?utm_campaign=repo_yolo_wandbtutorial) like GitHub for machine learning models. With a few lines of code, save everything you need to debug, compare and reproduce your models — architecture, hyperparameters, git commits, model weights, GPU usage, and even datasets and predictions. 13 | 14 | Used by top researchers including teams at OpenAI, Lyft, Github, and MILA, W&B is part of the new standard of best practices for machine learning. How W&B can help you optimize your machine learning workflows: 15 | 16 | - [Debug](https://wandb.ai/wandb/getting-started/reports/Visualize-Debug-Machine-Learning-Models--VmlldzoyNzY5MDk#Free-2) model performance in real time 17 | - [GPU usage](https://wandb.ai/wandb/getting-started/reports/Visualize-Debug-Machine-Learning-Models--VmlldzoyNzY5MDk#System-4) visualized automatically 18 | - [Custom charts](https://wandb.ai/wandb/customizable-charts/reports/Powerful-Custom-Charts-To-Debug-Model-Peformance--VmlldzoyNzY4ODI) for powerful, extensible visualization 19 | - [Share insights](https://wandb.ai/wandb/getting-started/reports/Visualize-Debug-Machine-Learning-Models--VmlldzoyNzY5MDk#Share-8) interactively with collaborators 20 | - [Optimize hyperparameters](https://docs.wandb.com/sweeps) efficiently 21 | - [Track](https://docs.wandb.com/artifacts) datasets, pipelines, and production models 22 | 23 | ## First-Time Setup 24 | 25 |
26 | Toggle Details 27 | When you first train, W&B will prompt you to create a new account and will generate an **API key** for you. If you are an existing user you can retrieve your key from https://wandb.ai/authorize. This key is used to tell W&B where to log your data. You only need to supply your key once, and then it is remembered on the same device. 28 | 29 | W&B will create a cloud **project** (default is 'YOLOv5') for your training runs, and each new training run will be provided a unique run **name** within that project as project/name. You can also manually set your project and run name as: 30 | 31 | ```shell 32 | $ python train.py --project ... --name ... 33 | ``` 34 | 35 | YOLOv5 notebook example: Open In Colab Open In Kaggle 36 | Screen Shot 2021-09-29 at 10 23 13 PM 37 | 38 |
39 | 40 | ## Viewing Runs 41 | 42 |
43 | Toggle Details 44 | Run information streams from your environment to the W&B cloud console as you train. This allows you to monitor and even cancel runs in realtime . All important information is logged: 45 | 46 | - Training & Validation losses 47 | - Metrics: Precision, Recall, mAP@0.5, mAP@0.5:0.95 48 | - Learning Rate over time 49 | - A bounding box debugging panel, showing the training progress over time 50 | - GPU: Type, **GPU Utilization**, power, temperature, **CUDA memory usage** 51 | - System: Disk I/0, CPU utilization, RAM memory usage 52 | - Your trained model as W&B Artifact 53 | - Environment: OS and Python types, Git repository and state, **training command** 54 | 55 |

Weights & Biases dashboard

56 |
57 | 58 | ## Disabling wandb 59 | 60 | - training after running `wandb disabled` inside that directory creates no wandb run 61 | ![Screenshot (84)](https://user-images.githubusercontent.com/15766192/143441777-c780bdd7-7cb4-4404-9559-b4316030a985.png) 62 | 63 | - To enable wandb again, run `wandb online` 64 | ![Screenshot (85)](https://user-images.githubusercontent.com/15766192/143441866-7191b2cb-22f0-4e0f-ae64-2dc47dc13078.png) 65 | 66 | ## Advanced Usage 67 | 68 | You can leverage W&B artifacts and Tables integration to easily visualize and manage your datasets, models and training evaluations. Here are some quick examples to get you started. 69 | 70 |
71 |

1: Train and Log Evaluation simultaneousy

72 | This is an extension of the previous section, but it'll also training after uploading the dataset. This also evaluation Table 73 | Evaluation table compares your predictions and ground truths across the validation set for each epoch. It uses the references to the already uploaded datasets, 74 | so no images will be uploaded from your system more than once. 75 |
76 | Usage 77 | Code $ python train.py --upload_data val 78 | 79 | ![Screenshot from 2021-11-21 17-40-06](https://user-images.githubusercontent.com/15766192/142761183-c1696d8c-3f38-45ab-991a-bb0dfd98ae7d.png) 80 | 81 |
82 | 83 |

2. Visualize and Version Datasets

84 | Log, visualize, dynamically query, and understand your data with W&B Tables. You can use the following command to log your dataset as a W&B Table. This will generate a {dataset}_wandb.yaml file which can be used to train from dataset artifact. 85 |
86 | Usage 87 | Code $ python utils/logger/wandb/log_dataset.py --project ... --name ... --data .. 88 | 89 | ![Screenshot (64)](https://user-images.githubusercontent.com/15766192/128486078-d8433890-98a3-4d12-8986-b6c0e3fc64b9.png) 90 | 91 |
92 | 93 |

3: Train using dataset artifact

94 | When you upload a dataset as described in the first section, you get a new config file with an added `_wandb` to its name. This file contains the information that 95 | can be used to train a model directly from the dataset artifact. This also logs evaluation 96 |
97 | Usage 98 | Code $ python train.py --data {data}_wandb.yaml 99 | 100 | ![Screenshot (72)](https://user-images.githubusercontent.com/15766192/128979739-4cf63aeb-a76f-483f-8861-1c0100b938a5.png) 101 | 102 |
103 | 104 |

4: Save model checkpoints as artifacts

105 | To enable saving and versioning checkpoints of your experiment, pass `--save_period n` with the base cammand, where `n` represents checkpoint interval. 106 | You can also log both the dataset and model checkpoints simultaneously. If not passed, only the final model will be logged 107 | 108 |
109 | Usage 110 | Code $ python train.py --save_period 1 111 | 112 | ![Screenshot (68)](https://user-images.githubusercontent.com/15766192/128726138-ec6c1f60-639d-437d-b4ee-3acd9de47ef3.png) 113 | 114 |
115 | 116 |
117 | 118 |

5: Resume runs from checkpoint artifacts.

119 | Any run can be resumed using artifacts if the --resume argument starts with wandb-artifact:// prefix followed by the run path, i.e, wandb-artifact://username/project/runid . This doesn't require the model checkpoint to be present on the local system. 120 | 121 |
122 | Usage 123 | Code $ python train.py --resume wandb-artifact://{run_path} 124 | 125 | ![Screenshot (70)](https://user-images.githubusercontent.com/15766192/128728988-4e84b355-6c87-41ae-a591-14aecf45343e.png) 126 | 127 |
128 | 129 |

6: Resume runs from dataset artifact & checkpoint artifacts.

130 | Local dataset or model checkpoints are not required. This can be used to resume runs directly on a different device 131 | The syntax is same as the previous section, but you'll need to lof both the dataset and model checkpoints as artifacts, i.e, set bot --upload_dataset or 132 | train from _wandb.yaml file and set --save_period 133 | 134 |
135 | Usage 136 | Code $ python train.py --resume wandb-artifact://{run_path} 137 | 138 | ![Screenshot (70)](https://user-images.githubusercontent.com/15766192/128728988-4e84b355-6c87-41ae-a591-14aecf45343e.png) 139 | 140 |
141 | 142 | 143 | 144 |

Reports

145 | W&B Reports can be created from your saved runs for sharing online. Once a report is created you will receive a link you can use to publically share your results. Here is an example report created from the COCO128 tutorial trainings of all four YOLOv5 models ([link](https://wandb.ai/glenn-jocher/yolov5_tutorial/reports/YOLOv5-COCO128-Tutorial-Results--VmlldzozMDI5OTY)). 146 | 147 | Weights & Biases Reports 148 | 149 | ## Environments 150 | 151 | YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including [CUDA](https://developer.nvidia.com/cuda)/[CUDNN](https://developer.nvidia.com/cudnn), [Python](https://www.python.org/) and [PyTorch](https://pytorch.org/) preinstalled): 152 | 153 | - **Google Colab and Kaggle** notebooks with free GPU: Open In Colab Open In Kaggle 154 | - **Google Cloud** Deep Learning VM. See [GCP Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/GCP-Quickstart) 155 | - **Amazon** Deep Learning AMI. See [AWS Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/AWS-Quickstart) 156 | - **Docker Image**. See [Docker Quickstart Guide](https://github.com/ultralytics/yolov5/wiki/Docker-Quickstart) Docker Pulls 157 | 158 | ## Status 159 | 160 | ![CI CPU testing](https://github.com/ultralytics/yolov5/workflows/CI%20CPU%20testing/badge.svg) 161 | 162 | If this badge is green, all [YOLOv5 GitHub Actions](https://github.com/ultralytics/yolov5/actions) Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training ([train.py](https://github.com/ultralytics/yolov5/blob/master/train.py)), validation ([val.py](https://github.com/ultralytics/yolov5/blob/master/val.py)), inference ([detect.py](https://github.com/ultralytics/yolov5/blob/master/detect.py)) and export ([export.py](https://github.com/ultralytics/yolov5/blob/master/export.py)) on macOS, Windows, and Ubuntu every 24 hours and on every commit. 163 | -------------------------------------------------------------------------------- /utils/loggers/wandb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/wandb/__init__.py -------------------------------------------------------------------------------- /utils/loggers/wandb/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/wandb/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/loggers/wandb/__pycache__/log_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/wandb/__pycache__/log_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utils/loggers/wandb/__pycache__/sweep.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/wandb/__pycache__/sweep.cpython-38.pyc -------------------------------------------------------------------------------- /utils/loggers/wandb/__pycache__/wandb_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zxx1218/LicensePlateDetection/21c509b447e12cf9fae146be184cb3bce256a51d/utils/loggers/wandb/__pycache__/wandb_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/loggers/wandb/log_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from wandb_utils import WandbLogger 4 | 5 | from utils.general import LOGGER 6 | 7 | WANDB_ARTIFACT_PREFIX = 'wandb-artifact://' 8 | 9 | 10 | def create_dataset_artifact(opt): 11 | logger = WandbLogger(opt, None, job_type='Dataset Creation') # TODO: return value unused 12 | if not logger.wandb: 13 | LOGGER.info("install wandb using `pip install wandb` to log the dataset") 14 | 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data', type=str, default='data/coco128.yaml', help='data.yaml path') 19 | parser.add_argument('--single-cls', action='store_true', help='train as single-class dataset') 20 | parser.add_argument('--project', type=str, default='YOLOv5', help='name of W&B Project') 21 | parser.add_argument('--entity', default=None, help='W&B entity') 22 | parser.add_argument('--name', type=str, default='log dataset', help='name of W&B run') 23 | 24 | opt = parser.parse_args() 25 | opt.resume = False # Explicitly disallow resume check for dataset upload job 26 | 27 | create_dataset_artifact(opt) 28 | -------------------------------------------------------------------------------- /utils/loggers/wandb/sweep.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | import wandb 5 | 6 | FILE = Path(__file__).resolve() 7 | ROOT = FILE.parents[3] # YOLOv5 root directory 8 | if str(ROOT) not in sys.path: 9 | sys.path.append(str(ROOT)) # add ROOT to PATH 10 | 11 | from train import parse_opt, train 12 | from utils.callbacks import Callbacks 13 | from utils.general import increment_path 14 | from utils.torch_utils import select_device 15 | 16 | 17 | def sweep(): 18 | wandb.init() 19 | # Get hyp dict from sweep agent. Copy because train() modifies parameters which confused wandb. 20 | hyp_dict = vars(wandb.config).get("_items").copy() 21 | 22 | # Workaround: get necessary opt args 23 | opt = parse_opt(known=True) 24 | opt.batch_size = hyp_dict.get("batch_size") 25 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok or opt.evolve)) 26 | opt.epochs = hyp_dict.get("epochs") 27 | opt.nosave = True 28 | opt.data = hyp_dict.get("data") 29 | opt.weights = str(opt.weights) 30 | opt.cfg = str(opt.cfg) 31 | opt.data = str(opt.data) 32 | opt.hyp = str(opt.hyp) 33 | opt.project = str(opt.project) 34 | device = select_device(opt.device, batch_size=opt.batch_size) 35 | 36 | # train 37 | train(hyp_dict, opt, device, callbacks=Callbacks()) 38 | 39 | 40 | if __name__ == "__main__": 41 | sweep() 42 | --------------------------------------------------------------------------------