├── LICENSE
├── README.md
├── UFLD.py
├── __pycache__
│   ├── UFLD.cpython-37.pyc
│   ├── UFLD.cpython-38.pyc
│   ├── common.cpython-37.pyc
│   ├── dataloader.cpython-37.pyc
│   ├── flirCapture2.cpython-36.pyc
│   ├── flirCapture2.cpython-38.pyc
│   ├── laneDetection.cpython-35.pyc
│   ├── laneDetection.cpython-36.pyc
│   ├── laneDetection.cpython-37.pyc
│   └── laneDetection.cpython-38.pyc
├── calibration.cache
├── calibration_data
│   ├── __pycache__
│   │   ├── constant.cpython-37.pyc
│   │   ├── dataloader.cpython-37.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   └── mytransforms.cpython-37.pyc
│   ├── constant.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── make_mini_tusimple.py
│   └── mytransforms.py
├── common.py
├── configs
│   ├── __pycache__
│   │   ├── constant.cpython-36.pyc
│   │   ├── constant.cpython-37.pyc
│   │   └── constant.cpython-38.pyc
│   ├── constant.py
│   └── tusimple_4.py
├── launch_opencv.py
├── mnist_calibration.cache
├── model
│   ├── __pycache__
│   │   ├── backbone.cpython-36.pyc
│   │   ├── backbone.cpython-37.pyc
│   │   ├── backbone.cpython-38.pyc
│   │   ├── model.cpython-36.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── model_convert.cpython-38.pyc
│   │   └── model_convert2.cpython-38.pyc
│   ├── backbone.py
│   └── model.py
├── onnx_to_tensorrt.py
├── onnx_to_tensorrt_int8.py
├── requirement.txt
├── tensorrt_run.py
├── test_devices.py
├── torch2onnx.py
└── utils
    ├── __pycache__
    │   ├── common.cpython-36.pyc
    │   ├── common.cpython-37.pyc
    │   ├── common.cpython-38.pyc
    │   ├── config.cpython-36.pyc
    │   ├── config.cpython-37.pyc
    │   ├── config.cpython-38.pyc
    │   ├── dist_utils.cpython-36.pyc
    │   ├── dist_utils.cpython-37.pyc
    │   ├── dist_utils.cpython-38.pyc
    │   ├── factory.cpython-38.pyc
    │   ├── loss.cpython-38.pyc
    │   └── metrics.cpython-38.pyc
    ├── common.py
    ├── config.py
    ├── dist_utils.py
    ├── factory.py
    ├── loss.py
    ├── metrics.py
    ├── onnx2trt.py
    ├── onnx2trt_test.py
    └── onnx_to_tensorrt.py
/LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software.
31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. 
If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. 
Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 
222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 
284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. 
If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 
402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. 
The 463 | work thus licensed is called the contributor's "contributor version". 464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. 
You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. 
If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 632 | <one line to give the program's name and a brief idea of what it does.> 633 | Copyright (C) <year>  <name of author> 634 | 635 | This program is free software: you can redistribute it and/or modify 636 | it under the terms of the GNU Affero General Public License as published 637 | by the Free Software Foundation, either version 3 of the License, or 638 | (at your option) any later version. 639 | 640 | This program is distributed in the hope that it will be useful, 641 | but WITHOUT ANY WARRANTY; without even the implied warranty of 642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 643 | GNU Affero General Public License for more details. 644 | 645 | You should have received a copy of the GNU Affero General Public License 646 | along with this program. If not, see <https://www.gnu.org/licenses/>. 647 | 648 | Also add information on how to contact you by electronic and paper mail. 649 | 650 | If your software can interact with users remotely through a computer 651 | network, you should also make sure that it provides a way for users to 652 | get its source. For example, if your program is a web application, its 653 | interface could display a "Source" link that leads users to an archive 654 | of the code. There are many ways you could offer source, and different 655 | solutions will be better for different programs; see section 13 for the 656 | specific requirements. 657 | 658 | You should also get your employer (if you work as a programmer) or school, 659 | if any, to sign a "copyright disclaimer" for the program, if necessary. 660 | For more information on this, and how to apply and follow the GNU AGPL, see 661 | <https://www.gnu.org/licenses/why-affero-gpl.html>. 662 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TRT_Ultra_Fast_Lane_Detect 2 | 3 | TRT_Ultra_Fast_Lane_Detect converts the Ultra-Fast-Lane-Detection model into a TensorRT engine using the TensorRT Python API. In addition to the conversion itself, this project provides the following: 4 | 5 | - The detection procedure is encapsulated in a reusable class (see UFLD.py). 6 | - The PyTorch model is converted into an ONNX model and then into a TensorRT engine. 7 | - TensorRT engines can be built in three precisions: FP32, FP16, and INT8. 8 | - The Tusimple dataset can be compressed with /calibration_data/make_mini_tusimple.py. The full dataset is largely redundant because only the 20th frame of each clip is annotated and used; the compressed dataset takes about 1 GB. (A sketch of the idea is given in the section below.) 9 | 10 | The original project, model, and paper are available at https://github.com/cfzd/Ultra-Fast-Lane-Detection 11 | 12 | 13 | 14 | ### Ultra-Fast-Lane-Detection 15 | 16 | PyTorch implementation of the paper "[Ultra Fast Structure-aware Deep Lane Detection](https://arxiv.org/abs/2004.11757)". 17 | 18 | Update: the paper has been accepted by ECCV 2020. 19 | 20 | [![alt text](https://github.com/cfzd/Ultra-Fast-Lane-Detection/raw/master/vis.jpg)](https://github.com/cfzd/Ultra-Fast-Lane-Detection/blob/master/vis.jpg) 21 | 22 | The evaluation code is modified from [SCNN](https://github.com/XingangPan/SCNN) and [Tusimple Benchmark](https://github.com/TuSimple/tusimple-benchmark). 23 | 24 | Caffe model and prototxt can be found [here](https://github.com/Jade999/caffe_lane_detection).
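### Compressing the Tusimple dataset

Only the 20th frame of each Tusimple clip carries a ground-truth annotation, so the other 19 frames can be dropped for calibration purposes. The sketch below illustrates the idea behind calibration_data/make_mini_tusimple.py; it is a minimal illustration rather than the actual script, and the names `compress_tusimple`, `src_root`, and `dst_root` are hypothetical.

```python
import os
import shutil

def compress_tusimple(src_root, dst_root):
    """Copy only the annotated 20th frame of every Tusimple clip."""
    for dirpath, _, filenames in os.walk(os.path.join(src_root, 'clips')):
        if '20.jpg' in filenames:  # only this frame has a ground-truth label
            rel = os.path.relpath(dirpath, src_root)
            os.makedirs(os.path.join(dst_root, rel), exist_ok=True)
            shutil.copy(os.path.join(dirpath, '20.jpg'),
                        os.path.join(dst_root, rel, '20.jpg'))
```

The label JSON files must be carried over as well for the subset to remain usable; see the script itself for its exact behavior.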
25 | 26 | 27 | 28 | ### Trained models 29 | 30 | Trained models can be downloaded from the links in the following table: 31 | 32 | | Dataset | Metric (paper) | Metric (this repo) | Avg FPS on GTX 1080Ti | Model | 33 | | -------- | ------------ | ---------------- | --------------------- | ------------------------------------------------------------ | 34 | | Tusimple | 95.87 | 95.82 | 306 | [GoogleDrive](https://drive.google.com/file/d/1WCYyur5ZaWczH15ecmeDowrW30xcLrCn/view?usp=sharing)/[BaiduDrive(code:bghd)](https://pan.baidu.com/s/1Fjm5yVq1JDpGjh4bdgdDLA) | 35 | | CULane | 68.4 | 69.7 | 324 | [GoogleDrive](https://drive.google.com/file/d/1zXBRTw50WOzvUp6XKsi8Zrk3MUC3uFuq/view?usp=sharing)/[BaiduDrive(code:w9tw)](https://pan.baidu.com/s/19Ig0TrV8MfmFTyCvbSa4ag) | 36 | 37 | 38 | 39 | ### Installation 40 | 41 | `pip3 install -r requirement.txt` 42 | 43 | 44 | 45 | ### Convert 46 | 47 | First, train or download a 4-lane model produced by the Ultra-Fast-Lane-Detection PyTorch implementation. If you need a different number of lanes, you will have to modify the code accordingly. 48 | 49 | 50 | 51 | Assume we now have a trained PyTorch model named "model.pth". 52 | 53 | 1. Use torch2onnx.py to convert the model into an ONNX model. Rename your model to "model.pth" first; the reference configuration file is configs/tusimple_4.py. (A minimal export sketch is given at the end of this README.) 54 | 55 | `python3 torch2onnx.py configs/${config_file}.py` 56 | 57 | 2. Use onnx_to_tensorrt.py to convert the ONNX model into a TensorRT engine (FP32 or FP16). 58 | 59 | `python3 onnx_to_tensorrt.py -p ${mode_in_fp16_or_fp32} --model ${model_name}` 60 | 61 | 3. Use onnx_to_tensorrt_int8.py to convert the ONNX model into a TensorRT engine (INT8). 62 | 63 | `python3 onnx_to_tensorrt_int8.py --model ${model_name}` 64 | 65 | 4. Run tensorrt_run.py to start detection: 66 | 67 | `python3 tensorrt_run.py --model ${model_name}` 68 | 69 | 70 | 71 | ### Evaluation 72 | 73 | | | PyTorch | libtorch | tensorRT(FP32) | tensorRT(FP16) | tensorRT(INT8) | 74 | | :--------: | :-----: | :------: | :------------: | :------------: | :------------: | 75 | | GTX 1060 | 55fps | 55fps | 55fps | Unsupported | 99fps | 76 | | Xavier AGX | 27fps | 27fps | -- | -- | -- | 77 | | Jetson TX1 | 8fps | 8fps | 8fps | 16fps | Unsupported | 78 | | Jetson Nano A01 (4GB) | -- | -- | -- | 8fps | Unsupported | 79 | 80 | Here "--" means the experiment has not been run yet. 81 | If you have hardware that is not listed, please post your results in the issues; they will be added to the table.
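### Appendix: reference sketches

The export in step 1 of the Convert section boils down to loading the trained `parsingNet` and calling `torch.onnx.export`. The sketch below is a minimal illustration, not the actual torch2onnx.py: the `parsingNet` arguments mirror UFLD.py, while `backbone='18'` and `griding_num = 100` (hence `cls_dim = (101, 56, 4)`) are assumed values for the usual 4-lane Tusimple configuration.

```python
import torch
from model.model import parsingNet

# Build the network the same way UFLD.py does. cls_dim is
# (griding bins + 1, row anchors, lanes); 101/56/4 are assumed values
# for a 4-lane Tusimple model with a ResNet-18 backbone.
net = parsingNet(pretrained=False, backbone='18',
                 cls_dim=(100 + 1, 56, 4), use_aux=False)

# Load the checkpoint and strip the 'module.' prefix left by DataParallel.
state_dict = torch.load('model.pth', map_location='cpu')['model']
net.load_state_dict({k.replace('module.', '', 1): v
                     for k, v in state_dict.items()}, strict=False)
net.eval()

# The network consumes 288x800 images; export with a fixed batch size of 1.
dummy = torch.randn(1, 3, 288, 800)
torch.onnx.export(net, dummy, 'model.onnx', opset_version=11,
                  input_names=['input'], output_names=['output'])
```

The resulting model.onnx is what steps 2 and 3 consume.

For the INT8 path, the generated calibration.cache maps tensor names to per-tensor scales. Interpreting each hex value as a big-endian IEEE-754 float32 is the commonly documented format of TensorRT calibration caches (an assumption here, not something this repo states):

```python
import struct

# Decode one calibration.cache entry, e.g. "input.1: 3c6966a5".
def decode_scale(hex_str):
    return struct.unpack('>f', bytes.fromhex(hex_str))[0]

print(decode_scale('3c6966a5'))  # roughly 0.0142
```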
82 | -------------------------------------------------------------------------------- /UFLD.py: -------------------------------------------------------------------------------- 1 | import torch, os, cv2 2 | from model.model import parsingNet 3 | from utils.common import merge_config 4 | from utils.dist_utils import dist_print 5 | from configs.constant import tusimple_row_anchor 6 | 7 | import scipy.special, tqdm 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | import time 12 | from torch.autograd import Variable 13 | import onnx 14 | 15 | 16 | class laneDetection(): 17 | def __init__(self): 18 | torch.backends.cudnn.benchmark = True 19 | self.args, self.cfg = merge_config() 20 | self.cls_num_per_lane = 56 21 | self.row_anchor = tusimple_row_anchor 22 | self.net = parsingNet(pretrained = False, backbone=self.cfg.backbone, cls_dim = (self.cfg.griding_num+1, self.cls_num_per_lane, self.cfg.num_lanes), use_aux=False).cuda() 23 | 24 | state_dict = torch.load(self.cfg.test_model, map_location='cpu')['model'] 25 | compatible_state_dict = {} 26 | for k, v in state_dict.items(): 27 | if 'module.' in k: 28 | compatible_state_dict[k[7:]] = v 29 | else: 30 | compatible_state_dict[k] = v 31 | 32 | self.net.load_state_dict(compatible_state_dict, strict=False) 33 | 34 | # do not comment out this line: eval() is required for correct inference 35 | self.net.eval() 36 | 37 | self.img_transforms = transforms.Compose([ 38 | transforms.Resize((288, 800)), 39 | transforms.ToTensor(), 40 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 41 | ]) 42 | 43 | self.img_w = 960 44 | self.img_h = 480 45 | self.scale_factor = 1 46 | self.color = [(255,0,0),(0,255,0),(0,0,255),(255,255,0)] 47 | self.idx = np.arange(self.cfg.griding_num) + 1 # bin indices 1..griding_num for the soft-argmax in parseResults 48 | self.idx = self.idx.reshape(-1, 1, 1) 49 | 50 | self.cpu_img = None 51 | self.gpu_img = None 52 | self.type = None 53 | self.gpu_output = None 54 | self.cpu_output = None 55 | 56 | col_sample = np.linspace(0, 800 - 1, self.cfg.griding_num) 57 | self.col_sample_w = col_sample[1] - col_sample[0] 58 | 59 | def setResolution(self, w, h): 60 | self.img_w = w 61 | self.img_h = h 62 | 63 | def getFrame(self, frame): 64 | self.cpu_img = frame 65 | 66 | def setScaleFactor(self, factor=1): 67 | self.scale_factor = factor 68 | 69 | def preprocess(self): 70 | tmp_img = cv2.cvtColor(self.cpu_img, cv2.COLOR_BGR2RGB) 71 | if self.scale_factor != 1: 72 | tmp_img = cv2.resize(tmp_img, (self.img_w//self.scale_factor, self.img_h//self.scale_factor)) 73 | tmp_img = Image.fromarray(tmp_img) 74 | tmp_img = self.img_transforms(tmp_img) 75 | self.gpu_img = tmp_img.unsqueeze(0).cuda() 76 | 77 | def inference(self): 78 | self.gpu_output = self.net(self.gpu_img) 79 | 80 | def parseResults(self): 81 | self.cpu_output = self.gpu_output[0].data.cpu().numpy() 82 | self.prob = scipy.special.softmax(self.cpu_output[:-1, :, :], axis=0) 83 | 84 | self.loc = np.sum(self.prob * self.idx, axis=0) # expected column bin per (row anchor, lane) 85 | self.cpu_output = np.argmax(self.cpu_output, axis=0) 86 | 87 | self.loc[self.cpu_output == self.cfg.griding_num] = 0 # zero out positions classified as "no lane" 88 | #self.cpu_output = self.loc 89 | 90 | # import pdb; pdb.set_trace() 91 | vis = self.cpu_img 92 | for i in range(self.loc.shape[1]): 93 | if np.sum(self.loc[:, i] > 0) > 40: # draw lane i only if enough row anchors detected it 94 | for k in range(self.loc.shape[0]): 95 | if self.loc[k, i] > 0: 96 | ppp = (int(self.loc[k, i] * self.col_sample_w * self.img_w / 800) - 1, int(self.img_h * (self.row_anchor[k]/288)) - 1 ) 97 | cv2.circle(vis, ppp, 3, self.color[i], -1) 98 | 99 | cv2.imshow("output", vis) 100 |
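# cv2.waitKey(1) below pumps the HighGUI event loop for about 1 ms so that
# the imshow() window above actually refreshes between frames.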
cv2.waitKey(1) 101 | return vis 102 | 103 | 104 | -------------------------------------------------------------------------------- /__pycache__/UFLD.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/UFLD.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/UFLD.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/UFLD.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/common.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/flirCapture2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/flirCapture2.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/flirCapture2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/flirCapture2.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/laneDetection.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/laneDetection.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/laneDetection.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/laneDetection.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/__pycache__/laneDetection.cpython-38.pyc -------------------------------------------------------------------------------- /calibration.cache: -------------------------------------------------------------------------------- 1 | TRT-7000-EntropyCalibration2 2 | input.1: 3c6966a5 3 | 127: 3c298ff7 4 | 128: 3b5df2ba 5 | 129: 3b5df2ba 6 | 130: 3b5df2ba 7 | 131: 3bd8b106 8 | 132: 3b5d4402 9 | 133: 3b151529 10 | 134: 3b5e7f99 11 | 135: 3bbe2fc8 12 | 136: 3b936c4c 13 | 137: 3b840b6f 14 | 138: 3bb62876 15 | 139: 3b44a09e 16 | 140: 3b1f11fd 17 | 141: 3b5c9d1f 18 | 142: 3b87e6e8 19 | 143: 3bd1e091 20 | 144: 3bd1e091 21 | 145: 3bed89d1 22 | 146: 3c3c2f3a 23 | 147: 3b7fc3b6 24 | 148: 3b636566 25 | 149: 3c2c5bd4 26 | 150: 3b248a2b 27 | 151: 3b69ad12 28 | 152: 3c749c21 29 | 153: 3b74d13d 30 | 154: 3c0c0a0a 31 | 155: 3c2a12fa 32 | 156: 3ade8c52 33 | 157: 3b74846d 34 | 158: 3c133ae7 35 | 159: 3c19d149 36 | 160: 3c0e6787 37 | 161: 3c1f198d 38 | 162: 3c475aa0 39 | 163: 3c020fdd 40 | 164: 3bc506df 41 | 165: 3c5efda5 42 | 166: 3b152a0b 43 | 167: 3b4bbbdd 44 | 168: 3c5adf69 45 | 169: 3c4ba73f 46 | 170: 3bcb724d 47 | 171: 3c4472d7 48 | 172: 3c03f7bf 49 | 173: 3b837231 50 | 174: 3c58ce13 51 | 175: 3c51b338 52 | 176: 3c4474b3 53 | 177: 3b98ca6b 54 | 178: 3c1a27de 55 | 179: 3bdc38ad 56 | 180: 3b06c3e8 57 | 181: 3c59fb33 58 | 182: 3ad771f3 59 | 183: 3c198c73 60 | 184: 3cb0b2f4 61 | 185: 3c6782e0 62 | 186: 3bcc34f8 63 | 187: 3be23788 64 | 188: 3be23788 65 | 189: 3b241ede 66 | 190: 3d46f62c 67 | 191: 3d4323f7 68 | 192: 3d549149 69 | 193: 3e3bfc93 70 | 195: 3e3bfc93 71 | (Unnamed Layer* 69) [Constant]_output: 3a0ac72c 72 | (Unnamed Layer* 70) [Matrix Multiply]_output: 3dd67f34 73 | (Unnamed Layer* 71) [Constant]_output: 38a25c2e 74 | (Unnamed Layer* 72) [Shuffle]_output: 38a25c2e 75 | 196: 3dd69133 76 | 197: 3dd69133 77 | (Unnamed Layer* 76) [Constant]_output: 396ebf65 78 | (Unnamed Layer* 77) [Matrix Multiply]_output: 3d8ddbd0 79 | (Unnamed Layer* 78) [Constant]_output: 38714d42 80 | (Unnamed Layer* 79) [Shuffle]_output: 38714d42 81 | 198: 3d8ddc8e 82 | 200: 3d8ddc8e 83 | -------------------------------------------------------------------------------- /calibration_data/__pycache__/constant.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/constant.cpython-37.pyc -------------------------------------------------------------------------------- /calibration_data/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /calibration_data/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /calibration_data/__pycache__/mytransforms.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/calibration_data/__pycache__/mytransforms.cpython-37.pyc -------------------------------------------------------------------------------- /calibration_data/constant.py: -------------------------------------------------------------------------------- 1 | # row anchors are a series of pre-defined coordinates in image height to detect lanes 2 | # the row anchors are defined according to the evaluation protocol of CULane and Tusimple 3 | # since our method will resize the image to 288x800 for training, the row anchors are defined with the height of 288 4 | # you can modify these row anchors according to your training image resolution 5 | 6 | tusimple_row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 7 | 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 8 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216, 9 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268, 10 | 272, 276, 280, 284] 11 | culane_row_anchor = [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287] 12 | -------------------------------------------------------------------------------- /calibration_data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch, os 2 | import numpy as np 3 | 4 | import torchvision.transforms as transforms 5 | import calibration_data.mytransforms as mytransforms 6 | from calibration_data.constant import tusimple_row_anchor, culane_row_anchor 7 | from calibration_data.dataset import LaneClsDataset, LaneTestDataset 8 | 9 | def get_train_loader(batch_size, data_root, griding_num, dataset, use_aux, distributed, num_lanes): 10 | target_transform = transforms.Compose([ 11 | mytransforms.FreeScaleMask((288, 800)), 12 | mytransforms.MaskToTensor(), 13 | ]) 14 | segment_transform = transforms.Compose([ 15 | mytransforms.FreeScaleMask((36, 100)), 16 | mytransforms.MaskToTensor(), 17 | ]) 18 | img_transform = transforms.Compose([ 19 | transforms.Resize((288, 800)), 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 22 | ]) 23 | simu_transform = mytransforms.Compose2([ 24 | mytransforms.RandomRotate(6), 25 | mytransforms.RandomUDoffsetLABEL(100), 26 | mytransforms.RandomLROffsetLABEL(200) 27 | ]) 28 | if dataset == 'CULane': 29 | train_dataset = LaneClsDataset(data_root, 30 | os.path.join(data_root, 'list/train_gt.txt'), 31 | img_transform=img_transform, target_transform=target_transform, 32 | simu_transform = simu_transform, 33 | segment_transform=segment_transform, 34 | row_anchor = culane_row_anchor, 35 | griding_num=griding_num, use_aux=use_aux, num_lanes = num_lanes) 36 | cls_num_per_lane = 18 37 | 38 | elif dataset == 'Tusimple': 39 | train_dataset = LaneClsDataset(data_root, 40 | os.path.join(data_root, 'train_gt.txt'), 41 | img_transform=img_transform, target_transform=target_transform, 42 | simu_transform = simu_transform, 43 | griding_num=griding_num, 44 | row_anchor = tusimple_row_anchor, 45 | segment_transform=segment_transform,use_aux=use_aux, num_lanes = num_lanes) 46 | cls_num_per_lane = 56 47 | else: 48 | raise NotImplementedError 49 | 50 | if distributed: 51 | sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 52 | else: 53 | sampler = torch.utils.data.RandomSampler(train_dataset) 
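# DistributedSampler shards the dataset across processes for multi-GPU
# training, while RandomSampler shuffles within a single process; the
# DataLoader below consumes whichever sampler was selected.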
54 | 55 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler = sampler, num_workers=4) 56 | 57 | return train_loader, cls_num_per_lane 58 | 59 | def get_test_loader(batch_size, data_root,dataset, distributed): 60 | img_transforms = transforms.Compose([ 61 | transforms.Resize((288, 800)), 62 | transforms.ToTensor(), 63 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 64 | ]) 65 | if dataset == 'CULane': 66 | test_dataset = LaneTestDataset(data_root,os.path.join(data_root, 'list/test.txt'),img_transform = img_transforms) 67 | cls_num_per_lane = 18 68 | elif dataset == 'Tusimple': 69 | test_dataset = LaneTestDataset(data_root,os.path.join(data_root, 'test.txt'), img_transform = img_transforms) 70 | cls_num_per_lane = 56 71 | 72 | if distributed: 73 | sampler = SeqDistributedSampler(test_dataset, shuffle = False) 74 | else: 75 | sampler = torch.utils.data.SequentialSampler(test_dataset) 76 | loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, sampler = sampler, num_workers=4) 77 | return loader 78 | 79 | 80 | class SeqDistributedSampler(torch.utils.data.distributed.DistributedSampler): 81 | ''' 82 | Change the behavior of DistributedSampler to sequential distributed sampling. 83 | Sequential sampling improves the stability of multi-threaded testing, which relies on multi-threaded file IO. 84 | Without sequential sampling, the file IO on one thread may interfere with the other threads. 85 | ''' 86 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False): 87 | super().__init__(dataset, num_replicas, rank, shuffle) 88 | def __iter__(self): 89 | g = torch.Generator() 90 | g.manual_seed(self.epoch) 91 | if self.shuffle: 92 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 93 | else: 94 | indices = list(range(len(self.dataset))) 95 | 96 | 97 | # add extra samples to make the split evenly divisible across ranks 98 | indices += indices[:(self.total_size - len(indices))] 99 | assert len(indices) == self.total_size 100 | 101 | 102 | num_per_rank = int(self.total_size // self.num_replicas) 103 | 104 | # sequential sampling: each rank takes one contiguous slice 105 | indices = indices[num_per_rank * self.rank : num_per_rank * (self.rank + 1)] 106 | 107 | assert len(indices) == self.num_samples 108 | 109 | return iter(indices) 110 | -------------------------------------------------------------------------------- /calibration_data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import os 4 | import pdb 5 | import numpy as np 6 | import cv2 7 | from calibration_data.mytransforms import find_start_pos 8 | 9 | 10 | def loader_func(path): 11 | return Image.open(path) 12 | 13 | 14 | class LaneTestDataset(torch.utils.data.Dataset): 15 | def __init__(self, path, list_path, img_transform=None): 16 | super(LaneTestDataset, self).__init__() 17 | self.path = path 18 | self.img_transform = img_transform 19 | with open(list_path, 'r') as f: 20 | self.list = f.readlines() 21 | self.list = [l[1:] if l[0] == '/' else l for l in self.list] # strip the leading '/' from CULane list entries so os.path.join works 22 | 23 | 24 | def __getitem__(self, index): 25 | name = self.list[index].split()[0] 26 | img_path = os.path.join(self.path, name) 27 | img = loader_func(img_path) 28 | 29 | if self.img_transform is not None: 30 | img = self.img_transform(img) 31 | 32 | return img, name 33 | 34 | def __len__(self): 35 | return len(self.list) 36 | 37 | 38 | class LaneClsDataset(torch.utils.data.Dataset): 39 | def
__init__(self, path, list_path, img_transform = None,target_transform = None,simu_transform = None, griding_num=50, load_name = False, 40 | row_anchor = None,use_aux=False,segment_transform=None, num_lanes = 8): 41 | super(LaneClsDataset, self).__init__() 42 | self.img_transform = img_transform 43 | self.target_transform = target_transform 44 | self.segment_transform = segment_transform 45 | self.simu_transform = simu_transform 46 | self.path = path 47 | self.griding_num = griding_num 48 | self.load_name = load_name 49 | self.use_aux = use_aux 50 | self.num_lanes = num_lanes 51 | 52 | with open(list_path, 'r') as f: 53 | self.list = f.readlines() 54 | 55 | self.row_anchor = row_anchor 56 | self.row_anchor.sort() 57 | 58 | def __getitem__(self, index): 59 | l = self.list[index] 60 | l_info = l.split() 61 | img_name, label_name = l_info[0], l_info[1] 62 | if img_name[0] == '/': 63 | img_name = img_name[1:] 64 | label_name = label_name[1:] 65 | 66 | label_path = os.path.join(self.path, label_name) 67 | label = loader_func(label_path) 68 | 69 | img_path = os.path.join(self.path, img_name) 70 | img = loader_func(img_path) 71 | 72 | 73 | if self.simu_transform is not None: 74 | img, label = self.simu_transform(img, label) 75 | lane_pts = self._get_index(label) 76 | # get the coordinates of lanes at row anchors 77 | 78 | 79 | 80 | w, h = img.size 81 | cls_label = self._grid_pts(lane_pts, self.griding_num, w) 82 | # make the coordinates to classification label 83 | if self.use_aux: 84 | assert self.segment_transform is not None 85 | seg_label = self.segment_transform(label) 86 | 87 | if self.img_transform is not None: 88 | img = self.img_transform(img) 89 | 90 | if self.use_aux: 91 | return img, cls_label, seg_label 92 | if self.load_name: 93 | return img, cls_label, img_name 94 | return img, cls_label 95 | 96 | def __len__(self): 97 | return len(self.list) 98 | 99 | def _grid_pts(self, pts, num_cols, w): 100 | # pts : numlane,n,2 101 | num_lane, n, n2 = pts.shape 102 | col_sample = np.linspace(0, w - 1, num_cols) 103 | 104 | assert n2 == 2 105 | to_pts = np.zeros((n, num_lane)) 106 | for i in range(num_lane): 107 | pti = pts[i, :, 1] 108 | to_pts[:, i] = np.asarray( 109 | [int(pt // (col_sample[1] - col_sample[0])) if pt != -1 else num_cols for pt in pti]) 110 | return to_pts.astype(int) 111 | 112 | def _get_index(self, label): 113 | w, h = label.size 114 | 115 | if h != 288: 116 | scale_f = lambda x : int((x * 1.0/288) * h) 117 | sample_tmp = list(map(scale_f,self.row_anchor)) 118 | 119 | all_idx = np.zeros((self.num_lanes,len(sample_tmp),2)) 120 | for i,r in enumerate(sample_tmp): 121 | label_r = np.asarray(label)[int(round(r))] 122 | for lane_idx in range(1, self.num_lanes + 1): 123 | pos = np.where(label_r == lane_idx)[0] 124 | if len(pos) == 0: 125 | all_idx[lane_idx - 1, i, 0] = r 126 | all_idx[lane_idx - 1, i, 1] = -1 127 | continue 128 | pos = np.mean(pos) 129 | all_idx[lane_idx - 1, i, 0] = r 130 | all_idx[lane_idx - 1, i, 1] = pos 131 | 132 | # data augmentation: extend the lane to the boundary of image 133 | 134 | all_idx_cp = all_idx.copy() 135 | for i in range(self.num_lanes): 136 | if np.all(all_idx_cp[i,:,1] == -1): 137 | continue 138 | # if there is no lane 139 | 140 | valid = all_idx_cp[i,:,1] != -1 141 | # get all valid lane points' index 142 | valid_idx = all_idx_cp[i,valid,:] 143 | # get all valid lane points 144 | if valid_idx[-1,0] == all_idx_cp[0,-1,0]: 145 | # if the last valid lane point's y-coordinate is already the last y-coordinate of all rows 146 | # this means this 
lane has reached the bottom boundary of the image 147 | # so we skip 148 | continue 149 | if len(valid_idx) < 6: 150 | continue 151 | # if the lane is too short to extend reliably 152 | 153 | valid_idx_half = valid_idx[len(valid_idx) // 2:,:] 154 | p = np.polyfit(valid_idx_half[:,0], valid_idx_half[:,1],deg = 1) 155 | start_line = valid_idx_half[-1,0] 156 | pos = find_start_pos(all_idx_cp[i,:,0],start_line) + 1 157 | 158 | fitted = np.polyval(p,all_idx_cp[i,pos:,0]) 159 | fitted = np.array([-1 if y < 0 or y > w-1 else y for y in fitted]) 160 | 161 | assert np.all(all_idx_cp[i,pos:,1] == -1) 162 | all_idx_cp[i,pos:,1] = fitted 163 | if -1 in all_idx[:, :, 0]: 164 | pdb.set_trace() 165 | return all_idx_cp 166 | -------------------------------------------------------------------------------- /calibration_data/make_mini_tusimple.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | path ='Tusimple/clips' # root of the TuSimple clips; only frame 20 of each clip is kept 4 | 5 | def get_filelist(root): 6 | for home, dirs, files in os.walk(root): 7 | for filename in files: 8 | if filename == "20.jpg" or filename == "20.png": 9 | continue 10 | else: 11 | print(filename) 12 | os.remove(os.path.join(home, filename)) 13 | 14 | if __name__ =="__main__": 15 | get_filelist(path) 16 | -------------------------------------------------------------------------------- /calibration_data/mytransforms.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import random 3 | import numpy as np 4 | from PIL import Image, ImageOps, ImageFilter 5 | #from config import cfg 6 | import torch 7 | import pdb 8 | import cv2 9 | 10 | # ===============================img transforms============================ 11 | 12 | class Compose2(object): 13 | def __init__(self, transforms): 14 | self.transforms = transforms 15 | 16 | def __call__(self, img, mask, bbx=None): 17 | if bbx is None: 18 | for t in self.transforms: 19 | img, mask = t(img, mask) 20 | return img, mask 21 | for t in self.transforms: 22 | img, mask, bbx = t(img, mask, bbx) 23 | return img, mask, bbx 24 | 25 | class FreeScale(object): 26 | def __init__(self, size): 27 | self.size = size # (h, w) 28 | 29 | def __call__(self, img, mask): 30 | return img.resize((self.size[1], self.size[0]), Image.BILINEAR), mask.resize((self.size[1], self.size[0]), Image.NEAREST) 31 | 32 | class FreeScaleMask(object): 33 | def __init__(self,size): 34 | self.size = size 35 | def __call__(self,mask): 36 | return mask.resize((self.size[1], self.size[0]), Image.NEAREST) 37 | 38 | class Scale(object): 39 | def __init__(self, size): 40 | self.size = size 41 | 42 | def __call__(self, img, mask): 43 | if img.size != mask.size: 44 | print(img.size) 45 | print(mask.size) 46 | assert img.size == mask.size 47 | w, h = img.size 48 | if (w <= h and w == self.size) or (h <= w and h == self.size): 49 | return img, mask 50 | if w < h: 51 | ow = self.size 52 | oh = int(self.size * h / w) 53 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 54 | else: 55 | oh = self.size 56 | ow = int(self.size * w / h) 57 | return img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST) 58 | 59 | 60 | class RandomRotate(object): 61 | """Rotates the given PIL.Image and its label mask by the same random angle 62 | drawn uniformly from [-angle, +angle] degrees; the image is resampled 63 | bilinearly and the label with nearest-neighbor so that class ids are preserved. 64 | """ 65 | 66 | def __init__(self, angle): 67 | self.angle = angle 68 | 69 | def __call__(self, image, label): 70 | #assert label is None or image.size == label.size 71 | #assert label is None or image.size == label.size 72 | 73 | angle = random.randint(0, self.angle * 2) - self.angle 74 | 75 | label = label.rotate(angle, resample=Image.NEAREST) 76 | image = image.rotate(angle, resample=Image.BILINEAR) 77 | 78 | return image, label 79 | 80 | 81 | 82 | # ===============================label transforms============================ 83 | 84 | class DeNormalize(object): 85 | def __init__(self, mean, std): 86 | self.mean = mean 87 | self.std = std 88 | 89 | def __call__(self, tensor): 90 | for t, m, s in zip(tensor, self.mean, self.std): 91 | t.mul_(s).add_(m) 92 | return tensor 93 | 94 | 95 | class MaskToTensor(object): 96 | def __call__(self, img): 97 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 98 | 99 | 100 | def find_start_pos(row_sample,start_line): # binary search: first index in row_sample at or beyond start_line 101 | # row_sample = row_sample.sort() 102 | # for i,r in enumerate(row_sample): 103 | # if r >= start_line: 104 | # return i 105 | l,r = 0,len(row_sample)-1 106 | while True: 107 | mid = int((l+r)/2) 108 | if r - l == 1: 109 | return r 110 | if row_sample[mid] < start_line: 111 | l = mid 112 | if row_sample[mid] > start_line: 113 | r = mid 114 | if row_sample[mid] == start_line: 115 | return mid 116 | 117 | class RandomLROffsetLABEL(object): 118 | def __init__(self,max_offset): 119 | self.max_offset = max_offset 120 | def __call__(self,img,label): 121 | offset = np.random.randint(-self.max_offset,self.max_offset) 122 | w, h = img.size 123 | 124 | img = np.array(img) 125 | if offset > 0: 126 | img[:,offset:,:] = img[:,0:w-offset,:] 127 | img[:,:offset,:] = 0 128 | if offset < 0: 129 | real_offset = -offset 130 | img[:,0:w-real_offset,:] = img[:,real_offset:,:] 131 | img[:,w-real_offset:,:] = 0 132 | 133 | label = np.array(label) 134 | if offset > 0: 135 | label[:,offset:] = label[:,0:w-offset] 136 | label[:,:offset] = 0 137 | if offset < 0: 138 | offset = -offset 139 | label[:,0:w-offset] = label[:,offset:] 140 | label[:,w-offset:] = 0 141 | return Image.fromarray(img),Image.fromarray(label) 142 | 143 | class RandomUDoffsetLABEL(object): 144 | def __init__(self,max_offset): 145 | self.max_offset = max_offset 146 | def __call__(self,img,label): 147 | offset = np.random.randint(-self.max_offset,self.max_offset) 148 | w, h = img.size 149 | 150 | img = np.array(img) 151 | if offset > 0: 152 | img[offset:,:,:] = img[0:h-offset,:,:] 153 | img[:offset,:,:] = 0 154 | if offset < 0: 155 | real_offset = -offset 156 | img[0:h-real_offset,:,:] = img[real_offset:,:,:] 157 | img[h-real_offset:,:,:] = 0 158 | 159 | label = np.array(label) 160 | if offset > 0: 161 | label[offset:,:] = label[0:h-offset,:] 162 | label[:offset,:] = 0 163 | if offset < 0: 164 | offset = -offset 165 | label[0:h-offset,:] = label[offset:,:] 166 | label[h-offset:,:] = 0 167 | return Image.fromarray(img),Image.fromarray(label) 168 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 | # 4 | # NOTICE TO LICENSEE: 5 | # 6 | # This source code and/or documentation ("Licensed Deliverables") are 7 | # subject to NVIDIA intellectual property rights under U.S. and 8 | # international Copyright laws. 9 | # 10 | # These Licensed Deliverables contained herein is PROPRIETARY and 11 | # CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | # conditions of a form of NVIDIA software license agreement by and 13 | # between NVIDIA and Licensee ("License Agreement") or electronically 14 | # accepted by Licensee. Notwithstanding any terms or conditions to 15 | # the contrary in the License Agreement, reproduction or disclosure 16 | # of the Licensed Deliverables to any third party without the express 17 | # written consent of NVIDIA is prohibited. 18 | # 19 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | # LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | # SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | # PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | # NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | # DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | # NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | # LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | # SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | # DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | # WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | # ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | # OF THESE LICENSED DELIVERABLES. 33 | # 34 | # U.S. Government End Users. These Licensed Deliverables are a 35 | # "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | # 1995), consisting of "commercial computer software" and "commercial 37 | # computer software documentation" as such terms are used in 48 38 | # C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | # only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | # 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | # U.S. Government End Users acquire the Licensed Deliverables with 42 | # only those rights set forth herein. 43 | # 44 | # Any use of the Licensed Deliverables in individual and commercial 45 | # software must include, in the user documentation and internal 46 | # comments to the code, the above Disclaimer and U.S. Government End 47 | # Users Notice. 48 | # 49 | 50 | from itertools import chain 51 | import argparse 52 | import os 53 | 54 | import pycuda.driver as cuda 55 | import pycuda.autoinit 56 | import numpy as np 57 | 58 | import tensorrt as trt 59 | 60 | try: 61 | # Sometimes python2 does not understand FileNotFoundError 62 | FileNotFoundError 63 | except NameError: 64 | FileNotFoundError = IOError 65 | 66 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 67 | 68 | def GiB(val): 69 | return val * 1 << 30 70 | 71 | 72 | def add_help(description): 73 | parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 74 | args, _ = parser.parse_known_args() 75 | 76 | 77 | def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]): 78 | ''' 79 | Parses sample arguments. 80 | 81 | Args: 82 | description (str): Description of the sample. 
83 | subfolder (str): The subfolder containing data relevant to this sample 84 | find_files (str): A list of filenames to find. Each filename will be replaced with an absolute path. 85 | 86 | Returns: 87 | str: Path of data directory. 88 | ''' 89 | 90 | # Standard command-line arguments for all samples. 91 | kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data") 92 | parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) 93 | parser.add_argument("-d", "--datadir", help="Location of the TensorRT sample data directory, and any additional data directories.", action="append", default=[kDEFAULT_DATA_ROOT]) 94 | args, _ = parser.parse_known_args() 95 | 96 | def get_data_path(data_dir): 97 | # If the subfolder exists, append it to the path, otherwise use the provided path as-is. 98 | data_path = os.path.join(data_dir, subfolder) 99 | if not os.path.exists(data_path): 100 | print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.") 101 | data_path = data_dir 102 | # Make sure data directory exists. 103 | if not (os.path.exists(data_path)): 104 | print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path)) 105 | return data_path 106 | 107 | data_paths = [get_data_path(data_dir) for data_dir in args.datadir] 108 | return data_paths, locate_files(data_paths, find_files) 109 | 110 | def locate_files(data_paths, filenames): 111 | """ 112 | Locates the specified files in the specified data directories. 113 | If a file exists in multiple data directories, the first directory is used. 114 | 115 | Args: 116 | data_paths (List[str]): The data directories. 117 | filename (List[str]): The names of the files to find. 118 | 119 | Returns: 120 | List[str]: The absolute paths of the files. 121 | 122 | Raises: 123 | FileNotFoundError if a file could not be located. 124 | """ 125 | found_files = [None] * len(filenames) 126 | for data_path in data_paths: 127 | # Find all requested files. 128 | for index, (found, filename) in enumerate(zip(found_files, filenames)): 129 | if not found: 130 | file_path = os.path.abspath(os.path.join(data_path, filename)) 131 | if os.path.exists(file_path): 132 | found_files[index] = file_path 133 | 134 | # Check that all files were found 135 | for f, filename in zip(found_files, filenames): 136 | if not f or not os.path.exists(f): 137 | raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}".format(filename, data_paths)) 138 | return found_files 139 | 140 | # Simple helper data class that's a little nicer to use than a 2-tuple. 141 | class HostDeviceMem(object): 142 | def __init__(self, host_mem, device_mem): 143 | self.host = host_mem 144 | self.device = device_mem 145 | 146 | def __str__(self): 147 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 148 | 149 | def __repr__(self): 150 | return self.__str__() 151 | 152 | # Allocates all buffers required for an engine, i.e. host/device inputs/outputs. 153 | def allocate_buffers(engine): 154 | inputs = [] 155 | outputs = [] 156 | bindings = [] 157 | stream = cuda.Stream() 158 | for binding in engine: 159 | size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size 160 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 161 | # Allocate host and device buffers 162 | host_mem = cuda.pagelocked_empty(size, dtype) 163 | device_mem = cuda.mem_alloc(host_mem.nbytes) 164 | # Append the device buffer to device bindings. 
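# (Page-locked host memory matters here: the async memcpy calls in
# do_inference_v2 below only overlap with the CUDA stream when the host
# buffer is pinned; a plain pageable numpy array would serialize the copies.)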
165 | bindings.append(int(device_mem)) 166 | # Append to the appropriate list. 167 | if engine.binding_is_input(binding): 168 | inputs.append(HostDeviceMem(host_mem, device_mem)) 169 | else: 170 | outputs.append(HostDeviceMem(host_mem, device_mem)) 171 | return inputs, outputs, bindings, stream 172 | 173 | # This function is generalized for multiple inputs/outputs. 174 | # inputs and outputs are expected to be lists of HostDeviceMem objects. 175 | def do_inference(context, bindings, inputs, outputs, stream, batch_size=1): 176 | # Transfer input data to the GPU. 177 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] 178 | # Run inference. 179 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) 180 | # Transfer predictions back from the GPU. 181 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] 182 | # Synchronize the stream 183 | stream.synchronize() 184 | # Return only the host outputs. 185 | return [out.host for out in outputs] 186 | 187 | # This function is generalized for multiple inputs/outputs for full dimension networks. 188 | # inputs and outputs are expected to be lists of HostDeviceMem objects. 189 | def do_inference_v2(context, bindings, inputs, outputs, stream): 190 | # Transfer input data to the GPU. 191 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] 192 | # Run inference. 193 | context.execute_async_v2(bindings=bindings, stream_handle=stream.handle) 194 | # Transfer predictions back from the GPU. 195 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] 196 | # Synchronize the stream 197 | stream.synchronize() 198 | # Return only the host outputs. 199 | return [out.host for out in outputs] 200 | -------------------------------------------------------------------------------- /configs/__pycache__/constant.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/constant.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-37.pyc -------------------------------------------------------------------------------- /configs/__pycache__/constant.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/configs/__pycache__/constant.cpython-38.pyc -------------------------------------------------------------------------------- /configs/constant.py: -------------------------------------------------------------------------------- 1 | # row anchors are a series of pre-defined coordinates in image height to detect lanes 2 | # the row anchors are defined according to the evaluation protocol of CULane and Tusimple 3 | # since our method will resize the image to 288x800 for training, the row anchors are defined with the height of 288 4 | # you can modify these row anchors according to your training image resolution 5 | 6 | tusimple_row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 7 | 116, 120, 
124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 8 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216, 9 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268, 10 | 272, 276, 280, 284] 11 | culane_row_anchor = [121, 131, 141, 150, 160, 170, 180, 189, 199, 209, 219, 228, 238, 248, 258, 267, 277, 287] 12 | -------------------------------------------------------------------------------- /configs/tusimple_4.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='Tusimple' 3 | data_root = "./data/Tusimple_ours" 4 | 5 | # TRAIN 6 | epoch = 60 7 | batch_size = 16 8 | optimizer = 'Adam' #['SGD','Adam'] 9 | # learning_rate = 0.1 10 | learning_rate = 2e-4 11 | weight_decay = 1e-4 12 | momentum = 0.9 13 | 14 | scheduler = 'cos' #['multi', 'cos'] 15 | # steps = [50,75] 16 | gamma = 0.1 17 | warmup = 'linear' 18 | warmup_iters = 100 19 | 20 | # NETWORK 21 | backbone = '18' 22 | griding_num = 100 23 | use_aux = True 24 | 25 | # LOSS 26 | sim_loss_w = 1.0 27 | shp_loss_w = 0.1 28 | 29 | # EXP 30 | note = '' 31 | 32 | log_path = '/media/kyle/Seagate/train_lane/train_lane4.log' 33 | 34 | # FINETUNE or RESUME MODEL PATH 35 | finetune = None 36 | resume = None 37 | 38 | # TEST 39 | test_model = None 40 | test_work_dir = "./data/Tusimple_ours" 41 | 42 | num_lanes = 4 43 | 44 | -------------------------------------------------------------------------------- /launch_opencv.py: -------------------------------------------------------------------------------- 1 | from UFLD import * 2 | import cv2 3 | import threading 4 | import time 5 | import numpy as np 6 | 7 | detector = laneDetection() 8 | detector.setResolution(640, 480) 9 | detector.setScaleFactor(4) 10 | frame = 0 11 | currentImage = None 12 | 13 | 14 | def threadDetect(): 15 | print("Detection initiating") 16 | global detector, currentImage, frame 17 | fps = [] 18 | print("Waiting for camera") 19 | time.sleep(1) 20 | ret = True 21 | print("Detection Begins:") 22 | while ret: 23 | t1 = time.time() 24 | detector.getFrame(currentImage) 25 | detector.preprocess() 26 | detector.inference() 27 | detector.parseResults() 28 | t2 = time.time() 29 | if frame > 30: 30 | fps.append(1/(t2-t1)) 31 | print("\ravg FPS: "+str(np.mean(fps)), end="", flush=True) 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | 37 | cap = cv2.VideoCapture(2) 38 | 39 | detecting = threading.Thread(target=threadDetect) 40 | detecting.setDaemon(True) 41 | detecting.start() 42 | 43 | 44 | while True: 45 | _, currentImage = cap.read() 46 | frame += 1 47 | 48 | detecting.join() 49 | 50 | 51 | -------------------------------------------------------------------------------- /mnist_calibration.cache: -------------------------------------------------------------------------------- 1 | TRT-7000-EntropyCalibration2 2 | input.1: 3c6966a5 3 | 127: 3c298ff7 4 | 128: 3b5df2ba 5 | 129: 3b5df2ba 6 | 130: 3b5df2ba 7 | 131: 3bd8b106 8 | 132: 3b5d4402 9 | 133: 3b151529 10 | 134: 3b5e7f99 11 | 135: 3bbe2fc8 12 | 136: 3b936c4c 13 | 137: 3b840b6f 14 | 138: 3bb62876 15 | 139: 3b44a09e 16 | 140: 3b1f11fd 17 | 141: 3b5c9d1f 18 | 142: 3b87e6e8 19 | 143: 3bd1e091 20 | 144: 3bd1e091 21 | 145: 3bed89d1 22 | 146: 3c3c2f3a 23 | 147: 3b7fc3b6 24 | 148: 3b636566 25 | 149: 3c2c5bd4 26 | 150: 3b248a2b 27 | 151: 3b69ad12 28 | 152: 3c749c21 29 | 153: 3b74d13d 30 | 154: 3c0c0a0a 31 | 155: 3c2a12fa 32 | 156: 3ade8c52 33 | 157: 3b74846d 34 | 158: 3c133ae7 35 | 159: 3c19d149 36 | 160: 3c0e6787 37 | 161: 3c1f198d 38 | 162: 3c475aa0 39 | 163: 
3c020fdd 40 | 164: 3bc506df 41 | 165: 3c5efda5 42 | 166: 3b152a0b 43 | 167: 3b4bbbdd 44 | 168: 3c5adf69 45 | 169: 3c4ba73f 46 | 170: 3bcb724d 47 | 171: 3c4472d7 48 | 172: 3c03f7bf 49 | 173: 3b837231 50 | 174: 3c58ce13 51 | 175: 3c51b338 52 | 176: 3c4474b3 53 | 177: 3b98ca6b 54 | 178: 3c1a27de 55 | 179: 3bdc38ad 56 | 180: 3b06c3e8 57 | 181: 3c59fb33 58 | 182: 3ad771f3 59 | 183: 3c198c73 60 | 184: 3cb0b2f4 61 | 185: 3c6782e0 62 | 186: 3bcc34f8 63 | 187: 3be23788 64 | 188: 3be23788 65 | 189: 3b241ede 66 | 190: 3d46f62c 67 | 191: 3d4323f7 68 | 192: 3d549149 69 | 193: 3e3bfc93 70 | 195: 3e3bfc93 71 | (Unnamed Layer* 69) [Constant]_output: 3a0ac72c 72 | (Unnamed Layer* 70) [Matrix Multiply]_output: 3dd67f34 73 | (Unnamed Layer* 71) [Constant]_output: 38a25c2e 74 | (Unnamed Layer* 72) [Shuffle]_output: 38a25c2e 75 | 196: 3dd69133 76 | 197: 3dd69133 77 | (Unnamed Layer* 76) [Constant]_output: 396ebf65 78 | (Unnamed Layer* 77) [Matrix Multiply]_output: 3d8ddbd0 79 | (Unnamed Layer* 78) [Constant]_output: 38714d42 80 | (Unnamed Layer* 79) [Shuffle]_output: 38714d42 81 | 198: 3d8ddc8e 82 | 200: 3d8ddc8e 83 | -------------------------------------------------------------------------------- /model/__pycache__/backbone.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/backbone.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_convert.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model_convert.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model_convert2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/model/__pycache__/model_convert2.cpython-38.pyc -------------------------------------------------------------------------------- /model/backbone.py: -------------------------------------------------------------------------------- 1 | import torch,pdb 2 | import torchvision 3 | import torch.nn.modules 4 | 5 | class vgg16bn(torch.nn.Module): 6 | def __init__(self,pretrained = False): 7 | super(vgg16bn,self).__init__() 8 | model = list(torchvision.models.vgg16_bn(pretrained=pretrained).features.children()) 9 | model = model[:33]+model[34:43] 10 | self.model = torch.nn.Sequential(*model) 11 | 12 | def forward(self,x): 13 | return self.model(x) 14 | class resnet(torch.nn.Module): 15 | def __init__(self,layers,pretrained = False): 16 | super(resnet,self).__init__() 17 | if layers == '18': 18 | model = torchvision.models.resnet18(pretrained=pretrained) 19 | elif layers == '34': 20 | model = torchvision.models.resnet34(pretrained=pretrained) 21 | elif layers == '50': 22 | model = torchvision.models.resnet50(pretrained=pretrained) 23 | elif layers == '101': 24 | model = torchvision.models.resnet101(pretrained=pretrained) 25 | elif layers == '152': 26 | model = torchvision.models.resnet152(pretrained=pretrained) 27 | elif layers == '50next': 28 | model = torchvision.models.resnext50_32x4d(pretrained=pretrained) 29 | elif layers == '101next': 30 | model = torchvision.models.resnext101_32x8d(pretrained=pretrained) 31 | elif layers == '50wide': 32 | model = torchvision.models.wide_resnet50_2(pretrained=pretrained) 33 | elif layers == '101wide': 34 | model = torchvision.models.wide_resnet101_2(pretrained=pretrained) 35 | else: 36 | raise NotImplementedError 37 | 38 | self.conv1 = model.conv1 39 | self.bn1 = model.bn1 40 | self.relu = model.relu 41 | self.maxpool = model.maxpool 42 | self.layer1 = model.layer1 43 | self.layer2 = model.layer2 44 | self.layer3 = model.layer3 45 | self.layer4 = model.layer4 46 | 47 | def forward(self,x): 48 | x = self.conv1(x) 49 | x = self.bn1(x) 50 | x = self.relu(x) 51 | x = self.maxpool(x) 52 | x = self.layer1(x) 53 | x2 = self.layer2(x) 54 | x3 = self.layer3(x2) 55 | x4 = self.layer4(x3) 56 | return x2,x3,x4 57 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from model.backbone import resnet 3 | import numpy as np 4 | 5 | class conv_bn_relu(torch.nn.Module): 6 | def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=False): 7 | super(conv_bn_relu,self).__init__() 8 | self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size, 9 | stride = stride, padding = padding, dilation = dilation,bias = bias) 10 | self.bn = torch.nn.BatchNorm2d(out_channels) 11 | self.relu = torch.nn.ReLU() 12 | 13 | def forward(self,x): 14 | x = self.conv(x) 15 | x = self.bn(x) 16 | x = self.relu(x) 17 | return x 18 | class parsingNet(torch.nn.Module): 19 | def __init__(self, size=(288, 800), pretrained=True, 
backbone='50', cls_dim=(37, 10, 4), use_aux=False): 20 | super(parsingNet, self).__init__() 21 | 22 | self.size = size 23 | self.w = size[1] # size is (h, w) = (288, 800) 24 | self.h = size[0] 25 | self.cls_dim = cls_dim # (num_gridding, num_cls_per_lane, num_of_lanes) 26 | # num_cls_per_lane is the number of row anchors 27 | self.use_aux = use_aux 28 | self.total_dim = np.prod(cls_dim) # product of the cls_dim entries, e.g. 37*10*4 29 | 30 | # input : nchw, 31 | # output: (w+1) * sample_rows * 4 32 | self.model = resnet(backbone, pretrained=pretrained) 33 | 34 | if self.use_aux: 35 | self.aux_header2 = torch.nn.Sequential( 36 | conv_bn_relu(128, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1), 37 | conv_bn_relu(128,128,3,padding=1), 38 | conv_bn_relu(128,128,3,padding=1), 39 | conv_bn_relu(128,128,3,padding=1), 40 | ) 41 | self.aux_header3 = torch.nn.Sequential( 42 | conv_bn_relu(256, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(1024, 128, kernel_size=3, stride=1, padding=1), 43 | conv_bn_relu(128,128,3,padding=1), 44 | conv_bn_relu(128,128,3,padding=1), 45 | ) 46 | self.aux_header4 = torch.nn.Sequential( 47 | conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(2048, 128, kernel_size=3, stride=1, padding=1), 48 | conv_bn_relu(128,128,3,padding=1), 49 | ) 50 | self.aux_combine = torch.nn.Sequential( 51 | conv_bn_relu(384, 256, 3,padding=2,dilation=2), 52 | conv_bn_relu(256, 128, 3,padding=2,dilation=2), 53 | conv_bn_relu(128, 128, 3,padding=2,dilation=2), 54 | conv_bn_relu(128, 128, 3,padding=4,dilation=4), 55 | torch.nn.Conv2d(128, cls_dim[-1] + 1,1) 56 | # output : n, num_of_lanes+1, h, w 57 | ) 58 | initialize_weights(self.aux_header2,self.aux_header3,self.aux_header4,self.aux_combine) 59 | 60 | self.cls = torch.nn.Sequential( 61 | torch.nn.Linear(1800, 2048), 62 | torch.nn.ReLU(), 63 | torch.nn.Linear(2048, self.total_dim), 64 | ) 65 | 66 | self.pool = torch.nn.Conv2d(512,8,1) if backbone in ['34','18'] else torch.nn.Conv2d(2048,8,1) 67 | # 1/32,2048 channel 68 | # 288,800 -> 9,25 spatially (stride 32); pooled to 8 channels -> 8*9*25 = 1800 69 | # (w+1) * sample_rows * 4 70 | # 37 * 10 * 4 71 | initialize_weights(self.cls) 72 | 73 | def forward(self, x): 74 | # n c h w -> n 2048 sh sw 75 | # -> n 2048 76 | x2,x3,fea = self.model(x) 77 | if self.use_aux: 78 | x2 = self.aux_header2(x2) 79 | x3 = self.aux_header3(x3) 80 | x3 = torch.nn.functional.interpolate(x3,scale_factor = 2,mode='bilinear') 81 | x4 = self.aux_header4(fea) 82 | x4 = torch.nn.functional.interpolate(x4,scale_factor = 4,mode='bilinear') 83 | aux_seg = torch.cat([x2,x3,x4],dim=1) 84 | aux_seg = self.aux_combine(aux_seg) 85 | else: 86 | aux_seg = None 87 | 88 | fea = self.pool(fea).view(-1, 1800) 89 | 90 | group_cls = self.cls(fea).view(-1, *self.cls_dim) 91 | 92 | if self.use_aux: 93 | return group_cls, aux_seg 94 | 95 | return group_cls 96 | 97 | 98 | def initialize_weights(*models): 99 | for model in models: 100 | real_init_weights(model) 101 | def real_init_weights(m): 102 | 103 | if isinstance(m, list): 104 | for mini_m in m: 105 | real_init_weights(mini_m) 106 | else: 107 | if isinstance(m, torch.nn.Conv2d): 108 | torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 109 | if m.bias is not None: 110 | torch.nn.init.constant_(m.bias, 0) 111 | elif isinstance(m, torch.nn.Linear): 112 | m.weight.data.normal_(0.0, std=0.01) 113 | elif isinstance(m, torch.nn.BatchNorm2d): 114 | torch.nn.init.constant_(m.weight, 1) 115 | torch.nn.init.constant_(m.bias, 0) 116 |
elif isinstance(m,torch.nn.Module): 117 | for mini_m in m.children(): 118 | real_init_weights(mini_m) 119 | else: 120 | print('unknown module', m) 121 | -------------------------------------------------------------------------------- /onnx_to_tensorrt.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import cv2 6 | import tensorrt as trt 7 | import pycuda.driver as cuda 8 | import pycuda.autoinit 9 | import numpy as np 10 | 11 | EXPLICIT_BATCH = [] 12 | if trt.__version__[0] >= '7': 13 | EXPLICIT_BATCH.append( 14 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 15 | 16 | mode = 'fp16' 17 | 18 | def build_engine(onnx_file_path, mode, verbose=False): 19 | """Build a TensorRT engine from an ONNX file.""" 20 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() 21 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: 22 | builder.max_workspace_size = 1 << 30 23 | builder.max_batch_size = 1 24 | if mode=='fp16': 25 | builder.fp16_mode = True 26 | else: 27 | builder.fp16_mode = False 28 | #builder.strict_type_constraints = True 29 | 30 | # Parse model file 31 | print('Loading ONNX file from path {}...'.format(onnx_file_path)) 32 | with open(onnx_file_path, 'rb') as model: 33 | if not parser.parse(model.read()): 34 | print('ERROR: Failed to parse the ONNX file.') 35 | for error in range(parser.num_errors): 36 | print(parser.get_error(error)) 37 | return None 38 | if trt.__version__[0] >= '7': 39 | # Reshape input to batch size 1 40 | shape = list(network.get_input(0).shape) 41 | shape[0] = 1 42 | network.get_input(0).shape = shape 43 | 44 | model_name = onnx_file_path[:-5] 45 | 46 | print('Building an engine. This may take a while...') 47 | print('(Use "--verbose" to enable verbose logging.)') 48 | engine = builder.build_cuda_engine(network) 49 | print('Completed creating engine.') 50 | return engine 51 | 52 | 53 | def main(): 54 | """Create a TensorRT engine for an ONNX-based model.""" 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | '-v', '--verbose', action='store_true', 58 | help='enable verbose output (for debugging)') 59 | parser.add_argument( 60 | '-m', '--model', type=str, default='model.onnx') 61 | parser.add_argument( 62 | '-p', '--precision', type=str, default='fp16') 63 | args = parser.parse_args() 64 | 65 | mode = args.precision 66 | onnx_file_path = args.model 67 | if not os.path.isfile(onnx_file_path): 68 | raise SystemExit('ERROR: file (%s) not found!'
% onnx_file_path) 69 | if mode=='fp16': 70 | engine_file_path = '%s_fp16.trt'% args.model[:-5] 71 | elif mode == 'fp32': 72 | engine_file_path = '%s_fp32.trt'% args.model[:-5] 73 | else: 74 | print("illegal mode") 75 | exit(0) 76 | engine = build_engine(onnx_file_path, mode,args.verbose) 77 | with open(engine_file_path, 'wb') as f: 78 | f.write(engine.serialize()) 79 | print('Serialized the TensorRT engine to file: %s' % engine_file_path) 80 | 81 | 82 | 83 | if __name__ == '__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /onnx_to_tensorrt_int8.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import cv2 6 | import tensorrt as trt 7 | import pycuda.driver as cuda 8 | import pycuda.autoinit 9 | import numpy as np 10 | import torchvision.transforms as transforms 11 | from PIL import Image 12 | 13 | img_transforms = transforms.Compose([ 14 | transforms.Resize((288, 800)), 15 | transforms.ToTensor(), 16 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 17 | ]) 18 | 19 | EXPLICIT_BATCH = [] 20 | if trt.__version__[0] >= '7': 21 | EXPLICIT_BATCH.append( 22 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 23 | 24 | 25 | 26 | class EntropyCalibrator(trt.IInt8EntropyCalibrator2): 27 | def __init__(self, training_data, cache_file, batch_size=16): 28 | # Whenever you specify a custom constructor for a TensorRT class, 29 | # you MUST call the constructor of the parent explicitly. 30 | trt.IInt8EntropyCalibrator2.__init__(self) 31 | 32 | self.cache_file = cache_file 33 | # Every time get_batch is called, the next batch of size batch_size will be copied to the device and returned. 34 | self.data = self.load_data(training_data) 35 | self.batch_size = batch_size 36 | self.current_index = 0 37 | 38 | # Allocate enough memory for a whole batch. 39 | self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size) 40 | 41 | # Returns a numpy buffer of shape (num_images, 1, 28, 28) 42 | def load_data(self, datapath): 43 | print("loading image data") 44 | imgs = os.listdir(datapath) 45 | dataset = [] 46 | for data in imgs: 47 | img = cv2.imread(datapath+data) 48 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 49 | img = Image.fromarray(img) 50 | img = img_transforms(img).numpy() 51 | dataset.append(img) 52 | print(dataset) 53 | return np.array(dataset) 54 | 55 | def get_batch_size(self): 56 | return self.batch_size 57 | 58 | # TensorRT passes along the names of the engine bindings to the get_batch function. 59 | # You don't necessarily have to use them, but they can be useful to understand the order of 60 | # the inputs. The bindings list is expected to have the same ordering as 'names'. 61 | def get_batch(self, names): 62 | if self.current_index + self.batch_size > self.data.shape[0]: 63 | return None 64 | 65 | current_batch = int(self.current_index / self.batch_size) 66 | if current_batch % 10 == 0: 67 | print("Calibrating batch {:}, containing {:} images".format(current_batch, self.batch_size)) 68 | 69 | batch = self.data[self.current_index:self.current_index + self.batch_size].ravel() 70 | cuda.memcpy_htod(self.device_input, batch) 71 | self.current_index += self.batch_size 72 | return [self.device_input] 73 | 74 | def read_calibration_cache(self): 75 | # If there is a cache, use it instead of calibrating again. Otherwise, implicitly return None. 
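# (TensorRT consults this method before running calibration: returning cached
# bytes skips the whole get_batch pass, so delete calibration.cache whenever
# the calibration image set or the network changes.)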
76 | if os.path.exists(self.cache_file): 77 | with open(self.cache_file, "rb") as f: 78 | return f.read() 79 | 80 | def write_calibration_cache(self, cache): 81 | with open(self.cache_file, "wb") as f: 82 | f.write(cache) 83 | 84 | def build_int8_engine(onnx_file_path, calib, batch_size, verbose=False): 85 | """Build a TensorRT engine from an ONNX file.""" 86 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() 87 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: 88 | builder.max_workspace_size = 1 << 30 89 | builder.max_batch_size = 1 90 | builder.int8_mode = True 91 | builder.int8_calibrator = calib 92 | 93 | # Parse model file 94 | print('Loading ONNX file from path {}...'.format(onnx_file_path)) 95 | with open(onnx_file_path, 'rb') as model: 96 | if not parser.parse(model.read()): 97 | print('ERROR: Failed to parse the ONNX file.') 98 | for error in range(parser.num_errors): 99 | print(parser.get_error(error)) 100 | return None 101 | if trt.__version__[0] >= '7': 102 | # Reshape input to batch size 1 103 | shape = list(network.get_input(0).shape) 104 | shape[0] = 1 105 | network.get_input(0).shape = shape 106 | 107 | print('Adding yolo_layer plugins...') 108 | model_name = onnx_file_path[:-5] 109 | 110 | print('Building an engine. This would take a while...') 111 | print('(Use "--verbose" to enable verbose logging.)') 112 | engine = builder.build_cuda_engine(network) 113 | print('Completed creating engine.') 114 | return engine 115 | 116 | 117 | def main(): 118 | """Create a TensorRT engine for ONNX-based model.""" 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument( 121 | '-v', '--verbose', action='store_true', 122 | help='enable verbose output (for debugging)') 123 | parser.add_argument( 124 | '-m', '--model', type=str, default='model.onnx', 125 | ) 126 | args = parser.parse_args() 127 | 128 | calibration_cache = "calibration.cache" 129 | data_path = 'calibration_data/testset/' 130 | calib = EntropyCalibrator(data_path, cache_file=calibration_cache) 131 | 132 | onnx_file_path = args.model 133 | if not os.path.isfile(onnx_file_path): 134 | raise SystemExit('ERROR: file (%s) not found!' 
% onnx_file_path) 135 | engine_file_path = '%s_int8.trt' % args.model[:-5] 136 | engine = build_int8_engine(onnx_file_path, calib, 16) 137 | with open(engine_file_path, 'wb') as f: 138 | f.write(engine.serialize()) 139 | print('Serialized the TensorRT engine to file: %s' % engine_file_path) 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | opencv-python 3 | matplotlib 4 | PySpin 5 | scipy 6 | addict 7 | tqdm 8 | tensorboard 9 | onnx 10 | tensorrt 11 | -------------------------------------------------------------------------------- /tensorrt_run.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | import cv2 6 | import tensorrt as trt 7 | import common 8 | import pycuda.driver as cuda 9 | import pycuda.autoinit 10 | import numpy as np 11 | import pycuda.gpuarray as gpuarray 12 | import time 13 | import scipy.special 14 | import torchvision.transforms as transforms 15 | from PIL import Image 16 | 17 | img_transforms = transforms.Compose([ 18 | transforms.Resize((288, 800)), 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 21 | ]) 22 | 23 | col_sample = np.linspace(0, 800 - 1, 100) 24 | col_sample_w = col_sample[1] - col_sample[0] 25 | 26 | img_w = 640 27 | img_h = 480 28 | 29 | row_anchor = [ 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 30 | 116, 120, 124, 128, 132, 136, 140, 144, 148, 152, 156, 160, 164, 31 | 168, 172, 176, 180, 184, 188, 192, 196, 200, 204, 208, 212, 216, 32 | 220, 224, 228, 232, 236, 240, 244, 248, 252, 256, 260, 264, 268, 33 | 272, 276, 280, 284] 34 | 35 | color = [(255,255,0), (255,0,0),(0,0,255),(0,255,0)] 36 | 37 | EXPLICIT_BATCH = [] 38 | if trt.__version__[0] >= '7': 39 | EXPLICIT_BATCH.append( 40 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 41 | 42 | def load_engine(trt_file_path, verbose=False): 43 | """Build a TensorRT engine from a TRT file.""" 44 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() 45 | print('Loading TRT file from path {}...'.format(trt_file_path)) 46 | with open(trt_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime: 47 | engine = runtime.deserialize_cuda_engine(f.read()) 48 | return engine 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument( 53 | '-v', '--verbose', action='store_true', 54 | help='enable verbose output (for debugging)') 55 | parser.add_argument( 56 | '-m', '--model', type=str, default='model', 57 | ) 58 | args = parser.parse_args() 59 | 60 | 61 | 62 | trt_file_path = '%s.trt' % args.model 63 | if not os.path.isfile(trt_file_path): 64 | raise SystemExit('ERROR: file (%s) not found!' 
% trt_file_path) 65 | engine_file_path = '%s.trt' % args.model 66 | engine = load_engine(trt_file_path, args.verbose) 67 | 68 | h_inputs, h_outputs, bindings, stream = common.allocate_buffers(engine) 69 | 70 | 71 | cap = cv2.VideoCapture(2) 72 | with engine.create_execution_context() as context: 73 | while True: 74 | _,frame = cap.read() 75 | t1 = time.time() 76 | img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 77 | img = Image.fromarray(img) 78 | img = img_transforms(img).numpy() 79 | 80 | h_inputs[0].host = img 81 | t3 = time.time() 82 | trt_outputs = common.do_inference_v2(context, bindings=bindings, inputs=h_inputs, outputs=h_outputs, stream=stream) 83 | t4 = time.time() 84 | 85 | 86 | out_j = trt_outputs[0].reshape(101, 56, 4) 87 | 88 | prob = scipy.special.softmax(out_j[:-1, :, :], axis=0) 89 | 90 | 91 | idx = np.arange(100) + 1 92 | idx = idx.reshape(-1, 1, 1) 93 | 94 | loc = np.sum(prob * idx, axis=0) 95 | out_j = np.argmax(out_j, axis=0) 96 | loc[out_j == 100] = 0 97 | out_j = loc 98 | 99 | # import pdb; pdb.set_trace() 100 | vis = frame 101 | for i in range(out_j.shape[1]): 102 | if np.sum(out_j[:, i] != 0) > 2: 103 | for k in range(out_j.shape[0]): 104 | if out_j[k, i] > 0: 105 | ppp = (int(out_j[k, i] * col_sample_w * img_w / 800) - 1, int(img_h * (row_anchor[k]/288)) - 1 ) 106 | cv2.circle(vis,ppp, img_w//300 ,color[i],-1) 107 | 108 | t2 = time.time() 109 | print('Inference time', (t4-t3)*1000) 110 | print('FPS', int(1/((t2-t1)))) 111 | cv2.imshow("OUTPUT", vis) 112 | cv2.waitKey(1) 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | 118 | -------------------------------------------------------------------------------- /test_devices.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | import time 4 | import numpy as np 5 | 6 | currentImage = None 7 | 8 | 9 | 10 | frame = 0 11 | if __name__ == "__main__": 12 | cap = cv2.VideoCapture(0) 13 | 14 | while True: 15 | frame += 1 16 | _, currentImage = cap.read() 17 | cv2.imshow("",currentImage) 18 | key = cv2.waitKey(1) 19 | if key == ord('a'): # waitKey returns an int key code, not a str 20 | cv2.imwrite(str(frame)+".jpg",currentImage) 21 | 22 | 23 | -------------------------------------------------------------------------------- /torch2onnx.py: -------------------------------------------------------------------------------- 1 | from UFLD import * 2 | import cv2 3 | import time 4 | import numpy as np 5 | import torch 6 | import onnx 7 | 8 | 9 | detector = laneDetection() 10 | detector.setResolution(640, 480) 11 | frame = 0 12 | currentImage = None 13 | 14 | 15 | if __name__ == "__main__": 16 | filepath = "model.onnx" 17 | dummy_input = torch.rand((1,3,288,800)).cuda() 18 | torch.onnx.export(detector.net, dummy_input, filepath) 19 | -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-37.pyc --------------------------------------------------------------------------------
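The decode loop in tensorrt_run.py above is worth restating on its own: the engine emits a (101, 56, 4) tensor of logits over 100 column cells plus one "no lane" bin, per row anchor and lane, and the x position is recovered as the softmax-weighted expectation over the 100 cells. A self-contained sketch of that step (the function name and the 640-pixel display width are illustrative; the 800-pixel grid width matches the training resolution used throughout):

import numpy as np
import scipy.special

def decode_lanes(out, griding_num=100, img_w=640):
    # out: (griding_num + 1, num_row_anchors, num_lanes) raw logits
    col_sample = np.linspace(0, 800 - 1, griding_num)       # cell centers in the 288x800 frame
    prob = scipy.special.softmax(out[:-1, :, :], axis=0)    # drop the "no lane" bin
    idx = np.arange(1, griding_num + 1).reshape(-1, 1, 1)
    loc = np.sum(prob * idx, axis=0)                        # expected cell index (sub-cell accuracy)
    loc[np.argmax(out, axis=0) == griding_num] = 0          # argmax on the last bin => no lane here
    return loc * (col_sample[1] - col_sample[0]) * img_w / 800  # cell index -> pixel x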
/utils/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dist_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dist_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dist_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/dist_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/factory.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KopiSoftware/TRT_Ultra_Fast_Lane_Detect/8ec69d25220f531a44933c5c5b8cae915511041f/utils/__pycache__/metrics.cpython-38.pyc 
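utils/common.py below implements the project's config-override pattern: merge_config() first loads the Python config file given as the positional argument, then overwrites any field that was also passed explicitly on the command line, so CLI flags win over the file. A short sketch of the resulting invocation (the train.py entry point is hypothetical here; merge_config and the flag names come from the code below):

# python train.py configs/tusimple_4.py --batch_size 8 --backbone 18
from utils.common import merge_config

args, cfg = merge_config()
print(cfg.batch_size)   # 8: the CLI value replaced the 16 from configs/tusimple_4.py
print(cfg.griding_num)  # 100: the CLI default and the config file agree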
-------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | from utils.dist_utils import is_main_process, dist_print, DistSummaryWriter 3 | from utils.config import Config 4 | import torch 5 | 6 | def str2bool(v): 7 | if isinstance(v, bool): 8 | return v 9 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 10 | return True 11 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 12 | return False 13 | else: 14 | raise argparse.ArgumentTypeError('Boolean value expected.') 15 | 16 | def get_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('config', help = 'path to config file') 19 | parser.add_argument('--local_rank', type=int, default=0) 20 | 21 | parser.add_argument('--dataset', default = 'Tusimple', type = str) 22 | parser.add_argument('--data_root', default = None, type = str) 23 | parser.add_argument('--epoch', default = None, type = int) 24 | parser.add_argument('--batch_size', default = None, type = int) 25 | parser.add_argument('--optimizer', default = None, type = str) 26 | parser.add_argument('--learning_rate', default = None, type = float) 27 | parser.add_argument('--weight_decay', default = None, type = float) 28 | parser.add_argument('--momentum', default = None, type = float) 29 | parser.add_argument('--scheduler', default = None, type = str) 30 | parser.add_argument('--steps', default = None, type = int, nargs='+') 31 | parser.add_argument('--gamma', default = None, type = float) 32 | parser.add_argument('--warmup', default = None, type = str) 33 | parser.add_argument('--warmup_iters', default = None, type = int) 34 | parser.add_argument('--backbone', default = None, type = str) 35 | parser.add_argument('--griding_num', default = 100, type = int) 36 | parser.add_argument('--use_aux', default = None, type = str2bool) 37 | parser.add_argument('--sim_loss_w', default = None, type = float) 38 | parser.add_argument('--shp_loss_w', default = None, type = float) 39 | parser.add_argument('--note', default = None, type = str) 40 | parser.add_argument('--log_path', default = None, type = str) 41 | parser.add_argument('--finetune', default = None, type = str) 42 | parser.add_argument('--resume', default = None, type = str) 43 | parser.add_argument('--test_model', default = 'model.pth', type = str) 44 | parser.add_argument('--test_work_dir', default = None, type = str) 45 | parser.add_argument('--num_lanes', default = 4, type = int) 46 | parser.add_argument('--video', default = 'test.avi', type = str) 47 | 48 | return parser 49 | 50 | def merge_config(): 51 | args = get_args().parse_args() 52 | cfg = Config.fromfile(args.config) 53 | 54 | items = ['dataset','data_root','epoch','batch_size','optimizer','learning_rate', 55 | 'weight_decay','momentum','scheduler','steps','gamma','warmup','warmup_iters', 56 | 'use_aux','griding_num','backbone','sim_loss_w','shp_loss_w','note','log_path', 57 | 'finetune','resume', 'test_model','test_work_dir', 'num_lanes','video'] 58 | for item in items: 59 | if getattr(args, item) is not None: 60 | dist_print('merge ', item, ' config') 61 | setattr(cfg, item, getattr(args, item)) 62 | return args, cfg 63 | 64 | 65 | def save_model(net, optimizer, epoch,save_path, distributed): 66 | if is_main_process(): 67 | model_state_dict = net.state_dict() 68 | state = {'model': model_state_dict, 'optimizer': optimizer.state_dict()} 69 | # state = {'model': model_state_dict} 70 | assert 
os.path.exists(save_path) 71 | model_path = os.path.join(save_path, 'ep%03d.pth' % epoch) 72 | torch.save(state, model_path) 73 | 74 | import pathspec 75 | 76 | def cp_projects(to_path): 77 | if is_main_process(): 78 | with open('./.gitignore','r') as fp: 79 | ign = fp.read() 80 | ign += '\n.git' 81 | spec = pathspec.PathSpec.from_lines(pathspec.patterns.GitWildMatchPattern, ign.splitlines()) 82 | all_files = {os.path.join(root,name) for root,dirs,files in os.walk('./') for name in files} 83 | matches = spec.match_files(all_files) 84 | matches = set(matches) 85 | to_cp_files = all_files - matches 86 | # to_cp_files = [f[2:] for f in to_cp_files] 87 | # pdb.set_trace() 88 | for f in to_cp_files: 89 | dirs = os.path.join(to_path,'code',os.path.split(f[2:])[0]) 90 | if not os.path.exists(dirs): 91 | os.makedirs(dirs) 92 | os.system('cp %s %s'%(f,os.path.join(to_path,'code',f[2:]))) 93 | 94 | 95 | import datetime, os 96 | def get_work_dir(cfg): 97 | now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') 98 | hyper_param_str = '_lr_%1.0e_b_%d' % (cfg.learning_rate, cfg.batch_size) 99 | work_dir = os.path.join(cfg.log_path, now + hyper_param_str + cfg.note) 100 | return work_dir 101 | 102 | def get_logger(work_dir, cfg): 103 | logger = DistSummaryWriter(work_dir) 104 | config_txt = os.path.join(work_dir, 'cfg.txt') 105 | if is_main_process(): 106 | with open(config_txt, 'w') as fp: 107 | fp.write(str(cfg)) 108 | 109 | return logger 110 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | import shutil 4 | import sys 5 | import tempfile 6 | from argparse import Action, ArgumentParser 7 | from collections import abc 8 | from importlib import import_module 9 | 10 | from addict import Dict 11 | 12 | 13 | BASE_KEY = '_base_' 14 | DELETE_KEY = '_delete_' 15 | 16 | 17 | class ConfigDict(Dict): 18 | 19 | def __missing__(self, name): 20 | raise KeyError(name) 21 | 22 | def __getattr__(self, name): 23 | try: 24 | value = super(ConfigDict, self).__getattr__(name) 25 | except KeyError: 26 | ex = AttributeError(f"'{self.__class__.__name__}' object has no " 27 | f"attribute '{name}'") 28 | except Exception as e: 29 | ex = e 30 | else: 31 | return value 32 | raise ex 33 | 34 | 35 | def add_args(parser, cfg, prefix=''): 36 | for k, v in cfg.items(): 37 | if isinstance(v, str): 38 | parser.add_argument('--' + prefix + k) 39 | elif isinstance(v, int): 40 | parser.add_argument('--' + prefix + k, type=int) 41 | elif isinstance(v, float): 42 | parser.add_argument('--' + prefix + k, type=float) 43 | elif isinstance(v, bool): 44 | parser.add_argument('--' + prefix + k, action='store_true') 45 | elif isinstance(v, dict): 46 | add_args(parser, v, prefix + k + '.') 47 | elif isinstance(v, abc.Iterable): 48 | parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') 49 | else: 50 | print(f'cannot parse key {prefix + k} of type {type(v)}') 51 | return parser 52 | 53 | 54 | class Config(object): 55 | """A facility for config and config files. 56 | It supports common file formats as configs: python/json/yaml. The interface 57 | is the same as a dict object and also allows access config values as 58 | attributes. 
59 | Example: 60 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) 61 | >>> cfg.a 62 | 1 63 | >>> cfg.b 64 | {'b1': [0, 1]} 65 | >>> cfg.b.b1 66 | [0, 1] 67 | >>> cfg = Config.fromfile('tests/data/config/a.py') 68 | >>> cfg.filename 69 | "/home/kchen/projects/mmcv/tests/data/config/a.py" 70 | >>> cfg.item4 71 | 'test' 72 | >>> cfg 73 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " 74 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" 75 | """ 76 | 77 | @staticmethod 78 | def _file2dict(filename): 79 | filename = osp.abspath(osp.expanduser(filename)) 80 | if filename.endswith('.py'): 81 | with tempfile.TemporaryDirectory() as temp_config_dir: 82 | temp_config_file = tempfile.NamedTemporaryFile( 83 | dir=temp_config_dir, suffix='.py') 84 | temp_config_name = osp.basename(temp_config_file.name) 85 | # close temp file 86 | temp_config_file.close() 87 | shutil.copyfile(filename, 88 | osp.join(temp_config_dir, temp_config_name)) 89 | temp_module_name = osp.splitext(temp_config_name)[0] 90 | sys.path.insert(0, temp_config_dir) 91 | mod = import_module(temp_module_name) 92 | sys.path.pop(0) 93 | cfg_dict = { 94 | name: value 95 | for name, value in mod.__dict__.items() 96 | if not name.startswith('__') 97 | } 98 | # delete imported module 99 | del sys.modules[temp_module_name] 100 | 101 | elif filename.endswith(('.yml', '.yaml', '.json')): 102 | import mmcv 103 | cfg_dict = mmcv.load(filename) 104 | else: 105 | raise IOError('Only py/yml/yaml/json type are supported now!') 106 | 107 | cfg_text = filename + '\n' 108 | with open(filename, 'r') as f: 109 | cfg_text += f.read() 110 | 111 | if BASE_KEY in cfg_dict: 112 | cfg_dir = osp.dirname(filename) 113 | base_filename = cfg_dict.pop(BASE_KEY) 114 | base_filename = base_filename if isinstance( 115 | base_filename, list) else [base_filename] 116 | 117 | cfg_dict_list = list() 118 | cfg_text_list = list() 119 | for f in base_filename: 120 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) 121 | cfg_dict_list.append(_cfg_dict) 122 | cfg_text_list.append(_cfg_text) 123 | 124 | base_cfg_dict = dict() 125 | for c in cfg_dict_list: 126 | if len(base_cfg_dict.keys() & c.keys()) > 0: 127 | raise KeyError('Duplicate key is not allowed among bases') 128 | base_cfg_dict.update(c) 129 | 130 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) 131 | cfg_dict = base_cfg_dict 132 | 133 | # merge cfg_text 134 | cfg_text_list.append(cfg_text) 135 | cfg_text = '\n'.join(cfg_text_list) 136 | 137 | return cfg_dict, cfg_text 138 | 139 | @staticmethod 140 | def _merge_a_into_b(a, b): 141 | # merge dict `a` into dict `b` (non-inplace). values in `a` will 142 | # overwrite `b`. 143 | # copy first to avoid inplace modification 144 | b = b.copy() 145 | for k, v in a.items(): 146 | if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): 147 | if not isinstance(b[k], dict): 148 | raise TypeError( 149 | f'{k}={v} in child config cannot inherit from base ' 150 | f'because {k} is a dict in the child config but is of ' 151 | f'type {type(b[k])} in base config. 
You may set ' 152 | f'`{DELETE_KEY}=True` to ignore the base config') 153 | b[k] = Config._merge_a_into_b(v, b[k]) 154 | else: 155 | b[k] = v 156 | return b 157 | 158 | @staticmethod 159 | def fromfile(filename): 160 | cfg_dict, cfg_text = Config._file2dict(filename) 161 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename) 162 | 163 | @staticmethod 164 | def auto_argparser(description=None): 165 | """Generate argparser from config file automatically (experimental) 166 | """ 167 | partial_parser = ArgumentParser(description=description) 168 | partial_parser.add_argument('config', help='config file path') 169 | cfg_file = partial_parser.parse_known_args()[0].config 170 | cfg = Config.fromfile(cfg_file) 171 | parser = ArgumentParser(description=description) 172 | parser.add_argument('config', help='config file path') 173 | add_args(parser, cfg) 174 | return parser, cfg 175 | 176 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None): 177 | if cfg_dict is None: 178 | cfg_dict = dict() 179 | elif not isinstance(cfg_dict, dict): 180 | raise TypeError('cfg_dict must be a dict, but ' 181 | f'got {type(cfg_dict)}') 182 | 183 | super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) 184 | super(Config, self).__setattr__('_filename', filename) 185 | if cfg_text: 186 | text = cfg_text 187 | elif filename: 188 | with open(filename, 'r') as f: 189 | text = f.read() 190 | else: 191 | text = '' 192 | super(Config, self).__setattr__('_text', text) 193 | 194 | @property 195 | def filename(self): 196 | return self._filename 197 | 198 | @property 199 | def text(self): 200 | return self._text 201 | 202 | @property 203 | def pretty_text(self): 204 | 205 | indent = 4 206 | 207 | def _indent(s_, num_spaces): 208 | s = s_.split('\n') 209 | if len(s) == 1: 210 | return s_ 211 | first = s.pop(0) 212 | s = [(num_spaces * ' ') + line for line in s] 213 | s = '\n'.join(s) 214 | s = first + '\n' + s 215 | return s 216 | 217 | def _format_basic_types(k, v): 218 | if isinstance(v, str): 219 | v_str = f"'{v}'" 220 | else: 221 | v_str = str(v) 222 | attr_str = f'{str(k)}={v_str}' 223 | attr_str = _indent(attr_str, indent) 224 | 225 | return attr_str 226 | 227 | def _format_list(k, v): 228 | # check if all items in the list are dict 229 | if all(isinstance(_, dict) for _ in v): 230 | v_str = '[\n' 231 | v_str += '\n'.join( 232 | f'dict({_indent(_format_dict(v_), indent)}),' 233 | for v_ in v).rstrip(',') 234 | attr_str = f'{str(k)}={v_str}' 235 | attr_str = _indent(attr_str, indent) + ']' 236 | else: 237 | attr_str = _format_basic_types(k, v) 238 | return attr_str 239 | 240 | def _format_dict(d, outest_level=False): 241 | r = '' 242 | s = [] 243 | for idx, (k, v) in enumerate(d.items()): 244 | is_last = idx >= len(d) - 1 245 | end = '' if outest_level or is_last else ',' 246 | if isinstance(v, dict): 247 | v_str = '\n' + _format_dict(v) 248 | attr_str = f'{str(k)}=dict({v_str}' 249 | attr_str = _indent(attr_str, indent) + ')' + end 250 | elif isinstance(v, list): 251 | attr_str = _format_list(k, v) + end 252 | else: 253 | attr_str = _format_basic_types(k, v) + end 254 | 255 | s.append(attr_str) 256 | r += '\n'.join(s) 257 | return r 258 | 259 | cfg_dict = self._cfg_dict.to_dict() 260 | text = _format_dict(cfg_dict, outest_level=True) 261 | 262 | return text 263 | 264 | def __repr__(self): 265 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' 266 | 267 | def __len__(self): 268 | return len(self._cfg_dict) 269 | 270 | def __getattr__(self, name): 271 | return 
getattr(self._cfg_dict, name) 272 | 273 | def __getitem__(self, name): 274 | return self._cfg_dict.__getitem__(name) 275 | 276 | def __setattr__(self, name, value): 277 | if isinstance(value, dict): 278 | value = ConfigDict(value) 279 | self._cfg_dict.__setattr__(name, value) 280 | 281 | def __setitem__(self, name, value): 282 | if isinstance(value, dict): 283 | value = ConfigDict(value) 284 | self._cfg_dict.__setitem__(name, value) 285 | 286 | def __iter__(self): 287 | return iter(self._cfg_dict) 288 | 289 | def dump(self): 290 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 291 | format_text = json.dumps(cfg_dict, indent=2) 292 | return format_text 293 | 294 | def merge_from_dict(self, options): 295 | """Merge list into cfg_dict 296 | Merge the dict parsed by MultipleKVAction into this cfg. 297 | Examples: 298 | >>> options = {'model.backbone.depth': 50, 299 | ... 'model.backbone.with_cp':True} 300 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) 301 | >>> cfg.merge_from_dict(options) 302 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 303 | >>> assert cfg_dict == dict( 304 | ... model=dict(backbone=dict(depth=50, with_cp=True))) 305 | Args: 306 | options (dict): dict of configs to merge from. 307 | """ 308 | option_cfg_dict = {} 309 | for full_key, v in options.items(): 310 | d = option_cfg_dict 311 | key_list = full_key.split('.') 312 | for subkey in key_list[:-1]: 313 | d.setdefault(subkey, ConfigDict()) 314 | d = d[subkey] 315 | subkey = key_list[-1] 316 | d[subkey] = v 317 | 318 | cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 319 | super(Config, self).__setattr__( 320 | '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict)) 321 | 322 | 323 | class DictAction(Action): 324 | """ 325 | argparse action to split an argument into KEY=VALUE form 326 | on the first = and append to a dictionary. 
List options should 327 | be passed as comma-separated values, i.e. KEY=V1,V2,V3 328 | """ 329 | 330 | @staticmethod 331 | def _parse_int_float_bool(val): 332 | try: 333 | return int(val) 334 | except ValueError: 335 | pass 336 | try: 337 | return float(val) 338 | except ValueError: 339 | pass 340 | if val.lower() in ['true', 'false']: 341 | return val.lower() == 'true' 342 | return val 343 | 344 | def __call__(self, parser, namespace, values, option_string=None): 345 | options = {} 346 | for kv in values: 347 | key, val = kv.split('=', maxsplit=1) 348 | val = [self._parse_int_float_bool(v) for v in val.split(',')] 349 | if len(val) == 1: 350 | val = val[0] 351 | options[key] = val 352 | setattr(namespace, self.dest, options) -------------------------------------------------------------------------------- /utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import pickle 4 | 5 | 6 | def get_world_size(): 7 | if not dist.is_available(): 8 | return 1 9 | if not dist.is_initialized(): 10 | return 1 11 | return dist.get_world_size() 12 | 13 | 14 | def to_python_float(t): 15 | if hasattr(t, 'item'): 16 | return t.item() 17 | else: 18 | return t[0] 19 | 20 | 21 | def get_rank(): 22 | if not dist.is_available(): 23 | return 0 24 | if not dist.is_initialized(): 25 | return 0 26 | return dist.get_rank() 27 | 28 | 29 | def is_main_process(): 30 | return get_rank() == 0 31 | 32 | 33 | def can_log(): 34 | return is_main_process() 35 | 36 | 37 | def dist_print(*args, **kwargs): 38 | if can_log(): 39 | print(*args, **kwargs) 40 | 41 | 42 | def synchronize(): 43 | """ 44 | Helper function to synchronize (barrier) among all processes when 45 | using distributed training 46 | """ 47 | if not dist.is_available(): 48 | return 49 | if not dist.is_initialized(): 50 | return 51 | world_size = dist.get_world_size() 52 | if world_size == 1: 53 | return 54 | dist.barrier() 55 | 56 | def dist_cat_reduce_tensor(tensor): 57 | if not dist.is_available(): 58 | return tensor 59 | if not dist.is_initialized(): 60 | return tensor 61 | # dist_print(tensor) 62 | rt = tensor.clone() 63 | all_list = [torch.zeros_like(tensor) for _ in range(get_world_size())] 64 | dist.all_gather(all_list,rt) 65 | # dist_print(all_list[0][1],all_list[1][1],all_list[2][1],all_list[3][1]) 66 | # dist_print(all_list[0][2],all_list[1][2],all_list[2][2],all_list[3][2]) 67 | # dist_print(all_list[0][3],all_list[1][3],all_list[2][3],all_list[3][3]) 68 | # dist_print(all_list[0].shape) 69 | return torch.cat(all_list,dim = 0) 70 | 71 | def dist_sum_reduce_tensor(tensor): 72 | if not dist.is_available(): 73 | return tensor 74 | if not dist.is_initialized(): 75 | return tensor 76 | if not isinstance(tensor, torch.Tensor): 77 | return tensor 78 | rt = tensor.clone() 79 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) # ReduceOp is the current spelling of the deprecated reduce_op alias 80 | return rt 81 | 82 | 83 | def dist_mean_reduce_tensor(tensor): 84 | rt = dist_sum_reduce_tensor(tensor) 85 | rt /= get_world_size() 86 | return rt 87 | 88 | 89 | def all_gather(data): 90 | """ 91 | Run all_gather on arbitrary picklable data (not necessarily tensors) 92 | Args: 93 | data: any picklable object 94 | Returns: 95 | list[data]: list of data gathered from each rank 96 | """ 97 | world_size = get_world_size() 98 | if world_size == 1: 99 | return [data] 100 | 101 | # serialized to a Tensor 102 | buffer = pickle.dumps(data) 103 | storage = torch.ByteStorage.from_buffer(buffer) 104 | tensor =
torch.ByteTensor(storage).to("cuda") 105 | 106 | # obtain Tensor size of each rank 107 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 108 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 109 | dist.all_gather(size_list, local_size) 110 | size_list = [int(size.item()) for size in size_list] 111 | max_size = max(size_list) 112 | 113 | # receiving Tensor from all ranks 114 | # we pad the tensor because torch all_gather does not support 115 | # gathering tensors of different shapes 116 | tensor_list = [] 117 | for _ in size_list: 118 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 119 | if local_size != max_size: 120 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 121 | tensor = torch.cat((tensor, padding), dim=0) 122 | dist.all_gather(tensor_list, tensor) 123 | 124 | data_list = [] 125 | for size, tensor in zip(size_list, tensor_list): 126 | buffer = tensor.cpu().numpy().tobytes()[:size] 127 | data_list.append(pickle.loads(buffer)) 128 | 129 | return data_list 130 | 131 | 132 | from torch.utils.tensorboard import SummaryWriter 133 | 134 | 135 | class DistSummaryWriter(SummaryWriter): 136 | def __init__(self, *args, **kwargs): 137 | if can_log(): 138 | super(DistSummaryWriter, self).__init__(*args, **kwargs) 139 | 140 | def add_scalar(self, *args, **kwargs): 141 | if can_log(): 142 | super(DistSummaryWriter, self).add_scalar(*args, **kwargs) 143 | 144 | def add_figure(self, *args, **kwargs): 145 | if can_log(): 146 | super(DistSummaryWriter, self).add_figure(*args, **kwargs) 147 | 148 | def add_graph(self, *args, **kwargs): 149 | if can_log(): 150 | super(DistSummaryWriter, self).add_graph(*args, **kwargs) 151 | 152 | def add_histogram(self, *args, **kwargs): 153 | if can_log(): 154 | super(DistSummaryWriter, self).add_histogram(*args, **kwargs) 155 | 156 | def add_image(self, *args, **kwargs): 157 | if can_log(): 158 | super(DistSummaryWriter, self).add_image(*args, **kwargs) 159 | 160 | def close(self): 161 | if can_log(): 162 | super(DistSummaryWriter, self).close() 163 | 164 | 165 | import tqdm 166 | 167 | 168 | def dist_tqdm(obj, *args, **kwargs): 169 | if can_log(): 170 | return tqdm.tqdm(obj, *args, **kwargs) 171 | else: 172 | return obj 173 | 174 | -------------------------------------------------------------------------------- /utils/factory.py: -------------------------------------------------------------------------------- 1 | from utils.loss import SoftmaxFocalLoss, ParsingRelationLoss, ParsingRelationDis 2 | from utils.metrics import MultiLabelAcc, AccTopk, Metric_mIoU 3 | from utils.dist_utils import DistSummaryWriter 4 | 5 | import torch 6 | 7 | 8 | def get_optimizer(net,cfg): 9 | training_params = filter(lambda p: p.requires_grad, net.parameters()) 10 | if cfg.optimizer == 'Adam': 11 | optimizer = torch.optim.Adam(training_params, lr=cfg.learning_rate, weight_decay=cfg.weight_decay) 12 | elif cfg.optimizer == 'SGD': 13 | optimizer = torch.optim.SGD(training_params, lr=cfg.learning_rate, momentum=cfg.momentum, 14 | weight_decay=cfg.weight_decay) 15 | else: 16 | raise NotImplementedError 17 | return optimizer 18 | 19 | def get_scheduler(optimizer, cfg, iters_per_epoch): 20 | if cfg.scheduler == 'multi': 21 | scheduler = MultiStepLR(optimizer, cfg.steps, cfg.gamma, iters_per_epoch, cfg.warmup, iters_per_epoch if cfg.warmup_iters is None else cfg.warmup_iters) 22 | elif cfg.scheduler == 'cos': 23 | scheduler = CosineAnnealingLR(optimizer, cfg.epoch * iters_per_epoch, eta_min = 0, warmup = 
cfg.warmup, warmup_iters = cfg.warmup_iters) 24 | else: 25 | raise NotImplementedError 26 | return scheduler 27 | 28 | def get_loss_dict(cfg): 29 | 30 | if cfg.use_aux: 31 | loss_dict = { 32 | 'name': ['cls_loss', 'relation_loss', 'aux_loss', 'relation_dis'], 33 | 'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), torch.nn.CrossEntropyLoss(), ParsingRelationDis()], 34 | 'weight': [1.0, cfg.sim_loss_w, 1.0, cfg.shp_loss_w], 35 | 'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('seg_out', 'seg_label'), ('cls_out',)] 36 | } 37 | else: 38 | loss_dict = { 39 | 'name': ['cls_loss', 'relation_loss', 'relation_dis'], 40 | 'op': [SoftmaxFocalLoss(2), ParsingRelationLoss(), ParsingRelationDis()], 41 | 'weight': [1.0, cfg.sim_loss_w, cfg.shp_loss_w], 42 | 'data_src': [('cls_out', 'cls_label'), ('cls_out',), ('cls_out',)] 43 | } 44 | 45 | return loss_dict 46 | 47 | def get_metric_dict(cfg): 48 | 49 | if cfg.use_aux: 50 | metric_dict = { 51 | 'name': ['top1', 'top2', 'top3', 'iou'], 52 | 'op': [MultiLabelAcc(), AccTopk(cfg.griding_num, 2), AccTopk(cfg.griding_num, 3), Metric_mIoU(8+1)], 53 | 'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('seg_out', 'seg_label')] 54 | } 55 | else: 56 | metric_dict = { 57 | 'name': ['top1', 'top2', 'top3'], 58 | 'op': [MultiLabelAcc(), AccTopk(cfg.griding_num, 2), AccTopk(cfg.griding_num, 3)], 59 | 'data_src': [('cls_out', 'cls_label'), ('cls_out', 'cls_label'), ('cls_out', 'cls_label')] 60 | } 61 | 62 | 63 | return metric_dict 64 | 65 | 66 | class MultiStepLR: 67 | def __init__(self, optimizer, steps, gamma = 0.1, iters_per_epoch = None, warmup = None, warmup_iters = None): 68 | self.warmup = warmup 69 | self.warmup_iters = warmup_iters 70 | self.optimizer = optimizer 71 | self.steps = steps 72 | self.steps.sort() 73 | self.gamma = gamma 74 | self.iters_per_epoch = iters_per_epoch 75 | self.iters = 0 76 | self.base_lr = [group['lr'] for group in optimizer.param_groups] 77 | 78 | def step(self, external_iter = None): 79 | self.iters += 1 80 | if external_iter is not None: 81 | self.iters = external_iter 82 | if self.warmup == 'linear' and self.iters < self.warmup_iters: 83 | rate = self.iters / self.warmup_iters 84 | for group, lr in zip(self.optimizer.param_groups, self.base_lr): 85 | group['lr'] = lr * rate 86 | return 87 | 88 | # multi policy 89 | if self.iters % self.iters_per_epoch == 0: 90 | epoch = int(self.iters / self.iters_per_epoch) 91 | power = -1 92 | for i, st in enumerate(self.steps): 93 | if epoch < st: 94 | power = i 95 | break 96 | if power == -1: 97 | power = len(self.steps) 98 | # print(self.iters, self.iters_per_epoch, self.steps, power) 99 | 100 | for group, lr in zip(self.optimizer.param_groups, self.base_lr): 101 | group['lr'] = lr * (self.gamma ** power) 102 | import math 103 | class CosineAnnealingLR: 104 | def __init__(self, optimizer, T_max , eta_min = 0, warmup = None, warmup_iters = None): 105 | self.warmup = warmup 106 | self.warmup_iters = warmup_iters 107 | self.optimizer = optimizer 108 | self.T_max = T_max 109 | self.eta_min = eta_min 110 | 111 | self.iters = 0 112 | self.base_lr = [group['lr'] for group in optimizer.param_groups] 113 | 114 | def step(self, external_iter = None): 115 | self.iters += 1 116 | if external_iter is not None: 117 | self.iters = external_iter 118 | if self.warmup == 'linear' and self.iters < self.warmup_iters: 119 | rate = self.iters / self.warmup_iters 120 | for group, lr in zip(self.optimizer.param_groups, self.base_lr): 121 | group['lr'] = lr * rate 122 | 
return 123 | 124 | # cos policy 125 | 126 | for group, lr in zip(self.optimizer.param_groups, self.base_lr): 127 | group['lr'] = self.eta_min + (lr - self.eta_min) * (1 + math.cos(math.pi * self.iters / self.T_max)) / 2 128 | 129 | 130 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | class OhemCELoss(nn.Module): 8 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs): 9 | super(OhemCELoss, self).__init__() 10 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda() 11 | self.n_min = n_min 12 | self.ignore_lb = ignore_lb 13 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') 14 | 15 | def forward(self, logits, labels): 16 | N, C, H, W = logits.size() 17 | loss = self.criteria(logits, labels).view(-1) 18 | loss, _ = torch.sort(loss, descending=True) 19 | if loss[self.n_min] > self.thresh: 20 | loss = loss[loss>self.thresh] 21 | else: 22 | loss = loss[:self.n_min] 23 | return torch.mean(loss) 24 | 25 | 26 | class SoftmaxFocalLoss(nn.Module): 27 | def __init__(self, gamma, ignore_lb=255, *args, **kwargs): 28 | super(SoftmaxFocalLoss, self).__init__() 29 | self.gamma = gamma 30 | self.nll = nn.NLLLoss(ignore_index=ignore_lb) 31 | 32 | def forward(self, logits, labels): 33 | scores = F.softmax(logits, dim=1) 34 | factor = torch.pow(1.-scores, self.gamma) 35 | log_score = F.log_softmax(logits, dim=1) 36 | log_score = factor * log_score 37 | loss = self.nll(log_score, labels) 38 | return loss 39 | 40 | class ParsingRelationLoss(nn.Module): 41 | def __init__(self): 42 | super(ParsingRelationLoss, self).__init__() 43 | def forward(self,logits): 44 | n,c,h,w = logits.shape 45 | loss_all = [] 46 | for i in range(0,h-1): 47 | loss_all.append(logits[:,:,i,:] - logits[:,:,i+1,:]) 48 | #loss0 : n,c,w 49 | loss = torch.cat(loss_all) 50 | return torch.nn.functional.smooth_l1_loss(loss,torch.zeros_like(loss)) 51 | 52 | 53 | 54 | class ParsingRelationDis(nn.Module): 55 | def __init__(self): 56 | super(ParsingRelationDis, self).__init__() 57 | self.l1 = torch.nn.L1Loss() 58 | # self.l1 = torch.nn.MSELoss() 59 | def forward(self, x): 60 | n,dim,num_rows,num_cols = x.shape 61 | x = torch.nn.functional.softmax(x[:,:dim-1,:,:],dim=1) 62 | embedding = torch.Tensor(np.arange(dim-1)).float().to(x.device).view(1,-1,1,1) 63 | pos = torch.sum(x*embedding,dim = 1) 64 | 65 | diff_list1 = [] 66 | for i in range(0,num_rows // 2): 67 | diff_list1.append(pos[:,i,:] - pos[:,i+1,:]) 68 | 69 | loss = 0 70 | for i in range(len(diff_list1)-1): 71 | loss += self.l1(diff_list1[i],diff_list1[i+1]) 72 | loss /= len(diff_list1) - 1 73 | return loss 74 | 75 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time,pdb 4 | 5 | def converter(data): 6 | if isinstance(data,torch.Tensor): 7 | data = data.cpu().data.numpy().flatten() 8 | return data.flatten() 9 | def fast_hist(label_pred, label_true,num_classes): 10 | #pdb.set_trace() 11 | hist = np.bincount(num_classes * label_true.astype(int) + label_pred, minlength=num_classes ** 2) 12 | hist = hist.reshape(num_classes, num_classes) 13 | return hist 14 | 15 | class Metric_mIoU(): 16 | def 
__init__(self,class_num): 17 | self.class_num = class_num 18 | self.hist = np.zeros((self.class_num,self.class_num)) 19 | def update(self,predict,target): 20 | predict,target = converter(predict),converter(target) 21 | 22 | self.hist += fast_hist(predict,target,self.class_num) 23 | 24 | def reset(self): 25 | self.hist = np.zeros((self.class_num,self.class_num)) 26 | def get_miou(self): 27 | miou = np.diag(self.hist) / ( 28 | np.sum(self.hist, axis=1) + np.sum(self.hist, axis=0) - 29 | np.diag(self.hist)) 30 | miou = np.nanmean(miou) 31 | return miou 32 | 33 | def get_acc(self): 34 | acc = np.diag(self.hist) / self.hist.sum(axis=1) 35 | acc = np.nanmean(acc) 36 | return acc 37 | def get(self): 38 | return self.get_miou() 39 | class MultiLabelAcc(): 40 | def __init__(self): 41 | self.cnt = 0 42 | self.correct = 0 43 | def reset(self): 44 | self.cnt = 0 45 | self.correct = 0 46 | def update(self,predict,target): 47 | predict,target = converter(predict),converter(target) 48 | self.cnt += len(predict) 49 | self.correct += np.sum(predict==target) 50 | def get_acc(self): 51 | return self.correct * 1.0 / self.cnt 52 | def get(self): 53 | return self.get_acc() 54 | class AccTopk(): 55 | def __init__(self,background_classes,k): 56 | self.background_classes = background_classes 57 | self.k = k 58 | self.cnt = 0 59 | self.top5_correct = 0 60 | def reset(self): 61 | self.cnt = 0 62 | self.top5_correct = 0 63 | def update(self,predict,target): 64 | predict,target = converter(predict),converter(target) 65 | self.cnt += len(predict) 66 | background_idx = (predict == self.background_classes) + (target == self.background_classes) 67 | self.top5_correct += np.sum(predict[background_idx] == target[background_idx]) 68 | not_background_idx = np.logical_not(background_idx) 69 | self.top5_correct += np.sum(np.absolute(predict[not_background_idx]-target[not_background_idx]) < self.k) 70 | def get(self): 71 | return self.top5_correct * 1.0 / self.cnt 72 | -------------------------------------------------------------------------------- /utils/onnx2trt.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import argparse 5 | 6 | import tensorrt as trt 7 | 8 | 9 | EXPLICIT_BATCH = [] 10 | 11 | 12 | if trt.__version__[0] >= '7': 13 | EXPLICIT_BATCH.append( 14 | 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 15 | 16 | 17 | def build_engine(onnx_file_path, verbose=False): 18 | """Build a TensorRT engine from an ONNX file.""" 19 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) if verbose else trt.Logger() 20 | with trt.Builder(TRT_LOGGER) as builder, builder.create_network(*EXPLICIT_BATCH) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: 21 | builder.max_workspace_size = 1 << 28 22 | builder.max_batch_size = 1 23 | builder.fp16_mode = False 24 | #builder.strict_type_constraints = True 25 | 26 | # Parse model file 27 | print('Loading ONNX file from path {}...'.format(onnx_file_path)) 28 | with open(onnx_file_path, 'rb') as model: 29 | if not parser.parse(model.read()): 30 | print('ERROR: Failed to parse the ONNX file.') 31 | for error in range(parser.num_errors): 32 | print(parser.get_error(error)) 33 | return None 34 | #if trt.__version__[0] >= '7': 35 | # The actual yolo*.onnx is generated with batch size 64. 36 | # Reshape input to batch size 1 37 | # shape = list(network.get_input(0).shape) 38 | # shape[0] = 1 39 | # network.get_input(0).shape = shape 40 | 41 | print('Adding yolo_layer plugins...') # leftover from the YOLO version of this script; no plugins are added here 42 | model_name = onnx_file_path[:-5] 43 | #network = add_yolo_plugins( 44 | # network, model_name, category_num, TRT_LOGGER) 45 | 46 | print('Building an engine.
This would take a while...') 47 | print('(Use "--verbose" to enable verbose logging.)') 48 | engine = builder.build_cuda_engine(network) 49 | print('Completed creating engine.') 50 | return engine 51 | 52 | 53 | def main(): 54 | """Create a TensorRT engine from an ONNX model.""" 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument( 57 | '-v', '--verbose', action='store_true', 58 | help='enable verbose output (for debugging)') 59 | parser.add_argument( 60 | '-c', '--category_num', type=int, default=4, 61 | help='number of object categories [4]') 62 | parser.add_argument( 63 | '-m', '--model', type=str, default='model', # just change this, e.g. res18_lane.pth -> default='res18_lane' 64 | help='basename of the model, e.g. "model" reads model.onnx and writes model.trt') 65 | args = parser.parse_args() 66 | 67 | onnx_file_path = '%s.onnx' % args.model 68 | if not os.path.isfile(onnx_file_path): 69 | raise SystemExit('ERROR: file (%s) not found! You might want to run torch2onnx.py first to generate it.' % onnx_file_path) 70 | engine_file_path = '%s.trt' % args.model 71 | engine = build_engine(onnx_file_path, args.verbose) 72 | if engine is None: 73 | raise SystemExit('ERROR: failed to build the TensorRT engine!') 74 | with open(engine_file_path, 'wb') as f: 75 | f.write(engine.serialize()) 76 | print('Serialized the TensorRT engine to file: %s' % engine_file_path) 77 | 78 | 79 | if __name__ == '__main__': 80 | main() 81 | --------------------------------------------------------------------------------