├── LICENSE
├── README.md
├── UFLD.py
├── __pycache__
│   ├── UFLD.cpython-37.pyc
│   ├── UFLD.cpython-38.pyc
│   ├── common.cpython-37.pyc
│   ├── dataloader.cpython-37.pyc
│   ├── flirCapture2.cpython-36.pyc
│   ├── flirCapture2.cpython-38.pyc
│   ├── laneDetection.cpython-35.pyc
│   ├── laneDetection.cpython-36.pyc
│   ├── laneDetection.cpython-37.pyc
│   └── laneDetection.cpython-38.pyc
├── calibration.cache
├── calibration_data
│   ├── __pycache__
│   │   ├── constant.cpython-37.pyc
│   │   ├── dataloader.cpython-37.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   └── mytransforms.cpython-37.pyc
│   ├── constant.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── make_mini_tusimple.py
│   └── mytransforms.py
├── common.py
├── configs
│   ├── __pycache__
│   │   ├── constant.cpython-36.pyc
│   │   ├── constant.cpython-37.pyc
│   │   └── constant.cpython-38.pyc
│   ├── constant.py
│   └── tusimple_4.py
├── launch_opencv.py
├── mnist_calibration.cache
├── model
│   ├── __pycache__
│   │   ├── backbone.cpython-36.pyc
│   │   ├── backbone.cpython-37.pyc
│   │   ├── backbone.cpython-38.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model_convert.cpython-38.pyc
│   │   └── model_convert2.cpython-38.pyc
│   ├── backbone.py
│   └── model.py
├── onnx_to_tensorrt.py
├── onnx_to_tensorrt_int8.py
├── requirement.txt
├── tensorrt_run.py
├── test_devices.py
├── torch2onnx.py
└── utils
    ├── __pycache__
    │   ├── common.cpython-36.pyc
    │   ├── common.cpython-37.pyc
    │   ├── common.cpython-38.pyc
    │   ├── config.cpython-36.pyc
    │   ├── config.cpython-37.pyc
    │   ├── config.cpython-38.pyc
    │   ├── dist_utils.cpython-36.pyc
    │   ├── dist_utils.cpython-37.pyc
    │   ├── dist_utils.cpython-38.pyc
    │   ├── factory.cpython-38.pyc
    │   ├── loss.cpython-38.pyc
    │   └── metrics.cpython-38.pyc
    ├── common.py
    ├── config.py
    ├── dist_utils.py
    ├── factory.py
    ├── loss.py
    ├── metrics.py
    ├── onnx2trt.py
    ├── onnx2trt_test.py
    └── onnx_to_tensorrt.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU AFFERO GENERAL PUBLIC LICENSE
2 | Version 3, 19 November 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU Affero General Public License is a free, copyleft license for
11 | software and other kinds of works, specifically designed to ensure
12 | cooperation with the community in the case of network server software.
13 |
14 | The licenses for most software and other practical works are designed
15 | to take away your freedom to share and change the works. By contrast,
16 | our General Public Licenses are intended to guarantee your freedom to
17 | share and change all versions of a program--to make sure it remains free
18 | software for all its users.
19 |
20 | When we speak of free software, we are referring to freedom, not
21 | price. Our General Public Licenses are designed to make sure that you
22 | have the freedom to distribute copies of free software (and charge for
23 | them if you wish), that you receive source code or can get it if you
24 | want it, that you can change the software or use pieces of it in new
25 | free programs, and that you know you can do these things.
26 |
27 | Developers that use our General Public Licenses protect your rights
28 | with two steps: (1) assert copyright on the software, and (2) offer
29 | you this License which gives you legal permission to copy, distribute
30 | and/or modify the software.
31 |
32 | A secondary benefit of defending all users' freedom is that
33 | improvements made in alternate versions of the program, if they
34 | receive widespread use, become available for other developers to
35 | incorporate. Many developers of free software are heartened and
36 | encouraged by the resulting cooperation. However, in the case of
37 | software used on network servers, this result may fail to come about.
38 | The GNU General Public License permits making a modified version and
39 | letting the public access it on a server without ever releasing its
40 | source code to the public.
41 |
42 | The GNU Affero General Public License is designed specifically to
43 | ensure that, in such cases, the modified source code becomes available
44 | to the community. It requires the operator of a network server to
45 | provide the source code of the modified version running there to the
46 | users of that server. Therefore, public use of a modified version, on
47 | a publicly accessible server, gives the public access to the source
48 | code of the modified version.
49 |
50 | An older license, called the Affero General Public License and
51 | published by Affero, was designed to accomplish similar goals. This is
52 | a different license, not a version of the Affero GPL, but Affero has
53 | released a new version of the Affero GPL which permits relicensing under
54 | this license.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | TERMS AND CONDITIONS
60 |
61 | 0. Definitions.
62 |
63 | "This License" refers to version 3 of the GNU Affero General Public License.
64 |
65 | "Copyright" also means copyright-like laws that apply to other kinds of
66 | works, such as semiconductor masks.
67 |
68 | "The Program" refers to any copyrightable work licensed under this
69 | License. Each licensee is addressed as "you". "Licensees" and
70 | "recipients" may be individuals or organizations.
71 |
72 | To "modify" a work means to copy from or adapt all or part of the work
73 | in a fashion requiring copyright permission, other than the making of an
74 | exact copy. The resulting work is called a "modified version" of the
75 | earlier work or a work "based on" the earlier work.
76 |
77 | A "covered work" means either the unmodified Program or a work based
78 | on the Program.
79 |
80 | To "propagate" a work means to do anything with it that, without
81 | permission, would make you directly or secondarily liable for
82 | infringement under applicable copyright law, except executing it on a
83 | computer or modifying a private copy. Propagation includes copying,
84 | distribution (with or without modification), making available to the
85 | public, and in some countries other activities as well.
86 |
87 | To "convey" a work means any kind of propagation that enables other
88 | parties to make or receive copies. Mere interaction with a user through
89 | a computer network, with no transfer of a copy, is not conveying.
90 |
91 | An interactive user interface displays "Appropriate Legal Notices"
92 | to the extent that it includes a convenient and prominently visible
93 | feature that (1) displays an appropriate copyright notice, and (2)
94 | tells the user that there is no warranty for the work (except to the
95 | extent that warranties are provided), that licensees may convey the
96 | work under this License, and how to view a copy of this License. If
97 | the interface presents a list of user commands or options, such as a
98 | menu, a prominent item in the list meets this criterion.
99 |
100 | 1. Source Code.
101 |
102 | The "source code" for a work means the preferred form of the work
103 | for making modifications to it. "Object code" means any non-source
104 | form of a work.
105 |
106 | A "Standard Interface" means an interface that either is an official
107 | standard defined by a recognized standards body, or, in the case of
108 | interfaces specified for a particular programming language, one that
109 | is widely used among developers working in that language.
110 |
111 | The "System Libraries" of an executable work include anything, other
112 | than the work as a whole, that (a) is included in the normal form of
113 | packaging a Major Component, but which is not part of that Major
114 | Component, and (b) serves only to enable use of the work with that
115 | Major Component, or to implement a Standard Interface for which an
116 | implementation is available to the public in source code form. A
117 | "Major Component", in this context, means a major essential component
118 | (kernel, window system, and so on) of the specific operating system
119 | (if any) on which the executable work runs, or a compiler used to
120 | produce the work, or an object code interpreter used to run it.
121 |
122 | The "Corresponding Source" for a work in object code form means all
123 | the source code needed to generate, install, and (for an executable
124 | work) run the object code and to modify the work, including scripts to
125 | control those activities. However, it does not include the work's
126 | System Libraries, or general-purpose tools or generally available free
127 | programs which are used unmodified in performing those activities but
128 | which are not part of the work. For example, Corresponding Source
129 | includes interface definition files associated with source files for
130 | the work, and the source code for shared libraries and dynamically
131 | linked subprograms that the work is specifically designed to require,
132 | such as by intimate data communication or control flow between those
133 | subprograms and other parts of the work.
134 |
135 | The Corresponding Source need not include anything that users
136 | can regenerate automatically from other parts of the Corresponding
137 | Source.
138 |
139 | The Corresponding Source for a work in source code form is that
140 | same work.
141 |
142 | 2. Basic Permissions.
143 |
144 | All rights granted under this License are granted for the term of
145 | copyright on the Program, and are irrevocable provided the stated
146 | conditions are met. This License explicitly affirms your unlimited
147 | permission to run the unmodified Program. The output from running a
148 | covered work is covered by this License only if the output, given its
149 | content, constitutes a covered work. This License acknowledges your
150 | rights of fair use or other equivalent, as provided by copyright law.
151 |
152 | You may make, run and propagate covered works that you do not
153 | convey, without conditions so long as your license otherwise remains
154 | in force. You may convey covered works to others for the sole purpose
155 | of having them make modifications exclusively for you, or provide you
156 | with facilities for running those works, provided that you comply with
157 | the terms of this License in conveying all material for which you do
158 | not control copyright. Those thus making or running the covered works
159 | for you must do so exclusively on your behalf, under your direction
160 | and control, on terms that prohibit them from making any copies of
161 | your copyrighted material outside their relationship with you.
162 |
163 | Conveying under any other circumstances is permitted solely under
164 | the conditions stated below. Sublicensing is not allowed; section 10
165 | makes it unnecessary.
166 |
167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168 |
169 | No covered work shall be deemed part of an effective technological
170 | measure under any applicable law fulfilling obligations under article
171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172 | similar laws prohibiting or restricting circumvention of such
173 | measures.
174 |
175 | When you convey a covered work, you waive any legal power to forbid
176 | circumvention of technological measures to the extent such circumvention
177 | is effected by exercising rights under this License with respect to
178 | the covered work, and you disclaim any intention to limit operation or
179 | modification of the work as a means of enforcing, against the work's
180 | users, your or third parties' legal rights to forbid circumvention of
181 | technological measures.
182 |
183 | 4. Conveying Verbatim Copies.
184 |
185 | You may convey verbatim copies of the Program's source code as you
186 | receive it, in any medium, provided that you conspicuously and
187 | appropriately publish on each copy an appropriate copyright notice;
188 | keep intact all notices stating that this License and any
189 | non-permissive terms added in accord with section 7 apply to the code;
190 | keep intact all notices of the absence of any warranty; and give all
191 | recipients a copy of this License along with the Program.
192 |
193 | You may charge any price or no price for each copy that you convey,
194 | and you may offer support or warranty protection for a fee.
195 |
196 | 5. Conveying Modified Source Versions.
197 |
198 | You may convey a work based on the Program, or the modifications to
199 | produce it from the Program, in the form of source code under the
200 | terms of section 4, provided that you also meet all of these conditions:
201 |
202 | a) The work must carry prominent notices stating that you modified
203 | it, and giving a relevant date.
204 |
205 | b) The work must carry prominent notices stating that it is
206 | released under this License and any conditions added under section
207 | 7. This requirement modifies the requirement in section 4 to
208 | "keep intact all notices".
209 |
210 | c) You must license the entire work, as a whole, under this
211 | License to anyone who comes into possession of a copy. This
212 | License will therefore apply, along with any applicable section 7
213 | additional terms, to the whole of the work, and all its parts,
214 | regardless of how they are packaged. This License gives no
215 | permission to license the work in any other way, but it does not
216 | invalidate such permission if you have separately received it.
217 |
218 | d) If the work has interactive user interfaces, each must display
219 | Appropriate Legal Notices; however, if the Program has interactive
220 | interfaces that do not display Appropriate Legal Notices, your
221 | work need not make them do so.
222 |
223 | A compilation of a covered work with other separate and independent
224 | works, which are not by their nature extensions of the covered work,
225 | and which are not combined with it such as to form a larger program,
226 | in or on a volume of a storage or distribution medium, is called an
227 | "aggregate" if the compilation and its resulting copyright are not
228 | used to limit the access or legal rights of the compilation's users
229 | beyond what the individual works permit. Inclusion of a covered work
230 | in an aggregate does not cause this License to apply to the other
231 | parts of the aggregate.
232 |
233 | 6. Conveying Non-Source Forms.
234 |
235 | You may convey a covered work in object code form under the terms
236 | of sections 4 and 5, provided that you also convey the
237 | machine-readable Corresponding Source under the terms of this License,
238 | in one of these ways:
239 |
240 | a) Convey the object code in, or embodied in, a physical product
241 | (including a physical distribution medium), accompanied by the
242 | Corresponding Source fixed on a durable physical medium
243 | customarily used for software interchange.
244 |
245 | b) Convey the object code in, or embodied in, a physical product
246 | (including a physical distribution medium), accompanied by a
247 | written offer, valid for at least three years and valid for as
248 | long as you offer spare parts or customer support for that product
249 | model, to give anyone who possesses the object code either (1) a
250 | copy of the Corresponding Source for all the software in the
251 | product that is covered by this License, on a durable physical
252 | medium customarily used for software interchange, for a price no
253 | more than your reasonable cost of physically performing this
254 | conveying of source, or (2) access to copy the
255 | Corresponding Source from a network server at no charge.
256 |
257 | c) Convey individual copies of the object code with a copy of the
258 | written offer to provide the Corresponding Source. This
259 | alternative is allowed only occasionally and noncommercially, and
260 | only if you received the object code with such an offer, in accord
261 | with subsection 6b.
262 |
263 | d) Convey the object code by offering access from a designated
264 | place (gratis or for a charge), and offer equivalent access to the
265 | Corresponding Source in the same way through the same place at no
266 | further charge. You need not require recipients to copy the
267 | Corresponding Source along with the object code. If the place to
268 | copy the object code is a network server, the Corresponding Source
269 | may be on a different server (operated by you or a third party)
270 | that supports equivalent copying facilities, provided you maintain
271 | clear directions next to the object code saying where to find the
272 | Corresponding Source. Regardless of what server hosts the
273 | Corresponding Source, you remain obligated to ensure that it is
274 | available for as long as needed to satisfy these requirements.
275 |
276 | e) Convey the object code using peer-to-peer transmission, provided
277 | you inform other peers where the object code and Corresponding
278 | Source of the work are being offered to the general public at no
279 | charge under subsection 6d.
280 |
281 | A separable portion of the object code, whose source code is excluded
282 | from the Corresponding Source as a System Library, need not be
283 | included in conveying the object code work.
284 |
285 | A "User Product" is either (1) a "consumer product", which means any
286 | tangible personal property which is normally used for personal, family,
287 | or household purposes, or (2) anything designed or sold for incorporation
288 | into a dwelling. In determining whether a product is a consumer product,
289 | doubtful cases shall be resolved in favor of coverage. For a particular
290 | product received by a particular user, "normally used" refers to a
291 | typical or common use of that class of product, regardless of the status
292 | of the particular user or of the way in which the particular user
293 | actually uses, or expects or is expected to use, the product. A product
294 | is a consumer product regardless of whether the product has substantial
295 | commercial, industrial or non-consumer uses, unless such uses represent
296 | the only significant mode of use of the product.
297 |
298 | "Installation Information" for a User Product means any methods,
299 | procedures, authorization keys, or other information required to install
300 | and execute modified versions of a covered work in that User Product from
301 | a modified version of its Corresponding Source. The information must
302 | suffice to ensure that the continued functioning of the modified object
303 | code is in no case prevented or interfered with solely because
304 | modification has been made.
305 |
306 | If you convey an object code work under this section in, or with, or
307 | specifically for use in, a User Product, and the conveying occurs as
308 | part of a transaction in which the right of possession and use of the
309 | User Product is transferred to the recipient in perpetuity or for a
310 | fixed term (regardless of how the transaction is characterized), the
311 | Corresponding Source conveyed under this section must be accompanied
312 | by the Installation Information. But this requirement does not apply
313 | if neither you nor any third party retains the ability to install
314 | modified object code on the User Product (for example, the work has
315 | been installed in ROM).
316 |
317 | The requirement to provide Installation Information does not include a
318 | requirement to continue to provide support service, warranty, or updates
319 | for a work that has been modified or installed by the recipient, or for
320 | the User Product in which it has been modified or installed. Access to a
321 | network may be denied when the modification itself materially and
322 | adversely affects the operation of the network or violates the rules and
323 | protocols for communication across the network.
324 |
325 | Corresponding Source conveyed, and Installation Information provided,
326 | in accord with this section must be in a format that is publicly
327 | documented (and with an implementation available to the public in
328 | source code form), and must require no special password or key for
329 | unpacking, reading or copying.
330 |
331 | 7. Additional Terms.
332 |
333 | "Additional permissions" are terms that supplement the terms of this
334 | License by making exceptions from one or more of its conditions.
335 | Additional permissions that are applicable to the entire Program shall
336 | be treated as though they were included in this License, to the extent
337 | that they are valid under applicable law. If additional permissions
338 | apply only to part of the Program, that part may be used separately
339 | under those permissions, but the entire Program remains governed by
340 | this License without regard to the additional permissions.
341 |
342 | When you convey a copy of a covered work, you may at your option
343 | remove any additional permissions from that copy, or from any part of
344 | it. (Additional permissions may be written to require their own
345 | removal in certain cases when you modify the work.) You may place
346 | additional permissions on material, added by you to a covered work,
347 | for which you have or can give appropriate copyright permission.
348 |
349 | Notwithstanding any other provision of this License, for material you
350 | add to a covered work, you may (if authorized by the copyright holders of
351 | that material) supplement the terms of this License with terms:
352 |
353 | a) Disclaiming warranty or limiting liability differently from the
354 | terms of sections 15 and 16 of this License; or
355 |
356 | b) Requiring preservation of specified reasonable legal notices or
357 | author attributions in that material or in the Appropriate Legal
358 | Notices displayed by works containing it; or
359 |
360 | c) Prohibiting misrepresentation of the origin of that material, or
361 | requiring that modified versions of such material be marked in
362 | reasonable ways as different from the original version; or
363 |
364 | d) Limiting the use for publicity purposes of names of licensors or
365 | authors of the material; or
366 |
367 | e) Declining to grant rights under trademark law for use of some
368 | trade names, trademarks, or service marks; or
369 |
370 | f) Requiring indemnification of licensors and authors of that
371 | material by anyone who conveys the material (or modified versions of
372 | it) with contractual assumptions of liability to the recipient, for
373 | any liability that these contractual assumptions directly impose on
374 | those licensors and authors.
375 |
376 | All other non-permissive additional terms are considered "further
377 | restrictions" within the meaning of section 10. If the Program as you
378 | received it, or any part of it, contains a notice stating that it is
379 | governed by this License along with a term that is a further
380 | restriction, you may remove that term. If a license document contains
381 | a further restriction but permits relicensing or conveying under this
382 | License, you may add to a covered work material governed by the terms
383 | of that license document, provided that the further restriction does
384 | not survive such relicensing or conveying.
385 |
386 | If you add terms to a covered work in accord with this section, you
387 | must place, in the relevant source files, a statement of the
388 | additional terms that apply to those files, or a notice indicating
389 | where to find the applicable terms.
390 |
391 | Additional terms, permissive or non-permissive, may be stated in the
392 | form of a separately written license, or stated as exceptions;
393 | the above requirements apply either way.
394 |
395 | 8. Termination.
396 |
397 | You may not propagate or modify a covered work except as expressly
398 | provided under this License. Any attempt otherwise to propagate or
399 | modify it is void, and will automatically terminate your rights under
400 | this License (including any patent licenses granted under the third
401 | paragraph of section 11).
402 |
403 | However, if you cease all violation of this License, then your
404 | license from a particular copyright holder is reinstated (a)
405 | provisionally, unless and until the copyright holder explicitly and
406 | finally terminates your license, and (b) permanently, if the copyright
407 | holder fails to notify you of the violation by some reasonable means
408 | prior to 60 days after the cessation.
409 |
410 | Moreover, your license from a particular copyright holder is
411 | reinstated permanently if the copyright holder notifies you of the
412 | violation by some reasonable means, this is the first time you have
413 | received notice of violation of this License (for any work) from that
414 | copyright holder, and you cure the violation prior to 30 days after
415 | your receipt of the notice.
416 |
417 | Termination of your rights under this section does not terminate the
418 | licenses of parties who have received copies or rights from you under
419 | this License. If your rights have been terminated and not permanently
420 | reinstated, you do not qualify to receive new licenses for the same
421 | material under section 10.
422 |
423 | 9. Acceptance Not Required for Having Copies.
424 |
425 | You are not required to accept this License in order to receive or
426 | run a copy of the Program. Ancillary propagation of a covered work
427 | occurring solely as a consequence of using peer-to-peer transmission
428 | to receive a copy likewise does not require acceptance. However,
429 | nothing other than this License grants you permission to propagate or
430 | modify any covered work. These actions infringe copyright if you do
431 | not accept this License. Therefore, by modifying or propagating a
432 | covered work, you indicate your acceptance of this License to do so.
433 |
434 | 10. Automatic Licensing of Downstream Recipients.
435 |
436 | Each time you convey a covered work, the recipient automatically
437 | receives a license from the original licensors, to run, modify and
438 | propagate that work, subject to this License. You are not responsible
439 | for enforcing compliance by third parties with this License.
440 |
441 | An "entity transaction" is a transaction transferring control of an
442 | organization, or substantially all assets of one, or subdividing an
443 | organization, or merging organizations. If propagation of a covered
444 | work results from an entity transaction, each party to that
445 | transaction who receives a copy of the work also receives whatever
446 | licenses to the work the party's predecessor in interest had or could
447 | give under the previous paragraph, plus a right to possession of the
448 | Corresponding Source of the work from the predecessor in interest, if
449 | the predecessor has it or can get it with reasonable efforts.
450 |
451 | You may not impose any further restrictions on the exercise of the
452 | rights granted or affirmed under this License. For example, you may
453 | not impose a license fee, royalty, or other charge for exercise of
454 | rights granted under this License, and you may not initiate litigation
455 | (including a cross-claim or counterclaim in a lawsuit) alleging that
456 | any patent claim is infringed by making, using, selling, offering for
457 | sale, or importing the Program or any portion of it.
458 |
459 | 11. Patents.
460 |
461 | A "contributor" is a copyright holder who authorizes use under this
462 | License of the Program or a work on which the Program is based. The
463 | work thus licensed is called the contributor's "contributor version".
464 |
465 | A contributor's "essential patent claims" are all patent claims
466 | owned or controlled by the contributor, whether already acquired or
467 | hereafter acquired, that would be infringed by some manner, permitted
468 | by this License, of making, using, or selling its contributor version,
469 | but do not include claims that would be infringed only as a
470 | consequence of further modification of the contributor version. For
471 | purposes of this definition, "control" includes the right to grant
472 | patent sublicenses in a manner consistent with the requirements of
473 | this License.
474 |
475 | Each contributor grants you a non-exclusive, worldwide, royalty-free
476 | patent license under the contributor's essential patent claims, to
477 | make, use, sell, offer for sale, import and otherwise run, modify and
478 | propagate the contents of its contributor version.
479 |
480 | In the following three paragraphs, a "patent license" is any express
481 | agreement or commitment, however denominated, not to enforce a patent
482 | (such as an express permission to practice a patent or covenant not to
483 | sue for patent infringement). To "grant" such a patent license to a
484 | party means to make such an agreement or commitment not to enforce a
485 | patent against the party.
486 |
487 | If you convey a covered work, knowingly relying on a patent license,
488 | and the Corresponding Source of the work is not available for anyone
489 | to copy, free of charge and under the terms of this License, through a
490 | publicly available network server or other readily accessible means,
491 | then you must either (1) cause the Corresponding Source to be so
492 | available, or (2) arrange to deprive yourself of the benefit of the
493 | patent license for this particular work, or (3) arrange, in a manner
494 | consistent with the requirements of this License, to extend the patent
495 | license to downstream recipients. "Knowingly relying" means you have
496 | actual knowledge that, but for the patent license, your conveying the
497 | covered work in a country, or your recipient's use of the covered work
498 | in a country, would infringe one or more identifiable patents in that
499 | country that you have reason to believe are valid.
500 |
501 | If, pursuant to or in connection with a single transaction or
502 | arrangement, you convey, or propagate by procuring conveyance of, a
503 | covered work, and grant a patent license to some of the parties
504 | receiving the covered work authorizing them to use, propagate, modify
505 | or convey a specific copy of the covered work, then the patent license
506 | you grant is automatically extended to all recipients of the covered
507 | work and works based on it.
508 |
509 | A patent license is "discriminatory" if it does not include within
510 | the scope of its coverage, prohibits the exercise of, or is
511 | conditioned on the non-exercise of one or more of the rights that are
512 | specifically granted under this License. You may not convey a covered
513 | work if you are a party to an arrangement with a third party that is
514 | in the business of distributing software, under which you make payment
515 | to the third party based on the extent of your activity of conveying
516 | the work, and under which the third party grants, to any of the
517 | parties who would receive the covered work from you, a discriminatory
518 | patent license (a) in connection with copies of the covered work
519 | conveyed by you (or copies made from those copies), or (b) primarily
520 | for and in connection with specific products or compilations that
521 | contain the covered work, unless you entered into that arrangement,
522 | or that patent license was granted, prior to 28 March 2007.
523 |
524 | Nothing in this License shall be construed as excluding or limiting
525 | any implied license or other defenses to infringement that may
526 | otherwise be available to you under applicable patent law.
527 |
528 | 12. No Surrender of Others' Freedom.
529 |
530 | If conditions are imposed on you (whether by court order, agreement or
531 | otherwise) that contradict the conditions of this License, they do not
532 | excuse you from the conditions of this License. If you cannot convey a
533 | covered work so as to satisfy simultaneously your obligations under this
534 | License and any other pertinent obligations, then as a consequence you may
535 | not convey it at all. For example, if you agree to terms that obligate you
536 | to collect a royalty for further conveying from those to whom you convey
537 | the Program, the only way you could satisfy both those terms and this
538 | License would be to refrain entirely from conveying the Program.
539 |
540 | 13. Remote Network Interaction; Use with the GNU General Public License.
541 |
542 | Notwithstanding any other provision of this License, if you modify the
543 | Program, your modified version must prominently offer all users
544 | interacting with it remotely through a computer network (if your version
545 | supports such interaction) an opportunity to receive the Corresponding
546 | Source of your version by providing access to the Corresponding Source
547 | from a network server at no charge, through some standard or customary
548 | means of facilitating copying of software. This Corresponding Source
549 | shall include the Corresponding Source for any work covered by version 3
550 | of the GNU General Public License that is incorporated pursuant to the
551 | following paragraph.
552 |
553 | Notwithstanding any other provision of this License, you have
554 | permission to link or combine any covered work with a work licensed
555 | under version 3 of the GNU General Public License into a single
556 | combined work, and to convey the resulting work. The terms of this
557 | License will continue to apply to the part which is the covered work,
558 | but the work with which it is combined will remain governed by version
559 | 3 of the GNU General Public License.
560 |
561 | 14. Revised Versions of this License.
562 |
563 | The Free Software Foundation may publish revised and/or new versions of
564 | the GNU Affero General Public License from time to time. Such new versions
565 | will be similar in spirit to the present version, but may differ in detail to
566 | address new problems or concerns.
567 |
568 | Each version is given a distinguishing version number. If the
569 | Program specifies that a certain numbered version of the GNU Affero General
570 | Public License "or any later version" applies to it, you have the
571 | option of following the terms and conditions either of that numbered
572 | version or of any later version published by the Free Software
573 | Foundation. If the Program does not specify a version number of the
574 | GNU Affero General Public License, you may choose any version ever published
575 | by the Free Software Foundation.
576 |
577 | If the Program specifies that a proxy can decide which future
578 | versions of the GNU Affero General Public License can be used, that proxy's
579 | public statement of acceptance of a version permanently authorizes you
580 | to choose that version for the Program.
581 |
582 | Later license versions may give you additional or different
583 | permissions. However, no additional obligations are imposed on any
584 | author or copyright holder as a result of your choosing to follow a
585 | later version.
586 |
587 | 15. Disclaimer of Warranty.
588 |
589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597 |
598 | 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608 | SUCH DAMAGES.
609 |
610 | 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these terms.
626 |
627 | To do so, attach the following notices to the program. It is safest
628 | to attach them to the start of each source file to most effectively
629 | state the exclusion of warranty; and each file should have at least
630 | the "copyright" line and a pointer to where the full notice is found.
631 |
632 |     <one line to give the program's name and a brief idea of what it does.>
633 |     Copyright (C) <year>  <name of author>
634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TRT_Ultra_Fast_Lane_Detect
2 |
3 | TRT_Ultra_Fast_Lane_Detect is an implementation that converts Ultra Fast Lane Detection into a TensorRT model via the TensorRT Python API. The other work in this project is listed below:
4 |
5 | - The detection procedure is encapsulated.
6 | - The PyTorch model is converted into an ONNX model and then into a TensorRT model.
7 | - The TensorRT models come in three precisions: FP32, FP16, and INT8.
8 | - The Tusimple dataset can be compressed by /calibration_data/make_mini_tusimple.py. The full dataset is redundant, since only every 20th frame is used; the compressed Tusimple dataset takes about 1 GB.
9 |
10 | The original project, model, and paper are available at https://github.com/cfzd/Ultra-Fast-Lane-Detection
11 |
12 |
13 |
14 | ### Ultra-Fast-Lane-Detection
15 |
16 | PyTorch implementation of the paper "[Ultra Fast Structure-aware Deep Lane Detection](https://arxiv.org/abs/2004.11757)".
17 |
18 | Update: the paper has been accepted by ECCV 2020.
19 |
20 | ![visualization](https://github.com/cfzd/Ultra-Fast-Lane-Detection/blob/master/vis.jpg)
21 |
22 | The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark).
23 |
24 | Caffe model and prototxt can be found [here](https://github.com/Jade999/caffe_lane_detection).
25 |
26 |
27 |
28 | ### Trained models
29 |
30 | The trained models can be obtained from the following table:
31 |
32 | | Dataset | Metric paper | Metric This repo | Avg FPS on GTX 1080Ti | Model |
33 | | -------- | ------------ | ---------------- | --------------------- | ------------------------------------------------------------ |
34 | | Tusimple | 95.87 | 95.82 | 306 | [GoogleDrive](https://drive.google.com/file/d/1WCYyur5ZaWczH15ecmeDowrW30xcLrCn/view?usp=sharing)/[BaiduDrive(code:bghd)](https://pan.baidu.com/s/1Fjm5yVq1JDpGjh4bdgdDLA) |
35 | | CULane | 68.4 | 69.7 | 324 | [GoogleDrive](https://drive.google.com/file/d/1zXBRTw50WOzvUp6XKsi8Zrk3MUC3uFuq/view?usp=sharing)/[BaiduDrive(code:w9tw)](https://pan.baidu.com/s/19Ig0TrV8MfmFTyCvbSa4ag) |
36 |
37 |
38 |
39 | ### Installation
40 |
41 | `pip3 install -r requirement.txt`
42 |
43 |
44 |
45 | ### Convert
46 |
47 | First of all, you have to train or download a 4-lane model trained with the PyTorch version of Ultra Fast Lane Detection. You will have to change some code if you want to use a different number of lanes.
48 |
49 |
50 |
51 | Now we have a trained PyTorch model, "model.pth".
52 |
53 | 1. Use torch2onnx.py to convert the model into an ONNX model. You should rename your model to "model.pth". The original configuration file is configs/tusimple_4.py. (A sketch of this export step appears after these steps.)
54 |
55 | `python3 torch2onnx.py configs/${config_file}.py`
56 |
57 | 2. Use onnx_to_tensorrt.py to convert the ONNX model into a TensorRT model (FP32 or FP16).
58 |
59 | `python3 onnx_to_tensorrt.py -p ${mode_in_fp16_or_fp32} --model ${model_name}`
60 |
61 | 3. Use onnx_to_tensorrt_int8.py to convert the ONNX model into a TensorRT model (INT8).
62 |
63 | `python3 onnx_to_tensorrt_int8.py --model ${model_name}`
64 |
65 | 4. Run tensorrt_run.py to start detection.
66 |
67 | `python tensorrt_run.py --model ${model_name}`
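
For reference, the PyTorch-to-ONNX step (step 1 above) boils down to a call to `torch.onnx.export`. The sketch below is illustrative rather than the exact contents of torch2onnx.py; the backbone string `'18'` and the 4-lane Tusimple dimensions (input 1x3x288x800, `cls_dim = (griding_num + 1, 56, 4)`) are assumptions taken from UFLD.py and configs/tusimple_4.py.

```python
import torch
from model.model import parsingNet

# Build the 4-lane Tusimple network; backbone '18' is an assumed default.
net = parsingNet(pretrained=False, backbone='18',
                 cls_dim=(100 + 1, 56, 4), use_aux=False).eval()

# Strip the 'module.' prefix left by DataParallel, as UFLD.py does.
state_dict = torch.load('model.pth', map_location='cpu')['model']
net.load_state_dict({k.replace('module.', '', 1): v for k, v in state_dict.items()},
                    strict=False)

# Export with a dummy 288x800 input, the resolution the model was trained at.
dummy = torch.randn(1, 3, 288, 800)
torch.onnx.export(net, dummy, 'model.onnx', opset_version=11,
                  input_names=['input'], output_names=['output'])
```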
68 |
69 |
70 |
71 | ### Evaluation
72 |
73 | | | Pytorch | libtorch | tensorRT(FP32) | tensorRT(FP16) | tensorRT(int8) |
74 | | :--------: | :-----: | :------: | :------------: | :------------: | :------------: |
75 | | GTX1060 | 55fps | 55fps | 55fps | Unsupported | 99fps |
76 | | Xavier AGX | 27fps | 27fps | -- | -- | -- |
77 | | Jetson TX1 | 8fps | 8fps | 8fps | 16fps | Unsupported |
78 | | Jetson nano A01(4GB) | -- | -- | -- | 8fps | Unsupported |
79 |
80 | where "--" denotes an experiment that has not been completed yet.
81 | Anyone with untested equipment is welcome to post results in the issues; they will be adopted here.
82 |
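For completeness, deserializing a built engine with the TensorRT Python API looks roughly like the sketch below; see tensorrt_run.py for the project's actual inference loop (the engine file name `model.trt` is an assumption).

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with open('model.trt', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()  # allocate buffers and execute from here
```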
--------------------------------------------------------------------------------
/UFLD.py:
--------------------------------------------------------------------------------
1 | import torch, os, cv2
2 | from model.model import parsingNet
3 | from utils.common import merge_config
4 | from utils.dist_utils import dist_print
5 | from configs.constant import tusimple_row_anchor
6 |
7 | import scipy.special, tqdm
8 | import numpy as np
9 | import torchvision.transforms as transforms
10 | from PIL import Image
11 | import time
12 | from torch.autograd import Variable
13 | import onnx
14 |
15 |
16 | class laneDetection():
17 | def __init__(self):
18 | torch.backends.cudnn.benchmark = True
19 | self.args, self.cfg = merge_config()
20 | self.cls_num_per_lane = 56
21 | self.row_anchor = tusimple_row_anchor
22 | self.net = parsingNet(pretrained = False, backbone=self.cfg.backbone, cls_dim = (self.cfg.griding_num+1, self.cls_num_per_lane, self.cfg.num_lanes), use_aux=False).cuda()
23 |
24 | state_dict = torch.load(self.cfg.test_model, map_location='cpu')['model']
25 | compatible_state_dict = {}
26 | for k, v in state_dict.items():
27 | if 'module.' in k:
28 | compatible_state_dict[k[7:]] = v
29 | else:
30 | compatible_state_dict[k] = v
31 |
32 | self.net.load_state_dict(compatible_state_dict, strict=False)
33 |
34 |         # do not comment out this line: eval mode is required for inference
35 | self.net.eval()
36 |
37 | self.img_transforms = transforms.Compose([
38 | transforms.Resize((288, 800)),
39 | transforms.ToTensor(),
40 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
41 | ])
42 |
43 | self.img_w = 960
44 | self.img_h = 480
45 | self.scale_factor = 1
46 | self.color = [(255,0,0),(0,255,0),(0,0,255),(255,255,0)]
47 | self.idx = np.arange(self.cfg.griding_num) + 1
48 | self.idx = self.idx.reshape(-1, 1, 1)
49 |
50 | self.cpu_img = None
51 | self.gpu_img = None
52 | self.type = None
53 | self.gpu_output = None
54 | self.cpu_output = None
55 |
56 | col_sample = np.linspace(0, 800 - 1, self.cfg.griding_num)
57 | self.col_sample_w = col_sample[1] - col_sample[0]
58 |
59 | def setResolution(self, w, h):
60 | self.img_w = w
61 | self.img_h = h
62 |
63 | def getFrame(self, frame):
64 | self.cpu_img = frame
65 |
66 | def setScaleFactor(self, factor=1):
67 | self.scale_factor = factor
68 |
69 | def preprocess(self):
70 | tmp_img = cv2.cvtColor(self.cpu_img, cv2.COLOR_BGR2RGB)
71 | if self.scale_factor != 1:
72 | tmp_img = cv2.resize(tmp_img, (self.img_w//self.scale_factor, self.img_h//self.scale_factor))
73 | tmp_img = Image.fromarray(tmp_img)
74 | tmp_img = self.img_transforms(tmp_img)
75 | self.gpu_img = tmp_img.unsqueeze(0).cuda()
76 |
77 | def inference(self):
78 | self.gpu_output = self.net(self.gpu_img)
79 |
80 | def parseResults(self):
81 | self.cpu_output = self.gpu_output[0].data.cpu().numpy()
82 |         self.prob = scipy.special.softmax(self.cpu_output[:-1, :, :], axis=0)  # softmax over the griding_num location classes (the last class means "no lane")
83 |
84 |         self.loc = np.sum(self.prob * self.idx, axis=0)  # expectation over grid indices -> sub-grid lane position
85 |         self.cpu_output = np.argmax(self.cpu_output, axis=0)
86 |
87 |         self.loc[self.cpu_output == self.cfg.griding_num] = 0  # argmax == griding_num means "no lane" at this anchor
88 | #self.cpu_output = self.loc
89 |
90 | # import pdb; pdb.set_trace()
91 | vis = self.cpu_img
92 | for i in range(self.loc.shape[1]):
93 | if np.sum(self.loc[:, i] > 0) > 40:
94 | for k in range(self.loc.shape[0]):
95 | if self.loc[k, i] > 0:
96 | ppp = (int(self.loc[k, i] * self.col_sample_w * self.img_w / 800) - 1, int(self.img_h * (self.row_anchor[k]/288)) - 1 )
97 | cv2.circle(vis,ppp,3, self.color[i], -1)
98 |
99 | cv2.imshow("output",vis)
100 | cv2.waitKey(1)
101 | return vis
102 |
103 |
104 |
--------------------------------------------------------------------------------
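A minimal usage sketch for the laneDetection class in UFLD.py above (a hypothetical driver, not part of the repo; it assumes a CUDA device, the command-line arguments expected by merge_config, and an example video file "road.mp4"):

    import cv2
    from UFLD import laneDetection

    det = laneDetection()                 # loads cfg.test_model onto the GPU
    det.setResolution(960, 480)
    cap = cv2.VideoCapture('road.mp4')    # hypothetical input video
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        det.getFrame(cv2.resize(frame, (960, 480)))
        det.preprocess()                  # BGR -> RGB, resize to 288x800, normalize, move to GPU
        det.inference()
        det.parseResults()                # draws lane points and shows the frame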
/__pycache__/UFLD.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/UFLD.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/UFLD.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/UFLD.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/common.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/common.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/flirCapture2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/flirCapture2.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/flirCapture2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/flirCapture2.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/laneDetection.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/laneDetection.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/laneDetection.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/laneDetection.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-38.pyc
--------------------------------------------------------------------------------
/calibration.cache:
--------------------------------------------------------------------------------
1 | TRT-7000-EntropyCalibration2
2 | input.1: 3c6966a5
3 | 127: 3c298ff7
4 | 128: 3b5df2ba
5 | 129: 3b5df2ba
6 | 130: 3b5df2ba
7 | 131: 3bd8b106
8 | 132: 3b5d4402
9 | 133: 3b151529
10 | 134: 3b5e7f99
11 | 135: 3bbe2fc8
12 | 136: 3b936c4c
13 | 137: 3b840b6f
14 | 138: 3bb62876
15 | 139: 3b44a09e
16 | 140: 3b1f11fd
17 | 141: 3b5c9d1f
18 | 142: 3b87e6e8
19 | 143: 3bd1e091
20 | 144: 3bd1e091
21 | 145: 3bed89d1
22 | 146: 3c3c2f3a
23 | 147: 3b7fc3b6
24 | 148: 3b636566
25 | 149: 3c2c5bd4
26 | 150: 3b248a2b
27 | 151: 3b69ad12
28 | 152: 3c749c21
29 | 153: 3b74d13d
30 | 154: 3c0c0a0a
31 | 155: 3c2a12fa
32 | 156: 3ade8c52
33 | 157: 3b74846d
34 | 158: 3c133ae7
35 | 159: 3c19d149
36 | 160: 3c0e6787
37 | 161: 3c1f198d
38 | 162: 3c475aa0
39 | 163: 3c020fdd
40 | 164: 3bc506df
41 | 165: 3c5efda5
42 | 166: 3b152a0b
43 | 167: 3b4bbbdd
44 | 168: 3c5adf69
45 | 169: 3c4ba73f
46 | 170: 3bcb724d
47 | 171: 3c4472d7
48 | 172: 3c03f7bf
49 | 173: 3b837231
50 | 174: 3c58ce13
51 | 175: 3c51b338
52 | 176: 3c4474b3
53 | 177: 3b98ca6b
54 | 178: 3c1a27de
55 | 179: 3bdc38ad
56 | 180: 3b06c3e8
57 | 181: 3c59fb33
58 | 182: 3ad771f3
59 | 183: 3c198c73
60 | 184: 3cb0b2f4
61 | 185: 3c6782e0
62 | 186: 3bcc34f8
63 | 187: 3be23788
64 | 188: 3be23788
65 | 189: 3b241ede
66 | 190: 3d46f62c
67 | 191: 3d4323f7
68 | 192: 3d549149
69 | 193: 3e3bfc93
70 | 195: 3e3bfc93
71 | (Unnamed Layer* 69) [Constant]_output: 3a0ac72c
72 | (Unnamed Layer* 70) [Matrix Multiply]_output: 3dd67f34
73 | (Unnamed Layer* 71) [Constant]_output: 38a25c2e
74 | (Unnamed Layer* 72) [Shuffle]_output: 38a25c2e
75 | 196: 3dd69133
76 | 197: 3dd69133
77 | (Unnamed Layer* 76) [Constant]_output: 396ebf65
78 | (Unnamed Layer* 77) [Matrix Multiply]_output: 3d8ddbd0
79 | (Unnamed Layer* 78) [Constant]_output: 38714d42
80 | (Unnamed Layer* 79) [Shuffle]_output: 38714d42
81 | 198: 3d8ddc8e
82 | 200: 3d8ddc8e
83 |
--------------------------------------------------------------------------------
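Note on the calibration cache above: after the TRT-7000-EntropyCalibration2 header, each line maps a tensor name to its calibration scale, stored as the big-endian hex encoding of an IEEE-754 float32. A minimal decoding sketch:

    import struct
    scale = struct.unpack('!f', bytes.fromhex('3c6966a5'))[0]  # value recorded for input.1
    print(scale)  # ~0.0142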
/calibration_data/__pycache__/constant.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/constant.cpython-37.pyc
--------------------------------------------------------------------------------
/calibration_data/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/calibration_data/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/calibration_data/__pycache__/mytransforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/mytransforms.cpython-37.pyc
--------------------------------------------------------------------------------
/calibration_data/constant.py:
--------------------------------------------------------------------------------
1 | # row anchors are a series of pre-defined coordinates in image height to detect lanes
2 | # the row anchors are defined according to the evaluation protocol of CULane and Tusimple
3 | # since our method will resize the image to 288x800 for training, the row anchors are defined with the height of 288
4 | # you can modify these row anchors according to your training image resolution
5 |
6 | tusimple_row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112,
7 | 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164,
8 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216,
9 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268,
10 | 272, 276, 280, 284]
11 | culane_row_anchor = [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287]
12 |
--------------------------------------------------------------------------------
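The row anchors above are defined at the 288-pixel training height; to draw on a frame of a different height they are rescaled linearly, as UFLD.py does in parseResults. A small sketch of that mapping (img_h = 480 is just an example value):

    from calibration_data.constant import tusimple_row_anchor

    img_h = 480  # example display height
    ys = [int(img_h * (r / 288)) - 1 for r in tusimple_row_anchor]  # same mapping as UFLD.parseResults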
/calibration_data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import numpy as np
3 |
4 | import torchvision.transforms as transforms
5 | import calibration_data.mytransforms as mytransforms
6 | from calibration_data.constant import tusimple_row_anchor, culane_row_anchor
7 | from calibration_data.dataset import LaneClsDataset, LaneTestDataset
8 |
9 | def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux, distributed, num_lanes):
10 | target_transform = transforms.Compose([
11 | mytransforms.FreeScaleMask((288, 800)),
12 | mytransforms.MaskToTensor(),
13 | ])
14 | segment_transform = transforms.Compose([
15 | mytransforms.FreeScaleMask((36, 100)),
16 | mytransforms.MaskToTensor(),
17 | ])
18 | img_transform = transforms.Compose([
19 | transforms.Resize((288, 800)),
20 | transforms.ToTensor(),
21 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
22 | ])
23 | simu_transform = mytransforms.Compose2([
24 | mytransforms.RandomRotate(6),
25 | mytransforms.RandomUDoffsetLABEL(100),
26 | mytransforms.RandomLROffsetLABEL(200)
27 | ])
28 | if dataset == 'CULane':
29 | train_dataset = LaneClsDataset(data_root,
30 | os.path.join(data_root, 'list/train_gt.txt'),
31 | img_transform=img_transform, target_transform=target_transform,
32 | simu_transform = simu_transform,
33 | segment_transform=segment_transform,
34 | row_anchor = culane_row_anchor,
35 | griding_num=griding_num, use_aux=use_aux, num_lanes = num_lanes)
36 | cls_num_per_lane = 18
37 |
38 | elif dataset == 'Tusimple':
39 | train_dataset = LaneClsDataset(data_root,
40 | os.path.join(data_root, 'train_gt.txt'),
41 | img_transform=img_transform, target_transform=target_transform,
42 | simu_transform = simu_transform,
43 | griding_num=griding_num,
44 | row_anchor = tusimple_row_anchor,
45 | segment_transform=segment_transform,use_aux=use_aux, num_lanes = num_lanes)
46 | cls_num_per_lane = 56
47 | else:
48 | raise NotImplementedError
49 |
50 | if distributed:
51 | sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
52 | else:
53 | sampler = torch.utils.data.RandomSampler(train_dataset)
54 |
55 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler = sampler, num_workers=4)
56 |
57 | return train_loader, cls_num_per_lane
58 |
59 | def get_test_loader(batch_size, data_root, dataset, distributed):
60 | img_transforms = transforms.Compose([
61 | transforms.Resize((288, 800)),
62 | transforms.ToTensor(),
63 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
64 | ])
65 | if dataset == 'CULane':
66 | test_dataset = LaneTestDataset(data_root,os.path.join(data_root, 'list/test.txt'),img_transform = img_transforms)
67 | cls_num_per_lane = 18
68 | elif dataset == 'Tusimple':
69 | test_dataset = LaneTestDataset(data_root,os.path.join(data_root, 'test.txt'), img_transform = img_transforms)
70 | cls_num_per_lane = 56
71 |     else: raise NotImplementedError  # guard against unknown dataset names
72 | if distributed:
73 | sampler = SeqDistributedSampler(test_dataset, shuffle = False)
74 | else:
75 | sampler = torch.utils.data.SequentialSampler(test_dataset)
76 | loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler = sampler, num_workers=4)
77 | return loader
78 |
79 |
80 | class SeqDistributedSampler(torch.utils.data.distributed.DistributedSampler):
81 | '''
82 | Change the behavior of DistributedSampler to sequential distributed sampling.
83 |     Sequential sampling improves the stability of multi-threaded testing, which relies on multi-threaded file I/O.
84 |     Without it, the file I/O on one thread may interfere with the others.
85 | '''
86 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
87 | super().__init__(dataset, num_replicas, rank, shuffle)
88 | def __iter__(self):
89 | g = torch.Generator()
90 | g.manual_seed(self.epoch)
91 | if self.shuffle:
92 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
93 | else:
94 | indices = list(range(len(self.dataset)))
95 |
96 |
97 | # add extra samples to make it evenly divisible
98 | indices += indices[:(self.total_size - len(indices))]
99 | assert len(indices) == self.total_size
100 |
101 |
102 | num_per_rank = int(self.total_size // self.num_replicas)
103 |
104 | # sequential sampling
105 | indices = indices[num_per_rank * self.rank : num_per_rank * (self.rank + 1)]
106 |
107 | assert len(indices) == self.num_samples
108 |
109 | return iter(indices)
110 |
--------------------------------------------------------------------------------
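The contiguous per-rank split that SeqDistributedSampler produces can be
checked on a toy dataset; passing num_replicas and rank explicitly avoids
needing an initialized process group (assumes the repo root is on sys.path):

    from calibration_data.dataloader import SeqDistributedSampler

    data = list(range(10))
    for rank in range(2):
        s = SeqDistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
        print(rank, list(s))  # rank 0 -> [0, 1, 2, 3, 4]; rank 1 -> [5, 6, 7, 8, 9]
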
/calibration_data/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import os
4 | import pdb
5 | import numpy as np
6 | import cv2
7 | from calibration_data.mytransforms import find_start_pos
8 |
9 |
10 | def loader_func(path):
11 | return Image.open(path)
12 |
13 |
14 | class LaneTestDataset(torch.utils.data.Dataset):
15 | def __init__(self, path, list_path, img_transform=None):
16 | super(LaneTestDataset, self).__init__()
17 | self.path = path
18 | self.img_transform = img_transform
19 | with open(list_path, 'r') as f:
20 | self.list = f.readlines()
21 | self.list = [l[1:] if l[0] == '/' else l for l in self.list] # exclude the incorrect path prefix '/' of CULane
22 |
23 |
24 | def __getitem__(self, index):
25 | name = self.list[index].split()[0]
26 | img_path = os.path.join(self.path, name)
27 | img = loader_func(img_path)
28 |
29 | if self.img_transform is not None:
30 | img = self.img_transform(img)
31 |
32 | return img, name
33 |
34 | def __len__(self):
35 | return len(self.list)
36 |
37 |
38 | class LaneClsDataset(torch.utils.data.Dataset):
39 | def __init__(self, path, list_path, img_transform = None,target_transform = None,simu_transform = None, griding_num=50, load_name = False,
40 | row_anchor = None,use_aux=False,segment_transform=None, num_lanes = 8):
41 | super(LaneClsDataset, self).__init__()
42 | self.img_transform = img_transform
43 | self.target_transform = target_transform
44 | self.segment_transform = segment_transform
45 | self.simu_transform = simu_transform
46 | self.path = path
47 | self.griding_num = griding_num
48 | self.load_name = load_name
49 | self.use_aux = use_aux
50 | self.num_lanes = num_lanes
51 |
52 | with open(list_path, 'r') as f:
53 | self.list = f.readlines()
54 |
55 | self.row_anchor = row_anchor
56 | self.row_anchor.sort()
57 |
58 | def __getitem__(self, index):
59 | l = self.list[index]
60 | l_info = l.split()
61 | img_name, label_name = l_info[0], l_info[1]
62 | if img_name[0] == '/':
63 | img_name = img_name[1:]
64 | label_name = label_name[1:]
65 |
66 | label_path = os.path.join(self.path, label_name)
67 | label = loader_func(label_path)
68 |
69 | img_path = os.path.join(self.path, img_name)
70 | img = loader_func(img_path)
71 |
72 |
73 | if self.simu_transform is not None:
74 | img, label = self.simu_transform(img, label)
75 | lane_pts = self._get_index(label)
76 | # get the coordinates of lanes at row anchors
77 |
78 |
79 |
80 | w, h = img.size
81 | cls_label = self._grid_pts(lane_pts, self.griding_num, w)
82 | # make the coordinates to classification label
83 | if self.use_aux:
84 | assert self.segment_transform is not None
85 | seg_label = self.segment_transform(label)
86 |
87 | if self.img_transform is not None:
88 | img = self.img_transform(img)
89 |
90 | if self.use_aux:
91 | return img, cls_label, seg_label
92 | if self.load_name:
93 | return img, cls_label, img_name
94 | return img, cls_label
95 |
96 | def __len__(self):
97 | return len(self.list)
98 |
99 | def _grid_pts(self, pts, num_cols, w):
100 | # pts : numlane,n,2
101 | num_lane, n, n2 = pts.shape
102 | col_sample = np.linspace(0, w - 1, num_cols)
103 |
104 | assert n2 == 2
105 | to_pts = np.zeros((n, num_lane))
106 | for i in range(num_lane):
107 | pti = pts[i, :, 1]
108 | to_pts[:, i] = np.asarray(
109 | [int(pt // (col_sample[1] - col_sample[0])) if pt != -1 else num_cols for pt in pti])
110 | return to_pts.astype(int)
111 |
112 | def _get_index(self, label):
113 | w, h = label.size
114 |
115 |         # scale the 288-based row anchors to the label height (identity when h == 288)
116 |         scale_f = lambda x: int((x * 1.0 / 288) * h)
117 |         sample_tmp = list(map(scale_f, self.row_anchor))
118 |
119 | all_idx = np.zeros((self.num_lanes,len(sample_tmp),2))
120 | for i,r in enumerate(sample_tmp):
121 | label_r = np.asarray(label)[int(round(r))]
122 | for lane_idx in range(1, self.num_lanes + 1):
123 | pos = np.where(label_r == lane_idx)[0]
124 | if len(pos) == 0:
125 | all_idx[lane_idx - 1, i, 0] = r
126 | all_idx[lane_idx - 1, i, 1] = -1
127 | continue
128 | pos = np.mean(pos)
129 | all_idx[lane_idx - 1, i, 0] = r
130 | all_idx[lane_idx - 1, i, 1] = pos
131 |
132 | # data augmentation: extend the lane to the boundary of image
133 |
134 | all_idx_cp = all_idx.copy()
135 | for i in range(self.num_lanes):
136 |             if np.all(all_idx_cp[i,:,1] == -1):
137 |                 # skip this lane if it has no valid points at all
138 |                 continue
139 |
140 | valid = all_idx_cp[i,:,1] != -1
141 | # get all valid lane points' index
142 | valid_idx = all_idx_cp[i,valid,:]
143 | # get all valid lane points
144 | if valid_idx[-1,0] == all_idx_cp[0,-1,0]:
145 | # if the last valid lane point's y-coordinate is already the last y-coordinate of all rows
146 | # this means this lane has reached the bottom boundary of the image
147 | # so we skip
148 | continue
149 |             if len(valid_idx) < 6:
150 |                 # the lane is too short to extrapolate reliably
151 |                 continue
152 |
153 | valid_idx_half = valid_idx[len(valid_idx) // 2:,:]
154 | p = np.polyfit(valid_idx_half[:,0], valid_idx_half[:,1],deg = 1)
155 | start_line = valid_idx_half[-1,0]
156 | pos = find_start_pos(all_idx_cp[i,:,0],start_line) + 1
157 |
158 | fitted = np.polyval(p,all_idx_cp[i,pos:,0])
159 | fitted = np.array([-1 if y < 0 or y > w-1 else y for y in fitted])
160 |
161 | assert np.all(all_idx_cp[i,pos:,1] == -1)
162 | all_idx_cp[i,pos:,1] = fitted
163 | if -1 in all_idx[:, :, 0]:
164 | pdb.set_trace()
165 | return all_idx_cp
166 |
--------------------------------------------------------------------------------
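The cell assignment in _grid_pts is simple enough to check standalone: an
x-coordinate is floor-divided by the grid spacing, and the sentinel -1 (no
lane at that row) maps to the extra "absent" class with index griding_num:

    import numpy as np

    griding_num, w = 100, 1280
    col_sample = np.linspace(0, w - 1, griding_num)
    spacing = col_sample[1] - col_sample[0]  # ~12.92 px per grid cell

    for x in (0, 640, -1):
        cell = int(x // spacing) if x != -1 else griding_num
        print(x, '->', cell)  # 0 -> 0, 640 -> 49, -1 -> 100
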
/calibration_data/make_mini_tusimple.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | path ='Tusimple/clips'
4 |
5 | def get_filelist(root):
6 |     for home, dirs, files in os.walk(root):
7 | for filename in files:
8 | if filename == "20.jpg" or filename == "20.png":
9 | continue
10 | else:
11 | print(filename)
12 | os.remove(os.path.join(home, filename))
13 |
14 | if __name__ =="__main__":
15 | get_filelist(path)
16 |
--------------------------------------------------------------------------------
/calibration_data/mytransforms.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | import random
3 | import numpy as np
4 | from PIL import Image, ImageOps, ImageFilter
5 | #from config import cfg
6 | import torch
7 | import pdb
8 | import cv2
9 |
10 | # ===============================img transforms============================
11 |
12 | class Compose2(object):
13 | def __init__(self, transforms):
14 | self.transforms = transforms
15 |
16 | def __call__(self, img, mask, bbx=None):
17 | if bbx is None:
18 | for t in self.transforms:
19 | img, mask = t(img, mask)
20 | return img, mask
21 | for t in self.transforms:
22 | img, mask, bbx = t(img, mask, bbx)
23 | return img, mask, bbx
24 |
25 | class FreeScale(object):
26 | def __init__(self, size):
27 | self.size = size # (h, w)
28 |
29 | def __call__(self, img, mask):
30 | return img.resize((self.size[1], self.size[0]), Image.BILINEAR), mask.resize((self.size[1], self.size[0]), Image.NEAREST)
31 |
32 | class FreeScaleMask(object):
33 | def __init__(self,size):
34 | self.size = size
35 | def __call__(self,mask):
36 | return mask.resize((self.size[1], self.size[0]), Image.NEAREST)
37 |
38 | class Scale(object):
39 | def __init__(self, size):
40 | self.size = size
41 |
42 | def __call__(self, img, mask):
43 | if img.size != mask.size:
44 | print(img.size)
45 | print(mask.size)
46 | assert img.size == mask.size
47 | w, h = img.size
48 | if (w <= h and w == self.size) or (h <= w and h == self.size):
49 | return img, mask
50 | if w < h:
51 | ow = self.size
52 | oh = int(self.size * h / w)
53 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)
54 | else:
55 | oh = self.size
56 | ow = int(self.size * w / h)
57 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)
58 |
59 |
60 | class RandomRotate(object):
61 | """Crops the given PIL.Image at a random location to have a region of
62 | the given size. size can be a tuple (target_height, target_width)
63 | or an integer, in which case the target will be of a square shape (size, size)
64 | """
65 |
66 | def __init__(self, angle):
67 | self.angle = angle
68 |
69 | def __call__(self, image, label):
70 | #assert label is None or image.size == label.size
71 | #assert label is None or image.size == label.size
72 |
73 | angle = random.randint(0, self.angle * 2) - self.angle
74 |
75 | label = label.rotate(angle, resample=Image.NEAREST)
76 | image = image.rotate(angle, resample=Image.BILINEAR)
77 |
78 | return image, label
79 |
80 |
81 |
82 | # ===============================label transforms============================
83 |
84 | class DeNormalize(object):
85 | def __init__(self, mean, std):
86 | self.mean = mean
87 | self.std = std
88 |
89 | def __call__(self, tensor):
90 | for t, m, s in zip(tensor, self.mean, self.std):
91 | t.mul_(s).add_(m)
92 | return tensor
93 |
94 |
95 | class MaskToTensor(object):
96 | def __call__(self, img):
97 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
98 |
99 |
100 | def find_start_pos(row_sample,start_line):
101 | # row_sample = row_sample.sort()
102 | # for i,r in enumerate(row_sample):
103 | # if r >= start_line:
104 | # return i
105 | l,r = 0,len(row_sample)-1
106 | while True:
107 | mid = int((l+r)/2)
108 | if r - l == 1:
109 | return r
110 | if row_sample[mid] < start_line:
111 | l = mid
112 | if row_sample[mid] > start_line:
113 | r = mid
114 | if row_sample[mid] == start_line:
115 | return mid
116 |
117 | class RandomLROffsetLABEL(object):
118 | def __init__(self,max_offset):
119 | self.max_offset = max_offset
120 | def __call__(self,img,label):
121 | offset = np.random.randint(-self.max_offset,self.max_offset)
122 | w, h = img.size
123 |
124 | img = np.array(img)
125 | if offset > 0:
126 | img[:,offset:,:] = img[:,0:w-offset,:]
127 | img[:,:offset,:] = 0
128 | if offset < 0:
129 | real_offset = -offset
130 | img[:,0:w-real_offset,:] = img[:,real_offset:,:]
131 | img[:,w-real_offset:,:] = 0
132 |
133 | label = np.array(label)
134 | if offset > 0:
135 | label[:,offset:] = label[:,0:w-offset]
136 | label[:,:offset] = 0
137 | if offset < 0:
138 | offset = -offset
139 | label[:,0:w-offset] = label[:,offset:]
140 | label[:,w-offset:] = 0
141 | return Image.fromarray(img),Image.fromarray(label)
142 |
143 | class RandomUDoffsetLABEL(object):
144 | def __init__(self,max_offset):
145 | self.max_offset = max_offset
146 | def __call__(self,img,label):
147 | offset = np.random.randint(-self.max_offset,self.max_offset)
148 | w, h = img.size
149 |
150 | img = np.array(img)
151 | if offset > 0:
152 | img[offset:,:,:] = img[0:h-offset,:,:]
153 | img[:offset,:,:] = 0
154 | if offset < 0:
155 | real_offset = -offset
156 | img[0:h-real_offset,:,:] = img[real_offset:,:,:]
157 | img[h-real_offset:,:,:] = 0
158 |
159 | label = np.array(label)
160 | if offset > 0:
161 | label[offset:,:] = label[0:h-offset,:]
162 | label[:offset,:] = 0
163 | if offset < 0:
164 | offset = -offset
165 | label[0:h-offset,:] = label[offset:,:]
166 | label[h-offset:,:] = 0
167 | return Image.fromarray(img),Image.fromarray(label)
168 |
--------------------------------------------------------------------------------
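find_start_pos is a small binary search for the first row anchor at or after
start_line; a quick sanity check (assumes the repo root is on sys.path):

    from calibration_data.mytransforms import find_start_pos

    rows = [64, 68, 72, 76, 80]
    print(find_start_pos(rows, 72))  # 2 (exact hit)
    print(find_start_pos(rows, 70))  # 2 (index of the first anchor >= 70)
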
/common.py:
--------------------------------------------------------------------------------
1 | #
2 | # Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 | #
4 | # NOTICE TO LICENSEE:
5 | #
6 | # This source code and/or documentation ("Licensed Deliverables") are
7 | # subject to NVIDIA intellectual property rights under U.S. and
8 | # international Copyright laws.
9 | #
10 | # These Licensed Deliverables contained herein is PROPRIETARY and
11 | # CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 | # conditions of a form of NVIDIA software license agreement by and
13 | # between NVIDIA and Licensee ("License Agreement") or electronically
14 | # accepted by Licensee. Notwithstanding any terms or conditions to
15 | # the contrary in the License Agreement, reproduction or disclosure
16 | # of the Licensed Deliverables to any third party without the express
17 | # written consent of NVIDIA is prohibited.
18 | #
19 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 | # LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 | # SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 | # PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 | # NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 | # DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 | # NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 | # LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 | # SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 | # DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 | # WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 | # ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 | # OF THESE LICENSED DELIVERABLES.
33 | #
34 | # U.S. Government End Users. These Licensed Deliverables are a
35 | # "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 | # 1995), consisting of "commercial computer software" and "commercial
37 | # computer software documentation" as such terms are used in 48
38 | # C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 | # only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 | # 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 | # U.S. Government End Users acquire the Licensed Deliverables with
42 | # only those rights set forth herein.
43 | #
44 | # Any use of the Licensed Deliverables in individual and commercial
45 | # software must include, in the user documentation and internal
46 | # comments to the code, the above Disclaimer and U.S. Government End
47 | # Users Notice.
48 | #
49 |
50 | from itertools import chain
51 | import argparse
52 | import os
53 |
54 | import pycuda.driver as cuda
55 | import pycuda.autoinit
56 | import numpy as np
57 |
58 | import tensorrt as trt
59 |
60 | try:
61 | # Sometimes python2 does not understand FileNotFoundError
62 | FileNotFoundError
63 | except NameError:
64 | FileNotFoundError = IOError
65 |
66 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
67 |
68 | def GiB(val):
69 | return val * 1 << 30
70 |
71 |
72 | def add_help(description):
73 | parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
74 | args, _ = parser.parse_known_args()
75 |
76 |
77 | def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):
78 | '''
79 | Parses sample arguments.
80 |
81 | Args:
82 | description (str): Description of the sample.
83 | subfolder (str): The subfolder containing data relevant to this sample
84 |         find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
85 |
86 | Returns:
87 | str: Path of data directory.
88 | '''
89 |
90 | # Standard command-line arguments for all samples.
91 | kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
92 | parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
93 | parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory, and any additional data directories.", action="append", default=[kDEFAULT_DATA_ROOT])
94 | args, _ = parser.parse_known_args()
95 |
96 | def get_data_path(data_dir):
97 | # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
98 | data_path = os.path.join(data_dir, subfolder)
99 | if not os.path.exists(data_path):
100 | print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
101 | data_path = data_dir
102 | # Make sure data directory exists.
103 | if not (os.path.exists(data_path)):
104 | print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
105 | return data_path
106 |
107 | data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
108 | return data_paths, locate_files(data_paths, find_files)
109 |
110 | def locate_files(data_paths, filenames):
111 | """
112 | Locates the specified files in the specified data directories.
113 | If a file exists in multiple data directories, the first directory is used.
114 |
115 | Args:
116 | data_paths (List[str]): The data directories.
117 |         filenames (List[str]): The names of the files to find.
118 |
119 | Returns:
120 | List[str]: The absolute paths of the files.
121 |
122 | Raises:
123 | FileNotFoundError if a file could not be located.
124 | """
125 | found_files = [None] * len(filenames)
126 | for data_path in data_paths:
127 | # Find all requested files.
128 | for index, (found, filename) in enumerate(zip(found_files, filenames)):
129 | if not found:
130 | file_path = os.path.abspath(os.path.join(data_path, filename))
131 | if os.path.exists(file_path):
132 | found_files[index] = file_path
133 |
134 | # Check that all files were found
135 | for f, filename in zip(found_files, filenames):
136 | if not f or not os.path.exists(f):
137 | raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths))
138 | return found_files
139 |
140 | # Simple helper data class that's a little nicer to use than a 2-tuple.
141 | class HostDeviceMem(object):
142 | def __init__(self, host_mem, device_mem):
143 | self.host = host_mem
144 | self.device = device_mem
145 |
146 | def __str__(self):
147 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
148 |
149 | def __repr__(self):
150 | return self.__str__()
151 |
152 | # Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
153 | def allocate_buffers(engine):
154 | inputs = []
155 | outputs = []
156 | bindings = []
157 | stream = cuda.Stream()
158 | for binding in engine:
159 | size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
160 | dtype = trt.nptype(engine.get_binding_dtype(binding))
161 | # Allocate host and device buffers
162 | host_mem = cuda.pagelocked_empty(size, dtype)
163 | device_mem = cuda.mem_alloc(host_mem.nbytes)
164 | # Append the device buffer to device bindings.
165 | bindings.append(int(device_mem))
166 | # Append to the appropriate list.
167 | if engine.binding_is_input(binding):
168 | inputs.append(HostDeviceMem(host_mem, device_mem))
169 | else:
170 | outputs.append(HostDeviceMem(host_mem, device_mem))
171 | return inputs, outputs, bindings, stream
172 |
173 | # This function is generalized for multiple inputs/outputs.
174 | # inputs and outputs are expected to be lists of HostDeviceMem objects.
175 | def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
176 | # Transfer input data to the GPU.
177 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
178 | # Run inference.
179 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
180 | # Transfer predictions back from the GPU.
181 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
182 | # Synchronize the stream
183 | stream.synchronize()
184 | # Return only the host outputs.
185 | return [out.host for out in outputs]
186 |
187 | # This function is generalized for multiple inputs/outputs for full dimension networks.
188 | # inputs and outputs are expected to be lists of HostDeviceMem objects.
189 | def do_inference_v2(context, bindings, inputs, outputs, stream):
190 | # Transfer input data to the GPU.
191 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
192 | # Run inference.
193 | context.execute_async_v2(bindings=bindings, stream_handle=stream.handle)
194 | # Transfer predictions back from the GPU.
195 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
196 | # Synchronize the stream
197 | stream.synchronize()
198 | # Return only the host outputs.
199 | return [out.host for out in outputs]
200 |
--------------------------------------------------------------------------------
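A minimal usage sketch for the helpers above, under the assumption that a
serialized engine named model_fp16.trt already exists (the name is only a
placeholder; any engine built by the scripts in this repo works the same):

    import numpy as np
    import tensorrt as trt
    import common

    TRT_LOGGER = trt.Logger()
    with open('model_fp16.trt', 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())

    inputs, outputs, bindings, stream = common.allocate_buffers(engine)
    # fill the (flattened, page-locked) input buffer with random data
    inputs[0].host = np.random.rand(*inputs[0].host.shape).astype(inputs[0].host.dtype)
    with engine.create_execution_context() as context:
        outs = common.do_inference_v2(context, bindings, inputs, outputs, stream)
    print([o.shape for o in outs])
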
/configs/__pycache__/constant.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-36.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/constant.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-37.pyc
--------------------------------------------------------------------------------
/configs/__pycache__/constant.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-38.pyc
--------------------------------------------------------------------------------
/configs/constant.py:
--------------------------------------------------------------------------------
1 | # Row anchors are a series of pre-defined y-coordinates (on the image height) at which lanes are detected.
2 | # They are defined according to the evaluation protocols of CULane and Tusimple.
3 | # Since the method resizes images to 288x800 for training, the row anchors are defined for a height of 288.
4 | # Modify these row anchors to match your own training image resolution.
5 |
6 | tusimple_row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112,
7 | 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164,
8 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216,
9 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268,
10 | 272, 276, 280, 284]
11 | culane_row_anchor = [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287]
12 |
--------------------------------------------------------------------------------
/configs/tusimple_4.py:
--------------------------------------------------------------------------------
1 | # DATA
2 | dataset='Tusimple'
3 | data_root = "./data/Tusimple_ours"
4 |
5 | # TRAIN
6 | epoch = 60
7 | batch_size = 16
8 | optimizer = 'Adam' #['SGD','Adam']
9 | # learning_rate = 0.1
10 | learning_rate = 2e-4
11 | weight_decay = 1e-4
12 | momentum = 0.9
13 |
14 | scheduler = 'cos' #['multi', 'cos']
15 | # steps = [50,75]
16 | gamma = 0.1
17 | warmup = 'linear'
18 | warmup_iters = 100
19 |
20 | # NETWORK
21 | backbone = '18'
22 | griding_num = 100
23 | use_aux = True
24 |
25 | # LOSS
26 | sim_loss_w = 1.0
27 | shp_loss_w = 0.1
28 |
29 | # EXP
30 | note = ''
31 |
32 | log_path = '/media/kyle/Seagate/train_lane/train_lane4.log'
33 |
34 | # FINETUNE or RESUME MODEL PATH
35 | finetune = None
36 | resume = None
37 |
38 | # TEST
39 | test_model = None
40 | test_work_dir = "./data/Tusimple_ours"
41 |
42 | num_lanes = 4
43 |
44 |
--------------------------------------------------------------------------------
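The Config class that consumes this file is defined in utils/config.py;
conceptually it executes the module and exposes its top-level names as
attributes. A rough, hypothetical equivalent (not the repo's actual
implementation):

    import importlib.util

    def load_config(path):
        # hypothetical stand-in for utils.config.Config.fromfile
        spec = importlib.util.spec_from_file_location('cfg', path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return {k: v for k, v in vars(mod).items() if not k.startswith('_')}

    cfg = load_config('configs/tusimple_4.py')
    print(cfg['dataset'], cfg['griding_num'], cfg['num_lanes'])  # Tusimple 100 4
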
/launch_opencv.py:
--------------------------------------------------------------------------------
1 | from UFLD import *
2 | import cv2
3 | import threading
4 | import time
5 | import numpy as np
6 |
7 | detector = laneDetection()
8 | detector.setResolution(640, 480)
9 | detector.setScaleFactor(4)
10 | frame = 0
11 | currentImage = None
12 |
13 |
14 | def threadDetect():
15 | print("Detection initiating")
16 | global detector, currentImage, frame
17 | fps = []
18 | print("Waiting for camera")
19 | time.sleep(1)
20 | ret = True
21 | print("Detection Begins:")
22 | while ret:
23 | t1 = time.time()
24 | detector.getFrame(currentImage)
25 | detector.preprocess()
26 | detector.inference()
27 | detector.parseResults()
28 | t2 = time.time()
29 | if frame > 30:
30 | fps.append(1/(t2-t1))
31 | print("\ravg FPS: "+str(np.mean(fps)), end="", flush=True)
32 |
33 |
34 |
35 | if __name__ == "__main__":
36 |
37 | cap = cv2.VideoCapture(2)
38 |
39 | detecting = threading.Thread(target=threadDetect)
40 |     detecting.daemon = True  # Thread.setDaemon() is deprecated
41 | detecting.start()
42 |
43 |
44 | while True:
45 | _, currentImage = cap.read()
46 | frame += 1
47 |
48 | detecting.join()
49 |
50 |
51 |
--------------------------------------------------------------------------------
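The script above shares currentImage between the capture loop and the
detection thread with no explicit synchronization (workable under CPython's
GIL, since rebinding the global is atomic, but easy to get wrong). A sketch
of one way to make the hand-off explicit with a lock:

    import threading, time
    import cv2

    latest = {'frame': None}
    lock = threading.Lock()

    def detect_loop():
        while True:
            with lock:
                img = None if latest['frame'] is None else latest['frame'].copy()
            if img is None:
                time.sleep(0.01)  # camera has not produced a frame yet
                continue
            # ... run preprocess / inference / parse on img here ...

    t = threading.Thread(target=detect_loop, daemon=True)
    t.start()
    cap = cv2.VideoCapture(0)
    while True:
        ok, frm = cap.read()
        if not ok:
            break
        with lock:
            latest['frame'] = frm
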
/mnist_calibration.cache:
--------------------------------------------------------------------------------
1 | TRT-7000-EntropyCalibration2
2 | input.1: 3c6966a5
3 | 127: 3c298ff7
4 | 128: 3b5df2ba
5 | 129: 3b5df2ba
6 | 130: 3b5df2ba
7 | 131: 3bd8b106
8 | 132: 3b5d4402
9 | 133: 3b151529
10 | 134: 3b5e7f99
11 | 135: 3bbe2fc8
12 | 136: 3b936c4c
13 | 137: 3b840b6f
14 | 138: 3bb62876
15 | 139: 3b44a09e
16 | 140: 3b1f11fd
17 | 141: 3b5c9d1f
18 | 142: 3b87e6e8
19 | 143: 3bd1e091
20 | 144: 3bd1e091
21 | 145: 3bed89d1
22 | 146: 3c3c2f3a
23 | 147: 3b7fc3b6
24 | 148: 3b636566
25 | 149: 3c2c5bd4
26 | 150: 3b248a2b
27 | 151: 3b69ad12
28 | 152: 3c749c21
29 | 153: 3b74d13d
30 | 154: 3c0c0a0a
31 | 155: 3c2a12fa
32 | 156: 3ade8c52
33 | 157: 3b74846d
34 | 158: 3c133ae7
35 | 159: 3c19d149
36 | 160: 3c0e6787
37 | 161: 3c1f198d
38 | 162: 3c475aa0
39 | 163: 3c020fdd
40 | 164: 3bc506df
41 | 165: 3c5efda5
42 | 166: 3b152a0b
43 | 167: 3b4bbbdd
44 | 168: 3c5adf69
45 | 169: 3c4ba73f
46 | 170: 3bcb724d
47 | 171: 3c4472d7
48 | 172: 3c03f7bf
49 | 173: 3b837231
50 | 174: 3c58ce13
51 | 175: 3c51b338
52 | 176: 3c4474b3
53 | 177: 3b98ca6b
54 | 178: 3c1a27de
55 | 179: 3bdc38ad
56 | 180: 3b06c3e8
57 | 181: 3c59fb33
58 | 182: 3ad771f3
59 | 183: 3c198c73
60 | 184: 3cb0b2f4
61 | 185: 3c6782e0
62 | 186: 3bcc34f8
63 | 187: 3be23788
64 | 188: 3be23788
65 | 189: 3b241ede
66 | 190: 3d46f62c
67 | 191: 3d4323f7
68 | 192: 3d549149
69 | 193: 3e3bfc93
70 | 195: 3e3bfc93
71 | (Unnamed Layer* 69) [Constant]_output: 3a0ac72c
72 | (Unnamed Layer* 70) [Matrix Multiply]_output: 3dd67f34
73 | (Unnamed Layer* 71) [Constant]_output: 38a25c2e
74 | (Unnamed Layer* 72) [Shuffle]_output: 38a25c2e
75 | 196: 3dd69133
76 | 197: 3dd69133
77 | (Unnamed Layer* 76) [Constant]_output: 396ebf65
78 | (Unnamed Layer* 77) [Matrix Multiply]_output: 3d8ddbd0
79 | (Unnamed Layer* 78) [Constant]_output: 38714d42
80 | (Unnamed Layer* 79) [Shuffle]_output: 38714d42
81 | 198: 3d8ddc8e
82 | 200: 3d8ddc8e
83 |
--------------------------------------------------------------------------------
/model/__pycache__/backbone.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/backbone.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/backbone.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model_convert.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model_convert.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model_convert2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model_convert2.cpython-38.pyc
--------------------------------------------------------------------------------
/model/backbone.py:
--------------------------------------------------------------------------------
1 | import torch,pdb
2 | import torchvision
3 | import torch.nn.modules
4 |
5 | class vgg16bn(torch.nn.Module):
6 | def __init__(self,pretrained = False):
7 | super(vgg16bn,self).__init__()
8 | model = list(torchvision.models.vgg16_bn(pretrained=pretrained).features.children())
9 | model = model[:33]+model[34:43]
10 | self.model = torch.nn.Sequential(*model)
11 |
12 | def forward(self,x):
13 | return self.model(x)
14 | class resnet(torch.nn.Module):
15 | def __init__(self,layers,pretrained = False):
16 | super(resnet,self).__init__()
17 | if layers == '18':
18 | model = torchvision.models.resnet18(pretrained=pretrained)
19 | elif layers == '34':
20 | model = torchvision.models.resnet34(pretrained=pretrained)
21 | elif layers == '50':
22 | model = torchvision.models.resnet50(pretrained=pretrained)
23 | elif layers == '101':
24 | model = torchvision.models.resnet101(pretrained=pretrained)
25 | elif layers == '152':
26 | model = torchvision.models.resnet152(pretrained=pretrained)
27 | elif layers == '50next':
28 | model = torchvision.models.resnext50_32x4d(pretrained=pretrained)
29 | elif layers == '101next':
30 | model = torchvision.models.resnext101_32x8d(pretrained=pretrained)
31 | elif layers == '50wide':
32 | model = torchvision.models.wide_resnet50_2(pretrained=pretrained)
33 | elif layers == '101wide':
34 | model = torchvision.models.wide_resnet101_2(pretrained=pretrained)
35 | else:
36 | raise NotImplementedError
37 |
38 | self.conv1 = model.conv1
39 | self.bn1 = model.bn1
40 | self.relu = model.relu
41 | self.maxpool = model.maxpool
42 | self.layer1 = model.layer1
43 | self.layer2 = model.layer2
44 | self.layer3 = model.layer3
45 | self.layer4 = model.layer4
46 |
47 | def forward(self,x):
48 | x = self.conv1(x)
49 | x = self.bn1(x)
50 | x = self.relu(x)
51 | x = self.maxpool(x)
52 | x = self.layer1(x)
53 | x2 = self.layer2(x)
54 | x3 = self.layer3(x2)
55 | x4 = self.layer4(x3)
56 | return x2,x3,x4
57 |
--------------------------------------------------------------------------------
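A quick shape check for the three feature maps the backbone returns on the
288x800 input used throughout this repo (run from the repo root):

    import torch
    from model.backbone import resnet

    net = resnet('18', pretrained=False).eval()
    with torch.no_grad():
        x2, x3, x4 = net(torch.randn(1, 3, 288, 800))
    print(x2.shape, x3.shape, x4.shape)
    # (1, 128, 36, 100) (1, 256, 18, 50) (1, 512, 9, 25)
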
/model/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from model.backbone import resnet
3 | import numpy as np
4 |
5 | class conv_bn_relu(torch.nn.Module):
6 | def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=False):
7 | super(conv_bn_relu,self).__init__()
8 | self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size,
9 | stride = stride, padding = padding, dilation = dilation,bias = bias)
10 | self.bn = torch.nn.BatchNorm2d(out_channels)
11 | self.relu = torch.nn.ReLU()
12 |
13 | def forward(self,x):
14 | x = self.conv(x)
15 | x = self.bn(x)
16 | x = self.relu(x)
17 | return x
18 | class parsingNet(torch.nn.Module):
19 | def __init__(self, size=(288, 800), pretrained=True, backbone='50', cls_dim=(37, 10, 4), use_aux=False):
20 | super(parsingNet, self).__init__()
21 |
22 |         self.size = size  # (height, width) = (288, 800)
23 |         self.h = size[0]
24 |         self.w = size[1]
25 | self.cls_dim = cls_dim # (num_gridding, num_cls_per_lane, num_of_lanes)
26 | # num_cls_per_lane is the number of row anchors
27 | self.use_aux = use_aux
28 |         self.total_dim = np.prod(cls_dim)  # num_gridding * num_cls_per_lane * num_of_lanes
29 |
30 | # input : nchw,
31 | # output: (w+1) * sample_rows * 4
32 | self.model = resnet(backbone, pretrained=pretrained)
33 |
34 | if self.use_aux:
35 | self.aux_header2 = torch.nn.Sequential(
36 | conv_bn_relu(128, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1),
37 | conv_bn_relu(128,128,3,padding=1),
38 | conv_bn_relu(128,128,3,padding=1),
39 | conv_bn_relu(128,128,3,padding=1),
40 | )
41 | self.aux_header3 = torch.nn.Sequential(
42 | conv_bn_relu(256, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(1024, 128, kernel_size=3, stride=1, padding=1),
43 | conv_bn_relu(128,128,3,padding=1),
44 | conv_bn_relu(128,128,3,padding=1),
45 | )
46 | self.aux_header4 = torch.nn.Sequential(
47 | conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(2048, 128, kernel_size=3, stride=1, padding=1),
48 | conv_bn_relu(128,128,3,padding=1),
49 | )
50 | self.aux_combine = torch.nn.Sequential(
51 | conv_bn_relu(384, 256, 3,padding=2,dilation=2),
52 | conv_bn_relu(256, 128, 3,padding=2,dilation=2),
53 | conv_bn_relu(128, 128, 3,padding=2,dilation=2),
54 | conv_bn_relu(128, 128, 3,padding=4,dilation=4),
55 | torch.nn.Conv2d(128, cls_dim[-1] + 1,1)
56 | # output : n, num_of_lanes+1, h, w
57 | )
58 | initialize_weights(self.aux_header2,self.aux_header3,self.aux_header4,self.aux_combine)
59 |
60 | self.cls = torch.nn.Sequential(
61 | torch.nn.Linear(1800, 2048),
62 | torch.nn.ReLU(),
63 | torch.nn.Linear(2048, self.total_dim),
64 | )
65 |
66 | self.pool = torch.nn.Conv2d(512,8,1) if backbone in ['34','18'] else torch.nn.Conv2d(2048,8,1)
67 | # 1/32,2048 channel
68 | # 288,800 -> 9,40,2048
69 | # (w+1) * sample_rows * 4
70 | # 37 * 10 * 4
71 | initialize_weights(self.cls)
72 |
73 | def forward(self, x):
74 | # n c h w - > n 2048 sh sw
75 | # -> n 2048
76 | x2,x3,fea = self.model(x)
77 | if self.use_aux:
78 | x2 = self.aux_header2(x2)
79 | x3 = self.aux_header3(x3)
80 | x3 = torch.nn.functional.interpolate(x3,scale_factor = 2,mode='bilinear')
81 | x4 = self.aux_header4(fea)
82 | x4 = torch.nn.functional.interpolate(x4,scale_factor = 4,mode='bilinear')
83 | aux_seg = torch.cat([x2,x3,x4],dim=1)
84 | aux_seg = self.aux_combine(aux_seg)
85 | else:
86 | aux_seg = None
87 |
88 | fea = self.pool(fea).view(-1, 1800)
89 |
90 | group_cls = self.cls(fea).view(-1, *self.cls_dim)
91 |
92 | if self.use_aux:
93 | return group_cls, aux_seg
94 |
95 | return group_cls
96 |
97 |
98 | def initialize_weights(*models):
99 | for model in models:
100 | real_init_weights(model)
101 | def real_init_weights(m):
102 |
103 | if isinstance(m, list):
104 | for mini_m in m:
105 | real_init_weights(mini_m)
106 | else:
107 | if isinstance(m, torch.nn.Conv2d):
108 | torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
109 | if m.bias is not None:
110 | torch.nn.init.constant_(m.bias, 0)
111 | elif isinstance(m, torch.nn.Linear):
112 | m.weight.data.normal_(0.0, std=0.01)
113 | elif isinstance(m, torch.nn.BatchNorm2d):
114 | torch.nn.init.constant_(m.weight, 1)
115 | torch.nn.init.constant_(m.bias, 0)
116 | elif isinstance(m,torch.nn.Module):
117 | for mini_m in m.children():
118 | real_init_weights(mini_m)
119 | else:
120 |             print('unknown module', m)
121 |
--------------------------------------------------------------------------------
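The 1800 flattened features come from the 1x1 pooling conv: the ResNet-18
feature map (512, 9, 25) is reduced to 8 channels, and 8 * 9 * 25 = 1800.
The classification output is (batch, griding_num + 1, row_anchors, lanes);
for the Tusimple setup that is (1, 101, 56, 4):

    import torch
    from model.model import parsingNet

    net = parsingNet(pretrained=False, backbone='18',
                     cls_dim=(101, 56, 4), use_aux=False).eval()
    with torch.no_grad():
        out = net(torch.randn(1, 3, 288, 800))
    print(out.shape)  # torch.Size([1, 101, 56, 4])
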
/onnx_to_tensorrt.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import cv2
6 | import tensorrt as trt
7 | import pycuda.driver as cuda
8 | import pycuda.autoinit
9 | import numpy as np
10 |
11 | EXPLICIT_BATCH = []
12 | if trt.__version__[0] >= '7':
13 | EXPLICIT_BATCH.append(
14 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
15 |
16 | mode = 'fp16'  # default; main() overrides this with --precision
17 |
18 | def build_engine(onnx_file_path, mode, verbose=False):
19 | """Build a TensorRT engine from an ONNX file."""
20 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
21 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
22 | builder.max_workspace_size = 1 << 30
23 | builder.max_batch_size = 1
24 | if mode=='fp16':
25 | builder.fp16_mode = True
26 | else:
27 | builder.fp16_mode = False
28 | #builder.strict_type_constraints = True
29 |
30 | # Parse model file
31 | print('Loading ONNX file from path {}...'.format(onnx_file_path))
32 | with open(onnx_file_path, 'rb') as model:
33 | if not parser.parse(model.read()):
34 | print('ERROR: Failed to parse the ONNX file.')
35 | for error in range(parser.num_errors):
36 | print(parser.get_error(error))
37 | return None
38 | if trt.__version__[0] >= '7':
39 | # Reshape input to batch size 1
40 | shape = list(network.get_input(0).shape)
41 | shape[0] = 1
42 | network.get_input(0).shape = shape
43 |
44 | model_name = onnx_file_path[:-5]
45 |
46 |         print('Building an engine. This may take a while...')
47 | print('(Use "--verbose" to enable verbose logging.)')
48 | engine = builder.build_cuda_engine(network)
49 | print('Completed creating engine.')
50 | return engine
51 |
52 |
53 | def main():
54 | """Create a TensorRT engine for ONNX-based Model."""
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument(
57 | '-v', '--verbose', action='store_true',
58 | help='enable verbose output (for debugging)')
59 | parser.add_argument(
60 | '-m', '--model', type=str, default='model.onnx')
61 | parser.add_argument(
62 | '-p', '--precision', type=str, default='fp16')
63 | args = parser.parse_args()
64 |
65 | mode = args.precision
66 | onnx_file_path = args.model
67 | if not os.path.isfile(onnx_file_path):
68 | raise SystemExit('ERROR: file (%s) not found!' % onnx_file_path)
69 | if mode=='fp16':
70 | engine_file_path = '%s_fp16.trt'% args.model[:-5]
71 | elif mode == 'fp32':
72 | engine_file_path = '%s_fp32.trt'% args.model[:-5]
73 | else:
74 | print("illegal mode")
75 | exit(0)
76 | engine = build_engine(onnx_file_path, mode,args.verbose)
77 | with open(engine_file_path, 'wb') as f:
78 | f.write(engine.serialize())
79 | print('Serialized the TensorRT engine to file: %s' % engine_file_path)
80 |
81 |
82 |
83 | if __name__ == '__main__':
84 | main()
85 |
--------------------------------------------------------------------------------
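Before handing an ONNX file to the builder it can be worth validating it
first; a small pre-flight check using the onnx package from requirement.txt
(model.onnx is the default name this repo's scripts produce):

    import onnx

    model = onnx.load('model.onnx')
    onnx.checker.check_model(model)  # raises if the graph is malformed
    print([i.name for i in model.graph.input],
          [o.name for o in model.graph.output])
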
/onnx_to_tensorrt_int8.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import cv2
6 | import tensorrt as trt
7 | import pycuda.driver as cuda
8 | import pycuda.autoinit
9 | import numpy as np
10 | import torchvision.transforms as transforms
11 | from PIL import Image
12 |
13 | img_transforms = transforms.Compose([
14 | transforms.Resize((288, 800)),
15 | transforms.ToTensor(),
16 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
17 | ])
18 |
19 | EXPLICIT_BATCH = []
20 | if trt.__version__[0] >= '7':
21 | EXPLICIT_BATCH.append(
22 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
23 |
24 |
25 |
26 | class EntropyCalibrator(trt.IInt8EntropyCalibrator2):
27 | def __init__(self, training_data, cache_file, batch_size=16):
28 | # Whenever you specify a custom constructor for a TensorRT class,
29 | # you MUST call the constructor of the parent explicitly.
30 | trt.IInt8EntropyCalibrator2.__init__(self)
31 |
32 | self.cache_file = cache_file
33 | # Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned.
34 | self.data = self.load_data(training_data)
35 | self.batch_size = batch_size
36 | self.current_index = 0
37 |
38 | # Allocate enough memory for a whole batch.
39 | self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size)
40 |
41 | # Returns a numpy buffer of shape (num_images, 1, 28, 28)
42 | def load_data(self, datapath):
43 | print("loading image data")
44 | imgs = os.listdir(datapath)
45 | dataset = []
46 | for data in imgs:
47 |             img = cv2.imread(os.path.join(datapath, data))
48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
49 | img = Image.fromarray(img)
50 | img = img_transforms(img).numpy()
51 | dataset.append(img)
52 |         print("loaded %d calibration images" % len(dataset))
53 | return np.array(dataset)
54 |
55 | def get_batch_size(self):
56 | return self.batch_size
57 |
58 | # TensorRT passes along the names of the engine bindings to the get_batch function.
59 | # You don't necessarily have to use them, but they can be useful to understand the order of
60 | # the inputs. The bindings list is expected to have the same ordering as 'names'.
61 | def get_batch(self, names):
62 | if self.current_index + self.batch_size > self.data.shape[0]:
63 | return None
64 |
65 | current_batch = int(self.current_index / self.batch_size)
66 | if current_batch % 10 == 0:
67 | print("Calibrating batch {:}, containing {:} images".format(current_batch, self.batch_size))
68 |
69 | batch = self.data[self.current_index:self.current_index + self.batch_size].ravel()
70 | cuda.memcpy_htod(self.device_input, batch)
71 | self.current_index += self.batch_size
72 | return [self.device_input]
73 |
74 | def read_calibration_cache(self):
75 | # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None.
76 | if os.path.exists(self.cache_file):
77 | with open(self.cache_file, "rb") as f:
78 | return f.read()
79 |
80 | def write_calibration_cache(self, cache):
81 | with open(self.cache_file, "wb") as f:
82 | f.write(cache)
83 |
84 | def build_int8_engine(onnx_file_path, calib, batch_size, verbose=False):
85 | """Build a TensorRT engine from an ONNX file."""
86 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
87 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
88 | builder.max_workspace_size = 1 << 30
89 | builder.max_batch_size = 1
90 | builder.int8_mode = True
91 | builder.int8_calibrator = calib
92 |
93 | # Parse model file
94 | print('Loading ONNX file from path {}...'.format(onnx_file_path))
95 | with open(onnx_file_path, 'rb') as model:
96 | if not parser.parse(model.read()):
97 | print('ERROR: Failed to parse the ONNX file.')
98 | for error in range(parser.num_errors):
99 | print(parser.get_error(error))
100 | return None
101 | if trt.__version__[0] >= '7':
102 | # Reshape input to batch size 1
103 | shape = list(network.get_input(0).shape)
104 | shape[0] = 1
105 | network.get_input(0).shape = shape
106 |
107 |         # no extra plugins are needed for this model
108 | model_name = onnx_file_path[:-5]
109 |
110 |         print('Building an engine. This may take a while...')
111 | print('(Use "--verbose" to enable verbose logging.)')
112 | engine = builder.build_cuda_engine(network)
113 | print('Completed creating engine.')
114 | return engine
115 |
116 |
117 | def main():
118 | """Create a TensorRT engine for ONNX-based model."""
119 | parser = argparse.ArgumentParser()
120 | parser.add_argument(
121 | '-v', '--verbose', action='store_true',
122 | help='enable verbose output (for debugging)')
123 | parser.add_argument(
124 | '-m', '--model', type=str, default='model.onnx',
125 | )
126 | args = parser.parse_args()
127 |
128 | calibration_cache = "calibration.cache"
129 | data_path = 'calibration_data/testset/'
130 | calib = EntropyCalibrator(data_path, cache_file=calibration_cache)
131 |
132 | onnx_file_path = args.model
133 | if not os.path.isfile(onnx_file_path):
134 | raise SystemExit('ERROR: file (%s) not found!' % onnx_file_path)
135 | engine_file_path = '%s_int8.trt' % args.model[:-5]
136 | engine = build_int8_engine(onnx_file_path, calib, 16)
137 | with open(engine_file_path, 'wb') as f:
138 | f.write(engine.serialize())
139 | print('Serialized the TensorRT engine to file: %s' % engine_file_path)
140 |
141 |
142 |
143 | if __name__ == '__main__':
144 | main()
145 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | opencv-python
3 | matplotlib
4 | PySpin
5 | scipy
6 | addict
7 | tqdm
8 | tensorboard
9 | onnx
10 | tensorrt
11 | # also imported by the scripts in this repo:
12 | torch
13 | torchvision
14 | pycuda
15 | Pillow
--------------------------------------------------------------------------------
/tensorrt_run.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import argparse
5 | import cv2
6 | import tensorrt as trt
7 | import common
8 | import pycuda.driver as cuda
9 | import pycuda.autoinit
10 | import numpy as np
11 | import pycuda.gpuarray as gpuarray
12 | import time
13 | import scipy.special
14 | import torchvision.transforms as transforms
15 | from PIL import Image
16 |
17 | img_transforms = transforms.Compose([
18 | transforms.Resize((288, 800)),
19 | transforms.ToTensor(),
20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
21 | ])
22 |
23 | col_sample = np.linspace(0, 800 - 1, 100)
24 | col_sample_w = col_sample[1] - col_sample[0]
25 |
26 | img_w = 640
27 | img_h = 480
28 |
29 | row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112,
30 | 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164,
31 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216,
32 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268,
33 | 272, 276, 280, 284]
34 |
35 | color = [(255,255,0), (255,0,0),(0,0,255),(0,255,0)]
36 |
37 | EXPLICIT_BATCH = []
38 | if trt.__version__[0] >= '7':
39 | EXPLICIT_BATCH.append(
40 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
41 |
42 | def load_engine(trt_file_path, verbose=False):
43 | """Build a TensorRT engine from a TRT file."""
44 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
45 | print('Loading TRT file from path {}...'.format(trt_file_path))
46 | with open(trt_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
47 | engine = runtime.deserialize_cuda_engine(f.read())
48 | return engine
49 |
50 | def main():
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument(
53 | '-v', '--verbose', action='store_true',
54 | help='enable verbose output (for debugging)')
55 | parser.add_argument(
56 | '-m', '--model', type=str, default='model',
57 | )
58 | args = parser.parse_args()
59 |
60 |
61 |
62 | trt_file_path = '%s.trt' % args.model
63 | if not os.path.isfile(trt_file_path):
64 | raise SystemExit('ERROR: file (%s) not found!' % trt_file_path)
65 | engine_file_path = '%s.trt' % args.model
66 | engine = load_engine(trt_file_path, args.verbose)
67 |
68 | h_inputs, h_outputs, bindings, stream = common.allocate_buffers(engine)
69 |
70 |
71 | cap = cv2.VideoCapture(2)
72 | with engine.create_execution_context() as context:
73 | while True:
74 | _,frame = cap.read()
75 | t1 = time.time()
76 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
77 | img = Image.fromarray(img)
78 | img = img_transforms(img).numpy()
79 |
80 | h_inputs[0].host = img
81 | t3 = time.time()
82 | trt_outputs = common.do_inference_v2(context, bindings=bindings, inputs=h_inputs, outputs=h_outputs, stream=stream)
83 | t4 = time.time()
84 |
85 |
86 | out_j = trt_outputs[0].reshape(101, 56, 4)
87 |
88 | prob = scipy.special.softmax(out_j[:-1, :, :], axis=0)
89 |
90 |
91 | idx = np.arange(100) + 1
92 | idx = idx.reshape(-1, 1, 1)
93 |
94 | loc = np.sum(prob * idx, axis=0)
95 | out_j = np.argmax(out_j, axis=0)
96 | loc[out_j == 100] = 0
97 | out_j = loc
98 |
99 | # import pdb; pdb.set_trace()
100 | vis = frame
101 | for i in range(out_j.shape[1]):
102 | if np.sum(out_j[:, i] != 0) > 2:
103 | for k in range(out_j.shape[0]):
104 | if out_j[k, i] > 0:
105 | ppp = (int(out_j[k, i] * col_sample_w * img_w / 800) - 1, int(img_h * (row_anchor[k]/288)) - 1 )
106 | cv2.circle(vis,ppp, img_w//300 ,color[i],-1)
107 |
108 | t2 = time.time()
109 | print('Inference time', (t4-t3)*1000)
110 | print('FPS', int(1/((t2-t1))))
111 | cv2.imshow("OUTPUT", vis)
112 | cv2.waitKey(1)
113 |
114 |
115 | if __name__ == '__main__':
116 | main()
117 |
118 |
--------------------------------------------------------------------------------
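The decode inside the loop above is worth seeing in isolation: softmax over
the 100 grid cells gives a per-(row, lane) distribution, whose expected
value is a sub-cell x location; positions whose argmax lands on the "no
lane" class (index 100) are zeroed out. With fake network output:

    import numpy as np
    import scipy.special

    out = np.random.randn(101, 56, 4).astype(np.float32)  # fake (grid+1, rows, lanes)
    prob = scipy.special.softmax(out[:-1, :, :], axis=0)
    idx = np.arange(100).reshape(-1, 1, 1) + 1
    loc = np.sum(prob * idx, axis=0)           # expected grid cell, shape (56, 4)
    loc[np.argmax(out, axis=0) == 100] = 0     # suppress "no lane" predictions
    print(loc.shape, float(loc.min()), float(loc.max()))
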
/test_devices.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 | import time
4 | import numpy as np
5 |
6 | currentImage = None
7 |
8 |
9 |
10 | frame = 0
11 | if __name__ == "__main__":
12 | cap = cv2.VideoCapture(0)
13 |
14 | while True:
15 | frame += 1
16 | _, currentImage = cap.read()
17 | cv2.imshow("",currentImage)
18 | key = cv2.waitKey(1)
19 |         if key == ord('a'):  # waitKey returns an int, not a str
20 | cv2.imwrite(str(frame)+".jpg",currentImage)
21 |
22 |
23 |
--------------------------------------------------------------------------------
/torch2onnx.py:
--------------------------------------------------------------------------------
1 | from UFLD import *
2 | import cv2
3 | import time
4 | import numpy as np
5 | import torch
6 | import onnx
7 |
8 |
9 | detector = laneDetection()
10 | detector.setResolution(640, 480)
11 | frame = 0
12 | currentImage = None
13 |
14 |
15 | if __name__ == "__main__":
16 | filepath = "model.onnx"
17 | dummy_input = torch.rand((1,3,288,800)).cuda()
18 | torch.onnx.export(detector.net, dummy_input, filepath)
19 |
--------------------------------------------------------------------------------
/utils/__pycache__/common.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/common.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dist_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dist_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dist_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/factory.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/factory.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | import os, argparse
2 | from utils.dist_utils import is_main_process, dist_print, DistSummaryWriter
3 | from utils.config import Config
4 | import torch
5 |
6 | def str2bool(v):
7 | if isinstance(v, bool):
8 | return v
9 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
10 | return True
11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
12 | return False
13 | else:
14 | raise argparse.ArgumentTypeError('Boolean value expected.')
15 |
16 | def get_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('config', help = 'path to config file')
19 | parser.add_argument('--local_rank', type=int, default=0)
20 |
21 | parser.add_argument('--dataset', default = 'Tusimple', type = str)
22 | parser.add_argument('--data_root', default = None, type = str)
23 | parser.add_argument('--epoch', default = None, type = int)
24 | parser.add_argument('--batch_size', default = None, type = int)
25 | parser.add_argument('--optimizer', default = None, type = str)
26 | parser.add_argument('--learning_rate', default = None, type = float)
27 | parser.add_argument('--weight_decay', default = None, type = float)
28 | parser.add_argument('--momentum', default = None, type = float)
29 | parser.add_argument('--scheduler', default = None, type = str)
30 | parser.add_argument('--steps', default = None, type = int, nargs='+')
31 | parser.add_argument('--gamma', default = None, type = float)
32 | parser.add_argument('--warmup', default = None, type = str)
33 | parser.add_argument('--warmup_iters', default = None, type = int)
34 | parser.add_argument('--backbone', default = None, type = str)
35 | parser.add_argument('--griding_num', default = 100, type = int)
36 | parser.add_argument('--use_aux', default = None, type = str2bool)
37 | parser.add_argument('--sim_loss_w', default = None, type = float)
38 | parser.add_argument('--shp_loss_w', default = None, type = float)
39 | parser.add_argument('--note', default = None, type = str)
40 | parser.add_argument('--log_path', default = None, type = str)
41 | parser.add_argument('--finetune', default = None, type = str)
42 | parser.add_argument('--resume', default = None, type = str)
43 | parser.add_argument('--test_model', default = 'model.pth', type = str)
44 | parser.add_argument('--test_work_dir', default = None, type = str)
45 | parser.add_argument('--num_lanes', default = 4, type = int)
46 | parser.add_argument('--video', default = 'test.avi', type = str)
47 |
48 | return parser
49 |
50 | def merge_config():
51 | args = get_args().parse_args()
52 | cfg = Config.fromfile(args.config)
53 |
54 | items = ['dataset','data_root','epoch','batch_size','optimizer','learning_rate',
55 | 'weight_decay','momentum','scheduler','steps','gamma','warmup','warmup_iters',
56 | 'use_aux','griding_num','backbone','sim_loss_w','shp_loss_w','note','log_path',
57 | 'finetune','resume', 'test_model','test_work_dir', 'num_lanes','video']
58 | for item in items:
59 | if getattr(args, item) is not None:
60 | dist_print('merge ', item, ' config')
61 | setattr(cfg, item, getattr(args, item))
62 | return args, cfg
63 |
64 |
65 | def save_model(net, optimizer, epoch,save_path, distributed):
66 | if is_main_process():
67 | model_state_dict = net.state_dict()
68 | state = {'model': model_state_dict, 'optimizer': optimizer.state_dict()}
69 | # state = {'model': model_state_dict}
70 | assert os.path.exists(save_path)
71 | model_path = os.path.join(save_path, 'ep%03d.pth' % epoch)
72 | torch.save(state, model_path)
73 |
74 | import pathspec
75 |
76 | def cp_projects(to_path):
77 | if is_main_process():
78 | with open('./.gitignore','r') as fp:
79 | ign = fp.read()
80 | ign += '\n.git'
81 | spec = pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, ign.splitlines())
82 | all_files = {os.path.join(root,name) for root,dirs,files in os.walk('./') for name in files}
83 | matches = spec.match_files(all_files)
84 | matches = set(matches)
85 | to_cp_files = all_files - matches
86 | # to_cp_files = [f[2:] for f in to_cp_files]
87 | # pdb.set_trace()
88 | for f in to_cp_files:
89 | dirs = os.path.join(to_path,'code',os.path.split(f[2:])[0])
90 | if not os.path.exists(dirs):
91 | os.makedirs(dirs)
92 | os.system('cp %s %s'%(f,os.path.join(to_path,'code',f[2:])))
93 |
94 |
95 | import datetime, os
96 | def get_work_dir(cfg):
97 | now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
98 | hyper_param_str = '_lr_%1.0e_b_%d' % (cfg.learning_rate, cfg.batch_size)
99 | work_dir = os.path.join(cfg.log_path, now + hyper_param_str + cfg.note)
100 | return work_dir
101 |
102 | def get_logger(work_dir, cfg):
103 | logger = DistSummaryWriter(work_dir)
104 | config_txt = os.path.join(work_dir, 'cfg.txt')
105 | if is_main_process():
106 | with open(config_txt, 'w') as fp:
107 | fp.write(str(cfg))
108 |
109 | return logger
110 |
--------------------------------------------------------------------------------
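For orientation, a minimal sketch of how the helpers in utils/common.py are typically wired together in a training entry point. The config path matches configs/tusimple_4.py from this repo, but the flag values and the resulting directory name are assumptions, not a documented CLI:

    # Sketch only: assumes an invocation like `python train.py configs/tusimple_4.py --batch_size 8`
    from utils.common import merge_config, get_work_dir, get_logger

    args, cfg = merge_config()          # CLI flags override fields parsed from the config file
    work_dir = get_work_dir(cfg)        # e.g. <log_path>/20240101_120000_lr_1e-04_b_8<note>
    logger = get_logger(work_dir, cfg)  # DistSummaryWriter; rank 0 also writes cfg.txt

--------------------------------------------------------------------------------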
/utils/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os.path as osp
3 | import shutil
4 | import sys
5 | import tempfile
6 | from argparse import Action, ArgumentParser
7 | from collections import abc
8 | from importlib import import_module
9 |
10 | from addict import Dict
11 |
12 |
13 | BASE_KEY = '_base_'
14 | DELETE_KEY = '_delete_'
15 |
16 |
17 | class ConfigDict(Dict):
18 |
19 | def __missing__(self, name):
20 | raise KeyError(name)
21 |
22 | def __getattr__(self, name):
23 | try:
24 | value = super(ConfigDict, self).__getattr__(name)
25 | except KeyError:
26 | ex = AttributeError(f"'{self.__class__.__name__}' object has no "
27 | f"attribute '{name}'")
28 | except Exception as e:
29 | ex = e
30 | else:
31 | return value
32 | raise ex
33 |
34 |
35 | def add_args(parser, cfg, prefix=''):
36 | for k, v in cfg.items():
37 | if isinstance(v, str):
38 | parser.add_argument('--' + prefix + k)
39 | elif isinstance(v, int):
40 | parser.add_argument('--' + prefix + k, type=int)
41 | elif isinstance(v, float):
42 | parser.add_argument('--' + prefix + k, type=float)
43 | elif isinstance(v, bool):
44 | parser.add_argument('--' + prefix + k, action='store_true')
45 | elif isinstance(v, dict):
46 | add_args(parser, v, prefix + k + '.')
47 | elif isinstance(v, abc.Iterable):
48 | parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
49 | else:
50 | print(f'cannot parse key {prefix + k} of type {type(v)}')
51 | return parser
52 |
53 |
54 | class Config(object):
55 | """A facility for config and config files.
56 | It supports common file formats as configs: python/json/yaml. The interface
57 | is the same as a dict object and also allows access config values as
58 | attributes.
59 | Example:
60 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
61 | >>> cfg.a
62 | 1
63 | >>> cfg.b
64 | {'b1': [0, 1]}
65 | >>> cfg.b.b1
66 | [0, 1]
67 | >>> cfg = Config.fromfile('tests/data/config/a.py')
68 | >>> cfg.filename
69 | "/home/kchen/projects/mmcv/tests/data/config/a.py"
70 | >>> cfg.item4
71 | 'test'
72 | >>> cfg
73 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
74 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
75 | """
76 |
77 | @staticmethod
78 | def _file2dict(filename):
79 | filename = osp.abspath(osp.expanduser(filename))
80 | if filename.endswith('.py'):
81 | with tempfile.TemporaryDirectory() as temp_config_dir:
82 | temp_config_file = tempfile.NamedTemporaryFile(
83 | dir=temp_config_dir, suffix='.py')
84 | temp_config_name = osp.basename(temp_config_file.name)
85 | # close temp file
86 | temp_config_file.close()
87 | shutil.copyfile(filename,
88 | osp.join(temp_config_dir, temp_config_name))
89 | temp_module_name = osp.splitext(temp_config_name)[0]
90 | sys.path.insert(0, temp_config_dir)
91 | mod = import_module(temp_module_name)
92 | sys.path.pop(0)
93 | cfg_dict = {
94 | name: value
95 | for name, value in mod.__dict__.items()
96 | if not name.startswith('__')
97 | }
98 | # delete imported module
99 | del sys.modules[temp_module_name]
100 |
101 | elif filename.endswith(('.yml', '.yaml', '.json')):
102 | import mmcv
103 | cfg_dict = mmcv.load(filename)
104 | else:
105 | raise IOError('Only py/yml/yaml/json types are supported for now!')
106 |
107 | cfg_text = filename + '\n'
108 | with open(filename, 'r') as f:
109 | cfg_text += f.read()
110 |
111 | if BASE_KEY in cfg_dict:
112 | cfg_dir = osp.dirname(filename)
113 | base_filename = cfg_dict.pop(BASE_KEY)
114 | base_filename = base_filename if isinstance(
115 | base_filename, list) else [base_filename]
116 |
117 | cfg_dict_list = list()
118 | cfg_text_list = list()
119 | for f in base_filename:
120 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
121 | cfg_dict_list.append(_cfg_dict)
122 | cfg_text_list.append(_cfg_text)
123 |
124 | base_cfg_dict = dict()
125 | for c in cfg_dict_list:
126 | if len(base_cfg_dict.keys() & c.keys()) > 0:
127 | raise KeyError('Duplicate key is not allowed among bases')
128 | base_cfg_dict.update(c)
129 |
130 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
131 | cfg_dict = base_cfg_dict
132 |
133 | # merge cfg_text
134 | cfg_text_list.append(cfg_text)
135 | cfg_text = '\n'.join(cfg_text_list)
136 |
137 | return cfg_dict, cfg_text
138 |
139 | @staticmethod
140 | def _merge_a_into_b(a, b):
141 | # merge dict `a` into dict `b` (non-inplace). values in `a` will
142 | # overwrite `b`.
143 | # copy first to avoid inplace modification
144 | b = b.copy()
145 | for k, v in a.items():
146 | if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
147 | if not isinstance(b[k], dict):
148 | raise TypeError(
149 | f'{k}={v} in child config cannot inherit from base '
150 | f'because {k} is a dict in the child config but is of '
151 | f'type {type(b[k])} in base config. You may set '
152 | f'`{DELETE_KEY}=True` to ignore the base config')
153 | b[k] = Config._merge_a_into_b(v, b[k])
154 | else:
155 | b[k] = v
156 | return b
157 |
158 | @staticmethod
159 | def fromfile(filename):
160 | cfg_dict, cfg_text = Config._file2dict(filename)
161 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
162 |
163 | @staticmethod
164 | def auto_argparser(description=None):
165 | """Generate argparser from config file automatically (experimental)
166 | """
167 | partial_parser = ArgumentParser(description=description)
168 | partial_parser.add_argument('config', help='config file path')
169 | cfg_file = partial_parser.parse_known_args()[0].config
170 | cfg = Config.fromfile(cfg_file)
171 | parser = ArgumentParser(description=description)
172 | parser.add_argument('config', help='config file path')
173 | add_args(parser, cfg)
174 | return parser, cfg
175 |
176 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
177 | if cfg_dict is None:
178 | cfg_dict = dict()
179 | elif not isinstance(cfg_dict, dict):
180 | raise TypeError('cfg_dict must be a dict, but '
181 | f'got {type(cfg_dict)}')
182 |
183 | super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
184 | super(Config, self).__setattr__('_filename', filename)
185 | if cfg_text:
186 | text = cfg_text
187 | elif filename:
188 | with open(filename, 'r') as f:
189 | text = f.read()
190 | else:
191 | text = ''
192 | super(Config, self).__setattr__('_text', text)
193 |
194 | @property
195 | def filename(self):
196 | return self._filename
197 |
198 | @property
199 | def text(self):
200 | return self._text
201 |
202 | @property
203 | def pretty_text(self):
204 |
205 | indent = 4
206 |
207 | def _indent(s_, num_spaces):
208 | s = s_.split('\n')
209 | if len(s) == 1:
210 | return s_
211 | first = s.pop(0)
212 | s = [(num_spaces * ' ') + line for line in s]
213 | s = '\n'.join(s)
214 | s = first + '\n' + s
215 | return s
216 |
217 | def _format_basic_types(k, v):
218 | if isinstance(v, str):
219 | v_str = f"'{v}'"
220 | else:
221 | v_str = str(v)
222 | attr_str = f'{str(k)}={v_str}'
223 | attr_str = _indent(attr_str, indent)
224 |
225 | return attr_str
226 |
227 | def _format_list(k, v):
228 | # check if all items in the list are dict
229 | if all(isinstance(_, dict) for _ in v):
230 | v_str = '[\n'
231 | v_str += '\n'.join(
232 | f'dict({_indent(_format_dict(v_), indent)}),'
233 | for v_ in v).rstrip(',')
234 | attr_str = f'{str(k)}={v_str}'
235 | attr_str = _indent(attr_str, indent) + ']'
236 | else:
237 | attr_str = _format_basic_types(k, v)
238 | return attr_str
239 |
240 | def _format_dict(d, outest_level=False):
241 | r = ''
242 | s = []
243 | for idx, (k, v) in enumerate(d.items()):
244 | is_last = idx >= len(d) - 1
245 | end = '' if outest_level or is_last else ','
246 | if isinstance(v, dict):
247 | v_str = '\n' + _format_dict(v)
248 | attr_str = f'{str(k)}=dict({v_str}'
249 | attr_str = _indent(attr_str, indent) + ')' + end
250 | elif isinstance(v, list):
251 | attr_str = _format_list(k, v) + end
252 | else:
253 | attr_str = _format_basic_types(k, v) + end
254 |
255 | s.append(attr_str)
256 | r += '\n'.join(s)
257 | return r
258 |
259 | cfg_dict = self._cfg_dict.to_dict()
260 | text = _format_dict(cfg_dict, outest_level=True)
261 |
262 | return text
263 |
264 | def __repr__(self):
265 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'
266 |
267 | def __len__(self):
268 | return len(self._cfg_dict)
269 |
270 | def __getattr__(self, name):
271 | return getattr(self._cfg_dict, name)
272 |
273 | def __getitem__(self, name):
274 | return self._cfg_dict.__getitem__(name)
275 |
276 | def __setattr__(self, name, value):
277 | if isinstance(value, dict):
278 | value = ConfigDict(value)
279 | self._cfg_dict.__setattr__(name, value)
280 |
281 | def __setitem__(self, name, value):
282 | if isinstance(value, dict):
283 | value = ConfigDict(value)
284 | self._cfg_dict.__setitem__(name, value)
285 |
286 | def __iter__(self):
287 | return iter(self._cfg_dict)
288 |
289 | def dump(self):
290 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
291 | format_text = json.dumps(cfg_dict, indent=2)
292 | return format_text
293 |
294 | def merge_from_dict(self, options):
295 | """Merge list into cfg_dict
296 | Merge the dict parsed by MultipleKVAction into this cfg.
297 | Examples:
298 | >>> options = {'model.backbone.depth': 50,
299 | ... 'model.backbone.with_cp':True}
300 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
301 | >>> cfg.merge_from_dict(options)
302 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
303 | >>> assert cfg_dict == dict(
304 | ... model=dict(backbone=dict(depth=50, with_cp=True)))
305 | Args:
306 | options (dict): dict of configs to merge from.
307 | """
308 | option_cfg_dict = {}
309 | for full_key, v in options.items():
310 | d = option_cfg_dict
311 | key_list = full_key.split('.')
312 | for subkey in key_list[:-1]:
313 | d.setdefault(subkey, ConfigDict())
314 | d = d[subkey]
315 | subkey = key_list[-1]
316 | d[subkey] = v
317 |
318 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
319 | super(Config, self).__setattr__(
320 | '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
321 |
322 |
323 | class DictAction(Action):
324 | """
325 | argparse action to split an argument into KEY=VALUE form
326 | on the first = and append to a dictionary. List options should
327 | be passed as comma-separated values, i.e. KEY=V1,V2,V3
328 | """
329 |
330 | @staticmethod
331 | def _parse_int_float_bool(val):
332 | try:
333 | return int(val)
334 | except ValueError:
335 | pass
336 | try:
337 | return float(val)
338 | except ValueError:
339 | pass
340 | if val.lower() in ['true', 'false']:
341 | return True if val.lower() == 'true' else False
342 | return val
343 |
344 | def __call__(self, parser, namespace, values, option_string=None):
345 | options = {}
346 | for kv in values:
347 | key, val = kv.split('=', maxsplit=1)
348 | val = [self._parse_int_float_bool(v) for v in val.split(',')]
349 | if len(val) == 1:
350 | val = val[0]
351 | options[key] = val
352 | setattr(namespace, self.dest, options)
--------------------------------------------------------------------------------
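A hedged usage sketch of the Config facility above. The file name mirrors configs/tusimple_4.py from this repo, but the accessed keys are illustrative:

    from utils.config import Config

    cfg = Config.fromfile('configs/tusimple_4.py')  # .py/.json/.yml/.yaml are accepted
    print(cfg.griding_num)                          # attribute-style access via ConfigDict
    cfg.merge_from_dict({'optimizer': 'SGD',
                         'loss.gamma': 2})          # dotted keys create nested dicts

--------------------------------------------------------------------------------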
/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import pickle
4 |
5 |
6 | def get_world_size():
7 | if not dist.is_available():
8 | return 1
9 | if not dist.is_initialized():
10 | return 1
11 | return dist.get_world_size()
12 |
13 |
14 | def to_python_float(t):
15 | if hasattr(t, 'item'):
16 | return t.item()
17 | else:
18 | return t[0]
19 |
20 |
21 | def get_rank():
22 | if not dist.is_available():
23 | return 0
24 | if not dist.is_initialized():
25 | return 0
26 | return dist.get_rank()
27 |
28 |
29 | def is_main_process():
30 | return get_rank() == 0
31 |
32 |
33 | def can_log():
34 | return is_main_process()
35 |
36 |
37 | def dist_print(*args, **kwargs):
38 | if can_log():
39 | print(*args, **kwargs)
40 |
41 |
42 | def synchronize():
43 | """
44 | Helper function to synchronize (barrier) among all processes when
45 | using distributed training
46 | """
47 | if not dist.is_available():
48 | return
49 | if not dist.is_initialized():
50 | return
51 | world_size = dist.get_world_size()
52 | if world_size == 1:
53 | return
54 | dist.barrier()
55 |
56 | def dist_cat_reduce_tensor(tensor):
57 | if not dist.is_available():
58 | return tensor
59 | if not dist.is_initialized():
60 | return tensor
61 | # dist_print(tensor)
62 | rt = tensor.clone()
63 | all_list = [torch.zeros_like(tensor) for _ in range(get_world_size())]
64 | dist.all_gather(all_list,rt)
65 | # dist_print(all_list[0][1],all_list[1][1],all_list[2][1],all_list[3][1])
66 | # dist_print(all_list[0][2],all_list[1][2],all_list[2][2],all_list[3][2])
67 | # dist_print(all_list[0][3],all_list[1][3],all_list[2][3],all_list[3][3])
68 | # dist_print(all_list[0].shape)
69 | return torch.cat(all_list,dim = 0)
70 |
71 | def dist_sum_reduce_tensor(tensor):
72 | if not dist.is_available():
73 | return tensor
74 | if not dist.is_initialized():
75 | return tensor
76 | if not isinstance(tensor, torch.Tensor):
77 | return tensor
78 | rt = tensor.clone()
79 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # dist.reduce_op is a deprecated alias of ReduceOp
80 | return rt
81 |
82 |
83 | def dist_mean_reduce_tensor(tensor):
84 | rt = dist_sum_reduce_tensor(tensor)
85 | rt /= get_world_size()
86 | return rt
87 |
88 |
89 | def all_gather(data):
90 | """
91 | Run all_gather on arbitrary picklable data (not necessarily tensors)
92 | Args:
93 | data: any picklable object
94 | Returns:
95 | list[data]: list of data gathered from each rank
96 | """
97 | world_size = get_world_size()
98 | if world_size == 1:
99 | return [data]
100 |
101 | # serialized to a Tensor
102 | buffer = pickle.dumps(data)
103 | storage = torch.ByteStorage.from_buffer(buffer)
104 | tensor = torch.ByteTensor(storage).to("cuda")
105 |
106 | # obtain Tensor size of each rank
107 | local_size = torch.LongTensor([tensor.numel()]).to("cuda")
108 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
109 | dist.all_gather(size_list, local_size)
110 | size_list = [int(size.item()) for size in size_list]
111 | max_size = max(size_list)
112 |
113 | # receiving Tensor from all ranks
114 | # we pad the tensor because torch all_gather does not support
115 | # gathering tensors of different shapes
116 | tensor_list = []
117 | for _ in size_list:
118 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
119 | if local_size != max_size:
120 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
121 | tensor = torch.cat((tensor, padding), dim=0)
122 | dist.all_gather(tensor_list, tensor)
123 |
124 | data_list = []
125 | for size, tensor in zip(size_list, tensor_list):
126 | buffer = tensor.cpu().numpy().tobytes()[:size]
127 | data_list.append(pickle.loads(buffer))
128 |
129 | return data_list
130 |
131 |
132 | from torch.utils.tensorboard import SummaryWriter
133 |
134 |
135 | class DistSummaryWriter(SummaryWriter):
136 | def __init__(self, *args, **kwargs):
137 | if can_log():
138 | super(DistSummaryWriter, self).__init__(*args, **kwargs)
139 |
140 | def add_scalar(self, *args, **kwargs):
141 | if can_log():
142 | super(DistSummaryWriter, self).add_scalar(*args, **kwargs)
143 |
144 | def add_figure(self, *args, **kwargs):
145 | if can_log():
146 | super(DistSummaryWriter, self).add_figure(*args, **kwargs)
147 |
148 | def add_graph(self, *args, **kwargs):
149 | if can_log():
150 | super(DistSummaryWriter, self).add_graph(*args, **kwargs)
151 |
152 | def add_histogram(self, *args, **kwargs):
153 | if can_log():
154 | super(DistSummaryWriter, self).add_histogram(*args, **kwargs)
155 |
156 | def add_image(self, *args, **kwargs):
157 | if can_log():
158 | super(DistSummaryWriter, self).add_image(*args, **kwargs)
159 |
160 | def close(self):
161 | if can_log():
162 | super(DistSummaryWriter, self).close()
163 |
164 |
165 | import tqdm
166 |
167 |
168 | def dist_tqdm(obj, *args, **kwargs):
169 | if can_log():
170 | return tqdm.tqdm(obj, *args, **kwargs)
171 | else:
172 | return obj
173 |
174 |
--------------------------------------------------------------------------------
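The helpers above all degrade to plain single-process behavior when torch.distributed is not initialized, so they can be exercised standalone; a small sketch:

    from utils.dist_utils import dist_print, dist_tqdm, all_gather

    dist_print('printed on the main process only')   # silent on other ranks
    for _ in dist_tqdm(range(3)):                    # tqdm bar only where can_log() is True
        pass
    print(all_gather({'acc': 0.9}))                  # world_size == 1: returns [data] as-is

--------------------------------------------------------------------------------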
/utils/factory.py:
--------------------------------------------------------------------------------
1 | from utils.loss import SoftmaxFocalLoss, ParsingRelationLoss, ParsingRelationDis
2 | from utils.metrics import MultiLabelAcc, AccTopk, Metric_mIoU
3 | from utils.dist_utils import DistSummaryWriter
4 |
5 | import torch
6 |
7 |
8 | def get_optimizer(net,cfg):
9 | training_params = filter(lambda p: p.requires_grad, net.parameters())
10 | if cfg.optimizer == 'Adam':
11 | optimizer = torch.optim.Adam(training_params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
12 | elif cfg.optimizer == 'SGD':
13 | optimizer = torch.optim.SGD(training_params, lr=cfg.learning_rate, momentum=cfg.momentum,
14 | weight_decay=cfg.weight_decay)
15 | else:
16 | raise NotImplementedError
17 | return optimizer
18 |
19 | def get_scheduler(optimizer, cfg, iters_per_epoch):
20 | if cfg.scheduler == 'multi':
21 | scheduler = MultiStepLR(optimizer, cfg.steps, cfg.gamma, iters_per_epoch, cfg.warmup, iters_per_epoch if cfg.warmup_iters is None else cfg.warmup_iters)
22 | elif cfg.scheduler == 'cos':
23 | scheduler = CosineAnnealingLR(optimizer, cfg.epoch * iters_per_epoch, eta_min = 0, warmup = cfg.warmup, warmup_iters = cfg.warmup_iters)
24 | else:
25 | raise NotImplementedError
26 | return scheduler
27 |
28 | def get_loss_dict(cfg):
29 |
30 | if cfg.use_aux:
31 | loss_dict = {
32 | 'name': ['cls_loss', 'relation_loss', 'aux_loss', 'relation_dis'],
33 | 'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), torch.nn.CrossEntropyLoss(), ParsingRelationDis()],
34 | 'weight': [1.0, cfg.sim_loss_w, 1.0, cfg.shp_loss_w],
35 | 'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('seg_out', 'seg_label'), ('cls_out',)]
36 | }
37 | else:
38 | loss_dict = {
39 | 'name': ['cls_loss', 'relation_loss', 'relation_dis'],
40 | 'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), ParsingRelationDis()],
41 | 'weight': [1.0, cfg.sim_loss_w, cfg.shp_loss_w],
42 | 'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('cls_out',)]
43 | }
44 |
45 | return loss_dict
46 |
47 | def get_metric_dict(cfg):
48 |
49 | if cfg.use_aux:
50 | metric_dict = {
51 | 'name': ['top1', 'top2', 'top3', 'iou'],
52 | 'op': [MultiLabelAcc(), AccTopk(cfg.griding_num, 2), AccTopk(cfg.griding_num, 3), Metric_mIoU(8+1)],
53 | 'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('seg_out', 'seg_label')]
54 | }
55 | else:
56 | metric_dict = {
57 | 'name': ['top1', 'top2', 'top3'],
58 | 'op': [MultiLabelAcc(), AccTopk(cfg.griding_num, 2), AccTopk(cfg.griding_num, 3)],
59 | 'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label')]
60 | }
61 |
62 |
63 | return metric_dict
64 |
65 |
66 | class MultiStepLR:
67 | def __init__(self, optimizer, steps, gamma = 0.1, iters_per_epoch = None, warmup = None, warmup_iters = None):
68 | self.warmup = warmup
69 | self.warmup_iters = warmup_iters
70 | self.optimizer = optimizer
71 | self.steps = steps
72 | self.steps.sort()
73 | self.gamma = gamma
74 | self.iters_per_epoch = iters_per_epoch
75 | self.iters = 0
76 | self.base_lr = [group['lr'] for group in optimizer.param_groups]
77 |
78 | def step(self, external_iter = None):
79 | self.iters += 1
80 | if external_iter is not None:
81 | self.iters = external_iter
82 | if self.warmup == 'linear' and self.iters < self.warmup_iters:
83 | rate = self.iters / self.warmup_iters
84 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
85 | group['lr'] = lr * rate
86 | return
87 |
88 | # multi policy
89 | if self.iters % self.iters_per_epoch == 0:
90 | epoch = int(self.iters / self.iters_per_epoch)
91 | power = -1
92 | for i, st in enumerate(self.steps):
93 | if epoch < st:
94 | power = i
95 | break
96 | if power == -1:
97 | power = len(self.steps)
98 | # print(self.iters, self.iters_per_epoch, self.steps, power)
99 |
100 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
101 | group['lr'] = lr * (self.gamma ** power)
102 | import math
103 | class CosineAnnealingLR:
104 | def __init__(self, optimizer, T_max , eta_min = 0, warmup = None, warmup_iters = None):
105 | self.warmup = warmup
106 | self.warmup_iters = warmup_iters
107 | self.optimizer = optimizer
108 | self.T_max = T_max
109 | self.eta_min = eta_min
110 |
111 | self.iters = 0
112 | self.base_lr = [group['lr'] for group in optimizer.param_groups]
113 |
114 | def step(self, external_iter = None):
115 | self.iters += 1
116 | if external_iter is not None:
117 | self.iters = external_iter
118 | if self.warmup == 'linear' and self.iters < self.warmup_iters:
119 | rate = self.iters / self.warmup_iters
120 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
121 | group['lr'] = lr * rate
122 | return
123 |
124 | # cos policy
125 |
126 | for group, lr in zip(self.optimizer.param_groups, self.base_lr):
127 | group['lr'] = self.eta_min + (lr - self.eta_min) * (1 + math.cos(math.pi * self.iters / self.T_max)) / 2
128 |
129 |
130 |
--------------------------------------------------------------------------------
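A self-contained sketch of the warmup-then-decay behavior of the MultiStepLR class above; all hyperparameter values here are illustrative:

    import torch
    from utils.factory import MultiStepLR

    net = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    sched = MultiStepLR(opt, steps=[2, 4], gamma=0.1, iters_per_epoch=10,
                        warmup='linear', warmup_iters=5)
    for _ in range(30):
        sched.step()  # lr ramps linearly over 5 iters, then drops 10x after epochs 2 and 4
    print(opt.param_groups[0]['lr'])  # 0.1 * 0.1**1 == 0.01 after epoch 3

--------------------------------------------------------------------------------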
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import numpy as np
6 |
7 | class OhemCELoss(nn.Module):
8 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
9 | super(OhemCELoss, self).__init__()
10 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
11 | self.n_min = n_min
12 | self.ignore_lb = ignore_lb
13 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
14 |
15 | def forward(self, logits, labels):
16 | N, C, H, W = logits.size()
17 | loss = self.criteria(logits, labels).view(-1)
18 | loss, _ = torch.sort(loss, descending=True)
19 | if loss[self.n_min] > self.thresh:
20 | loss = loss[loss>self.thresh]
21 | else:
22 | loss = loss[:self.n_min]
23 | return torch.mean(loss)
24 |
25 |
26 | class SoftmaxFocalLoss(nn.Module):
27 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
28 | super(SoftmaxFocalLoss, self).__init__()
29 | self.gamma = gamma
30 | self.nll = nn.NLLLoss(ignore_index=ignore_lb)
31 |
32 | def forward(self, logits, labels):
33 | scores = F.softmax(logits, dim=1)
34 | factor = torch.pow(1.-scores, self.gamma)
35 | log_score = F.log_softmax(logits, dim=1)
36 | log_score = factor * log_score
37 | loss = self.nll(log_score, labels)
38 | return loss
39 |
40 | class ParsingRelationLoss(nn.Module):
41 | def __init__(self):
42 | super(ParsingRelationLoss, self).__init__()
43 | def forward(self,logits):
44 | n,c,h,w = logits.shape
45 | loss_all = []
46 | for i in range(0,h-1):
47 | loss_all.append(logits[:,:,i,:] - logits[:,:,i+1,:])
48 | #loss0 : n,c,w
49 | loss = torch.cat(loss_all)
50 | return torch.nn.functional.smooth_l1_loss(loss,torch.zeros_like(loss))
51 |
52 |
53 |
54 | class ParsingRelationDis(nn.Module):
55 | def __init__(self):
56 | super(ParsingRelationDis, self).__init__()
57 | self.l1 = torch.nn.L1Loss()
58 | # self.l1 = torch.nn.MSELoss()
59 | def forward(self, x):
60 | n,dim,num_rows,num_cols = x.shape
61 | x = torch.nn.functional.softmax(x[:,:dim-1,:,:],dim=1)
62 | embedding = torch.Tensor(np.arange(dim-1)).float().to(x.device).view(1,-1,1,1)
63 | pos = torch.sum(x*embedding,dim = 1)
64 |
65 | diff_list1 = []
66 | for i in range(0,num_rows // 2):
67 | diff_list1.append(pos[:,i,:] - pos[:,i+1,:])
68 |
69 | loss = 0
70 | for i in range(len(diff_list1)-1):
71 | loss += self.l1(diff_list1[i],diff_list1[i+1])
72 | loss /= len(diff_list1) - 1
73 | return loss
74 |
75 |
--------------------------------------------------------------------------------
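A quick sanity-check sketch for SoftmaxFocalLoss above. The tensor layout follows the (batch, griding_num+1, rows, lanes) convention used elsewhere in this repo, but the concrete sizes are assumptions:

    import torch
    from utils.loss import SoftmaxFocalLoss

    logits = torch.randn(2, 101, 56, 4)         # assumed: N x (grid+1) x rows x lanes
    labels = torch.randint(0, 101, (2, 56, 4))  # per-cell class indices
    print(SoftmaxFocalLoss(gamma=2)(logits, labels))

--------------------------------------------------------------------------------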
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import time,pdb
4 |
5 | def converter(data):
6 | if isinstance(data,torch.Tensor):
7 | data = data.cpu().data.numpy().flatten()
8 | return data.flatten()
9 | def fast_hist(label_pred, label_true,num_classes):
10 | #pdb.set_trace()
11 | hist = np.bincount(num_classes * label_true.astype(int) + label_pred, minlength=num_classes ** 2)
12 | hist = hist.reshape(num_classes, num_classes)
13 | return hist
14 |
15 | class Metric_mIoU():
16 | def __init__(self,class_num):
17 | self.class_num = class_num
18 | self.hist = np.zeros((self.class_num,self.class_num))
19 | def update(self,predict,target):
20 | predict,target = converter(predict),converter(target)
21 |
22 | self.hist += fast_hist(predict,target,self.class_num)
23 |
24 | def reset(self):
25 | self.hist = np.zeros((self.class_num,self.class_num))
26 | def get_miou(self):
27 | miou = np.diag(self.hist) / (
28 | np.sum(self.hist, axis=1) + np.sum(self.hist, axis=0) -
29 | np.diag(self.hist))
30 | miou = np.nanmean(miou)
31 | return miou
32 |
33 | def get_acc(self):
34 | acc = np.diag(self.hist) / self.hist.sum(axis=1)
35 | acc = np.nanmean(acc)
36 | return acc
37 | def get(self):
38 | return self.get_miou()
39 | class MultiLabelAcc():
40 | def __init__(self):
41 | self.cnt = 0
42 | self.correct = 0
43 | def reset(self):
44 | self.cnt = 0
45 | self.correct = 0
46 | def update(self,predict,target):
47 | predict,target = converter(predict),converter(target)
48 | self.cnt += len(predict)
49 | self.correct += np.sum(predict==target)
50 | def get_acc(self):
51 | return self.correct * 1.0 / self.cnt
52 | def get(self):
53 | return self.get_acc()
54 | class AccTopk():
55 | def __init__(self,background_classes,k):
56 | self.background_classes = background_classes
57 | self.k = k
58 | self.cnt = 0
59 | self.top5_correct = 0
60 | def reset(self):
61 | self.cnt = 0
62 | self.top5_correct = 0
63 | def update(self,predict,target):
64 | predict,target = converter(predict),converter(target)
65 | self.cnt += len(predict)
66 | background_idx = (predict == self.background_classes) + (target == self.background_classes)
67 | self.top5_correct += np.sum(predict[background_idx] == target[background_idx])
68 | not_background_idx = np.logical_not(background_idx)
67 | self.top5_correct += np.sum(np.absolute(predict[not_background_idx]-target[not_background_idx]) < self.k)
68 | def get(self):
69 | return self.top5_correct * 1.0 / self.cnt
70 | 
--------------------------------------------------------------------------------
/utils/onnx2trt.py:
--------------------------------------------------------------------------------
1 | """onnx2trt.py
2 | 
3 | Build a TensorRT engine from an ONNX model.
4 | """
5 | 
6 | import os
7 | import argparse
8 | 
9 | import tensorrt as trt
10 | 
11 | EXPLICIT_BATCH = []
12 | if trt.__version__[0] >= '7':
13 | EXPLICIT_BATCH.append(
14 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
15 |
16 |
17 | def build_engine(onnx_file_path, verbose=False):
18 | """Build a TensorRT engine from an ONNX file."""
19 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger()
20 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
21 | builder.max_workspace_size = 1 << 28
22 | builder.max_batch_size = 1
23 | builder.fp16_mode = False
24 | #builder.strict_type_constraints = True
25 |
26 | # Parse model file
27 | print('Loading ONNX file from path {}...'.format(onnx_file_path))
28 | with open(onnx_file_path, 'rb') as model:
29 | if not parser.parse(model.read()):
30 | print('ERROR: Failed to parse the ONNX file.')
31 | for error in range(parser.num_errors):
32 | print(parser.get_error(error))
33 | return None
34 | #if trt.__version__[0] >= '7':
35 | # The actual yolo*.onnx is generated with batch size 64.
36 | # Reshape input to batch size 1
37 | # shape = list(network.get_input(0).shape)
38 | # shape[0] = 1
39 | # network.get_input(0).shape = shape
40 |
41 | print('Adding yolo_layer plugins...')
42 | model_name = onnx_file_path[:-5]
43 | #network = add_yolo_plugins(
44 | # network, model_name, category_num, TRT_LOGGER)
45 |
46 | print('Building an engine. This may take a while...')
47 | print('(Use "--verbose" to enable verbose logging.)')
48 | engine = builder.build_cuda_engine(network)
49 | print('Completed creating engine.')
50 | return engine
51 |
52 |
53 | def main():
54 | """Create a TensorRT engine for ONNX-based YOLO."""
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument(
57 | '-v', '--verbose', action='store_true',
58 | help='enable verbose output (for debugging)')
59 | parser.add_argument(
60 | '-c', '--category_num', type=int, default=4,
61 | help='number of object categories [80]')
62 | parser.add_argument(
63 | '-m', '--model', type=str, default='model',  # just change this, e.g. res18_lane.pth -> default='res18_lane'
64 | help=('[yolov3|yolov3-tiny|yolov3-spp|yolov4|yolov4-tiny]-'
65 | '[{dimension}], where dimension could be a single '
66 | 'number (e.g. 288, 416, 608) or WxH (e.g. 416x256)'))
67 | args = parser.parse_args()
68 |
69 | onnx_file_path = '%s.onnx' % args.model
70 | if not os.path.isfile(onnx_file_path):
71 | raise SystemExit('ERROR: file (%s) not found! You might want to run yolo_to_onnx.py first to generate it.' % onnx_file_path)
72 | engine_file_path = '%s.trt' % args.model
73 | engine = build_engine(onnx_file_path, args.verbose)
74 | with open(engine_file_path, 'wb') as f:
75 | f.write(engine.serialize())
76 | print('Serialized the TensorRT engine to file: %s' % engine_file_path)
77 |
78 |
79 | if __name__ == '__main__':
80 | main()
81 |
--------------------------------------------------------------------------------
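Finally, a hedged sketch of driving build_engine() programmatically instead of via the CLI above; the module path and file names are examples only:

    # Mirrors what main() above does, minus the argument parsing.
    from utils.onnx2trt import build_engine   # module path assumed

    engine = build_engine('model.onnx', verbose=False)
    if engine is not None:
        with open('model.trt', 'wb') as f:
            f.write(engine.serialize())

--------------------------------------------------------------------------------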