├── assets
│   └── screenshots.jpg
├── requirements.txt
├── datasets
│   └── data.yaml
├── inferences
│   ├── configs
│   │   └── config.toml
│   └── engines.py
├── runs
│   └── main.py
├── README.md
└── .gitignore

/assets/screenshots.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LIU42/TrafficRules/HEAD/assets/screenshots.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
opencv-python~=4.11.0.86
onnxruntime~=1.20.1
toml~=0.10.2
numpy~=2.1.1
--------------------------------------------------------------------------------
/datasets/data.yaml:
--------------------------------------------------------------------------------
train: "../train/images"
val: "../val/images"

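# Class name scheme, as used by inferences/engines.py and described in the README:
# the letter is the light shape (F: circular, L: left arrow, S: up/straight arrow, R: right arrow)
# and the digit is the colour (0: red, 1: green).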
names:
  0: "F0"
  1: "F1"
  2: "L0"
  3: "L1"
  4: "S0"
  5: "S1"
  6: "R0"
  7: "R1"
--------------------------------------------------------------------------------
/inferences/configs/config.toml:
--------------------------------------------------------------------------------
precision = "fp32"
providers = ["OpenVINOExecutionProvider", "CPUExecutionProvider"]

conf-threshold = 0.25
iou-threshold = 0.45

model-path = "inferences/models/detection-fp32.onnx"
--------------------------------------------------------------------------------
/runs/main.py:
--------------------------------------------------------------------------------
import argparse
import cv2
import inferences.engines as engines


def put_bbox(image, detection_result):
    detection_bbox, detection_label = detection_result

    x1 = detection_bbox[0]
    y1 = detection_bbox[1]
    x2 = detection_bbox[2]
    y2 = detection_bbox[3]

    # Red box (BGR) for lights whose colour digit is '0', green box otherwise.
    if detection_label[1] == '0':
        return cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 215), thickness=2)
    else:
        return cv2.rectangle(image, (x1, y1), (x2, y2), (0, 215, 0), thickness=2)


def put_text(image, detection_result):
    detection_bbox, detection_label = detection_result

    x1 = detection_bbox[0]
    y1 = detection_bbox[1]

    if detection_label[1] == '0':
        return cv2.putText(image, detection_label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 215), thickness=2)
    else:
        return cv2.putText(image, detection_label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 215, 0), thickness=2)


def put_rule(image, position, result):
    # Solid green dot if the direction is allowed, solid red dot otherwise.
    if result:
        return cv2.circle(image, position, 12, (0, 215, 0), thickness=-1)
    else:
        return cv2.circle(image, position, 12, (0, 0, 215), thickness=-1)


def main(args):
    sources = args.sources
    outputs = args.outputs

    for source, output in zip(sources, outputs):
        image = cv2.imread(source)
        results, detections = engines.inference(image)

        for detection in detections:
            image = put_bbox(image, detection)
            image = put_text(image, detection)

        # Status dots for the (left, straight, right) passage rules.
        image = put_rule(image, (20, 20), results[0])
        image = put_rule(image, (50, 20), results[1])
        image = put_rule(image, (80, 20), results[2])

        cv2.imwrite(output, image)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--sources', nargs='+', required=True)
    parser.add_argument('-o', '--outputs', nargs='+', required=True)

    main(parser.parse_args())
--------------------------------------------------------------------------------
/inferences/engines.py:
--------------------------------------------------------------------------------
import onnxruntime as ort
import cv2
import toml
import numpy as np


configs = toml.load('inferences/configs/config.toml')

classes_labels = [
    'F0', 'F1',
    'L0', 'L1',
    'S0', 'S1',
    'R0', 'R1',
]

session = ort.InferenceSession(configs['model-path'], providers=configs['providers'])


def normalize(inputs):
    return inputs / 255.0


def preprocess(image):
    # HWC BGR uint8 -> CHW RGB, normalized to [0, 1] in the configured precision.
    inputs = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))

    if configs['precision'] == 'fp16':
        inputs = normalize(inputs).astype(np.float16)
    else:
        inputs = normalize(inputs).astype(np.float32)

    return np.expand_dims(inputs, axis=0)


def get_valid_outputs(outputs):
    # Keep rows whose best class score exceeds the confidence threshold.
    valid_outputs = outputs[np.amax(outputs[:, 4:12], axis=1) > configs['conf-threshold']]

    bboxes = valid_outputs[:, 0:4]
    scores = valid_outputs[:, 4:12]

    return bboxes.astype(np.int32), np.max(scores, axis=1), np.argmax(scores, axis=1)


def non_max_suppression(outputs):
    bboxes, scores, classes = get_valid_outputs(outputs)

    # Convert (cx, cy, w, h) boxes to (x, y, w, h) with the top-left corner as origin.
    bboxes[:, 0] -= bboxes[:, 2] >> 1
    bboxes[:, 1] -= bboxes[:, 3] >> 1

    for index in cv2.dnn.NMSBoxes(bboxes, scores, configs['conf-threshold'], configs['iou-threshold'], eta=0.5):
        x1 = bboxes[index, 0]
        y1 = bboxes[index, 1]

        x2 = bboxes[index, 2] + x1
        y2 = bboxes[index, 3] + y1

        yield (x1, y1, x2, y2), classes_labels[classes[index]]


def detection_inference(image):
    detections = session.run(['output0'], {'images': preprocess(image)})
    detections = detections[0]
    detections = detections.squeeze().transpose()

    return [detection for detection in non_max_suppression(detections)]


def inference(image):
    detection_outputs = detection_inference(image)

    # (left turn, straight, right turn); right turn stays allowed unless a red right arrow is found.
    result0 = False
    result1 = False
    result2 = True

    # Circular lights first (lower priority): a green circle opens left and straight.
    for _, label in detection_outputs:
        if label == 'F1':
            result0 = True
            result1 = True

    # Arrow lights second (higher priority): they override the circular light for their direction.
    for _, label in detection_outputs:
        if label == 'L0':
            result0 = False
        if label == 'L1':
            result0 = True
        if label == 'S0':
            result1 = False
        if label == 'S1':
            result1 = True
        if label == 'R0':
            result2 = False
        if label == 'R1':
            result2 = True

    return (result0, result1, result2), detection_outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TrafficRules: Traffic Light Rule Recognition

*New in v2.0.0: the model is now trained with YOLO11 on a richer dataset, the former object-detection and signal-classification steps are merged into one, and the filtering that was redundant in most cases has been removed, so traffic light recognition happens in a single pass. The resulting model is slightly more accurate and slightly faster at inference, and is easier to deploy.*

## Overview

This project performs YOLO11-based recognition of the passage rules given by traffic lights at intersections. Recognition consists of the following two steps:

1. **Object detection.** A YOLO11 detection model locates the traffic lights in the image and identifies their colour and shape (circle, left arrow, up arrow, or right arrow).

2. **Rule parsing.** The detected lights are parsed into the passage rules they express, i.e. whether going straight, turning left, and turning right are allowed.

   - A circular light governs all three directions, but with lower priority.

   - An arrow light governs only its own direction, but with higher priority.

In addition, if there is no explicit signal, that is, no red right-arrow light, turning right is treated as allowed by default.
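The parsed rules can also be consumed directly from Python. Below is a minimal sketch, assuming the project root is on the import path and the model weights are already in inferences/models; the image path is only illustrative.

```python
import cv2
import inferences.engines as engines

# Load an intersection image (640 x 480, BGR as returned by OpenCV); the path is illustrative.
image = cv2.imread('crossing.jpg')

# rules is a tuple of booleans (left turn allowed, straight allowed, right turn allowed);
# detections is a list of ((x1, y1, x2, y2), label) pairs, where a label such as 'R0' means a red right arrow.
rules, detections = engines.inference(image)

left_allowed, straight_allowed, right_allowed = rules
print(f'left: {left_allowed}, straight: {straight_allowed}, right: {right_allowed}')
```

The three booleans follow the priority rules above: arrow lights override the circular light for their own direction, and the right turn remains allowed unless a red right arrow is detected.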
## Screenshots

![](./assets/screenshots.jpg)

## Performance

The model takes a fixed input size of 640 x 480 and is trained from [YOLO11n](https://docs.ultralytics.com/zh/models/yolo11/). Detection accuracy on the current dataset is as follows.

| Class | Precision | Recall | mAP50 | mAP50-95 |
| ----- | --------- | ------ | ----- | -------- |
| ALL   | 0.97      | 0.971  | 0.989 | 0.89     |
| F0    | 0.99      | 1      | 0.995 | 0.871    |
| F1    | 1         | 0.981  | 0.995 | 0.872    |
| L0    | 0.981     | 0.985  | 0.994 | 0.912    |
| L1    | 0.982     | 1      | 0.995 | 0.915    |
| S0    | 1         | 0.817  | 0.944 | 0.878    |
| S1    | 1         | 0.987  | 0.995 | 0.885    |
| R0    | 0.815     | 1      | 0.995 | 0.914    |
| R1    | 0.993     | 1      | 0.995 | 0.876    |

*Note: the dataset used for training is fairly small, so robustness in real-world conditions may be limited.*

## Usage

First install the dependencies. Inference is deployed with [ONNX Runtime](https://onnxruntime.ai/); if you have specific Execution Provider requirements, configure them as described in the [official documentation](https://onnxruntime.ai/docs/execution-providers/).

```shell-session
pip install -r requirements.txt
```

Prepare the images to be recognized (the program expects 640 x 480 images), download the pre-trained model weights from this project's Releases, and extract them into the inferences/models directory. Then run runs/main.py; its command-line arguments are described below, and the number of input image paths must match the number of output image paths:

| Argument  | Short form | Description                                                |
|:---------:|:----------:|:----------------------------------------------------------:|
| --sources | -s         | Paths of the input images to predict, separated by spaces. |
| --outputs | -o         | Paths of the output result images, separated by spaces.    |

Here is an example run.

```shell-session
python runs/main.py --sources "s1.jpg" "s2.jpg" --outputs "o1.jpg" "o2.jpg"
```

The recognition code is designed as a plug-and-play Python module: copy the whole inferences package into the root directory of another project, set up the environment and install the dependencies, and call it the same way runs/main.py does.

The default configuration file of the rule-recognition module is inferences/configs/config.toml; its properties are described below:

| Property       | Description                                                        |
|:--------------:|:------------------------------------------------------------------:|
| providers      | List of ONNX Runtime Execution Providers used for model inference. |
| precision      | Inference precision, either "fp32" (single) or "fp16" (half).      |
| model-path     | Path of the inference model to load.                               |
| conf-threshold | Confidence threshold for object detection.                         |
| iou-threshold  | IoU threshold for non-maximum suppression.                         |

To train the model on your own dataset, install the [Ultralytics](https://docs.ultralytics.com/) framework, train as described in the [official documentation](https://docs.ultralytics.com/), and finally export the model to ONNX format for deployment.

```shell-session
pip install ultralytics
```
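As a rough sketch, training and ONNX export with the Ultralytics Python API look like the following; the dataset path matches datasets/data.yaml in this repository, while the epoch count and image size are only illustrative.

```python
from ultralytics import YOLO

# Fine-tune a YOLO11n checkpoint on this project's dataset (hyperparameters are illustrative).
model = YOLO('yolo11n.pt')
model.train(data='datasets/data.yaml', epochs=100, imgsz=640)

# Export the trained model to ONNX; pass half=True for an fp16 model.
model.export(format='onnx')
```

Point model-path in inferences/configs/config.toml at the exported .onnx file (and set precision accordingly) to use the new weights.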
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/

checkpoints/*
datasets/train/*
datasets/val/*
runs/outputs/*
runs/samples/*
inferences/models/*
--------------------------------------------------------------------------------