├── .gitattributes
├── .gitignore
├── LaneDetectionLaneNet
│   ├── .idea
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── lanenet-lane-detection-master.iml
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   └── workspace.xml
│   ├── LICENSE
│   ├── README.md
│   ├── _config.yml
│   ├── config
│   │   └── global_config.py
│   ├── data_provider
│   │   ├── lanenet_data_feed_pipline.py
│   │   └── tf_io_pipline_tools.py
│   ├── lanenet_model
│   │   ├── lanenet.py
│   │   ├── lanenet_back_end.py
│   │   ├── lanenet_discriminative_loss.py
│   │   ├── lanenet_front_end.py
│   │   └── lanenet_postprocess.py
│   ├── mnn_project
│   │   ├── __init__.py
│   │   ├── config.ini
│   │   ├── config_parser.cpp
│   │   ├── config_parser.h
│   │   ├── convert_lanenet_model_into_mnn_model.sh
│   │   ├── dbscan.hpp
│   │   ├── freeze_lanenet_model.py
│   │   ├── kdtree.cpp
│   │   ├── kdtree.h
│   │   ├── lanenet_model.cpp
│   │   └── lanenet_model.h
│   ├── requirements.txt
│   ├── semantic_segmentation_zoo
│   │   ├── __init__.py
│   │   ├── cnn_basenet.py
│   │   └── vgg16_based_fcn.py
│   └── tools
│       ├── evaluate_lanenet_on_tusimple.py
│       ├── evaluate_model_utils.py
│       ├── generate_tusimple_dataset.py
│       ├── lane_and_object_detection_on_video.py
│       ├── test_lanenet.py
│       └── train_lanenet.py
├── OUT
│   ├── output_video - Copy.avi
│   └── output_video.avi
├── ObstacleDetectionYOLO
│   ├── ObjectDetection_YOLO.py
│   └── images
│       ├── 0.jpg
│       └── 3.jpg
├── Project_report.pdf
├── Test_Detections_On_Video.py
└── readme.md
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/.idea/lanenet-lane-detection-master.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2018 Luo Yao
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/README.md:
--------------------------------------------------------------------------------
1 | # LaneNet-Lane-Detection
2 | A TensorFlow implementation of a deep neural network for real-time lane detection, mainly based on the IEEE IV conference
3 | paper "Towards End-to-End Lane Detection: an Instance Segmentation Approach". You can refer to the paper for details:
4 | https://arxiv.org/abs/1802.05591. The model consists of an encoder-decoder stage, a binary semantic segmentation branch,
5 | and an instance semantic segmentation branch trained with a discriminative loss function for the real-time lane detection task.
6 |
7 | The main network architecture is as follows:
8 |
9 | `Network Architecture`
10 | 
11 |
12 | ## Installation
13 | This software has only been tested on Ubuntu 16.04 (x64), Python 3.5, CUDA 9.0 and cuDNN 7.0 with a GTX-1070 GPU.
14 | To install this software you need TensorFlow 1.10.0. Other TensorFlow versions have not been tested, but versions
15 | above 1.10 are expected to work. The remaining required packages can be installed with
16 |
17 | ```
18 | pip3 install -r requirements.txt
19 | ```
20 |
21 | ## Test model
22 | This repo includes a model trained on the TuSimple lane dataset [Tusimple_Lane_Detection](http://benchmark.tusimple.ai/#/).
23 | The deep neural network inference part can reach around 50 fps, which is similar to the figure reported in the paper, but
24 | the current input pipeline still needs to be improved to achieve a real-time lane detection system.
25 |
26 | The trained lanenet model weights files are stored in
27 | [new_lanenet_model_file](https://www.dropbox.com/sh/tnsf0lw6psszvy4/AAA81r53jpUI3wLsRW6TiPCya?dl=0). You can
28 | download the model and put the files in the folder model/tusimple_lanenet_vgg/
29 |
30 | You can test a single image on the trained model as follows
31 |
32 | ```
33 | python tools/test_lanenet.py --weights_path ./model/tusimple_lanenet_vgg/tusimple_lanenet_vgg.ckpt
34 | --image_path ./data/tusimple_test_image/0.jpg
35 | ```
36 | The results are as follows:
37 |
38 | `Test Input Image`
39 |
40 | 
41 |
42 | `Test Lane Mask Image`
43 |
44 | 
45 |
46 | `Test Lane Binary Segmentation Image`
47 |
48 | 
49 |
50 | `Test Lane Instance Segmentation Image`
51 |
52 | 
53 |
54 | If you want to evaluate the model on the whole TuSimple test dataset you may call
55 | ```
56 | python tools/evaluate_lanenet_on_tusimple.py
57 | --image_dir ROOT_DIR/TUSIMPLE_DATASET/test_set/clips
58 | --weights_path ./model/tusimple_lanenet_vgg/tusimple_lanenet_vgg.ckpt
59 | --save_dir ROOT_DIR/TUSIMPLE_DATASET/test_set/test_output
60 | ```
61 | If you set the save_dir argument the results will be saved in that folder;
62 | otherwise the results will not be saved but will be
63 | displayed during the inference process, holding for 3 seconds per image.
64 | I tested the model on the whole TuSimple lane
65 | detection dataset and made a video of it. You may catch a glimpse of it below.
66 |
67 | `Tusimple test dataset gif`
68 | 
69 |
70 | ## Train your own model
71 | #### Data Preparation
72 | First organize your training data following the data/training_data_example folder structure, and generate
73 | a train.txt and a val.txt to record the data used for training the model.
74 |
75 | Each training sample consists of three components: the original image, a binary segmentation label image and an
76 | instance segmentation label image. The binary segmentation label uses 255 for lane pixels and 0 for the rest. The
77 | instance segmentation label uses a different pixel value for each lane and 0 for the rest.
78 |
79 | All training images will be rescaled to the same size according to the config file.
80 |
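Each line of train.txt / val.txt lists the original image, the binary label image and the instance label image,
separated by single spaces. The paths below only illustrate the expected format:

```
data/training_data_example/gt_image/0000.png data/training_data_example/gt_binary_image/0000.png data/training_data_example/gt_instance_image/0000.png
```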
81 | Use the script here to generate the tensorflow records file
82 |
83 | ```
84 | python data_provider/lanenet_data_feed_pipline.py
85 | --dataset_dir ./data/training_data_example
86 | --tfrecords_dir ./data/training_data_example/tfrecords
87 | ```
88 |
89 | #### Train model
90 | In my experiment the number of training epochs is 80010, the batch size is 4, the initial learning rate is 0.001 and polynomial
91 | decay with power 0.9 is used (a short sketch of this schedule follows the training command below). For the training parameters you can check config/global_config.py for details.
92 | You can switch the --net argument to change the base encoder stage. If you choose --net vgg then VGG16 will be used as
93 | the base encoder stage and pretrained parameters will be loaded. You can also modify the training
94 | script to load your own pretrained parameters, or implement your own base encoder stage.
95 | You may call the following script to train your own model
96 |
97 | ```
98 | python tools/train_lanenet.py
99 | --net vgg
100 | --dataset_dir ./data/training_data_example
101 | -m 0
102 | ```
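As a point of reference, the polynomial learning-rate decay mentioned above can be set up in TensorFlow 1.x roughly as in the sketch below. The variable names, the decay step count and the end learning rate are assumptions for illustration; the actual schedule lives in tools/train_lanenet.py, which is not reproduced in this section.

```python
import tensorflow as tf

# Sketch of the decay schedule described above (initial lr 0.001, power 0.9).
global_step = tf.Variable(0, trainable=False, name='global_step')
learning_rate = tf.train.polynomial_decay(
    learning_rate=0.001,         # initial learning rate quoted above
    global_step=global_step,
    decay_steps=80010,           # assumed to match the number of training steps
    end_learning_rate=0.000001,  # assumed floor value
    power=0.9                    # polynomial power quoted above
)
# The 2018.11.10 update below notes that plain SGD (with momentum) replaced Adam.
optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
```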
103 | You can also continue the training process from the snapshot by
104 | ```
105 | python tools/train_lanenet.py
106 | --net vgg
107 | --dataset_dir data/training_data_example/
108 | --weights_path path/to/your/last/checkpoint
109 | -m 0
110 | ```
111 |
112 | You may monitor the training process using tensorboard tools
113 |
114 | During my experiment the `Total loss` drops as follows:
115 | 
116 |
117 | The `Binary Segmentation loss` drops as follows:
118 | 
119 |
120 | The `Instance Segmentation loss` drops as follows:
121 | 
122 |
123 | ## Experiment
124 | The accuracy during training process rises as follows:
125 | 
126 |
127 | Please cite my repo [lanenet-lane-detection](https://github.com/MaybeShewill-CV/lanenet-lane-detection) if you use it.
128 |
129 | ## Recent updates 2018.11.10
130 | Adjusted some basic cnn ops according to the new tensorflow api. The
131 | traditional SGD optimizer is now used to optimize the whole model instead of the
132 | original Adam optimizer used in the original paper. I have found that the
133 | SGD optimizer leads to a more stable training process and does not
134 | easily get stuck in nan loss, which often happened with the original
135 | code.
136 |
137 | I have uploaded a new lanenet model trained on tusimple dataset using the
138 | new code here [new_lanenet_model_file](https://www.dropbox.com/sh/tnsf0lw6psszvy4/AAA81r53jpUI3wLsRW6TiPCya?dl=0).
139 | You may download the new model weights and update to the new code. To update
140 | the code you just need to
141 |
142 | ```
143 | git pull origin master
144 | ```
145 | The rest is just the same as mentioned above. I will soon
146 | release a new model trained on the CULane dataset.
147 |
148 | ## Recent updates 2018.12.13
149 | Since a lot of users want an automatic tool to generate the training samples
150 | from the TuSimple dataset, I have uploaded the tool I use to generate the
151 | training samples. You need to first download the TuSimple dataset and unzip the
152 | file to your local disk. Then run the following command to generate the
153 | training samples and the train.txt file.
154 |
155 | ```angular2html
156 | python tools/generate_tusimple_dataset.py --src_dir path/to/your/unzipped/file
157 | ```
158 |
159 | The script will create the train folder and the test folder. The training
160 | samples (original rgb image, binary label image and instance label image) will
161 | be automatically generated in the training/gt_image, training/gt_binary_image and
162 | training/gt_instance_image folders. You may check them yourself before starting
163 | the training process.
164 |
165 | Note that the script only processes the training samples, and you
166 | need to select several lines from the train.txt to generate your own
167 | val.txt file. In order to obtain the test images you can modify the
168 | script on your own.
169 |
170 | ## Recent updates 2019.05.16
171 |
172 | New model weights can be found [here](https://www.dropbox.com/sh/tnsf0lw6psszvy4/AAA81r53jpUI3wLsRW6TiPCya?dl=0)
173 |
174 | ## MNN Project
175 |
176 | Tools are provided to convert the lanenet tensorflow ckpt model into an mnn model and deploy
177 | the model on mobile devices.
178 |
179 | #### Freeze your tensorflow ckpt model weights file
180 | ```
181 | cd LANENET_PROJECT_ROOT_DIR
182 | python mnn_project/freeze_lanenet_model.py -w lanenet.ckpt -s lanenet.pb
183 | ```
184 |
185 | #### Convert pb model into mnn model
186 | ```
187 | cd MNN_PROJECT_ROOT_DIR/tools/converter/build
188 | ./MNNConvert -f TF --modelFile lanenet.pb --MNNModel lanenet.mnn --bizCode MNN
189 | ```
190 |
191 | #### Add lanenet source code into MNN project
192 |
193 | Add the lanenet source code into the MNN project and modify CMakeLists.txt to
194 | compile the executable binary file.
195 |
196 | ## TODO
197 | - [x] Add an embedding visualization tool to visualize the embedding feature map
198 | - [x] Add a detailed explanation of training the components of lanenet separately.
199 | - [x] Train the model on different datasets
200 | - ~~[ ] Adjust the lanenet hnet model and merge the hnet model to the main lanenet model~~
201 | - ~~[ ] Change the normalization function from BN to GN~~
202 |
203 | ## Acknowledgement
204 |
205 | The lanenet project refers to the following projects:
206 |
207 | - [MNN](https://github.com/alibaba/MNN)
208 | - [SimpleDBSCAN](https://github.com/CallmeNezha/SimpleDBSCAN)
209 |
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/config/global_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 18-1-31 上午11:21
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : global_config.py
7 | # @IDE: PyCharm Community Edition
8 | """
9 | Set global configuration
10 | """
11 | from easydict import EasyDict as edict
12 |
13 | __C = edict()
14 | # Consumers can get config by: from config import cfg
15 |
16 | cfg = __C
17 |
18 | # Train options
19 | __C.TRAIN = edict()
20 |
21 | # Set the shadownet training epochs
22 | __C.TRAIN.EPOCHS = 80010
23 | # Set the display step
24 | __C.TRAIN.DISPLAY_STEP = 1
25 | # Set the test display step during training process
26 | __C.TRAIN.VAL_DISPLAY_STEP = 1000
27 | # Set the momentum parameter of the optimizer
28 | __C.TRAIN.MOMENTUM = 0.9
29 | # Set the initial learning rate
30 | __C.TRAIN.LEARNING_RATE = 0.0005
31 | # Set the GPU resource used during training process
32 | __C.TRAIN.GPU_MEMORY_FRACTION = 0.95
33 | # Set the GPU allow growth parameter during tensorflow training process
34 | __C.TRAIN.TF_ALLOW_GROWTH = True
35 | # Set the shadownet training batch size
36 | __C.TRAIN.BATCH_SIZE = 4
37 | # Set the shadownet validation batch size
38 | __C.TRAIN.VAL_BATCH_SIZE = 4
39 | # Set the class numbers
40 | __C.TRAIN.CLASSES_NUMS = 2
41 | # Set the image height
42 | __C.TRAIN.IMG_HEIGHT = 256
43 | # Set the image width
44 | __C.TRAIN.IMG_WIDTH = 512
45 | # Set the embedding features dims
46 | __C.TRAIN.EMBEDDING_FEATS_DIMS = 4
47 | # Set the random crop pad size
48 | __C.TRAIN.CROP_PAD_SIZE = 32
49 | # Set cpu multi process thread nums
50 | __C.TRAIN.CPU_MULTI_PROCESS_NUMS = 6
51 | # Set the train moving average decay
52 | __C.TRAIN.MOVING_AVERAGE_DECAY = 0.9999
53 | # Set the GPU nums
54 | __C.TRAIN.GPU_NUM = 2
55 |
56 | # Test options
57 | __C.TEST = edict()
58 |
59 | # Set the GPU resource used during testing process
60 | __C.TEST.GPU_MEMORY_FRACTION = 0.8
61 | # Set the GPU allow growth parameter during tensorflow testing process
62 | __C.TEST.TF_ALLOW_GROWTH = True
63 | # Set the test batch size
64 | __C.TEST.BATCH_SIZE = 2
65 |
66 | # Post-process options
67 | __C.POSTPROCESS = edict()
68 |
69 | # Set the post process connect components analysis min area threshold
70 | __C.POSTPROCESS.MIN_AREA_THRESHOLD = 100
71 | # Set the post process dbscan search radius threshold
72 | __C.POSTPROCESS.DBSCAN_EPS = 0.35
73 | # Set the post process dbscan min samples threshold
74 | __C.POSTPROCESS.DBSCAN_MIN_SAMPLES = 1000
75 |
--------------------------------------------------------------------------------
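The post-processing options above (MIN_AREA_THRESHOLD, DBSCAN_EPS, DBSCAN_MIN_SAMPLES) are consumed by lanenet_model/lanenet_postprocess.py, which is not included in this section. As a rough sketch of how such parameters are typically used, the helper below clusters per-pixel lane embeddings with scikit-learn's DBSCAN; it is illustrative only and not the repo's actual implementation.

```python
import numpy as np
from sklearn.cluster import DBSCAN


def cluster_lane_embeddings(lane_embeddings, eps=0.35, min_samples=1000):
    """Cluster embedding vectors of lane pixels into lane instances.

    :param lane_embeddings: (N, embedding_dim) array of embeddings sampled at
        pixels where the binary segmentation mask is positive.
    """
    db = DBSCAN(eps=eps, min_samples=min_samples)
    labels = db.fit_predict(lane_embeddings)  # -1 marks noise points
    lane_ids = [lbl for lbl in np.unique(labels) if lbl != -1]
    return labels, lane_ids
```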
/LaneDetectionLaneNet/data_provider/lanenet_data_feed_pipline.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-4-23 下午3:54
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : lanenet_data_feed_pipline.py
7 | # @IDE: PyCharm
8 | """
9 | Lanenet data feed pip line
10 | """
11 | import argparse
12 | import glob
13 | import os
14 | import os.path as ops
15 | import random
16 |
17 | import glog as log
18 | import tensorflow as tf
19 |
20 | from LaneDetectionLaneNet.config import global_config
21 | from LaneDetectionLaneNet.data_provider import tf_io_pipline_tools
22 |
23 | CFG = global_config.cfg
24 |
25 |
26 | def init_args():
27 | """
28 |
29 | :return:
30 | """
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument('--dataset_dir', type=str, help='The source lanenet data dir path')
33 | parser.add_argument('--tfrecords_dir', type=str, help='The dir path to save converted tfrecords')
34 |
35 | return parser.parse_args()
36 |
37 |
38 | class LaneNetDataProducer(object):
39 | """
40 | Convert raw image file into tfrecords
41 | """
42 |
43 | def __init__(self, dataset_dir):
44 | """
45 |
46 | :param dataset_dir:
47 | """
48 | self._dataset_dir = dataset_dir
49 |
50 | self._gt_image_dir = ops.join(dataset_dir, 'gt_image')
51 | self._gt_binary_image_dir = ops.join(dataset_dir, 'gt_binary_image')
52 | self._gt_instance_image_dir = ops.join(dataset_dir, 'gt_instance_image')
53 |
54 | self._train_example_index_file_path = ops.join(self._dataset_dir, 'train.txt')
55 | self._test_example_index_file_path = ops.join(self._dataset_dir, 'test.txt')
56 | self._val_example_index_file_path = ops.join(self._dataset_dir, 'val.txt')
57 |
58 | if not self._is_source_data_complete():
59 | raise ValueError('Source image data is not complete, '
60 | 'please check if one of the image folder is not exist')
61 |
62 | if not self._is_training_sample_index_file_complete():
63 | self._generate_training_example_index_file()
64 |
65 | def generate_tfrecords(self, save_dir, step_size=10000):
66 | """
67 | Generate tensorflow records file
68 | :param save_dir:
69 | :param step_size: generate a tfrecord every step_size examples
70 | :return:
71 | """
72 |
73 | def _read_training_example_index_file(_index_file_path):
74 |
75 | assert ops.exists(_index_file_path)
76 |
77 | _example_gt_path_info = []
78 | _example_gt_binary_path_info = []
79 | _example_gt_instance_path_info = []
80 |
81 | with open(_index_file_path, 'r') as _file:
82 | for _line in _file:
83 | _example_info = _line.rstrip('\r').rstrip('\n').split(' ')
84 | _example_gt_path_info.append(_example_info[0])
85 | _example_gt_binary_path_info.append(_example_info[1])
86 | _example_gt_instance_path_info.append(_example_info[2])
87 |
88 | ret = {
89 | 'gt_path_info': _example_gt_path_info,
90 | 'gt_binary_path_info': _example_gt_binary_path_info,
91 | 'gt_instance_path_info': _example_gt_instance_path_info
92 | }
93 |
94 | return ret
95 |
96 | def _split_writing_tfrecords_task(
97 | _example_gt_paths, _example_gt_binary_paths, _example_gt_instance_paths, _flags='train'):
98 |
99 | _split_example_gt_paths = []
100 | _split_example_gt_binary_paths = []
101 | _split_example_gt_instance_paths = []
102 | _split_tfrecords_save_paths = []
103 |
104 | for i in range(0, len(_example_gt_paths), step_size):
105 | _split_example_gt_paths.append(_example_gt_paths[i:i + step_size])
106 | _split_example_gt_binary_paths.append(_example_gt_binary_paths[i:i + step_size])
107 | _split_example_gt_instance_paths.append(_example_gt_instance_paths[i:i + step_size])
108 |
109 | if i + step_size > len(_example_gt_paths):
110 | _split_tfrecords_save_paths.append(
111 | ops.join(save_dir, '{:s}_{:d}_{:d}.tfrecords'.format(_flags, i, len(_example_gt_paths))))
112 | else:
113 | _split_tfrecords_save_paths.append(
114 | ops.join(save_dir, '{:s}_{:d}_{:d}.tfrecords'.format(_flags, i, i + step_size)))
115 |
116 | ret = {
117 | 'gt_paths': _split_example_gt_paths,
118 | 'gt_binary_paths': _split_example_gt_binary_paths,
119 | 'gt_instance_paths': _split_example_gt_instance_paths,
120 | 'tfrecords_paths': _split_tfrecords_save_paths
121 | }
122 |
123 | return ret
124 |
125 | # make save dirs
126 | os.makedirs(save_dir, exist_ok=True)
127 |
128 | # start generating training example tfrecords
129 | log.info('Start generating training example tfrecords')
130 |
131 | # collecting train images paths info
132 | train_image_paths_info = _read_training_example_index_file(self._train_example_index_file_path)
133 | train_gt_images_paths = train_image_paths_info['gt_path_info']
134 | train_gt_binary_images_paths = train_image_paths_info['gt_binary_path_info']
135 | train_gt_instance_images_paths = train_image_paths_info['gt_instance_path_info']
136 |
137 | # split training images according step size
138 | train_split_result = _split_writing_tfrecords_task(
139 | train_gt_images_paths, train_gt_binary_images_paths, train_gt_instance_images_paths, _flags='train')
140 | train_example_gt_paths = train_split_result['gt_paths']
141 | train_example_gt_binary_paths = train_split_result['gt_binary_paths']
142 | train_example_gt_instance_paths = train_split_result['gt_instance_paths']
143 | train_example_tfrecords_paths = train_split_result['tfrecords_paths']
144 |
145 | for index, example_gt_paths in enumerate(train_example_gt_paths):
146 | tf_io_pipline_tools.write_example_tfrecords(
147 | example_gt_paths,
148 | train_example_gt_binary_paths[index],
149 | train_example_gt_instance_paths[index],
150 | train_example_tfrecords_paths[index]
151 | )
152 |
153 | log.info('Generating training example tfrecords complete')
154 |
155 | # start generating validation example tfrecords
156 | log.info('Start generating validation example tfrecords')
157 |
158 | # collecting validation images paths info
159 | val_image_paths_info = _read_training_example_index_file(self._val_example_index_file_path)
160 | val_gt_images_paths = val_image_paths_info['gt_path_info']
161 | val_gt_binary_images_paths = val_image_paths_info['gt_binary_path_info']
162 | val_gt_instance_images_paths = val_image_paths_info['gt_instance_path_info']
163 |
164 | # split validation images according step size
165 | val_split_result = _split_writing_tfrecords_task(
166 | val_gt_images_paths, val_gt_binary_images_paths, val_gt_instance_images_paths, _flags='val')
167 | val_example_gt_paths = val_split_result['gt_paths']
168 | val_example_gt_binary_paths = val_split_result['gt_binary_paths']
169 | val_example_gt_instance_paths = val_split_result['gt_instance_paths']
170 | val_example_tfrecords_paths = val_split_result['tfrecords_paths']
171 |
172 | for index, example_gt_paths in enumerate(val_example_gt_paths):
173 | tf_io_pipline_tools.write_example_tfrecords(
174 | example_gt_paths,
175 | val_example_gt_binary_paths[index],
176 | val_example_gt_instance_paths[index],
177 | val_example_tfrecords_paths[index]
178 | )
179 |
180 | log.info('Generating validation example tfrecords complete')
181 |
182 | # generate test example tfrecords
183 | log.info('Start generating testing example tfrecords')
184 |
185 | # collecting test images paths info
186 | test_image_paths_info = _read_training_example_index_file(self._test_example_index_file_path)
187 | test_gt_images_paths = test_image_paths_info['gt_path_info']
188 | test_gt_binary_images_paths = test_image_paths_info['gt_binary_path_info']
189 | test_gt_instance_images_paths = test_image_paths_info['gt_instance_path_info']
190 |
191 | # split validating images according step size
192 | test_split_result = _split_writing_tfrecords_task(
193 | test_gt_images_paths, test_gt_binary_images_paths, test_gt_instance_images_paths, _flags='test')
194 | test_example_gt_paths = test_split_result['gt_paths']
195 | test_example_gt_binary_paths = test_split_result['gt_binary_paths']
196 | test_example_gt_instance_paths = test_split_result['gt_instance_paths']
197 | test_example_tfrecords_paths = test_split_result['tfrecords_paths']
198 |
199 | for index, example_gt_paths in enumerate(test_example_gt_paths):
200 | tf_io_pipline_tools.write_example_tfrecords(
201 | example_gt_paths,
202 | test_example_gt_binary_paths[index],
203 | test_example_gt_instance_paths[index],
204 | test_example_tfrecords_paths[index]
205 | )
206 |
207 | log.info('Generating testing example tfrecords complete')
208 |
209 | return
210 |
211 | def _is_source_data_complete(self):
212 | """
213 | Check if source data complete
214 | :return:
215 | """
216 | return \
217 | ops.exists(self._gt_binary_image_dir) and \
218 | ops.exists(self._gt_instance_image_dir) and \
219 | ops.exists(self._gt_image_dir)
220 |
221 | def _is_training_sample_index_file_complete(self):
222 | """
223 | Check if the training sample index file is complete
224 | :return:
225 | """
226 | return \
227 | ops.exists(self._train_example_index_file_path) and \
228 | ops.exists(self._test_example_index_file_path) and \
229 | ops.exists(self._val_example_index_file_path)
230 |
231 | def _generate_training_example_index_file(self):
232 | """
233 | Generate training example index file, split source file into 0.85, 0.1, 0.05 for training,
234 | testing and validation. Each image folder is processed separately
235 | :return:
236 | """
237 |
238 | def _gather_example_info():
239 | """
240 |
241 | :return:
242 | """
243 | _info = []
244 |
245 | for _gt_image_path in glob.glob('{:s}/*.png'.format(self._gt_image_dir)):
246 | _gt_binary_image_name = ops.split(_gt_image_path)[1]
247 | _gt_binary_image_path = ops.join(self._gt_binary_image_dir, _gt_binary_image_name)
248 | _gt_instance_image_name = ops.split(_gt_image_path)[1]
249 | _gt_instance_image_path = ops.join(self._gt_instance_image_dir, _gt_instance_image_name)
250 |
251 | assert ops.exists(_gt_binary_image_path), '{:s} not exist'.format(_gt_binary_image_path)
252 | assert ops.exists(_gt_instance_image_path), '{:s} not exist'.format(_gt_instance_image_path)
253 |
254 | _info.append('{:s} {:s} {:s}\n'.format(
255 | _gt_image_path,
256 | _gt_binary_image_path,
257 | _gt_instance_image_path)
258 | )
259 |
260 | return _info
261 |
262 | def _split_training_examples(_example_info):
263 | random.shuffle(_example_info)
264 |
265 | _example_nums = len(_example_info)
266 |
267 | _train_example_info = _example_info[:int(_example_nums * 0.85)]
268 | _val_example_info = _example_info[int(_example_nums * 0.85):int(_example_nums * 0.9)]
269 | _test_example_info = _example_info[int(_example_nums * 0.9):]
270 |
271 | return _train_example_info, _test_example_info, _val_example_info
272 |
273 | train_example_info, test_example_info, val_example_info = _split_training_examples(_gather_example_info())
274 |
275 | random.shuffle(train_example_info)
276 | random.shuffle(test_example_info)
277 | random.shuffle(val_example_info)
278 |
279 | with open(ops.join(self._dataset_dir, 'train.txt'), 'w') as file:
280 | file.write(''.join(train_example_info))
281 |
282 | with open(ops.join(self._dataset_dir, 'test.txt'), 'w') as file:
283 | file.write(''.join(test_example_info))
284 |
285 | with open(ops.join(self._dataset_dir, 'val.txt'), 'w') as file:
286 | file.write(''.join(val_example_info))
287 |
288 | log.info('Generating training example index file complete')
289 |
290 | return
291 |
292 |
293 | class LaneNetDataFeeder(object):
294 | """
295 | Read training examples from tfrecords for the lanenet model
296 | """
297 |
298 | def __init__(self, dataset_dir, flags='train'):
299 | """
300 |
301 | :param dataset_dir:
302 | :param flags:
303 | """
304 | self._dataset_dir = dataset_dir
305 |
306 | self._tfrecords_dir = ops.join(dataset_dir, 'tfrecords')
307 | if not ops.exists(self._tfrecords_dir):
308 | raise ValueError('{:s} not exist, please check again'.format(self._tfrecords_dir))
309 |
310 | self._dataset_flags = flags.lower()
311 | if self._dataset_flags not in ['train', 'test', 'val']:
312 | raise ValueError('flags of the data feeder should be \'train\', \'test\', \'val\'')
313 |
314 | def inputs(self, batch_size, num_epochs):
315 | """
316 | dataset feed pipline input
317 | :param batch_size:
318 | :param num_epochs:
319 | :return: A tuple (gt_images, gt_binary_labels, gt_instance_labels), where:
320 | * gt_images is a float tensor with shape [batch_size, H, W, 3]
321 | in the range [-1.0, 1.0].
322 | * gt_binary_labels and gt_instance_labels are tensors with shape
323 | [batch_size, H, W, 1] holding the binary and instance segmentation labels.
324 | """
325 | if not num_epochs:
326 | num_epochs = None
327 |
328 | tfrecords_file_paths = glob.glob('{:s}/{:s}*.tfrecords'.format(
329 | self._tfrecords_dir, self._dataset_flags)
330 | )
331 | random.shuffle(tfrecords_file_paths)
332 |
333 | with tf.name_scope('input_tensor'):
334 |
335 | # TFRecordDataset opens a binary file and reads one record at a time.
336 | # `tfrecords_file_paths` could also be a list of filenames, which will be read in order.
337 | dataset = tf.data.TFRecordDataset(tfrecords_file_paths)
338 |
339 | # The map transformation takes a function and applies it to every element
340 | # of the dataset.
341 | dataset = dataset.map(map_func=tf_io_pipline_tools.decode,
342 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
343 | if self._dataset_flags != 'test':
344 | dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_train,
345 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
346 | else:
347 | dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_test,
348 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
349 | dataset = dataset.map(map_func=tf_io_pipline_tools.normalize,
350 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
351 |
352 | # The shuffle transformation uses a finite-sized buffer to shuffle elements
353 | # in memory. The parameter is the number of elements in the buffer. For
354 | # completely uniform shuffling, set the parameter to be the same as the
355 | # number of elements in the dataset.
356 | if self._dataset_flags != 'test':
357 | dataset = dataset.shuffle(buffer_size=1000)
358 | # repeat num epochs
359 | dataset = dataset.repeat()
360 |
361 | dataset = dataset.batch(batch_size, drop_remainder=True)
362 |
363 | iterator = dataset.make_one_shot_iterator()
364 |
365 | return iterator.get_next(name='{:s}_IteratorGetNext'.format(self._dataset_flags))
366 |
367 |
368 | if __name__ == '__main__':
369 | # init args
370 | args = init_args()
371 |
372 | assert ops.exists(args.dataset_dir), '{:s} not exist'.format(args.dataset_dir)
373 |
374 | producer = LaneNetDataProducer(dataset_dir=args.dataset_dir)
375 | producer.generate_tfrecords(save_dir=args.tfrecords_dir, step_size=1000)
376 |
--------------------------------------------------------------------------------
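A minimal sketch of consuming the LaneNetDataFeeder defined above from a TensorFlow 1.x session. The dataset path is illustrative, the tfrecords are assumed to have been generated already, and the repo root's parent is assumed to be on PYTHONPATH (matching the module's own import style).

```python
import tensorflow as tf

from LaneDetectionLaneNet.data_provider import lanenet_data_feed_pipline

# Build the input tensors for one training batch.
feeder = lanenet_data_feed_pipline.LaneNetDataFeeder(
    dataset_dir='./data/training_data_example', flags='train')
gt_images, gt_binary_labels, gt_instance_labels = feeder.inputs(batch_size=4, num_epochs=1)

with tf.Session() as sess:
    imgs, binary, instance = sess.run([gt_images, gt_binary_labels, gt_instance_labels])
    # With the default config this prints (4, 256, 512, 3) (4, 256, 512, 1) (4, 256, 512, 1).
    print(imgs.shape, binary.shape, instance.shape)
```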
/LaneDetectionLaneNet/data_provider/tf_io_pipline_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-4-23 下午3:53
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : tf_io_pipline_tools.py
7 | # @IDE: PyCharm
8 | """
9 | tensorflow io pip line tools
10 | """
11 | import os
12 | import os.path as ops
13 |
14 | import cv2
15 | import glog as log
16 | import numpy as np
17 | import tensorflow as tf
18 |
19 | from LaneDetectionLaneNet.config import global_config
20 |
21 | CFG = global_config.cfg
22 |
23 | RESIZE_IMAGE_HEIGHT = CFG.TRAIN.IMG_HEIGHT + CFG.TRAIN.CROP_PAD_SIZE
24 | RESIZE_IMAGE_WIDTH = CFG.TRAIN.IMG_WIDTH + CFG.TRAIN.CROP_PAD_SIZE
25 | CROP_IMAGE_HEIGHT = CFG.TRAIN.IMG_HEIGHT
26 | CROP_IMAGE_WIDTH = CFG.TRAIN.IMG_WIDTH
27 |
28 |
29 | def int64_feature(value):
30 | """
31 |
32 | :return:
33 | """
34 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
35 |
36 |
37 | def bytes_feature(value):
38 | """
39 |
40 | :param value:
41 | :return:
42 | """
43 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
44 |
45 |
46 | def write_example_tfrecords(gt_images_paths, gt_binary_images_paths, gt_instance_images_paths, tfrecords_path):
47 | """
48 | write tfrecords
49 | :param gt_images_paths:
50 | :param gt_binary_images_paths:
51 | :param gt_instance_images_paths:
52 | :param tfrecords_path:
53 | :return:
54 | """
55 | _tfrecords_dir = ops.split(tfrecords_path)[0]
56 | os.makedirs(_tfrecords_dir, exist_ok=True)
57 |
58 | log.info('Writing {:s}....'.format(tfrecords_path))
59 |
60 | with tf.python_io.TFRecordWriter(tfrecords_path) as _writer:
61 | for _index, _gt_image_path in enumerate(gt_images_paths):
62 |
63 | # prepare gt image
64 | _gt_image = cv2.imread(_gt_image_path, cv2.IMREAD_UNCHANGED)
65 | if _gt_image.shape != (RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH, 3):
66 | _gt_image = cv2.resize(
67 | _gt_image,
68 | dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
69 | interpolation=cv2.INTER_LINEAR
70 | )
71 | _gt_image_raw = _gt_image.tostring()
72 |
73 | # prepare gt binary image
74 | _gt_binary_image = cv2.imread(gt_binary_images_paths[_index], cv2.IMREAD_UNCHANGED)
75 | if _gt_binary_image.shape != (RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH):
76 | _gt_binary_image = cv2.resize(
77 | _gt_binary_image,
78 | dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
79 | interpolation=cv2.INTER_NEAREST
80 | )
81 | _gt_binary_image = np.array(_gt_binary_image / 255.0, dtype=np.uint8)
82 | _gt_binary_image_raw = _gt_binary_image.tostring()
83 |
84 | # prepare gt instance image
85 | _gt_instance_image = cv2.imread(gt_instance_images_paths[_index], cv2.IMREAD_UNCHANGED)
86 | if _gt_instance_image.shape != (RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH):
87 | _gt_instance_image = cv2.resize(
88 | _gt_instance_image,
89 | dsize=(RESIZE_IMAGE_WIDTH, RESIZE_IMAGE_HEIGHT),
90 | interpolation=cv2.INTER_NEAREST
91 | )
92 | _gt_instance_image_raw = _gt_instance_image.tostring()
93 |
94 | _example = tf.train.Example(
95 | features=tf.train.Features(
96 | feature={
97 | 'gt_image_raw': bytes_feature(_gt_image_raw),
98 | 'gt_binary_image_raw': bytes_feature(_gt_binary_image_raw),
99 | 'gt_instance_image_raw': bytes_feature(_gt_instance_image_raw)
100 | }))
101 | _writer.write(_example.SerializeToString())
102 |
103 | log.info('Writing {:s} complete'.format(tfrecords_path))
104 |
105 | return
106 |
107 |
108 | def decode(serialized_example):
109 | """
110 | Parses an image and label from the given `serialized_example`
111 | :param serialized_example:
112 | :return:
113 | """
114 | features = tf.parse_single_example(
115 | serialized_example,
116 | # Defaults are not specified since both keys are required.
117 | features={
118 | 'gt_image_raw': tf.FixedLenFeature([], tf.string),
119 | 'gt_binary_image_raw': tf.FixedLenFeature([], tf.string),
120 | 'gt_instance_image_raw': tf.FixedLenFeature([], tf.string)
121 | })
122 |
123 | # decode gt image
124 | gt_image_shape = tf.stack([RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH, 3])
125 | gt_image = tf.decode_raw(features['gt_image_raw'], tf.uint8)
126 | gt_image = tf.reshape(gt_image, gt_image_shape)
127 |
128 | # decode gt binary image
129 | gt_binary_image_shape = tf.stack([RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH, 1])
130 | gt_binary_image = tf.decode_raw(features['gt_binary_image_raw'], tf.uint8)
131 | gt_binary_image = tf.reshape(gt_binary_image, gt_binary_image_shape)
132 |
133 | # decode gt instance image
134 | gt_instance_image_shape = tf.stack([RESIZE_IMAGE_HEIGHT, RESIZE_IMAGE_WIDTH, 1])
135 | gt_instance_image = tf.decode_raw(features['gt_instance_image_raw'], tf.uint8)
136 | gt_instance_image = tf.reshape(gt_instance_image, gt_instance_image_shape)
137 |
138 | return gt_image, gt_binary_image, gt_instance_image
139 |
140 |
141 | def central_crop(image, crop_height, crop_width):
142 | """
143 | Performs central crops of the given image
144 | :param image:
145 | :param crop_height:
146 | :param crop_width:
147 | :return:
148 | """
149 | shape = tf.shape(input=image)
150 | height, width = shape[0], shape[1]
151 |
152 | amount_to_be_cropped_h = (height - crop_height)
153 | crop_top = amount_to_be_cropped_h // 2
154 | amount_to_be_cropped_w = (width - crop_width)
155 | crop_left = amount_to_be_cropped_w // 2
156 |
157 | return tf.slice(image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
158 |
159 |
160 | def augment_for_train(gt_image, gt_binary_image, gt_instance_image):
161 | """
162 |
163 | :param gt_image:
164 | :param gt_binary_image:
165 | :param gt_instance_image:
166 | :return:
167 | """
168 | # convert image from uint8 to float32
169 | gt_image = tf.cast(gt_image, tf.float32)
170 | gt_binary_image = tf.cast(gt_binary_image, tf.float32)
171 | gt_instance_image = tf.cast(gt_instance_image, tf.float32)
172 |
173 | # apply random color augmentation
174 | gt_image, gt_binary_image, gt_instance_image = random_color_augmentation(
175 | gt_image, gt_binary_image, gt_instance_image
176 | )
177 |
178 | # apply random flip augmentation
179 | gt_image, gt_binary_image, gt_instance_image = random_horizon_flip_batch_images(
180 | gt_image, gt_binary_image, gt_instance_image
181 | )
182 |
183 | # apply random crop image
184 | return random_crop_batch_images(
185 | gt_image=gt_image,
186 | gt_binary_image=gt_binary_image,
187 | gt_instance_image=gt_instance_image,
188 | cropped_size=[CROP_IMAGE_WIDTH, CROP_IMAGE_HEIGHT]
189 | )
190 |
191 |
192 | def augment_for_test(gt_image, gt_binary_image, gt_instance_image):
193 | """
194 |
195 | :param gt_image:
196 | :param gt_binary_image:
197 | :param gt_instance_image:
198 | :return:
199 | """
200 | # apply central crop
201 | gt_image = central_crop(
202 | image=gt_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
203 | )
204 | gt_binary_image = central_crop(
205 | image=gt_binary_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
206 | )
207 | gt_instance_image = central_crop(
208 | image=gt_instance_image, crop_height=CROP_IMAGE_HEIGHT, crop_width=CROP_IMAGE_WIDTH
209 | )
210 |
211 | return gt_image, gt_binary_image, gt_instance_image
212 |
213 |
214 | def normalize(gt_image, gt_binary_image, gt_instance_image):
215 | """
216 | Normalize the image data by scaling pixel values to the range [-1.0, 1.0]
217 | :param gt_image:
218 | :param gt_binary_image:
219 | :param gt_instance_image:
220 | :return:
221 | """
222 |
223 | if gt_image.get_shape().as_list()[-1] != 3 \
224 | or gt_binary_image.get_shape().as_list()[-1] != 1 \
225 | or gt_instance_image.get_shape().as_list()[-1] != 1:
226 | log.error(gt_image.get_shape())
227 | log.error(gt_binary_image.get_shape())
228 | log.error(gt_instance_image.get_shape())
229 | raise ValueError('Input must be of size [height, width, C>0]')
230 |
231 | gt_image = tf.subtract(tf.divide(gt_image, tf.constant(127.5, dtype=tf.float32)),
232 | tf.constant(1.0, dtype=tf.float32))
233 |
234 | return gt_image, gt_binary_image, gt_instance_image
235 |
236 |
237 | def random_crop_batch_images(gt_image, gt_binary_image, gt_instance_image, cropped_size):
238 | """
239 | Random crop image batch data for training
240 | :param gt_image:
241 | :param gt_binary_image:
242 | :param gt_instance_image:
243 | :param cropped_size:
244 | :return:
245 | """
246 | concat_images = tf.concat([gt_image, gt_binary_image, gt_instance_image], axis=-1)
247 |
248 | concat_cropped_images = tf.image.random_crop(
249 | concat_images,
250 | [cropped_size[1], cropped_size[0], tf.shape(concat_images)[-1]],
251 | seed=1234
252 | )
253 |
254 | cropped_gt_image = tf.slice(
255 | concat_cropped_images,
256 | begin=[0, 0, 0],
257 | size=[cropped_size[1], cropped_size[0], 3]
258 | )
259 | cropped_gt_binary_image = tf.slice(
260 | concat_cropped_images,
261 | begin=[0, 0, 3],
262 | size=[cropped_size[1], cropped_size[0], 1]
263 | )
264 | cropped_gt_instance_image = tf.slice(
265 | concat_cropped_images,
266 | begin=[0, 0, 4],
267 | size=[cropped_size[1], cropped_size[0], 1]
268 | )
269 |
270 | return cropped_gt_image, cropped_gt_binary_image, cropped_gt_instance_image
271 |
272 |
273 | def random_horizon_flip_batch_images(gt_image, gt_binary_image, gt_instance_image):
274 | """
275 | Random horizon flip image batch data for training
276 | :param gt_image:
277 | :param gt_binary_image:
278 | :param gt_instance_image:
279 | :return:
280 | """
281 | concat_images = tf.concat([gt_image, gt_binary_image, gt_instance_image], axis=-1)
282 |
283 | [image_height, image_width, _] = gt_image.get_shape().as_list()
284 |
285 | concat_flipped_images = tf.image.random_flip_left_right(
286 | image=concat_images,
287 | seed=1
288 | )
289 |
290 | flipped_gt_image = tf.slice(
291 | concat_flipped_images,
292 | begin=[0, 0, 0],
293 | size=[image_height, image_width, 3]
294 | )
295 | flipped_gt_binary_image = tf.slice(
296 | concat_flipped_images,
297 | begin=[0, 0, 3],
298 | size=[image_height, image_width, 1]
299 | )
300 | flipped_gt_instance_image = tf.slice(
301 | concat_flipped_images,
302 | begin=[0, 0, 4],
303 | size=[image_height, image_width, 1]
304 | )
305 |
306 | return flipped_gt_image, flipped_gt_binary_image, flipped_gt_instance_image
307 |
308 |
309 | def random_color_augmentation(gt_image, gt_binary_image, gt_instance_image):
310 | """
311 | Random color augmentation
312 | :param gt_image:
313 | :param gt_binary_image:
314 | :param gt_instance_image:
315 | :return:
316 | """
317 | # first apply random saturation augmentation
318 | gt_image = tf.image.random_saturation(gt_image, 0.8, 1.2)
319 | # second apply random brightness augmentation
320 | gt_image = tf.image.random_brightness(gt_image, 0.05)
321 | # third apply random contrast augmentation
322 | gt_image = tf.image.random_contrast(gt_image, 0.7, 1.3)
323 |
324 | gt_image = tf.clip_by_value(gt_image, 0.0, 255.0)
325 |
326 | return gt_image, gt_binary_image, gt_instance_image
327 |
--------------------------------------------------------------------------------
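Note that normalize() above does not subtract an ImageNet mean; it rescales pixel values to [-1.0, 1.0] by dividing by 127.5 and subtracting 1.0. A small helper like the one below (illustrative, not part of the repo) undoes that scaling when you want to visualize the network input:

```python
import numpy as np


def denormalize_image(normalized_image):
    """Map an image from the [-1.0, 1.0] range back to uint8 [0, 255] for visualization."""
    image = (normalized_image + 1.0) * 127.5
    return np.clip(image, 0, 255).astype(np.uint8)
```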
/LaneDetectionLaneNet/lanenet_model/lanenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Author: Mayur Sunil Jawalkar (mj8628)
5 | Kunjan Suresh Mhaske (km1556)
6 |
7 | Implement LaneNet Model
8 | """
9 | import tensorflow as tf
10 |
11 | from LaneDetectionLaneNet.config import global_config
12 | from LaneDetectionLaneNet.lanenet_model import lanenet_back_end
13 | from LaneDetectionLaneNet.lanenet_model import lanenet_front_end
14 | from LaneDetectionLaneNet.semantic_segmentation_zoo import cnn_basenet
15 |
16 | CFG = global_config.cfg
17 |
18 |
19 | class LaneNet(cnn_basenet.CNNBaseModel):
20 | """
21 |
22 | """
23 | def __init__(self, phase, net_flag='vgg', reuse=tf.AUTO_REUSE):
24 | """
26 | Initialize the model with a phase ('train' or 'test') and a frontend flag (currently only 'vgg')
26 | """
27 | super(LaneNet, self).__init__()
28 | self._net_flag = net_flag
29 | self._reuse = reuse
30 |
31 | self._frontend = lanenet_front_end.LaneNetFrondEnd(
32 | phase=phase, net_flag=net_flag
33 | )
34 | self._backend = lanenet_back_end.LaneNetBackEnd(
35 | phase=phase
36 | )
37 |
38 | def inference(self, input_tensor, name):
39 | """
40 | Run frontend feature extraction and the backend to get binary and instance segmentation predictions
41 | :param input_tensor:
42 | :param name:
43 | :return:
44 | """
45 | with tf.variable_scope(name_or_scope=name, reuse=self._reuse):
46 | # first extract image features
47 | extract_feats_result = self._frontend.build_model(
48 | input_tensor=input_tensor,
49 | name='{:s}_frontend'.format(self._net_flag),
50 | reuse=self._reuse
51 | )
52 |
53 | # second apply backend process
54 | binary_seg_prediction, instance_seg_prediction = self._backend.inference(
55 | binary_seg_logits=extract_feats_result['binary_segment_logits']['data'],
56 | instance_seg_logits=extract_feats_result['instance_segment_logits']['data'],
57 | name='{:s}_backend'.format(self._net_flag),
58 | reuse=self._reuse
59 | )
60 |
61 | if not self._reuse:
62 | self._reuse = True
63 |
64 | return binary_seg_prediction, instance_seg_prediction
65 |
66 | def compute_loss(self, input_tensor, binary_label, instance_label, name):
67 | """
68 | calculate lanenet loss for training
69 | :param input_tensor:
70 | :param binary_label:
71 | :param instance_label:
72 | :param name:
73 | :return:
74 | """
75 | with tf.variable_scope(name_or_scope=name, reuse=self._reuse):
76 | # first extract image features
77 | extract_feats_result = self._frontend.build_model(
78 | input_tensor=input_tensor,
79 | name='{:s}_frontend'.format(self._net_flag),
80 | reuse=self._reuse
81 | )
82 |
83 | # second apply backend process
84 | calculated_losses = self._backend.compute_loss(
85 | binary_seg_logits=extract_feats_result['binary_segment_logits']['data'],
86 | binary_label=binary_label,
87 | instance_seg_logits=extract_feats_result['instance_segment_logits']['data'],
88 | instance_label=instance_label,
89 | name='{:s}_backend'.format(self._net_flag),
90 | reuse=self._reuse
91 | )
92 |
93 | if not self._reuse:
94 | self._reuse = True
95 |
96 | return calculated_losses
97 |
--------------------------------------------------------------------------------
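A minimal inference-graph sketch for the LaneNet class above. The input size (1 x 256 x 512 x 3, the resolution the post-processing code rescales from), the variable scope name and the checkpoint path are illustrative assumptions, not values fixed by this file:

import tensorflow as tf

from LaneDetectionLaneNet.lanenet_model import lanenet

input_tensor = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_tensor')
net = lanenet.LaneNet(phase='test', net_flag='vgg')
binary_seg_ret, instance_seg_ret = net.inference(input_tensor=input_tensor, name='lanenet_model')

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess=sess, save_path='model/tusimple_lanenet.ckpt')  # hypothetical checkpoint path
    # binary_mask, instance_embedding = sess.run(
    #     [binary_seg_ret, instance_seg_ret], feed_dict={input_tensor: preprocessed_batch})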
/LaneDetectionLaneNet/lanenet_model/lanenet_back_end.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 19-4-24 3:54 PM
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : lanenet_back_end.py
7 | # @IDE: PyCharm
8 | """
9 | LaneNet backend branch which is mainly used for binary and instance segmentation loss calculation
10 | """
11 | import tensorflow as tf
12 |
13 | from LaneDetectionLaneNet.config import global_config
14 | from LaneDetectionLaneNet.lanenet_model import lanenet_discriminative_loss
15 | from LaneDetectionLaneNet.semantic_segmentation_zoo import cnn_basenet
16 |
17 | CFG = global_config.cfg
18 |
19 |
20 | class LaneNetBackEnd(cnn_basenet.CNNBaseModel):
21 | """
22 | LaneNet backend branch which is mainly used for binary and instance segmentation loss calculation
23 | """
24 | def __init__(self, phase):
25 | """
26 | init lanenet backend
27 | :param phase: train or test
28 | """
29 | super(LaneNetBackEnd, self).__init__()
30 | self._phase = phase
31 | self._is_training = self._is_net_for_training()
32 |
33 | def _is_net_for_training(self):
34 | """
35 | if the net is used for training or not
36 | :return:
37 | """
38 | if isinstance(self._phase, tf.Tensor):
39 | phase = self._phase
40 | else:
41 | phase = tf.constant(self._phase, dtype=tf.string)
42 |
43 | return tf.equal(phase, tf.constant('train', dtype=tf.string))
44 |
45 | @classmethod
46 | def _compute_class_weighted_cross_entropy_loss(cls, onehot_labels, logits, classes_weights):
47 | """
48 | Softmax cross entropy loss where each pixel is weighted by the weight of its class
49 | :param onehot_labels:
50 | :param logits:
51 | :param classes_weights:
52 | :return:
53 | """
54 | loss_weights = tf.reduce_sum(tf.multiply(onehot_labels, classes_weights), axis=3)
55 |
56 | loss = tf.losses.softmax_cross_entropy(
57 | onehot_labels=onehot_labels,
58 | logits=logits,
59 | weights=loss_weights
60 | )
61 |
62 | return loss
63 |
64 | def compute_loss(self, binary_seg_logits, binary_label,
65 | instance_seg_logits, instance_label,
66 | name, reuse):
67 | """
68 | compute lanenet loss
69 | :param binary_seg_logits:
70 | :param binary_label:
71 | :param instance_seg_logits:
72 | :param instance_label:
73 | :param name:
74 | :param reuse:
75 | :return:
76 | """
77 | with tf.variable_scope(name_or_scope=name, reuse=reuse):
78 | # calculate class weighted binary seg loss
79 | with tf.variable_scope(name_or_scope='binary_seg'):
80 | binary_label_onehot = tf.one_hot(
81 | tf.reshape(
82 | tf.cast(binary_label, tf.int32),
83 | shape=[binary_label.get_shape().as_list()[0],
84 | binary_label.get_shape().as_list()[1],
85 | binary_label.get_shape().as_list()[2]]),
86 | depth=CFG.TRAIN.CLASSES_NUMS,
87 | axis=-1
88 | )
89 |
90 | binary_label_plain = tf.reshape(
91 | binary_label,
92 | shape=[binary_label.get_shape().as_list()[0] *
93 | binary_label.get_shape().as_list()[1] *
94 | binary_label.get_shape().as_list()[2] *
95 | binary_label.get_shape().as_list()[3]])
96 | unique_labels, unique_id, counts = tf.unique_with_counts(binary_label_plain)
97 | counts = tf.cast(counts, tf.float32)
98 | inverse_weights = tf.divide(
99 | 1.0,
100 | tf.log(tf.add(tf.divide(counts, tf.reduce_sum(counts)), tf.constant(1.02)))
101 | )
102 |
103 | binary_segmentation_loss = self._compute_class_weighted_cross_entropy_loss(
104 | onehot_labels=binary_label_onehot,
105 | logits=binary_seg_logits,
106 | classes_weights=inverse_weights
107 | )
108 |
109 | # calculate class weighted instance seg loss
110 | with tf.variable_scope(name_or_scope='instance_seg'):
111 |
112 | pix_bn = self.layerbn(
113 | inputdata=instance_seg_logits, is_training=self._is_training, name='pix_bn')
114 | pix_relu = self.relu(inputdata=pix_bn, name='pix_relu')
115 | pix_embedding = self.conv2d(
116 | inputdata=pix_relu,
117 | out_channel=CFG.TRAIN.EMBEDDING_FEATS_DIMS,
118 | kernel_size=1,
119 | use_bias=False,
120 | name='pix_embedding_conv'
121 | )
122 | pix_image_shape = (pix_embedding.get_shape().as_list()[1], pix_embedding.get_shape().as_list()[2])
123 | instance_segmentation_loss, l_var, l_dist, l_reg = \
124 | lanenet_discriminative_loss.discriminative_loss(
125 | pix_embedding, instance_label, CFG.TRAIN.EMBEDDING_FEATS_DIMS,
126 | pix_image_shape, 0.5, 3.0, 1.0, 1.0, 0.001
127 | )
128 |
129 | l2_reg_loss = tf.constant(0.0, tf.float32)
130 | for vv in tf.trainable_variables():
131 | if 'bn' in vv.name or 'gn' in vv.name:
132 | continue
133 | else:
134 | l2_reg_loss = tf.add(l2_reg_loss, tf.nn.l2_loss(vv))
135 | l2_reg_loss *= 0.001
136 | total_loss = binary_segmentation_loss + instance_segmentation_loss + l2_reg_loss
137 |
138 | ret = {
139 | 'total_loss': total_loss,
140 | 'binary_seg_logits': binary_seg_logits,
141 | 'instance_seg_logits': pix_embedding,
142 | 'binary_seg_loss': binary_segmentation_loss,
143 | 'discriminative_loss': instance_segmentation_loss
144 | }
145 |
146 | return ret
147 |
148 | def inference(self, binary_seg_logits, instance_seg_logits, name, reuse):
149 | """
150 | Produce the binary segmentation prediction and the pixel embedding used for instance segmentation
151 | :param binary_seg_logits:
152 | :param instance_seg_logits:
153 | :param name:
154 | :param reuse:
155 | :return:
156 | """
157 | with tf.variable_scope(name_or_scope=name, reuse=reuse):
158 |
159 | with tf.variable_scope(name_or_scope='binary_seg'):
160 | binary_seg_score = tf.nn.softmax(logits=binary_seg_logits)
161 | binary_seg_prediction = tf.argmax(binary_seg_score, axis=-1)
162 |
163 | with tf.variable_scope(name_or_scope='instance_seg'):
164 |
165 | pix_bn = self.layerbn(
166 | inputdata=instance_seg_logits, is_training=self._is_training, name='pix_bn')
167 | pix_relu = self.relu(inputdata=pix_bn, name='pix_relu')
168 | instance_seg_prediction = self.conv2d(
169 | inputdata=pix_relu,
170 | out_channel=CFG.TRAIN.EMBEDDING_FEATS_DIMS,
171 | kernel_size=1,
172 | use_bias=False,
173 | name='pix_embedding_conv'
174 | )
175 |
176 | return binary_seg_prediction, instance_seg_prediction
177 |
--------------------------------------------------------------------------------
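For reference, the inverse class weights built in compute_loss above follow a bounded inverse log-frequency scheme; restated in LaTeX (this is a transcription of the code, not an addition to it):

w_c = \frac{1}{\ln\left(\frac{n_c}{\sum_k n_k} + 1.02\right)}

where n_c is the pixel count of class c in the batch. The 1.02 offset keeps the weight finite for very rare classes, capping it at roughly 1 / ln(1.02) ≈ 50.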
/LaneDetectionLaneNet/lanenet_model/lanenet_discriminative_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 18-5-11 3:48 PM
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : lanenet_discriminative_loss.py
7 | # @IDE: PyCharm Community Edition
8 | """
9 | Discriminative Loss for instance segmentation
10 | """
11 | import tensorflow as tf
12 |
13 |
14 | def discriminative_loss_single(
15 | prediction,
16 | correct_label,
17 | feature_dim,
18 | label_shape,
19 | delta_v,
20 | delta_d,
21 | param_var,
22 | param_dist,
23 | param_reg):
24 | """
25 | discriminative loss
26 | :param prediction: inference of network
27 | :param correct_label: instance label
28 | :param feature_dim: feature dimension of prediction
29 | :param label_shape: shape of label
30 | :param delta_v: cut off variance distance
31 | :param delta_d: cut off cluster distance
32 | :param param_var: weight for intra cluster variance
33 | :param param_dist: weight for inter cluster distances
34 | :param param_reg: weight regularization
35 | """
36 | correct_label = tf.reshape(
37 | correct_label, [label_shape[1] * label_shape[0]]
38 | )
39 | reshaped_pred = tf.reshape(
40 | prediction, [label_shape[1] * label_shape[0], feature_dim]
41 | )
42 |
43 | # calculate instance nums
44 | unique_labels, unique_id, counts = tf.unique_with_counts(correct_label)
45 | counts = tf.cast(counts, tf.float32)
46 | num_instances = tf.size(unique_labels)
47 |
48 | # calculate instance pixel embedding mean vec
49 | segmented_sum = tf.unsorted_segment_sum(
50 | reshaped_pred, unique_id, num_instances)
51 | mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1)))
52 | mu_expand = tf.gather(mu, unique_id)
53 |
54 | distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1, ord=1)
55 | distance = tf.subtract(distance, delta_v)
56 | distance = tf.clip_by_value(distance, 0., distance)
57 | distance = tf.square(distance)
58 |
59 | l_var = tf.unsorted_segment_sum(distance, unique_id, num_instances)
60 | l_var = tf.div(l_var, counts)
61 | l_var = tf.reduce_sum(l_var)
62 | l_var = tf.divide(l_var, tf.cast(num_instances, tf.float32))
63 |
64 | mu_interleaved_rep = tf.tile(mu, [num_instances, 1])
65 | mu_band_rep = tf.tile(mu, [1, num_instances])
66 | mu_band_rep = tf.reshape(
67 | mu_band_rep,
68 | (num_instances *
69 | num_instances,
70 | feature_dim))
71 |
72 | mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
73 |
74 | intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), axis=1)
75 | zero_vector = tf.zeros(1, dtype=tf.float32)
76 | bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
77 | mu_diff_bool = tf.boolean_mask(mu_diff, bool_mask)
78 |
79 | mu_norm = tf.norm(mu_diff_bool, axis=1, ord=1)
80 | mu_norm = tf.subtract(2. * delta_d, mu_norm)
81 | mu_norm = tf.clip_by_value(mu_norm, 0., mu_norm)
82 | mu_norm = tf.square(mu_norm)
83 |
84 | l_dist = tf.reduce_mean(mu_norm)
85 |
86 | l_reg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))
87 |
88 | param_scale = 1.
89 | l_var = param_var * l_var
90 | l_dist = param_dist * l_dist
91 | l_reg = param_reg * l_reg
92 |
93 | loss = param_scale * (l_var + l_dist + l_reg)
94 |
95 | return loss, l_var, l_dist, l_reg
96 |
97 |
98 | def discriminative_loss(prediction, correct_label, feature_dim, image_shape,
99 | delta_v, delta_d, param_var, param_dist, param_reg):
100 | """
101 | Batch-wise discriminative loss, computed by looping over the batch with tf.while_loop
102 | :return: discriminative loss and its three components
103 | """
104 |
105 | def cond(label, batch, out_loss, out_var, out_dist, out_reg, i):
106 | return tf.less(i, tf.shape(batch)[0])
107 |
108 | def body(label, batch, out_loss, out_var, out_dist, out_reg, i):
109 | disc_loss, l_var, l_dist, l_reg = discriminative_loss_single(
110 | prediction[i], correct_label[i], feature_dim, image_shape, delta_v, delta_d, param_var, param_dist, param_reg)
111 |
112 | out_loss = out_loss.write(i, disc_loss)
113 | out_var = out_var.write(i, l_var)
114 | out_dist = out_dist.write(i, l_dist)
115 | out_reg = out_reg.write(i, l_reg)
116 |
117 | return label, batch, out_loss, out_var, out_dist, out_reg, i + 1
118 |
119 | # TensorArray is a data structure that supports dynamic writes
120 | output_ta_loss = tf.TensorArray(
121 | dtype=tf.float32, size=0, dynamic_size=True)
122 | output_ta_var = tf.TensorArray(
123 | dtype=tf.float32, size=0, dynamic_size=True)
124 | output_ta_dist = tf.TensorArray(
125 | dtype=tf.float32, size=0, dynamic_size=True)
126 | output_ta_reg = tf.TensorArray(
127 | dtype=tf.float32, size=0, dynamic_size=True)
128 |
129 | _, _, out_loss_op, out_var_op, out_dist_op, out_reg_op, _ = tf.while_loop(
130 | cond, body, [
131 | correct_label, prediction, output_ta_loss, output_ta_var, output_ta_dist, output_ta_reg, 0])
132 | out_loss_op = out_loss_op.stack()
133 | out_var_op = out_var_op.stack()
134 | out_dist_op = out_dist_op.stack()
135 | out_reg_op = out_reg_op.stack()
136 |
137 | disc_loss = tf.reduce_mean(out_loss_op)
138 | l_var = tf.reduce_mean(out_var_op)
139 | l_dist = tf.reduce_mean(out_dist_op)
140 | l_reg = tf.reduce_mean(out_reg_op)
141 |
142 | return disc_loss, l_var, l_dist, l_reg
143 |
--------------------------------------------------------------------------------
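Written out, discriminative_loss_single above implements the discriminative loss of De Brabandere et al. (2017) with L1 distances. With C clusters, N_c pixels in cluster c, embeddings x_i, cluster means \mu_c, and the hinge [z]_+ = \max(z, 0):

L_{var}  = \frac{1}{C} \sum_{c=1}^{C} \frac{1}{N_c} \sum_{i=1}^{N_c} \big[\, \lVert \mu_c - x_i \rVert_1 - \delta_v \,\big]_+^2
L_{dist} = \frac{1}{C(C-1)} \sum_{c_A \neq c_B} \big[\, 2\delta_d - \lVert \mu_{c_A} - \mu_{c_B} \rVert_1 \,\big]_+^2
L_{reg}  = \frac{1}{C} \sum_{c=1}^{C} \lVert \mu_c \rVert_1
L = \text{param\_var} \cdot L_{var} + \text{param\_dist} \cdot L_{dist} + \text{param\_reg} \cdot L_{reg}

lanenet_back_end.py calls this with \delta_v = 0.5, \delta_d = 3.0 and weights 1.0, 1.0, 0.001.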
/LaneDetectionLaneNet/lanenet_model/lanenet_front_end.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 19-4-24 3:53 PM
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : lanenet_front_end.py
7 | # @IDE: PyCharm
8 | """
9 | LaneNet frontend branch which is mainly used for feature extraction
10 | """
11 | from LaneDetectionLaneNet.semantic_segmentation_zoo import cnn_basenet
12 | from LaneDetectionLaneNet.semantic_segmentation_zoo import vgg16_based_fcn
13 |
14 |
15 | class LaneNetFrondEnd(cnn_basenet.CNNBaseModel):
16 | """
17 | LaneNet frontend which is used to extract image features for the subsequent backend processing
18 | """
19 | def __init__(self, phase, net_flag):
20 | """
21 | Select the frontend network according to net_flag (currently only 'vgg')
22 | """
23 | super(LaneNetFrondEnd, self).__init__()
24 |
25 | self._frontend_net_map = {
26 | 'vgg': vgg16_based_fcn.VGG16FCN(phase=phase)
27 | }
28 |
29 | self._net = self._frontend_net_map[net_flag]
30 |
31 | def build_model(self, input_tensor, name, reuse):
32 | """
33 | Build the selected frontend network on top of the input tensor
34 | :param input_tensor:
35 | :param name:
36 | :param reuse:
37 | :return:
38 | """
39 |
40 | return self._net.build_model(
41 | input_tensor=input_tensor,
42 | name=name,
43 | reuse=reuse
44 | )
45 |
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/lanenet_model/lanenet_postprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 18-5-30 10:04 AM
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
6 | # @File : lanenet_postprocess.py
7 | # @IDE: PyCharm Community Edition
8 | """
9 | LaneNet model post process
10 | """
11 | import os.path as ops
12 | import math
13 |
14 | import cv2
15 | import glog as log
16 | import numpy as np
17 | from sklearn.cluster import DBSCAN
18 | from sklearn.preprocessing import StandardScaler
19 |
20 | from LaneDetectionLaneNet.config import global_config
21 |
22 | CFG = global_config.cfg
23 |
24 |
25 | def _morphological_process(image, kernel_size=5):
26 | """
27 | morphological process to fill the holes in the binary segmentation result
28 | :param image:
29 | :param kernel_size:
30 | :return:
31 | """
32 | if len(image.shape) == 3:
33 | raise ValueError('Binary segmentation result image should be a single channel image')
34 |
35 | if image.dtype != np.uint8:
36 | image = np.array(image, np.uint8)
37 |
38 | kernel = cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size))
39 |
40 | # the closing operation fills holes
41 | closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1)
42 |
43 | return closing
44 |
45 |
46 | def _connect_components_analysis(image):
47 | """
48 | connected components analysis to remove small components
49 | :param image:
50 | :return:
51 | """
52 | if len(image.shape) == 3:
53 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
54 | else:
55 | gray_image = image
56 |
57 | return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S)
58 |
59 |
60 | class _LaneFeat(object):
61 | """
62 | Container for a single lane's embedding features and pixel coordinates
63 | """
64 | def __init__(self, feat, coord, class_id=-1):
65 | """
66 | lane feat object
67 | :param feat: lane embedding feats [feature_1, feature_2, ...]
68 | :param coord: lane coordinates [x, y]
69 | :param class_id: lane class id
70 | """
71 | self._feat = feat
72 | self._coord = coord
73 | self._class_id = class_id
74 |
75 | @property
76 | def feat(self):
77 | """
78 |
79 | :return:
80 | """
81 | return self._feat
82 |
83 | @feat.setter
84 | def feat(self, value):
85 | """
86 |
87 | :param value:
88 | :return:
89 | """
90 | if not isinstance(value, np.ndarray):
91 | value = np.array(value, dtype=np.float64)
92 |
93 | if value.dtype != np.float64:
94 | value = np.array(value, dtype=np.float64)
95 |
96 | self._feat = value
97 |
98 | @property
99 | def coord(self):
100 | """
101 |
102 | :return:
103 | """
104 | return self._coord
105 |
106 | @coord.setter
107 | def coord(self, value):
108 | """
109 |
110 | :param value:
111 | :return:
112 | """
113 | if not isinstance(value, np.ndarray):
114 | value = np.array(value)
115 |
116 | if value.dtype != np.int32:
117 | value = np.array(value, dtype=np.int32)
118 |
119 | self._coord = value
120 |
121 | @property
122 | def class_id(self):
123 | """
124 |
125 | :return:
126 | """
127 | return self._class_id
128 |
129 | @class_id.setter
130 | def class_id(self, value):
131 | """
132 |
133 | :param value:
134 | :return:
135 | """
136 | if not isinstance(value, (int, np.integer)):
137 | raise ValueError('Class id must be integer')
138 |
139 | self._class_id = value
140 |
141 |
142 | class _LaneNetCluster(object):
143 | """
144 | Instance segmentation result cluster
145 | """
146 |
147 | def __init__(self):
148 | """
149 | Color map used to draw each clustered lane instance
150 | """
151 | self._color_map = [np.array([255, 0, 0]),
152 | np.array([0, 255, 0]),
153 | np.array([0, 0, 255]),
154 | np.array([125, 125, 0]),
155 | np.array([0, 125, 125]),
156 | np.array([125, 0, 125]),
157 | np.array([50, 100, 50]),
158 | np.array([100, 50, 100])]
159 |
160 | @staticmethod
161 | def _embedding_feats_dbscan_cluster(embedding_image_feats):
162 | """
163 | dbscan cluster
164 | :param embedding_image_feats:
165 | :return:
166 | """
167 | db = DBSCAN(eps=CFG.POSTPROCESS.DBSCAN_EPS, min_samples=CFG.POSTPROCESS.DBSCAN_MIN_SAMPLES)
168 | try:
169 | features = StandardScaler().fit_transform(embedding_image_feats)
170 | db.fit(features)
171 | except Exception as err:
172 | log.error(err)
173 | ret = {
174 | 'origin_features': None,
175 | 'cluster_nums': 0,
176 | 'db_labels': None,
177 | 'unique_labels': None,
178 | 'cluster_center': None
179 | }
180 | return ret
181 | db_labels = db.labels_
182 | unique_labels = np.unique(db_labels)
183 |
184 | num_clusters = len(unique_labels)
185 | cluster_centers = db.components_
186 |
187 | ret = {
188 | 'origin_features': features,
189 | 'cluster_nums': num_clusters,
190 | 'db_labels': db_labels,
191 | 'unique_labels': unique_labels,
192 | 'cluster_center': cluster_centers
193 | }
194 |
195 | return ret
196 |
197 | @staticmethod
198 | def _get_lane_embedding_feats(binary_seg_ret, instance_seg_ret):
199 | """
200 | get lane embedding features according to the binary seg result
201 | :param binary_seg_ret:
202 | :param instance_seg_ret:
203 | :return:
204 | """
205 | idx = np.where(binary_seg_ret == 255)
206 | lane_embedding_feats = instance_seg_ret[idx]
207 | # idx_scale = np.vstack((idx[0] / 256.0, idx[1] / 512.0)).transpose()
208 | # lane_embedding_feats = np.hstack((lane_embedding_feats, idx_scale))
209 | lane_coordinate = np.vstack((idx[1], idx[0])).transpose()
210 |
211 | assert lane_embedding_feats.shape[0] == lane_coordinate.shape[0]
212 |
213 | ret = {
214 | 'lane_embedding_feats': lane_embedding_feats,
215 | 'lane_coordinates': lane_coordinate
216 | }
217 |
218 | return ret
219 |
220 | def apply_lane_feats_cluster(self, binary_seg_result, instance_seg_result):
221 | """
222 | Cluster the lane embedding features with DBSCAN and build a colored lane mask
223 | :param binary_seg_result:
224 | :param instance_seg_result:
225 | :return:
226 | """
227 | # get embedding feats and coords
228 | get_lane_embedding_feats_result = self._get_lane_embedding_feats(
229 | binary_seg_ret=binary_seg_result,
230 | instance_seg_ret=instance_seg_result
231 | )
232 |
233 | # dbscan cluster
234 | dbscan_cluster_result = self._embedding_feats_dbscan_cluster(
235 | embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats']
236 | )
237 |
238 | mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1], 3], dtype=np.uint8)
239 | db_labels = dbscan_cluster_result['db_labels']
240 | unique_labels = dbscan_cluster_result['unique_labels']
241 | coord = get_lane_embedding_feats_result['lane_coordinates']
242 |
243 | if db_labels is None:
244 | return None, None
245 |
246 | lane_coords = []
247 |
248 | for index, label in enumerate(unique_labels.tolist()):
249 | if label == -1:
250 | continue
251 | idx = np.where(db_labels == label)
252 | pix_coord_idx = tuple((coord[idx][:, 1], coord[idx][:, 0]))
253 | mask[pix_coord_idx] = self._color_map[index]
254 | lane_coords.append(coord[idx])
255 |
256 | return mask, lane_coords
257 |
258 |
259 | class LaneNetPostProcessor(object):
260 | """
261 | lanenet post process for lane generation
262 | """
263 | def __init__(self, ipm_remap_file_path='./LaneDetectionLaneNet/data/tusimple_ipm_remap.yml'):
264 | """
265 |
266 | :param ipm_remap_file_path: ipm generate file path
267 | """
268 | assert ops.exists(ipm_remap_file_path), '{:s} not exist'.format(ipm_remap_file_path)
269 |
270 | self._cluster = _LaneNetCluster()
271 | self._ipm_remap_file_path = ipm_remap_file_path
272 |
273 | remap_file_load_ret = self._load_remap_matrix()
274 | self._remap_to_ipm_x = remap_file_load_ret['remap_to_ipm_x']
275 | self._remap_to_ipm_y = remap_file_load_ret['remap_to_ipm_y']
276 |
277 | self._color_map = [np.array([255, 0, 0]),
278 | np.array([0, 255, 0]),
279 | np.array([0, 0, 255]),
280 | np.array([125, 125, 0]),
281 | np.array([0, 125, 125]),
282 | np.array([125, 0, 125]),
283 | np.array([50, 100, 50]),
284 | np.array([100, 50, 100])]
285 |
286 | def _load_remap_matrix(self):
287 | """
288 | Load the inverse perspective mapping remap matrices from the yml file
289 | :return:
290 | """
291 | fs = cv2.FileStorage(self._ipm_remap_file_path, cv2.FILE_STORAGE_READ)
292 |
293 | remap_to_ipm_x = fs.getNode('remap_ipm_x').mat()
294 | remap_to_ipm_y = fs.getNode('remap_ipm_y').mat()
295 |
296 | ret = {
297 | 'remap_to_ipm_x': remap_to_ipm_x,
298 | 'remap_to_ipm_y': remap_to_ipm_y,
299 | }
300 |
301 | fs.release()
302 |
303 | return ret
304 |
305 | def postprocess(self, binary_seg_result, instance_seg_result=None,
306 | min_area_threshold=100, source_image=None,
307 | data_source='tusimple'):
308 | """
309 | Full post processing: morphology, connected component filtering, DBSCAN clustering and lane curve fitting
310 | :param binary_seg_result:
311 | :param instance_seg_result:
312 | :param min_area_threshold:
313 | :param source_image:
314 | :param data_source:
315 | :return:
316 | """
317 | # convert binary_seg_result
318 | binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8)
319 |
320 | # apply an image morphology operation to fill in holes and remove small areas
321 | morphological_ret = _morphological_process(binary_seg_result, kernel_size=5)
322 |
323 | connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret)
324 |
325 | labels = connect_components_analysis_ret[1]
326 | stats = connect_components_analysis_ret[2]
327 | for index, stat in enumerate(stats):
328 | if stat[4] <= min_area_threshold:
329 | idx = np.where(labels == index)
330 | morphological_ret[idx] = 0
331 |
332 | # apply embedding features cluster
333 | mask_image, lane_coords = self._cluster.apply_lane_feats_cluster(
334 | binary_seg_result=morphological_ret,
335 | instance_seg_result=instance_seg_result
336 | )
337 |
338 | if mask_image is None:
339 | return {
340 | 'mask_image': None,
341 | 'fit_params': None,
342 | 'source_image': None,
343 | }
344 |
345 | # lane line fit
346 | fit_params = []
347 | src_lane_pts = [] # lane pts every single lane
348 | for lane_index, coords in enumerate(lane_coords):
349 | if data_source == 'tusimple':
350 | tmp_mask = np.zeros(shape=(720, 1280), dtype=np.uint8)
351 | tmp_mask[tuple((np.int_(coords[:, 1] * 720 / 256), np.int_(coords[:, 0] * 1280 / 512)))] = 255
352 | elif data_source == 'beec_ccd':
353 | tmp_mask = np.zeros(shape=(1350, 2448), dtype=np.uint8)
354 | tmp_mask[tuple((np.int_(coords[:, 1] * 1350 / 256), np.int_(coords[:, 0] * 2448 / 512)))] = 255
355 | else:
356 | raise ValueError('Wrong data source, only tusimple and beec_ccd are supported')
357 | tmp_ipm_mask = cv2.remap(
358 | tmp_mask,
359 | self._remap_to_ipm_x,
360 | self._remap_to_ipm_y,
361 | interpolation=cv2.INTER_NEAREST
362 | )
363 | nonzero_y = np.array(tmp_ipm_mask.nonzero()[0])
364 | nonzero_x = np.array(tmp_ipm_mask.nonzero()[1])
365 |
366 | fit_param = np.polyfit(nonzero_y, nonzero_x, 2)
367 | fit_params.append(fit_param)
368 |
369 | [ipm_image_height, ipm_image_width] = tmp_ipm_mask.shape
370 | plot_y = np.linspace(10, ipm_image_height, ipm_image_height - 10)
371 | fit_x = fit_param[0] * plot_y ** 2 + fit_param[1] * plot_y + fit_param[2]
372 | # fit_x = fit_param[0] * plot_y ** 3 + fit_param[1] * plot_y ** 2 + fit_param[2] * plot_y + fit_param[3]
373 |
374 | lane_pts = []
375 | for index in range(0, plot_y.shape[0], 5):
376 | src_x = self._remap_to_ipm_x[
377 | int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
378 | if src_x <= 0:
379 | continue
380 | src_y = self._remap_to_ipm_y[
381 | int(plot_y[index]), int(np.clip(fit_x[index], 0, ipm_image_width - 1))]
382 | src_y = src_y if src_y > 0 else 0
383 |
384 | lane_pts.append([src_x, src_y])
385 |
386 | src_lane_pts.append(lane_pts)
387 |
388 | # tusimple test data samples points along the y axis every 10 pixels
389 | source_image_width = source_image.shape[1]
390 | for index, single_lane_pts in enumerate(src_lane_pts):
391 | single_lane_pt_x = np.array(single_lane_pts, dtype=np.float32)[:, 0]
392 | single_lane_pt_y = np.array(single_lane_pts, dtype=np.float32)[:, 1]
393 | if data_source == 'tusimple':
394 | start_plot_y = 240
395 | end_plot_y = 720
396 | elif data_source == 'beec_ccd':
397 | start_plot_y = 820
398 | end_plot_y = 1350
399 | else:
400 | raise ValueError('Wrong data source, only tusimple and beec_ccd are supported')
401 | step = int(math.floor((end_plot_y - start_plot_y) / 10))
402 | for plot_y in np.linspace(start_plot_y, end_plot_y, step):
403 | diff = single_lane_pt_y - plot_y
404 | fake_diff_bigger_than_zero = diff.copy()
405 | fake_diff_smaller_than_zero = diff.copy()
406 | fake_diff_bigger_than_zero[np.where(diff <= 0)] = float('inf')
407 | fake_diff_smaller_than_zero[np.where(diff > 0)] = float('-inf')
408 | idx_low = np.argmax(fake_diff_smaller_than_zero)
409 | idx_high = np.argmin(fake_diff_bigger_than_zero)
410 |
411 | previous_src_pt_x = single_lane_pt_x[idx_low]
412 | previous_src_pt_y = single_lane_pt_y[idx_low]
413 | last_src_pt_x = single_lane_pt_x[idx_high]
414 | last_src_pt_y = single_lane_pt_y[idx_high]
415 |
416 | if previous_src_pt_y < start_plot_y or last_src_pt_y < start_plot_y or \
417 | fake_diff_smaller_than_zero[idx_low] == float('-inf') or \
418 | fake_diff_bigger_than_zero[idx_high] == float('inf'):
419 | continue
420 |
421 | interpolation_src_pt_x = (abs(previous_src_pt_y - plot_y) * previous_src_pt_x +
422 | abs(last_src_pt_y - plot_y) * last_src_pt_x) / \
423 | (abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
424 | interpolation_src_pt_y = (abs(previous_src_pt_y - plot_y) * previous_src_pt_y +
425 | abs(last_src_pt_y - plot_y) * last_src_pt_y) / \
426 | (abs(previous_src_pt_y - plot_y) + abs(last_src_pt_y - plot_y))
427 |
428 | if interpolation_src_pt_x > source_image_width or interpolation_src_pt_x < 10:
429 | continue
430 |
431 | lane_color = self._color_map[index].tolist()
432 | cv2.circle(source_image, (int(interpolation_src_pt_x),
433 | int(interpolation_src_pt_y)), 5, lane_color, -1)
434 | ret = {
435 | 'mask_image': mask_image,
436 | 'fit_params': fit_params,
437 | 'source_image': source_image,
438 | }
439 |
440 | return ret
441 |
--------------------------------------------------------------------------------
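A minimal usage sketch for the post-processor above. The two segmentation arrays stand in for the network outputs (binary mask with lane pixels set to 1, and the pixel embedding map), the constructor's default remap yml must exist on disk, and all variable names here are placeholders:

from LaneDetectionLaneNet.lanenet_model import lanenet_postprocess

postprocessor = lanenet_postprocess.LaneNetPostProcessor()

# binary_seg_image: H x W array with lane pixels set to 1 (postprocess() scales it to 255)
# instance_seg_image: H x W x EMBEDDING_FEATS_DIMS pixel embedding from the backend
postprocess_result = postprocessor.postprocess(
    binary_seg_result=binary_seg_image,
    instance_seg_result=instance_seg_image,
    source_image=source_image,  # original BGR frame, e.g. 720x1280 for tusimple
    data_source='tusimple'
)
mask_image = postprocess_result['mask_image']    # colored cluster mask
overlay = postprocess_result['source_image']     # source frame with fitted lane points drawn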
/LaneDetectionLaneNet/mnn_project/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time    : 2019/11/5 5:03 PM
4 | # @Author : LuoYao
5 | # @Site : ICode
6 | # @File    : __init__.py
7 | # @IDE: PyCharm
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/mnn_project/config.ini:
--------------------------------------------------------------------------------
1 | [LaneNet]
2 | # model file path
3 | model_file_path=~/MNN-0.2.1.0/beec_task/lane_detection/model/lanenet_model.mnn
4 | # pixel embedding feature dims
5 | pix_embedding_feature_dims=4
6 | # dbscan neighborhood radius (eps) threshold
7 | dbscan_neighbor_radius=0.4
8 | # minimum number of samples for a dbscan core point
9 | dbscan_core_object_min_pts=500
10 |
--------------------------------------------------------------------------------
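The same INI file can be read from Python with the standard library, which is a quick way to confirm the values the C++ parser below will see (a sketch, not part of the repo; the path is assumed relative to the project root):

import configparser

cfg = configparser.ConfigParser()
cfg.read('LaneDetectionLaneNet/mnn_project/config.ini')

model_path = cfg['LaneNet']['model_file_path']
embedding_dims = cfg['LaneNet'].getint('pix_embedding_feature_dims')   # 4
dbscan_eps = cfg['LaneNet'].getfloat('dbscan_neighbor_radius')         # 0.4
dbscan_min_pts = cfg['LaneNet'].getint('dbscan_core_object_min_pts')   # 500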
/LaneDetectionLaneNet/mnn_project/config_parser.cpp:
--------------------------------------------------------------------------------
1 | /************************************************
2 | * Author: MaybeShewill-CV
3 | * File: configParser.cpp
4 | * Date: 2019/10/10 10:39 AM
5 | ************************************************/
6 |
7 | #include "config_parser.h"
8 |
9 | #include <fstream>
10 | #include <sstream>
11 | #include <string>
12 | #include <algorithm>
13 | #include <cctype>
14 | #include <stdexcept>
15 |
16 | namespace beec {
17 | namespace config_parse_utils {
18 |
19 | ConfigParser::ConfigParser(const std::string &filename) {
20 |
21 | std::ifstream fin(filename);
22 |
23 | if (fin.good()) {
24 | std::string line;
25 | std::string current_header = "";
26 | while (std::getline(fin, line)) {
27 | trim(line);
28 |
29 | // Skip empty lines
30 | if (line.size() == 0)
31 | continue;
32 |
33 | switch (line[0]) {
34 | case '#':
35 | case ';':
36 | // Ignore comments
37 | break;
38 | case '[':
39 | // Section header
40 | current_header = read_header(line);
41 | break;
42 | default:
43 | // Everything else will be configurations
44 | read_configuration(line, current_header);
45 | }
46 | }
47 | fin.close();
48 | } else {
49 | throw std::runtime_error("File `" + filename + "` does not exist");
50 | }
51 | }
52 |
53 | std::map<std::string, std::string> ConfigParser::get_section(const std::string &section_name) const {
54 |
55 | if (_m_sections.count(section_name) == 0) {
56 | std::string error = "No such key: `" + section_name + "`";
57 | throw std::out_of_range(error);
58 | }
59 | return _m_sections.at(section_name);
60 | }
61 |
62 | std::map<std::string, std::string> ConfigParser::operator[](const std::string &section_name) const {
63 |
64 | if (_m_sections.count(section_name) == 0) {
65 | std::string error = "No such key: `" + section_name + "`";
66 | throw std::out_of_range(error);
67 | }
68 | return _m_sections.at(section_name);
69 | }
70 |
71 | void ConfigParser::dump(FILE *log_file) {
72 |
73 | // Set up iterators
74 | std::map<std::string, std::string>::iterator itr1;
75 | std::map<std::string, std::map<std::string, std::string> >::iterator itr2;
76 | for (itr2 = _m_sections.begin(); itr2 != _m_sections.end(); itr2++) {
77 | fprintf(log_file, "[%s]\n", itr2->first.c_str());
78 | for (itr1 = itr2->second.begin(); itr1 != itr2->second.end(); itr1++) {
79 | fprintf(log_file, "%s=%s\n", itr1->first.c_str(), itr1->second.c_str());
80 | }
81 | }
82 | }
83 |
84 | std::string ConfigParser::read_header(const std::string &line) {
85 |
86 | if (line[line.size() - 1] != ']')
87 | throw std::runtime_error("Invalid section header: `" + line + "`");
88 | return trim_copy(line.substr(1, line.size() - 2));
89 | }
90 |
91 | void ConfigParser::read_configuration(const std::string &line, const std::string &header) {
92 | if (header == "") {
93 | std::string error = "No section provided for: `" + line + "`";
94 | throw std::runtime_error(error);
95 | }
96 |
97 | if (line.find('=') == std::string::npos) {
98 | std::string error = "Invalid configuration: `" + line + "`";
99 | throw std::runtime_error(error);
100 | }
101 |
102 | std::istringstream iss(line);
103 | std::string key;
104 | std::string val;
105 | std::getline(iss, key, '=');
106 |
107 | if (key.size() == 0) {
108 | std::string error = "No key found in configuration: `" + line + "`";
109 | throw std::runtime_error(error);
110 | }
111 |
112 | std::getline(iss, val);
113 |
114 | _m_sections[header][trim_copy(key)] = trim_copy(val);
115 | }
116 |
117 | // trim from start (in place)
118 | void ConfigParser::ltrim(std::string &s) {
119 | s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int ch) {
120 | return !std::isspace(ch);
121 | }));
122 | }
123 |
124 | // trim from end (in place)
125 | void ConfigParser::rtrim(std::string &s) {
126 | s.erase(std::find_if(s.rbegin(), s.rend(), [](int ch) {
127 | return !std::isspace(ch);
128 | }).base(), s.end());
129 | }
130 |
131 | // trim from both ends (in place)
132 | void ConfigParser::trim(std::string &s) {
133 | ltrim(s);
134 | rtrim(s);
135 | }
136 |
137 | // trim from start (copying)
138 | std::string ConfigParser::ltrim_copy(std::string s) {
139 | ltrim(s);
140 | return s;
141 | }
142 |
143 | // trim from end (copying)
144 | std::string ConfigParser::rtrim_copy(std::string s) {
145 | rtrim(s);
146 | return s;
147 | }
148 |
149 | // trim from both ends (copying)
150 | std::string ConfigParser::trim_copy(std::string s) {
151 | trim(s);
152 | return s;
153 | }
154 | }
155 | }
--------------------------------------------------------------------------------
/LaneDetectionLaneNet/mnn_project/config_parser.h:
--------------------------------------------------------------------------------
1 | /************************************************
2 | * Author: MaybeShewill-CV
3 | * File: configParser.h
4 | * Date: 2019/10/10 10:39 AM
5 | ************************************************/
6 |
7 | #ifndef MNN_CONFIGPARSER_H
8 | #define MNN_CONFIGPARSER_H
9 |
10 | // Config parser
11 |
12 | #include <string>
13 | #include <map>
14 | #include <fstream>
15 | #include <cstdio>