├── assets
│   └── screenshots.jpg
├── requirements.txt
├── datasets
│   └── data.yaml
├── inferences
│   ├── configs
│   │   └── config.toml
│   └── engines.py
├── runs
│   └── main.py
├── README.md
└── .gitignore
/assets/screenshots.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LIU42/TrafficRules/HEAD/assets/screenshots.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python~=4.11.0.86
2 | onnxruntime~=1.20.1
3 | toml~=0.10.2
4 | numpy~=2.1.1
--------------------------------------------------------------------------------
/datasets/data.yaml:
--------------------------------------------------------------------------------
train: "../train/images"
val: "../val/images"

names:
  0: "F0"
  1: "F1"
  2: "L0"
  3: "L1"
  4: "S0"
  5: "S1"
  6: "R0"
  7: "R1"

--------------------------------------------------------------------------------
/inferences/configs/config.toml:
--------------------------------------------------------------------------------
1 | precision = "fp32"
2 | providers = ["OpenVINOExecutionProvider", "CPUExecutionProvider"]
3 |
4 | conf-threshold = 0.25
5 | iou-threshold = 0.45
6 |
7 | model-path = "inferences/models/detection-fp32.onnx"
8 |
--------------------------------------------------------------------------------
/runs/main.py:
--------------------------------------------------------------------------------
import argparse
import cv2
import inferences.engines as engines


def put_bbox(image, detection_result):
    detection_bbox, detection_label = detection_result

    x1 = detection_bbox[0]
    y1 = detection_bbox[1]
    x2 = detection_bbox[2]
    y2 = detection_bbox[3]

    # Labels ending in '0' are red lights, drawn in red (BGR); '1' are green.
    if detection_label[1] == '0':
        return cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 215), thickness=2)
    else:
        return cv2.rectangle(image, (x1, y1), (x2, y2), (0, 215, 0), thickness=2)


def put_text(image, detection_result):
    detection_bbox, detection_label = detection_result

    x1 = detection_bbox[0]
    y1 = detection_bbox[1]

    if detection_label[1] == '0':
        return cv2.putText(image, detection_label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 215), thickness=2)
    else:
        return cv2.putText(image, detection_label, (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 215, 0), thickness=2)


def put_rule(image, position, result):
    if result:
        return cv2.circle(image, position, 12, (0, 215, 0), thickness=-1)
    else:
        return cv2.circle(image, position, 12, (0, 0, 215), thickness=-1)


def main(args):
    sources = args.sources
    outputs = args.outputs

    for source, output in zip(sources, outputs):
        image = cv2.imread(source)
        results, detections = engines.inference(image)

        for detection in detections:
            image = put_bbox(image, detection)
            image = put_text(image, detection)

        # Three indicator dots: left turn, straight ahead, right turn (green = allowed).
        image = put_rule(image, (20, 20), results[0])
        image = put_rule(image, (50, 20), results[1])
        image = put_rule(image, (80, 20), results[2])

        cv2.imwrite(output, image)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-s', '--sources', nargs='+', required=True)
    parser.add_argument('-o', '--outputs', nargs='+', required=True)

    main(parser.parse_args())
--------------------------------------------------------------------------------
/inferences/engines.py:
--------------------------------------------------------------------------------
import onnxruntime as ort
import cv2
import toml
import numpy as np


configs = toml.load('inferences/configs/config.toml')

# Class labels: the letter encodes the light shape (F = circular, L = left arrow,
# S = straight/up arrow, R = right arrow), the digit the colour (0 = red, 1 = green).
classes_labels = [
    'F0', 'F1',
    'L0', 'L1',
    'S0', 'S1',
    'R0', 'R1',
]

session = ort.InferenceSession(configs['model-path'], providers=configs['providers'])


def normalize(inputs):
    return inputs / 255.0


def preprocess(image):
    # BGR -> RGB, then HWC -> CHW to match the model's expected input layout.
    inputs = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).transpose((2, 0, 1))

    if configs['precision'] == 'fp16':
        inputs = normalize(inputs).astype(np.float16)
    else:
        inputs = normalize(inputs).astype(np.float32)

    return np.expand_dims(inputs, axis=0)


def get_valid_outputs(outputs):
    # Each output row is [cx, cy, w, h, score_0, ..., score_7]; keep only rows
    # whose best class score exceeds the confidence threshold.
    valid_outputs = outputs[np.amax(outputs[:, 4:12], axis=1) > configs['conf-threshold']]

    bboxes = valid_outputs[:, 0:4]
    scores = valid_outputs[:, 4:12]

    return bboxes.astype(np.int32), np.max(scores, axis=1), np.argmax(scores, axis=1)


def non_max_suppression(outputs):
    bboxes, scores, classes = get_valid_outputs(outputs)

    # Convert box centers to top-left corners, giving (x, y, w, h) boxes for NMS.
    bboxes[:, 0] -= bboxes[:, 2] >> 1
    bboxes[:, 1] -= bboxes[:, 3] >> 1

    for index in cv2.dnn.NMSBoxes(bboxes, scores, configs['conf-threshold'], configs['iou-threshold'], eta=0.5):
        x1 = bboxes[index, 0]
        y1 = bboxes[index, 1]

        x2 = bboxes[index, 2] + x1
        y2 = bboxes[index, 3] + y1

        yield (x1, y1, x2, y2), classes_labels[classes[index]]


def detection_inference(image):
    detections = session.run(['output0'], {'images': preprocess(image)})
    detections = detections[0]
    detections = detections.squeeze().transpose()

    return [detection for detection in non_max_suppression(detections)]


def inference(image):
    detection_outputs = detection_inference(image)

    # Defaults: left turn and straight ahead not allowed, right turn allowed.
    result0 = False
    result1 = False
    result2 = True

    # A green circular light (F1) allows both left turn and straight ahead.
    for _, label in detection_outputs:
        if label == 'F1':
            result0 = True
            result1 = True

    # Arrow lights have higher priority and override the circular light
    # for their own direction (L = left, S = straight, R = right).
    for _, label in detection_outputs:
        if label == 'L0':
            result0 = False
        if label == 'L1':
            result0 = True
        if label == 'S0':
            result1 = False
        if label == 'S1':
            result1 = True
        if label == 'R0':
            result2 = False
        if label == 'R1':
            result2 = True

    return (result0, result1, result2), detection_outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TrafficRules: Traffic Light Rule Recognition

*What's new in v2.0.0: the model is now trained with YOLO11 on a richer dataset, the original two steps of object detection and signal classification are merged into one, and the filtering stage that was redundant in most cases has been removed, so traffic light recognition happens in a single pass. The resulting model is slightly more accurate and faster at inference, and is easier to deploy.*

## Overview

This project recognizes the right-of-way rules signalled by traffic lights at intersections, based on YOLO11. Rule recognition consists of the following two steps:

1. **Object detection.** A YOLO11 detection model locates each traffic light in the image and identifies its colour and shape (circle, left arrow, up arrow, or right arrow).

2. **Rule parsing.** The detected traffic lights are parsed into the rules they express, i.e. whether going straight, turning left, and turning right are allowed.

    - A circular light controls all three directions, but with lower priority.

    - An arrow light controls only its own direction, but with higher priority.

In addition, if there is no explicit signal, i.e. no red right-arrow light, turning right is allowed by default. The parsing logic is sketched below.

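The sketch below restates this priority scheme in Python; it condenses the logic in inferences/engines.py, using the class names from datasets/data.yaml (the letter appears to encode the shape, F/L/S/R for circular, left-arrow, up-arrow and right-arrow lights, and the digit the colour, 0 for red and 1 for green):

```python
def parse_rules(labels):
    # Defaults: left turn and straight ahead are blocked, right turn is allowed
    # unless an explicit red right-arrow light says otherwise.
    rules = {'L': False, 'S': False, 'R': True}

    # A green circular light (F1) opens left turn and straight ahead (low priority).
    if 'F1' in labels:
        rules['L'] = rules['S'] = True

    # Arrow lights have higher priority: each one overrides the circular light
    # for its own direction only.
    for label in labels:
        if label[0] in ('L', 'S', 'R'):
            rules[label[0]] = (label[1] == '1')

    return rules['L'], rules['S'], rules['R']


# A red circle plus a green left arrow: left allowed, straight blocked, right allowed.
print(parse_rules(['F0', 'L1']))   # (True, False, True)
```
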
## Demo

![screenshots](assets/screenshots.jpg)

## Performance

The model's input image size is fixed at 640 x 480. It is trained from the [YOLO11n](https://docs.ultralytics.com/zh/models/yolo11/) model; the traffic light detection accuracy on the current dataset is as follows.

| Class | Precision | Recall | mAP50 | mAP50-95 |
| ----- | --------- | ------ | ----- | -------- |
| ALL | 0.97 | 0.971 | 0.989 | 0.89 |
| F0 | 0.99 | 1 | 0.995 | 0.871 |
| F1 | 1 | 0.981 | 0.995 | 0.872 |
| L0 | 0.981 | 0.985 | 0.994 | 0.912 |
| L1 | 0.982 | 1 | 0.995 | 0.915 |
| S0 | 1 | 0.817 | 0.944 | 0.878 |
| S1 | 1 | 0.987 | 0.995 | 0.885 |
| R0 | 0.815 | 1 | 0.995 | 0.914 |
| R1 | 0.993 | 1 | 0.995 | 0.876 |

*Note: the dataset used to train this project is fairly small, so robustness in real-world conditions may be limited.*

## Usage

First install the dependencies. This project uses [ONNX Runtime](https://onnxruntime.ai/) to deploy model inference; if you have specific Execution Provider requirements, see the [official documentation](https://onnxruntime.ai/docs/execution-providers/) to configure them.

```shell-session
pip install -r requirements.txt
```

Prepare the images to be recognized (the program expects 640 x 480 images), download my trained model weights from this project's Releases, extract them into the inferences/models directory, and run runs/main.py. Its command-line arguments are described below; the number of input image paths must match the number of output image paths:

| Argument | Short form | Description |
|:---------:|:-----:|:--------------------------:|
| --sources | -s | List of input image file paths; separate multiple paths with spaces. |
| --outputs | -o | List of output image file paths; separate multiple paths with spaces. |

Here is an example run.

```shell-session
python runs/main.py --sources "s1.jpg" "s2.jpg" --outputs "o1.jpg" "o2.jpg"
```

The recognition code is designed as a plug-and-play Python module: copy the entire inferences package into the root directory of another project, set up the environment and install the dependencies, and call it the same way runs/main.py does.

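A minimal calling sketch (the image file name is a placeholder; note that engines.py loads inferences/configs/config.toml and the model through paths relative to the working directory, so run it from the project root or adapt those paths):

```python
import cv2

import inferences.engines as engines

# Read a 640 x 480 frame and run detection plus rule parsing in one call.
image = cv2.imread('example.jpg')
results, detections = engines.inference(image)

left, straight, right = results                 # True means the movement is allowed
for (x1, y1, x2, y2), label in detections:      # one entry per detected light
    print(label, (x1, y1, x2, y2))
```
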
The default configuration file for the rule recognition module is inferences/configs/config.toml; its properties are described below:

| Property | Description |
|:--------------:|:-----------------------------------------:|
| providers | List of ONNX Runtime Execution Providers used for inference. |
| precision | Inference precision, either "fp32" (single precision) or "fp16" (half precision). |
| model-path | Path the inference model is loaded from. |
| conf-threshold | Confidence threshold for object detection. |
| iou-threshold | IoU threshold for non-maximum suppression in detection. |

To train the model on your own dataset, install the [Ultralytics](https://docs.ultralytics.com/) framework, train the model following the [official documentation](https://docs.ultralytics.com/), and finally convert it to ONNX format for deployment.

```shell-session
pip install ultralytics
```

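A rough sketch of training and exporting with Ultralytics (the hyperparameters shown are placeholders; the deployed model in this project expects 640 x 480 inputs, so adjust the image size settings accordingly):

```python
from ultralytics import YOLO

# Fine-tune a pretrained YOLO11n model on this project's dataset.
model = YOLO('yolo11n.pt')
model.train(data='datasets/data.yaml', epochs=100, imgsz=640)

# Export the trained weights to ONNX for deployment with ONNX Runtime.
model.export(format='onnx')
```
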
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 | .vscode/
162 |
163 | checkpoints/*
164 | datasets/train/*
165 | datasets/val/*
166 | runs/outputs/*
167 | runs/samples/*
168 | inferences/models/*
169 |
--------------------------------------------------------------------------------