├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── rn18_pyramid.py
│   └── rn18_single_scale.py
├── data
│   ├── __init__.py
│   ├── ade20k
│   │   ├── __init__.py
│   │   └── ade20k.py
│   ├── camvid
│   │   ├── __init__.py
│   │   └── camvid.py
│   ├── cityscapes
│   │   ├── __init__.py
│   │   ├── cityscapes.py
│   │   └── labels.py
│   ├── mux
│   │   ├── __init__.py
│   │   └── util.py
│   ├── transform
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── border.py
│   │   ├── class_uniform.py
│   │   ├── flow.py
│   │   ├── flow_utils.py
│   │   ├── jitter.py
│   │   ├── labels.py
│   │   └── photometric.py
│   ├── util.py
│   └── vistas
│       ├── __init__.py
│       └── vistas.py
├── datasets
│   └── .gitkeep
├── eval.py
├── evaluation
│   ├── __init__.py
│   ├── evaluate.py
│   └── prediction.py
├── lib
│   ├── build.sh
│   ├── cylib.h
│   └── cylib.pyx
├── models
│   ├── __init__.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── boundary_loss.py
│   │   ├── semseg_loss.py
│   │   └── util.py
│   ├── resnet
│   │   ├── __init__.py
│   │   ├── resnet_pyramid.py
│   │   └── resnet_single_scale.py
│   ├── semseg.py
│   └── util.py
├── requirements.txt
├── train.py
└── weights
    └── .gitkeep
/.gitignore: -------------------------------------------------------------------------------- 1 | lib/cylib.c 2 | lib/cylib.cc 3 | lib/cylib.html 4 | lib/cylib.so -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it.
43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwiftNet 2 | 3 | Source code to reproduce results from 4 | [In Defense of Pre-trained ImageNet Architectures for Real-time Semantic Segmentation of Road-driving Images](https://arxiv.org/abs/1903.08469) (Oršić et al., CVPR 2019). 17 | 18 | ## Steps to reproduce 19 | 20 | ### Install requirements 21 | * Python 3.7+ 22 | ```bash 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ### Download Cityscapes 27 | 28 | From https://www.cityscapes-dataset.com/downloads/ download: 29 | * leftImg8bit_trainvaltest.zip (11GB) 30 | * gtFine_trainvaltest.zip (241MB) 31 | 32 | Either download and extract to `datasets/`, or create a symbolic link `datasets/Cityscapes`. 33 | Expected dataset structure for Cityscapes is: 34 | ``` 35 | labels/ 36 | train/ 37 | aachen/ 38 | aachen_000000_000019.png 39 | ... 40 | ... 41 | val/ 42 | ... 43 | rgb/ 44 | train/ 45 | aachen/ 46 | aachen_000000_000019.png 47 | ... 48 | ... 49 | val/ 50 | ... 51 | ``` 52 | 53 | 54 | ### Evaluate 55 | ##### Pre-trained Cityscapes models [available](https://drive.google.com/drive/folders/1DqX-N-nMtGG9QfMY_cKtULCKTfEuV4WT?usp=sharing) 56 | * Download and extract to the `weights` directory. 57 | 58 | Set `evaluating = True` inside the config file (e.g.
`configs/rn18_single_scale.py`) and run: 59 | ```bash 60 | python eval.py configs/rn18_single_scale.py 61 | ``` 62 | 63 | ### Train 64 | ```bash 65 | python train.py configs/rn18_single_scale.py --store_dir=/path/to/store/experiments 66 | ``` 67 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/configs/__init__.py -------------------------------------------------------------------------------- /configs/rn18_pyramid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from torchvision.transforms import Compose 5 | import torch.optim as optim 6 | from pathlib import Path 7 | import os 8 | import numpy as np 9 | 10 | from models.semseg import SemsegModel 11 | from models.resnet.resnet_pyramid import * 12 | from models.loss import BoundaryAwareFocalLoss 13 | from data.transform import * 14 | from data.cityscapes import Cityscapes 15 | from evaluation import StorePreds 16 | 17 | from models.util import get_n_params 18 | 19 | path = os.path.abspath(__file__) 20 | dir_path = os.path.dirname(path) 21 | root = Path.home() / Path('datasets/Cityscapes') 22 | 23 | evaluating = False 24 | random_crop_size = 768 25 | 26 | scale = 1 27 | mean = [73.15, 82.90, 72.3] 28 | std = [47.67, 48.49, 47.73] 29 | mean_rgb = tuple(np.uint8(scale * np.array(mean))) 30 | 31 | num_classes = Cityscapes.num_classes 32 | ignore_id = Cityscapes.num_classes 33 | class_info = Cityscapes.class_info 34 | color_info = Cityscapes.color_info 35 | 36 | num_levels = 3 37 | ostride = 4 38 | target_size_crops = (random_crop_size, random_crop_size) 39 | target_size_crops_feats = (random_crop_size // ostride, random_crop_size // ostride) 40 | 41 | eval_each = 4 42 | dist_trans_bins = (16, 64, 128) 43 | dist_trans_alphas = (8., 4., 2., 1.) 
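The `dist_trans_bins` / `dist_trans_alphas` pair above parameterizes the boundary-aware loss: each pixel's loss is weighted by its distance to the nearest label boundary, so pixels within 16 px of a boundary receive the highest weight (8.0) and pixels beyond 128 px the lowest (1.0). A minimal sketch of that binning, assuming a precomputed per-pixel distance map (the real transform is `LabelDistanceTransform` in `data/transform/labels.py`; the helper below is illustrative only):

```python
import numpy as np

def distance_to_weight(dist, bins=(16, 64, 128), alphas=(8., 4., 2., 1.)):
    # np.digitize maps dist < 16 to bin 0, [16, 64) to 1, [64, 128) to 2,
    # and >= 128 to 3, so boundary-adjacent pixels get the largest weight.
    return np.asarray(alphas)[np.digitize(dist, bins)]

print(distance_to_weight(np.array([0., 20., 200.])))  # -> [8. 4. 1.]
```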
44 | target_size = (2048, 1024) 45 | target_size_feats = (2048 // ostride, 1024 // ostride) 46 | 47 | trans_val = Compose( 48 | [Open(), 49 | SetTargetSize(target_size=target_size, target_size_feats=target_size_feats), 50 | Tensor(), 51 | ] 52 | ) 53 | 54 | if evaluating: 55 | trans_train = trans_train_val = trans_val 56 | else: 57 | trans_train = Compose( 58 | [Open(), 59 | RandomFlip(), 60 | RandomSquareCropAndScale(random_crop_size, ignore_id=ignore_id, mean=mean_rgb), 61 | SetTargetSize(target_size=target_size_crops, target_size_feats=target_size_crops_feats), 62 | LabelDistanceTransform(num_classes=num_classes, reduce=True, bins=dist_trans_bins, alphas=dist_trans_alphas), 63 | Tensor(), 64 | ]) 65 | 66 | dataset_train = Cityscapes(root, transforms=trans_train, subset='train') 67 | dataset_val = Cityscapes(root, transforms=trans_val, subset='val') 68 | 69 | backbone = resnet18(pretrained=True, 70 | pyramid_levels=num_levels, 71 | k_upsample=3, 72 | scale=scale, 73 | mean=mean, 74 | std=std, 75 | k_bneck=1, 76 | output_stride=ostride, 77 | efficient=True) 78 | model = SemsegModel(backbone, num_classes, k=1, bias=True) 79 | if evaluating: 80 | model.load_state_dict(torch.load('weights/rn18_pyramid/model_best.pt'), strict=False) 81 | else: 82 | model.criterion = BoundaryAwareFocalLoss(gamma=.5, num_classes=num_classes, ignore_id=ignore_id) 83 | 84 | bn_count = 0 85 | for m in model.modules(): 86 | if isinstance(m, nn.BatchNorm2d): 87 | bn_count += 1 88 | print(f'Num BN layers: {bn_count}') 89 | 90 | if not evaluating: 91 | lr = 4e-4 92 | lr_min = 1e-6 93 | fine_tune_factor = 4 94 | weight_decay = 1e-4 95 | epochs = 250 96 | 97 | optim_params = [ 98 | {'params': model.random_init_params(), 'lr': lr, 'weight_decay': weight_decay}, 99 | {'params': model.fine_tune_params(), 'lr': lr / fine_tune_factor, 100 | 'weight_decay': weight_decay / fine_tune_factor}, 101 | ] 102 | 103 | optimizer = optim.Adam(optim_params, betas=(0.9, 0.99)) 104 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, lr_min) 105 | 106 | batch_size = bs = 14 107 | print(f'Batch size: {bs}') 108 | nw = 4 109 | 110 | loader_val = DataLoader(dataset_val, batch_size=1, collate_fn=custom_collate, num_workers=nw) 111 | if evaluating: 112 | loader_train = DataLoader(dataset_train, batch_size=1, collate_fn=custom_collate, num_workers=nw) 113 | else: 114 | loader_train = DataLoader(dataset_train, batch_size=batch_size, num_workers=nw, pin_memory=True, 115 | drop_last=True, collate_fn=custom_collate, shuffle=True) 116 | 117 | total_params = get_n_params(model.parameters()) 118 | ft_params = get_n_params(model.fine_tune_params()) 119 | ran_params = get_n_params(model.random_init_params()) 120 | assert total_params == (ft_params + ran_params) 121 | print(f'Num params: {total_params:,} = {ran_params:,}(random init) + {ft_params:,}(fine tune)') 122 | 123 | if evaluating: 124 | eval_loaders = [(loader_val, 'val'), (loader_train, 'train')] 125 | store_dir = f'{dir_path}/out/' 126 | for d in ['', 'val', 'train']: 127 | os.makedirs(store_dir + d, exist_ok=True) 128 | to_color = ColorizeLabels(color_info) 129 | to_image = Compose([Numpy(), to_color]) 130 | eval_observers = [StorePreds(store_dir, to_image, to_color)] 131 | -------------------------------------------------------------------------------- /configs/rn18_single_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision.transforms import Compose 4 | 
import torch.optim as optim 5 | from pathlib import Path 6 | import numpy as np 7 | import os 8 | 9 | from models.semseg import SemsegModel 10 | from models.resnet.resnet_single_scale import * 11 | from models.loss import SemsegCrossEntropy 12 | from data.transform import * 13 | from data.cityscapes import Cityscapes 14 | from evaluation import StorePreds 15 | 16 | from models.util import get_n_params 17 | 18 | root = Path.home() / Path('datasets/Cityscapes') 19 | path = os.path.abspath(__file__) 20 | dir_path = os.path.dirname(path) 21 | 22 | evaluating = False 23 | random_crop_size = 768 24 | 25 | scale = 1 26 | mean = [73.15, 82.90, 72.3] 27 | std = [47.67, 48.49, 47.73] 28 | mean_rgb = tuple(np.uint8(scale * np.array(mean))) 29 | 30 | num_classes = Cityscapes.num_classes 31 | ignore_id = Cityscapes.num_classes 32 | class_info = Cityscapes.class_info 33 | color_info = Cityscapes.color_info 34 | 35 | target_size_crops = (random_crop_size, random_crop_size) 36 | target_size_crops_feats = (random_crop_size // 4, random_crop_size // 4) 37 | target_size = (2048, 1024) 38 | target_size_feats = (2048 // 4, 1024 // 4) 39 | 40 | eval_each = 4 41 | 42 | 43 | trans_val = Compose( 44 | [Open(), 45 | SetTargetSize(target_size=target_size, target_size_feats=target_size_feats), 46 | Tensor(), 47 | ] 48 | ) 49 | 50 | if evaluating: 51 | trans_train = trans_val 52 | else: 53 | trans_train = Compose( 54 | [Open(), 55 | RandomFlip(), 56 | RandomSquareCropAndScale(random_crop_size, ignore_id=num_classes, mean=mean_rgb), 57 | SetTargetSize(target_size=target_size_crops, target_size_feats=target_size_crops_feats), 58 | Tensor(), 59 | ] 60 | ) 61 | 62 | dataset_train = Cityscapes(root, transforms=trans_train, subset='train') 63 | dataset_val = Cityscapes(root, transforms=trans_val, subset='val') 64 | 65 | resnet = resnet18(pretrained=True, efficient=False, mean=mean, std=std, scale=scale) 66 | model = SemsegModel(resnet, num_classes) 67 | if evaluating: 68 | model.load_state_dict(torch.load('weights/rn18_single_scale/model_best.pt')) 69 | else: 70 | model.criterion = SemsegCrossEntropy(num_classes=num_classes, ignore_id=ignore_id) 71 | lr = 4e-4 72 | lr_min = 1e-6 73 | fine_tune_factor = 4 74 | weight_decay = 1e-4 75 | epochs = 250 76 | 77 | optim_params = [ 78 | {'params': model.random_init_params(), 'lr': lr, 'weight_decay': weight_decay}, 79 | {'params': model.fine_tune_params(), 'lr': lr / fine_tune_factor, 80 | 'weight_decay': weight_decay / fine_tune_factor}, 81 | ] 82 | 83 | optimizer = optim.Adam(optim_params, betas=(0.9, 0.99)) 84 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, lr_min) 85 | 86 | batch_size = 14 87 | print(f'Batch size: {batch_size}') 88 | 89 | if evaluating: 90 | loader_train = DataLoader(dataset_train, batch_size=1, collate_fn=custom_collate) 91 | else: 92 | loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, 93 | pin_memory=True, 94 | drop_last=True, collate_fn=custom_collate) 95 | loader_val = DataLoader(dataset_val, batch_size=1, collate_fn=custom_collate) 96 | 97 | total_params = get_n_params(model.parameters()) 98 | ft_params = get_n_params(model.fine_tune_params()) 99 | ran_params = get_n_params(model.random_init_params()) 100 | spp_params = get_n_params(model.backbone.spp.parameters()) 101 | assert total_params == (ft_params + ran_params) 102 | print(f'Num params: {total_params:,} = {ran_params:,}(random init) + {ft_params:,}(fine tune)') 103 | print(f'SPP params: {spp_params:,}') 104 | 105 | if evaluating: 106 
| eval_loaders = [(loader_val, 'val'), (loader_train, 'train')] 107 | store_dir = f'{dir_path}/out/' 108 | for d in ['', 'val', 'train', 'training']: 109 | os.makedirs(store_dir + d, exist_ok=True) 110 | to_color = ColorizeLabels(color_info) 111 | to_image = Compose([DenormalizeTh(scale, mean, std), Numpy(), to_color]) 112 | eval_observers = [StorePreds(store_dir, to_image, to_color)] 113 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .cityscapes import Cityscapes 2 | from .mux import * -------------------------------------------------------------------------------- /data/ade20k/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/data/ade20k/__init__.py -------------------------------------------------------------------------------- /data/ade20k/ade20k.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | from scipy.io import loadmat 4 | import numpy as np 5 | 6 | 7 | def init_ade20k_class_color_info(path: Path): 8 | colors = loadmat(str(path / 'color150.mat'))['colors'] 9 | classes = [] 10 | with (path / 'object150_info.csv').open('r') as f: 11 | for i, line in enumerate(f.readlines()): 12 | if bool(i): 13 | classes += [line.rstrip().split(',')[-1]] 14 | return classes + ['void'], np.concatenate([colors, np.array([[0, 0, 0]], dtype=colors.dtype)]) 15 | 16 | 17 | class_info, color_info = init_ade20k_class_color_info(Path('/home/morsic/datasets/ADE20k')) 18 | 19 | 20 | class ADE20k(Dataset): 21 | class_info = class_info 22 | color_info = color_info 23 | num_classes = 150 24 | 25 | def __init__(self, root: Path, transforms: lambda x: x, subset='training', open_images=True, epoch=None): 26 | self.root = root 27 | self.open_images = open_images 28 | self.images_dir = root / 'ADEChallengeData2016/images/' / subset 29 | self.labels_dir = root / 'ADEChallengeData2016/annotations/' / subset 30 | 31 | self.images = list(sorted(self.images_dir.glob('*.jpg'))) 32 | self.labels = list(sorted(self.labels_dir.glob('*.png'))) 33 | 34 | self.transforms = transforms 35 | self.subset = subset 36 | self.epoch = epoch 37 | 38 | print(f'Num images: {len(self)}') 39 | 40 | def __len__(self): 41 | return len(self.images) 42 | 43 | def __getitem__(self, item): 44 | ret_dict = { 45 | 'name': self.images[item].stem, 46 | 'subset': self.subset, 47 | 'labels': self.labels[item] 48 | } 49 | if self.open_images: 50 | ret_dict['image'] = self.images[item] 51 | if self.epoch is not None: 52 | ret_dict['epoch'] = int(self.epoch.value) 53 | return self.transforms(ret_dict) 54 | -------------------------------------------------------------------------------- /data/camvid/__init__.py: -------------------------------------------------------------------------------- 1 | from .camvid import CamVid -------------------------------------------------------------------------------- /data/camvid/camvid.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | 4 | class_info = ['building', 'tree', 'sky', 'car', 'sign', 'road', 'pedestrian', 'fence', 'column pole', 'sidewalk', 5 | 'bicyclist'] 6 | color_info = [(128, 0, 0), (128, 128, 0), (128, 128, 128), (64, 
0, 128), (192, 128, 128), (128, 64, 128), (64, 64, 0), 7 | (64, 64, 128), (192, 192, 128), (0, 0, 192), (0, 128, 192)] 8 | 9 | color_info += [[0, 0, 0]] 10 | 11 | 12 | class CamVid(Dataset): 13 | class_info = class_info 14 | color_info = color_info 15 | num_classes = len(class_info) 16 | 17 | mean = [111.376, 63.110, 83.670] 18 | std = [41.608, 54.237, 68.889] 19 | 20 | def __init__(self, root: Path, transforms: lambda x: x, subset='train'): 21 | self.root = root 22 | self.subset = subset 23 | self.image_names = [line.rstrip() for line in (root / f'{subset}.txt').open('r').readlines()] 24 | name_filter = lambda x: x.name in self.image_names 25 | self.images = list(filter(name_filter, (self.root / 'rgb').iterdir())) 26 | self.labels = list(filter(name_filter, (self.root / 'labels/ids').iterdir())) 27 | self.transforms = transforms 28 | print(f'Num images: {len(self)}') 29 | 30 | def __len__(self): 31 | return len(self.images) 32 | 33 | def __getitem__(self, item): 34 | ret_dict = { 35 | 'image': self.images[item], 36 | 'name': self.images[item].stem, 37 | 'subset': self.subset, 38 | 'labels': self.labels[item] 39 | } 40 | return self.transforms(ret_dict) 41 | -------------------------------------------------------------------------------- /data/cityscapes/__init__.py: -------------------------------------------------------------------------------- 1 | from .cityscapes import Cityscapes 2 | -------------------------------------------------------------------------------- /data/cityscapes/cityscapes.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | 4 | from .labels import labels 5 | 6 | class_info = [label.name for label in labels if label.ignoreInEval is False] 7 | color_info = [label.color for label in labels if label.ignoreInEval is False] 8 | 9 | color_info += [[0, 0, 0]] 10 | 11 | map_to_id = {} 12 | inst_map_to_id = {} 13 | i, j = 0, 0 14 | for label in labels: 15 | if label.ignoreInEval is False: 16 | map_to_id[label.id] = i 17 | i += 1 18 | if label.hasInstances is True: 19 | inst_map_to_id[label.id] = j 20 | j += 1 21 | 22 | id_to_map = {id: i for i, id in map_to_id.items()} 23 | inst_id_to_map = {id: i for i, id in inst_map_to_id.items()} 24 | 25 | 26 | class Cityscapes(Dataset): 27 | class_info = class_info 28 | color_info = color_info 29 | num_classes = 19 30 | 31 | map_to_id = map_to_id 32 | id_to_map = id_to_map 33 | 34 | inst_map_to_id = inst_map_to_id 35 | inst_id_to_map = inst_id_to_map 36 | 37 | mean = [0.485, 0.456, 0.406] 38 | std = [0.229, 0.224, 0.225] 39 | 40 | def __init__(self, root: Path, transforms: lambda x: x, subset='train', open_depth=False, labels_dir='labels', epoch=None): 41 | self.root = root 42 | self.images_dir = self.root / 'rgb' / subset 43 | self.labels_dir = self.root / labels_dir / subset 44 | self.depth_dir = self.root / 'depth' / subset 45 | self.subset = subset 46 | self.has_labels = subset != 'test' 47 | self.open_depth = open_depth 48 | self.images = list(sorted(self.images_dir.glob('*/*.ppm'))) 49 | if self.has_labels: 50 | self.labels = list(sorted(self.labels_dir.glob('*/*.png'))) 51 | self.transforms = transforms 52 | self.epoch = epoch 53 | 54 | print(f'Num images: {len(self)}') 55 | 56 | def __len__(self): 57 | return len(self.images) 58 | 59 | def __getitem__(self, item): 60 | ret_dict = { 61 | 'image': self.images[item], 62 | 'name': self.images[item].stem, 63 | 'subset': self.subset, 64 | } 65 | if self.has_labels: 66 |
ret_dict['labels'] = self.labels[item] 67 | if self.epoch is not None: 68 | ret_dict['epoch'] = int(self.epoch.value) 69 | return self.transforms(ret_dict) 70 | -------------------------------------------------------------------------------- /data/cityscapes/labels.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | #-------------------------------------------------------------------------------- 4 | # Definitions 5 | #-------------------------------------------------------------------------------- 6 | 7 | # a label and all meta information 8 | Label = namedtuple( 'Label' , [ 9 | 10 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 11 | # We use them to uniquely name a class 12 | 13 | 'id' , # An integer ID that is associated with this label. 14 | # The IDs are used to represent the label in ground truth images 15 | # An ID of -1 means that this label does not have an ID and thus 16 | # is ignored when creating ground truth images (e.g. license plate). 17 | # Do not modify these IDs, since exactly these IDs are expected by the 18 | # evaluation server. 19 | 20 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 21 | # ground truth images with train IDs, using the tools provided in the 22 | # 'preparation' folder. However, make sure to validate or submit results 23 | # to our evaluation server using the regular IDs above! 24 | # For trainIds, multiple labels might have the same ID. Then, these labels 25 | # are mapped to the same class in the ground truth images. For the inverse 26 | # mapping, we use the label that is defined first in the list below. 27 | # For example, mapping all void-type classes to the same ID in training, 28 | # might make sense for some approaches. 29 | # Max value is 255! 30 | 31 | 'category' , # The name of the category that this label belongs to 32 | 33 | 'categoryId' , # The ID of this category. Used to create ground truth images 34 | # on category level. 35 | 36 | 'hasInstances', # Whether this label distinguishes between single instances or not 37 | 38 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 39 | # during evaluations or not 40 | 41 | 'color' , # The color of this label 42 | ] ) 43 | 44 | 45 | #-------------------------------------------------------------------------------- 46 | # A list of all labels 47 | #-------------------------------------------------------------------------------- 48 | 49 | # Please adapt the train IDs as appropriate for your approach. 50 | # Note that you might want to ignore labels with ID 255 during training. 51 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 52 | # Make sure to provide your results using the original IDs and not the training IDs. 53 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
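Before the label table that follows, it is worth seeing how such a mapping is typically consumed: raw ids stored in the ground-truth PNGs are remapped to contiguous train ids, with every `ignoreInEval` class collapsed onto a single ignore index. A sketch of that remap as a flat lookup table (this repo builds an equivalent `map_to_id` dict in `data/cityscapes/cityscapes.py`; the LUT helper below is hypothetical):

```python
import numpy as np

def build_train_id_lut(labels, ignore_id=19):
    # 256-entry table: raw Cityscapes id -> contiguous train id; every class
    # with ignoreInEval=True falls through to ignore_id (19 = num_classes).
    lut = np.full(256, ignore_id, dtype=np.uint8)
    train_id = 0
    for label in labels:
        if not label.ignoreInEval:
            lut[label.id] = train_id
            train_id += 1
    return lut

# usage: train_labels = build_train_id_lut(labels)[np.asarray(label_image)]
```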
54 | 55 | labels = [ 56 | # name id trainId category catId hasInstances ignoreInEval color 57 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 58 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 59 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 60 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 61 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 62 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 63 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 64 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 65 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 66 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 67 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 68 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 69 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 70 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 71 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 72 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 73 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 74 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 75 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 76 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 77 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 78 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 79 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 80 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 81 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 82 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 83 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 84 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 85 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 86 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 87 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 88 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 89 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 90 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 91 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 92 | ] 93 | 94 | 95 | def get_train_ids(): 96 | train_ids = [] 97 | for i in labels: 98 | if not i.ignoreInEval: 99 | train_ids.append(i.id) 100 | return train_ids -------------------------------------------------------------------------------- /data/mux/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import pyramid_sizes -------------------------------------------------------------------------------- /data/mux/util.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | 4 | def 
pyramid_sizes(size, alphas, scale=1.0): 5 | w, h = size[0], size[1] 6 | th_sc = lambda wh, alpha: int(ceil(wh / (alpha * scale))) 7 | return [(th_sc(w, a), th_sc(h, a)) for a in alphas] 8 | -------------------------------------------------------------------------------- /data/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .border import * 3 | from .flow import * 4 | from .jitter import * 5 | from .labels import * 6 | from .photometric import * 7 | from .class_uniform import * -------------------------------------------------------------------------------- /data/transform/base.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from torch.utils.data.dataloader import default_collate 3 | import numpy as np 4 | import torch 5 | from PIL import Image as pimg 6 | 7 | from data.transform.flow_utils import readFlow 8 | 9 | RESAMPLE = pimg.BICUBIC 10 | RESAMPLE_D = pimg.BILINEAR 11 | 12 | __all__ = ['Open', 'SetTargetSize', 'Numpy', 'Tensor', 'detection_collate', 'custom_collate', 'RESAMPLE', 'RESAMPLE_D'] 13 | 14 | 15 | class Open: 16 | def __init__(self, palette=None, copy_labels=True): 17 | self.palette = palette 18 | self.copy_labels = copy_labels 19 | 20 | def __call__(self, example: dict): 21 | try: 22 | ret_dict = {} 23 | for k in ['image', 'image_next', 'image_prev']: 24 | if k in example: 25 | ret_dict[k] = pimg.open(example[k]).convert('RGB') 26 | if k == 'image': 27 | ret_dict['target_size'] = ret_dict['image'].size 28 | if 'depth' in example: 29 | example['depth'] = pimg.open(example['depth']) 30 | if 'labels' in example: 31 | ret_dict['labels'] = pimg.open(example['labels']) 32 | if self.palette is not None: 33 | ret_dict['labels'].putpalette(self.palette) 34 | if self.copy_labels: 35 | ret_dict['original_labels'] = ret_dict['labels'].copy() 36 | if 'flow' in example: 37 | ret_dict['flow'] = readFlow(example['flow']) 38 | except OSError: 39 | print(example) 40 | raise 41 | return {**example, **ret_dict} 42 | 43 | 44 | class SetTargetSize: 45 | def __init__(self, target_size, target_size_feats, stride=4): 46 | self.target_size = target_size 47 | self.target_size_feats = target_size_feats 48 | self.stride = stride 49 | 50 | def __call__(self, example): 51 | if all([self.target_size, self.target_size_feats]): 52 | example['target_size'] = self.target_size[::-1] 53 | example['target_size_feats'] = self.target_size_feats[::-1] 54 | else: 55 | k = 'original_labels' if 'original_labels' in example else 'image' 56 | example['target_size'] = example[k].shape[-2:] 57 | example['target_size_feats'] = tuple([s // self.stride for s in example[k].shape[-2:]]) 58 | example['alphas'] = [-1] 59 | example['target_level'] = 0 60 | return example 61 | 62 | 63 | class Tensor: 64 | def _trans(self, img, dtype): 65 | img = np.array(img, dtype=dtype) 66 | if len(img.shape) == 3: 67 | img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) 68 | return torch.from_numpy(img) 69 | 70 | def __call__(self, example): 71 | ret_dict = {} 72 | for k in ['image', 'image_next', 'image_prev']: 73 | if k in example: 74 | ret_dict[k] = self._trans(example[k], np.float32) 75 | if 'depth' in example: 76 | ret_dict['depth'] = self._trans(example['depth'], np.uint8) 77 | if 'labels' in example: 78 | ret_dict['labels'] = self._trans(example['labels'], np.int64) 79 | if 'original_labels' in example: 80 | ret_dict['original_labels'] = 
self._trans(example['original_labels'], np.int64) 81 | if 'depth_hist' in example: 82 | ret_dict['depth_hist'] = [self._trans(d, np.float32) for d in example['depth_hist']] if isinstance( 83 | example['depth_hist'], list) else self._trans(example['depth_hist'], np.float32) 84 | if 'pyramid' in example: 85 | ret_dict['pyramid'] = [self._trans(p, np.float32) for p in example['pyramid']] 86 | if 'pyramid_ms' in example: 87 | ret_dict['pyramid_ms'] = [[self._trans(p, np.float32) for p in pyramids] for pyramids in 88 | example['pyramid_ms']] 89 | if 'mux_indices' in example: 90 | ret_dict['mux_indices'] = torch.stack([torch.from_numpy(midx.flatten()) for midx in example['mux_indices']]) 91 | if 'mux_masks' in example: 92 | ret_dict['mux_masks'] = [torch.from_numpy(np.uint8(mi)).unsqueeze(0) for mi in example['mux_masks']] 93 | if 'depth_bins' in example: 94 | ret_dict['depth_bins'] = torch.stack([torch.from_numpy(b) for b in example['depth_bins']]) 95 | if 'flow' in example: 96 | # ret_dict['flow'] = torch.from_numpy(example['flow']).permute(2, 0, 1).contiguous() 97 | ret_dict['flow'] = torch.from_numpy(np.ascontiguousarray(example['flow'])) 98 | # if 'flow_next' in example: 99 | # ret_dict['flow_next'] = torch.from_numpy(example['flow_next']).permute(2, 0, 1 ).contiguous() 100 | if 'flow_sub' in example: 101 | # ret_dict['flow_sub'] = torch.from_numpy(example['flow_sub']).permute(2, 0, 1).contiguous() 102 | ret_dict['flow_sub'] = torch.from_numpy(np.ascontiguousarray(example['flow_sub'])) 103 | if 'flipped' in example: 104 | del example['flipped'] 105 | return {**example, **ret_dict} 106 | 107 | 108 | class Numpy: 109 | def __call__(self, example): 110 | image = example['image'] 111 | axes = [0, 2, 3, 1] if len(image.shape) == 4 else [1, 2, 0] 112 | ret_dict = { 113 | 'image': image.numpy().transpose(axes) 114 | } 115 | for k in ['labels', 'original_labels']: 116 | if k in example and isinstance(example[k], torch.Tensor): 117 | ret_dict[k] = example[k].numpy() 118 | return {**example, **ret_dict} 119 | 120 | 121 | def detection_collate(batch): 122 | """Custom collate fn for dealing with batches of images that have a different 123 | number of associated object annotations (bounding boxes). 
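    Keys listed in custom_keys (currently just 'target_size') are gathered into plain Python lists,
    while every other key goes through default_collate.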
124 | 125 | Arguments: 126 | batch: (tuple) A tuple of tensor images and lists of annotations 127 | 128 | Return: 129 | A tuple containing: 130 | 1) (tensor) batch of images stacked on their 0 dim 131 | 2) (list of tensors) annotations for a given image are stacked on 0 dim 132 | """ 133 | custom = defaultdict(list) 134 | custom_keys = ['target_size', ] 135 | for sample in batch: 136 | for k in custom_keys: 137 | custom[k] += [sample[k]] 138 | other = {k: default_collate([b[k] for b in batch]) for k in 139 | filter(lambda x: x not in custom, batch[0].keys())} 140 | return {**other, **custom} 141 | 142 | 143 | def custom_collate(batch, del_orig_labels=False): 144 | keys = ['target_size', 'target_size_feats', 'alphas', 'target_level'] 145 | values = {} 146 | for k in keys: 147 | if k in batch[0]: 148 | values[k] = batch[0][k] 149 | for b in batch: 150 | if del_orig_labels: del b['original_labels'] 151 | for k in values.keys(): 152 | del b[k] 153 | if 'mux_indices' in b: 154 | b['mux_indices'] = b['mux_indices'].view(-1) 155 | batch = default_collate(batch) 156 | # if 'image_next' in batch: 157 | # batch['image'] = torch.cat([batch['image'], batch['image_next']], dim=0).contiguous() 158 | # del batch['image_next'] 159 | for k, v in values.items(): 160 | batch[k] = v 161 | return batch 162 | -------------------------------------------------------------------------------- /data/transform/border.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | __all__ = ['LabelDistanceTransform', 'NeighborhoodLabels', 'InstanceBorders'] 5 | 6 | 7 | class LabelDistanceTransform: 8 | def __init__(self, num_classes, bins=(4, 16, 64, 128), alphas=(8., 6., 4., 2., 1.), reduce=False, 9 | ignore_id=19): 10 | self.num_classes = num_classes 11 | self.reduce = reduce 12 | self.bins = bins 13 | self.alphas = alphas 14 | self.ignore_id = ignore_id 15 | 16 | def __call__(self, example): 17 | labels = np.array(example['labels']) 18 | present_classes = np.unique(labels) 19 | distances = np.zeros([self.num_classes] + list(labels.shape), dtype=np.float32) - 1. 20 | for i in range(self.num_classes): 21 | if i not in present_classes: 22 | continue 23 | class_mask = labels == i 24 | distances[i][class_mask] = cv2.distanceTransform(np.uint8(class_mask), cv2.DIST_L2, maskSize=5)[class_mask] 25 | if self.reduce: 26 | ignore_mask = labels == self.ignore_id 27 | distances[distances < 0] = 0 28 | distances = distances.sum(axis=0) 29 | label_distance_bins = np.digitize(distances, self.bins) 30 | label_distance_alphas = np.zeros(label_distance_bins.shape, dtype=np.float32) 31 | for idx, alpha in enumerate(self.alphas): 32 | label_distance_alphas[label_distance_bins == idx] = alpha 33 | label_distance_alphas[ignore_mask] = 0 34 | example['label_distance_alphas'] = label_distance_alphas 35 | else: 36 | example['label_distance_transform'] = distances 37 | return example 38 | 39 | 40 | class InstanceBorders: 41 | def __init__(self, instance_classes=8, thresh=.3): 42 | self.instance_classes = instance_classes 43 | self.thresh = thresh 44 | 45 | def __call__(self, example): 46 | shape = [self.instance_classes] + list(example['labels'].size)[::-1] 47 | instance_borders = np.zeros(shape, dtype=np.float32) 48 | instances = example['instances'] 49 | for k in instances: 50 | for instance in instances[k]: 51 | dist_trans = cv2.distanceTransform(instance.astype(np.uint8), cv2.DIST_L2, maskSize=5) 52 | dist_trans[instance] = 1. 
/ dist_trans[instance] 53 | dist_trans[dist_trans < self.thresh] = .0 54 | instance_borders[k] += dist_trans 55 | example['instance_borders'] = instance_borders 56 | return example 57 | 58 | 59 | class NeighborhoodLabels: 60 | def __init__(self, num_classes, k=3, stride=1, discrete=False): 61 | self.num_classes = num_classes 62 | self.k = k 63 | self.pad = k // 2 64 | self.stride = stride 65 | self.discrete = discrete 66 | 67 | def __call__(self, example): 68 | labels = np.array(example['labels']) 69 | p = self.pad 70 | labels_padded = self.num_classes * np.ones([1, 1] + [sh + 2 * p for sh in labels.shape], dtype=labels.dtype) 71 | labels_padded[..., p:-p, p:-p] = labels.copy() 72 | label_col = im2col_cython.im2col_cython(labels_padded, self.k, self.k, padding=0, stride=self.stride)  # NOTE: im2col_cython is an external compiled extension; it is neither imported in this file nor built by lib/build.sh 73 | label_col_hist = im2col_cython.hist_from_cols(label_col, self.num_classes).reshape( 74 | [self.num_classes + 1] + list(labels.shape)) 75 | label_neighborhood_hist = label_col_hist / np.float32(self.k ** 2) 76 | if self.discrete: 77 | example['label_neighborhood_hist'] = (label_neighborhood_hist[:self.num_classes] > 0.).astype(np.float32) 78 | else: 79 | example['label_neighborhood_hist'] = label_neighborhood_hist 80 | return example 81 | -------------------------------------------------------------------------------- /data/transform/class_uniform.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | from tqdm import tqdm 4 | import random 5 | from PIL import Image as pimg 6 | from collections import defaultdict 7 | import warnings 8 | 9 | from data.transform import RESAMPLE, RESAMPLE_D 10 | from data.util import bb_intersection_over_union, crop_and_scale_img 11 | from data.transform.flow_utils import crop_and_scale_flow 12 | 13 | __all__ = ['create_class_uniform_strategy', 'ClassUniformSquareCropAndScale'] 14 | 15 | 16 | def create_class_uniform_strategy(instances, incidences, epochs=1): 17 | incidences = incidences[:-1] # remove ignore id 18 | num_images = len(instances) 19 | num_classes = incidences.shape[0] 20 | present_in_image = np.zeros((num_images, num_classes), dtype=np.uint32) 21 | image_names = np.array(list(instances.keys())) 22 | 23 | for i, (k, v) in enumerate(tqdm(instances.items(), total=len(instances))): 24 | for idx in v.keys(): 25 | if idx >= num_classes: 26 | continue 27 | present_in_image[i, idx] += len(v[idx]) 28 | 29 | class_incidence_histogram = incidences / incidences.sum() 30 | indices_by_occurence = np.argsort(class_incidence_histogram) 31 | p_r = class_incidence_histogram.sum() / class_incidence_histogram 32 | p_r[np.logical_or(np.isnan(p_r), np.isinf(p_r))] = 0.
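# p_r holds inverse-frequency class weights: rare classes get large values, and classes with zero
# incidence were zeroed out above; the normalization below turns p_r into a sampling distribution.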
33 | p_r /= p_r.sum() 34 | images_to_sample = np.round(num_images * p_r).astype(np.uint32) 35 | 36 | # weights = ((present_in_image > 0) * p_r.reshape(1, -1)).sum(-1) 37 | weights = (present_in_image * p_r.reshape(1, -1)).sum(-1) 38 | 39 | strategy = [] 40 | for e in range(epochs): 41 | chosen_classes = {} 42 | chosen_class = num_classes * np.ones(num_images, dtype=np.uint32) 43 | is_image_chosen = np.zeros(num_images, dtype=bool) 44 | for idx in indices_by_occurence: 45 | possibilities = np.where((present_in_image[:, idx] > 0) & ~is_image_chosen)[0] 46 | to_sample = min(images_to_sample[idx], len(possibilities)) 47 | chosen = np.random.choice(possibilities, to_sample) 48 | is_image_chosen[chosen] = 1 49 | chosen_class[chosen] = idx 50 | for n, c in zip(image_names, chosen_class): 51 | chosen_classes[n] = c 52 | strategy += [chosen_classes] 53 | statistics = defaultdict(int) 54 | for v in chosen_classes.values(): 55 | statistics[v] += 1 56 | return strategy, weights 57 | 58 | 59 | class ClassUniformSquareCropAndScale: 60 | def __init__(self, wh, mean, ignore_id, strategy, class_instances, min=.5, max=2., 61 | scale_method=lambda scale, wh, size: int(scale * wh), p_true_random_crop=.5): 62 | self.wh = wh 63 | self.min = min 64 | self.max = max 65 | self.mean = mean 66 | self.ignore_id = ignore_id 67 | self.random_gens = [self._rand_location, self._gen_instance_box] 68 | self.scale_method = scale_method 69 | self.strategy = strategy 70 | self.class_instances = class_instances 71 | self.p_true_random_crop = p_true_random_crop 72 | 73 | def _random_instance(self, name, epoch): 74 | instances = self.class_instances[name] 75 | chosen_class = self.strategy[epoch][name] 76 | if chosen_class == self.ignore_id: 77 | return None 78 | try: 79 | return random.choice(instances[chosen_class]) 80 | except IndexError: 81 | return None 82 | 83 | def _gen_instance_box(self, W, H, target_wh, name, flipped, epoch): 84 | # warnings.warn(f'ClassUniformSquareCropAndScale, epoch {epoch}') 85 | bbox = self._random_instance(name, epoch) 86 | if bbox is not None: 87 | if not (random.uniform(0, 1) < self.p_true_random_crop): 88 | wmin, wmax, hmin, hmax = bbox 89 | if flipped: 90 | wmin, wmax = W - 1 - wmax, W - 1 - wmin 91 | inst_box = [wmin, hmin, wmax, hmax] 92 | for _ in range(50): 93 | box = self._rand_location(W, H, target_wh) 94 | if bb_intersection_over_union(box, inst_box) > 0.: 95 | break 96 | return box 97 | return self._rand_location(W, H, target_wh) 98 | 99 | def _rand_location(self, W, H, target_wh, *args, **kwargs): 100 | try: 101 | w = np.random.randint(0, W - target_wh + 1) 102 | h = np.random.randint(0, H - target_wh + 1) 103 | except ValueError: 104 | print(f'Exception in RandomSquareCropAndScale: {target_wh}') 105 | w = h = 0 106 | # left, upper, right, lower) 107 | return w, h, w + target_wh, h + target_wh 108 | 109 | def _trans(self, img: pimg, crop_box, target_size, pad_size, resample, blank_value): 110 | return crop_and_scale_img(img, crop_box, target_size, pad_size, resample, blank_value) 111 | 112 | def __call__(self, example): 113 | image = example['image'] 114 | scale = np.random.uniform(self.min, self.max) 115 | W, H = image.size 116 | box_size = self.scale_method(scale, self.wh, image.size) 117 | pad_size = (max(box_size, W), max(box_size, H)) 118 | target_size = (self.wh, self.wh) 119 | flipped = example['flipped'] if 'flipped' in example else False 120 | crop_box = self._gen_instance_box(pad_size[0], pad_size[1], box_size, example.get('name'), flipped, 121 | example.get('epoch', 0)) 122
| ret_dict = { 123 | 'image': self._trans(image, crop_box, target_size, pad_size, RESAMPLE, self.mean), 124 | } 125 | if 'labels' in example: 126 | ret_dict['labels'] = self._trans(example['labels'], crop_box, target_size, pad_size, pimg.NEAREST, 127 | self.ignore_id) 128 | for k in ['image_prev', 'image_next']: 129 | if k in example: 130 | ret_dict[k] = self._trans(example[k], crop_box, target_size, pad_size, RESAMPLE, 131 | self.mean) 132 | if 'depth' in example: 133 | ret_dict['depth'] = self._trans(example['depth'], crop_box, target_size, pad_size, RESAMPLE_D, 0) 134 | if 'flow' in example: 135 | ret_dict['flow'] = crop_and_scale_flow(example['flow'], crop_box, target_size, pad_size, scale) 136 | return {**example, **ret_dict} 137 | -------------------------------------------------------------------------------- /data/transform/flow.py: -------------------------------------------------------------------------------- 1 | from .flow_utils import subsample_flow 2 | 3 | __all__ = ['SubsampleFlow'] 4 | 5 | 6 | class SubsampleFlow: 7 | def __init__(self, subsampling=4): 8 | self.subsampling = subsampling 9 | 10 | def __call__(self, example): 11 | example['flow_sub'] = subsample_flow(example['flow'], self.subsampling) 12 | return example 13 | -------------------------------------------------------------------------------- /data/transform/flow_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 4 | import numpy as np 5 | from PIL import Image as pimg 6 | 7 | from data.util import crop_and_scale_img 8 | 9 | ''' 10 | Adapted from https://github.com/NVIDIA/flownet2-pytorch 11 | ''' 12 | 13 | 14 | def readFlow(fn): 15 | """ Read .flo file in Middlebury format""" 16 | # Code adapted from: 17 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 18 | 19 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 20 | # print 'fn = %s'%(fn) 21 | with open(fn, 'rb') as f: 22 | magic = np.fromfile(f, np.float32, count=1) 23 | if 202021.25 != magic: 24 | print('Magic number incorrect. 
Invalid .flo file') 25 | return None 26 | else: 27 | w = np.fromfile(f, np.int32, count=1) 28 | h = np.fromfile(f, np.int32, count=1) 29 | # print 'Reading %d x %d flo file\n' % (w, h) 30 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 31 | # Reshape data into 3D array (columns, rows, bands) 32 | # The reshape here is for visualization, the original code is (w,h,2) 33 | return np.resize(data, (int(h), int(w), 2)) 34 | 35 | 36 | def flow2rgb(flow): 37 | hsv = np.zeros(list(flow.shape[:-1]) + [3], dtype=np.uint8) 38 | hsv[..., 1] = 255 39 | mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1]) 40 | hsv[..., 0] = ang * 180 / np.pi / 2 41 | hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX) 42 | return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB) 43 | 44 | 45 | def offset_flow(img, flow): 46 | ''' 47 | :param img: torch.FloatTensor of shape NxCxHxW 48 | :param flow: torch.FloatTensor of shape NxHxWx2 49 | :return: torch.FloatTensor of shape NxCxHxW 50 | ''' 51 | N, C, H, W = img.shape 52 | # generate identity sampling grid 53 | gx, gy = torch.meshgrid(torch.arange(H), torch.arange(W)) 54 | gx = gx.float().div(gx.max() - 1).view(1, H, W, 1) 55 | gy = gy.float().div(gy.max() - 1).view(1, H, W, 1) 56 | grid = torch.cat([gy, gx], dim=-1).mul(2.).sub(1) 57 | # generate normalized flow field 58 | flown = flow.clone() 59 | flown[..., 0] /= W 60 | flown[..., 1] /= H 61 | # calculate offset field 62 | grid += flown 63 | return F.grid_sample(img, grid), grid 64 | 65 | 66 | def backward_warp(x, flo): 67 | """ 68 | warp an image/tensor (im2) back to im1, according to the optical flow 69 | x: [B, C, H, W] (im2) 70 | flo: [B, 2, H, W] flow 71 | """ 72 | B, C, H, W = x.size() 73 | # mesh grid 74 | xx = torch.arange(0, W).to(x.device).view(1, -1).repeat(H, 1) 75 | yy = torch.arange(0, H).to(x.device).view(-1, 1).repeat(1, W) 76 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 77 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 78 | grid = torch.cat((xx, yy), 1).float() 79 | 80 | vgrid = grid + flo 81 | 82 | # scale grid to [-1,1] 83 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 84 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 85 | 86 | vgrid = vgrid.permute(0, 2, 3, 1) 87 | output = F.grid_sample(x, vgrid) 88 | 89 | mask = torch.ones_like(x) 90 | mask = F.grid_sample(mask, vgrid) 91 | 92 | mask[mask < 0.9999] = 0 93 | mask[mask > 0] = 1 94 | 95 | return output * mask, mask > 0. 
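# Minimal usage sketch for backward_warp (shapes assumed, not from the original repo):
#   im2 = torch.randn(2, 3, 64, 128)          # frame t+1
#   flow = torch.zeros(2, 2, 64, 128)         # per-pixel displacement from t to t+1
#   warped, valid = backward_warp(im2, flow)  # zero flow ~ identity warp
# `valid` marks pixels whose sampling locations stayed inside the image.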
96 | 97 | 98 | def pad_flow(flow, size): 99 | h, w, _ = flow.shape 100 | shape = list(size) + [2] 101 | new_flow = np.zeros(shape, dtype=flow.dtype) 102 | new_flow[:h, :w] = flow 103 | return new_flow 104 | 105 | def flip_flow_horizontal(flow): 106 | flow = np.flip(flow, axis=1) 107 | flow[..., 0] *= -1 108 | return flow 109 | 110 | 111 | def crop_and_scale_flow(flow, crop_box, target_size, pad_size, scale): 112 | def _trans(uv): 113 | return crop_and_scale_img(uv, crop_box, target_size, pad_size, resample=pimg.NEAREST, blank_value=0) 114 | 115 | u, v = [pimg.fromarray(uv.squeeze()) for uv in np.split(flow * scale, 2, axis=-1)] 116 | dtype = flow.dtype 117 | return np.stack([np.array(_trans(u), dtype=dtype), np.array(_trans(v), dtype=dtype)], axis=-1) 118 | 119 | 120 | def subsample_flow(flow, subsampling): 121 | dtype = flow.dtype 122 | u, v = [pimg.fromarray(uv.squeeze()) for uv in np.split(flow / subsampling, 2, axis=-1)] 123 | size = tuple([int(round(wh / subsampling)) for wh in u.size]) 124 | u, v = u.resize(size), v.resize(size) 125 | return np.stack([np.array(u, dtype=dtype), np.array(v, dtype=dtype)], axis=-1) 126 | -------------------------------------------------------------------------------- /data/transform/jitter.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | from math import ceil 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image as pimg 8 | 9 | from data.transform import RESAMPLE, RESAMPLE_D 10 | from data.transform.flow_utils import pad_flow, crop_and_scale_flow, flip_flow_horizontal 11 | from data.util import bb_intersection_over_union, crop_and_scale_img 12 | 13 | __all__ = ['Pad', 'PadToFactor', 'Normalize', 'Denormalize', 'DenormalizeTh', 'Resize', 'RandomFlip', 14 | 'RandomSquareCropAndScale', 'ResizeLongerSide', 'Downsample'] 15 | 16 | 17 | class Pad: 18 | def __init__(self, size, ignore_id, mean): 19 | self.size = size 20 | self.ignore_id = ignore_id 21 | self.mean = mean 22 | 23 | def _do(self, data, color): 24 | blank = pimg.new(mode=data.mode, size=self.size, color=color) 25 | blank.paste(data) 26 | return blank 27 | 28 | def __call__(self, example): 29 | ret_dict = {} 30 | for k, c in zip(['image', 'labels', 'original_labels', 'image_next', 'image_prev'], 31 | [self.mean, self.ignore_id, self.ignore_id, self.mean, self.mean]): 32 | if k in example: 33 | ret_dict[k] = self._do(example[k], c) 34 | if 'flow' in example: 35 | ret_dict['flow'] = pad_flow(example['flow'], self.size) 36 | return {**example, **ret_dict} 37 | 38 | 39 | class PadToFactor: 40 | def __init__(self, factor, ignore_id, mean): 41 | self.factor = factor 42 | self.ignore_id = ignore_id 43 | self.mean = mean 44 | 45 | def _do(self, data, color, size): 46 | blank = pimg.new(mode=data.mode, size=size, color=color) 47 | blank.paste(data) 48 | return blank 49 | 50 | def __call__(self, example): 51 | ret_dict = {} 52 | size = tuple(map(lambda x: ceil(x / self.factor) * self.factor, example['image'].size)) 53 | for k, c in zip(['image', 'labels', 'original_labels', 'image_next', 'image_prev'], 54 | [self.mean, self.ignore_id, self.ignore_id, self.mean, self.mean]): 55 | if k in example: 56 | ret_dict[k] = self._do(example[k], c, size) 57 | if 'flow' in example: 58 | ret_dict['flow'] = pad_flow(example['flow'], size) 59 | return {**example, **ret_dict} 60 | 61 | 62 | class Norm: 63 | def __init__(self, scale, mean, std): 64 | self.scale = scale 65 | self.mean = mean 66 | self.std = std 67 | 68 | def _trans(self, img): 69 | raise
NotImplementedError 70 | 71 | def __call__(self, example): 72 | ret_dict = { 73 | 'image': self._trans(example['image']) 74 | } 75 | for k in ['image_prev', 'image_next']: 76 | if k in example: 77 | ret_dict[k] = self._trans(example[k]) 78 | if 'pyramid' in example: 79 | ret_dict['pyramid'] = [self._trans(p) for p in example['pyramid']] 80 | if 'pyramid_ms' in example: 81 | ret_dict['pyramid_ms'] = [[self._trans(p) for p in pyramid] for pyramid in example['pyramid_ms']] 82 | return {**example, **ret_dict} 83 | 84 | 85 | class Normalize(Norm): 86 | def _trans(self, img): 87 | img = np.array(img).astype(np.float32) 88 | if self.scale != 1: 89 | img /= self.scale 90 | img -= self.mean 91 | img /= self.std 92 | return img 93 | 94 | 95 | class Denormalize(Norm): 96 | def _trans(self, img): 97 | img = np.array(img) 98 | img *= self.std 99 | img += self.mean 100 | if self.scale != 1: 101 | img *= self.scale 102 | return img 103 | 104 | 105 | class DenormalizeTh(Norm): 106 | def __init__(self, scale, mean, std): 107 | super(DenormalizeTh, self).__init__(scale, mean, std) 108 | self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1) 109 | self.std = torch.FloatTensor(std).view(1, 3, 1, 1) 110 | 111 | def _trans(self, img): 112 | img *= self.std 113 | img += self.mean 114 | if self.scale != 1: 115 | img *= self.scale 116 | return img 117 | 118 | 119 | class Downsample: 120 | def __init__(self, factor=2): 121 | self.factor = factor 122 | 123 | def __call__(self, example): 124 | if self.factor <= 1: 125 | return example 126 | W, H = example['image'].size 127 | w, h = W // self.factor, H // self.factor 128 | size = (w, h) 129 | ret_dict = { 130 | 'image': example['image'].resize(size, resample=RESAMPLE), 131 | 'labels': example['labels'].resize(size, resample=pimg.NEAREST), 132 | } 133 | if 'depth' in example: 134 | ret_dict['depth'] = example['depth'].resize(size, resample=RESAMPLE) 135 | return {**example, **ret_dict} 136 | 137 | 138 | class RandomSquareCropAndScale: 139 | def __init__(self, wh, mean, ignore_id, min=.5, max=2., class_incidence=None, class_instances=None, 140 | inst_classes=(3, 12, 14, 15, 16, 17, 18), scale_method=lambda scale, wh, size: int(scale * wh)): 141 | self.wh = wh 142 | self.min = min 143 | self.max = max 144 | self.mean = mean 145 | self.ignore_id = ignore_id 146 | self.random_gens = [self._rand_location] 147 | self.scale_method = scale_method 148 | 149 | if class_incidence is not None and class_instances is not None: 150 | self.true_random = False 151 | class_incidence_obj = np.load(class_incidence) 152 | with open(class_instances, 'rb') as f: 153 | self.class_instances = pickle.load(f) 154 | inst_classes = np.array(inst_classes) 155 | class_freq = class_incidence_obj[inst_classes].astype(np.float32) 156 | class_prob = 1. / (class_freq / class_freq.sum()) 157 | class_prob /= class_prob.sum() 158 | self.p_class = {k.item(): v.item() for k, v in zip(inst_classes, class_prob)} 159 | self.random_gens += [self._gen_instance_box] 160 | print(f'Instance based random cropping:\n\t{self.p_class}') 161 | 162 | def _random_instance(self, name, W, H): 163 | def weighted_random_choice(choices): 164 | max = sum(choices) 165 | pick = random.uniform(0, max) 166 | key, current = 0, 0. 
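# roulette-wheel selection: walk the cumulative weights until the running sum exceeds the uniformly drawn `pick`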
167 | for key, value in enumerate(choices): 168 | current += value 169 | if current > pick: 170 | return key 171 | key += 1 172 | return key 173 | 174 | instances = self.class_instances[name] 175 | possible_classes = list(set(self.p_class.keys()).intersection(instances.keys())) 176 | roulette = [] 177 | flat_instances = [] 178 | for c in possible_classes: 179 | flat_instances += instances[c] 180 | roulette += [self.p_class[c]] * len(instances[c]) 181 | if len(flat_instances) == 0: 182 | return [0, W - 1, 0, H - 1] 183 | index = weighted_random_choice(roulette) 184 | return flat_instances[index] 185 | 186 | def _gen_instance_box(self, W, H, target_wh, name, flipped): 187 | wmin, wmax, hmin, hmax = self._random_instance(name, W, H) 188 | if flipped: 189 | wmin, wmax = W - 1 - wmax, W - 1 - wmin 190 | inst_box = [wmin, hmin, wmax, hmax] 191 | for _ in range(50): 192 | box = self._rand_location(W, H, target_wh) 193 | if bb_intersection_over_union(box, inst_box) > 0.: 194 | break 195 | return box 196 | 197 | def _rand_location(self, W, H, target_wh, *args, **kwargs): 198 | try: 199 | w = np.random.randint(0, W - target_wh + 1) 200 | h = np.random.randint(0, H - target_wh + 1) 201 | except ValueError: 202 | print(f'Exception in RandomSquareCropAndScale: {target_wh}') 203 | w = h = 0 204 | # left, upper, right, lower) 205 | return w, h, w + target_wh, h + target_wh 206 | 207 | def _trans(self, img: pimg, crop_box, target_size, pad_size, resample, blank_value): 208 | return crop_and_scale_img(img, crop_box, target_size, pad_size, resample, blank_value) 209 | 210 | def __call__(self, example): 211 | image = example['image'] 212 | scale = np.random.uniform(self.min, self.max) 213 | W, H = image.size 214 | box_size = self.scale_method(scale, self.wh, image.size) 215 | pad_size = (max(box_size, W), max(box_size, H)) 216 | target_size = (self.wh, self.wh) 217 | crop_fn = random.choice(self.random_gens) 218 | flipped = example['flipped'] if 'flipped' in example else False 219 | crop_box = crop_fn(pad_size[0], pad_size[1], box_size, example.get('name'), flipped) 220 | ret_dict = { 221 | 'image': self._trans(image, crop_box, target_size, pad_size, RESAMPLE, self.mean), 222 | } 223 | if 'labels' in example: 224 | ret_dict['labels'] = self._trans(example['labels'], crop_box, target_size, pad_size, pimg.NEAREST, self.ignore_id) 225 | for k in ['image_prev', 'image_next']: 226 | if k in example: 227 | ret_dict[k] = self._trans(example[k], crop_box, target_size, pad_size, RESAMPLE, 228 | self.mean) 229 | if 'depth' in example: 230 | ret_dict['depth'] = self._trans(example['depth'], crop_box, target_size, pad_size, RESAMPLE_D, 0) 231 | if 'flow' in example: 232 | ret_dict['flow'] = crop_and_scale_flow(example['flow'], crop_box, target_size, pad_size, scale) 233 | return {**example, **ret_dict} 234 | 235 | 236 | class RandomFlip: 237 | def _trans(self, img: pimg, flip: bool): 238 | return img.transpose(pimg.FLIP_LEFT_RIGHT) if flip else img 239 | 240 | def __call__(self, example): 241 | flip = np.random.choice([False, True]) 242 | ret_dict = {} 243 | for k in ['image', 'image_next', 'image_prev', 'labels', 'depth']: 244 | if k in example: 245 | ret_dict[k] = self._trans(example[k], flip) 246 | if ('flow' in example) and flip: 247 | ret_dict['flow'] = flip_flow_horizontal(example['flow']) 248 | return {**example, **ret_dict} 249 | 250 | 251 | class Resize: 252 | def __init__(self, size): 253 | self.size = size 254 | 255 | def __call__(self, example): 256 | # raise NotImplementedError() 257 | ret_dict = {'image': 
example['image'].resize(self.size, resample=RESAMPLE)} 258 | if 'labels' in example: 259 | ret_dict['labels'] = example['labels'].resize(self.size, resample=pimg.NEAREST) 260 | if 'depth' in example: 261 | ret_dict['depth'] = example['depth'].resize(self.size, resample=RESAMPLE_D) 262 | return {**example, **ret_dict} 263 | 264 | 265 | class ResizeLongerSide: 266 | def __init__(self, size): 267 | self.size = size 268 | 269 | def __call__(self, example): 270 | ret_dict = {} 271 | k = 'image' if 'image' in example else 'labels' 272 | scale = self.size / max(example[k].size) 273 | size = tuple([int(wh * scale) for wh in example[k].size]) 274 | if 'image' in example: 275 | ret_dict['image'] = example['image'].resize(size, resample=RESAMPLE) 276 | if 'labels' in example: 277 | ret_dict['labels'] = example['labels'].resize(size, resample=pimg.NEAREST) 278 | # if 'original_labels' in example: 279 | # ret_dict['original_labels'] = example['original_labels'].resize(size, resample=pimg.NEAREST) 280 | if 'depth' in example: 281 | ret_dict['depth'] = example['depth'].resize(size, resample=RESAMPLE_D) 282 | return {**example, **ret_dict} 283 | -------------------------------------------------------------------------------- /data/transform/labels.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from PIL import Image as pimg 5 | 6 | __all__ = ['ExtractInstances', 'RemapLabels', 'ColorizeLabels'] 7 | 8 | 9 | class ExtractInstances: 10 | def __init__(self, inst_map_to_id=None): 11 | self.inst_map_to_id = inst_map_to_id 12 | 13 | def __call__(self, example: dict): 14 | labels = np.int32(example['labels']) 15 | unique_ids = np.unique(labels) 16 | instances = defaultdict(list) 17 | for id in filter(lambda x: x > 1000, unique_ids): 18 | cls = self.inst_map_to_id.get(id // 1000, None) 19 | if cls is not None: 20 | instances[cls] += [labels == id] 21 | example['instances'] = instances 22 | return example 23 | 24 | 25 | class RemapLabels: 26 | def __init__(self, mapping: dict, ignore_id, total=35): 27 | self.mapping = np.ones((max(total, max(mapping.keys())) + 1,), dtype=np.uint8) * ignore_id 28 | self.ignore_id = ignore_id 29 | for i in range(len(self.mapping)): 30 | self.mapping[i] = mapping[i] if i in mapping else ignore_id 31 | 32 | def _trans(self, labels): 33 | max_k = self.mapping.shape[0] - 1 34 | labels[labels > max_k] //= 1000 35 | labels = self.mapping[labels].astype(labels.dtype) 36 | return labels 37 | 38 | def __call__(self, example): 39 | if not isinstance(example, dict): 40 | return self._trans(example) 41 | if 'labels' not in example: 42 | return example 43 | ret_dict = {'labels': pimg.fromarray(self._trans(np.array(example['labels'])))} 44 | if 'original_labels' in example: 45 | ret_dict['original_labels'] = pimg.fromarray(self._trans(np.array(example['original_labels']))) 46 | return {**example, **ret_dict} 47 | 48 | 49 | class ColorizeLabels: 50 | def __init__(self, color_info): 51 | self.color_info = np.array(color_info) 52 | 53 | def _trans(self, lab): 54 | R, G, B = [np.zeros_like(lab) for _ in range(3)] 55 | for l in np.unique(lab): 56 | mask = lab == l 57 | R[mask] = self.color_info[l][0] 58 | G[mask] = self.color_info[l][1] 59 | B[mask] = self.color_info[l][2] 60 | return np.stack((R, G, B), axis=-1).astype(np.uint8) 61 | 62 | def __call__(self, example): 63 | if not isinstance(example, dict): 64 | return self._trans(example) 65 | assert 'labels' in example 66 | return {**example, 
**{'labels': self._trans(example['labels']), 67 | 'original_labels': self._trans(example['original_labels'])}} 68 | -------------------------------------------------------------------------------- /data/transform/photometric.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | from PIL import Image as pimg 5 | 6 | __all__ = ['PhotometricDistort'] 7 | 8 | 9 | class Compose(object): 10 | 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, img): 15 | for t in self.transforms: 16 | img = t(img) 17 | return img 18 | 19 | 20 | class RandomSaturation(object): 21 | def __init__(self, lower=0.5, upper=1.5): 22 | self.lower = lower 23 | self.upper = upper 24 | assert self.upper >= self.lower, "saturation upper must be >= lower." 25 | assert self.lower >= 0, "saturation lower must be non-negative." 26 | 27 | def __call__(self, image): 28 | if random.randint(0, 2): 29 | image[:, :, 1] *= random.uniform(self.lower, self.upper) 30 | 31 | return image 32 | 33 | 34 | class RandomHue(object): 35 | def __init__(self, delta=18.0): 36 | assert 0.0 <= delta <= 360.0 37 | self.delta = delta 38 | 39 | def __call__(self, image): 40 | if random.randint(0, 2): 41 | image[:, :, 0] += random.uniform(-self.delta, self.delta) 42 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 43 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 44 | return image 45 | 46 | 47 | class SwapChannels(object): 48 | """Transforms a tensorized image by swapping the channels in the order 49 | specified in the swap tuple. 50 | Args: 51 | swaps (int triple): final order of channels 52 | eg: (2, 1, 0) 53 | """ 54 | 55 | def __init__(self, swaps): 56 | self.swaps = swaps 57 | 58 | def __call__(self, image): 59 | """ 60 | Args: 61 | image (Tensor): image tensor to be transformed 62 | Return: 63 | a tensor with channels swapped according to swap 64 | """ 65 | image = image[:, :, self.swaps] 66 | return image 67 | 68 | 69 | class RandomLightingNoise(object): 70 | def __init__(self): 71 | self.perms = ((0, 1, 2), (0, 2, 1), 72 | (1, 0, 2), (1, 2, 0), 73 | (2, 0, 1), (2, 1, 0)) 74 | 75 | def __call__(self, image): 76 | if random.randint(0, 2): 77 | swap = self.perms[random.randint(0, len(self.perms) - 1)] 78 | shuffle = SwapChannels(swap) # shuffle channels 79 | image = shuffle(image) 80 | return image 81 | 82 | 83 | class ConvertColor(object): 84 | def __init__(self, current='BGR', transform='HSV'): 85 | self.transform = transform 86 | self.current = current 87 | 88 | def __call__(self, image): 89 | if self.current == 'BGR' and self.transform == 'HSV': 90 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 91 | elif self.current == 'HSV' and self.transform == 'BGR': 92 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 93 | else: 94 | raise NotImplementedError 95 | return image 96 | 97 | 98 | class RandomContrast(object): 99 | def __init__(self, lower=0.5, upper=1.5): 100 | self.lower = lower 101 | self.upper = upper 102 | assert self.upper >= self.lower, "contrast upper must be >= lower." 103 | assert self.lower >= 0, "contrast lower must be non-negative."
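# multiplicative jitter: __call__ below scales the whole float image by a random alpha in [lower, upper]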
104 | 105 | # expects float image 106 | def __call__(self, image): 107 | if random.randint(0, 2): 108 | alpha = random.uniform(self.lower, self.upper) 109 | image *= alpha 110 | return image 111 | 112 | 113 | class RandomBrightness(object): 114 | def __init__(self, delta=32): 115 | assert delta >= 0.0 116 | assert delta <= 255.0 117 | self.delta = delta 118 | 119 | def __call__(self, image): 120 | if random.randint(0, 2): 121 | delta = random.uniform(-self.delta, self.delta) 122 | image += delta 123 | return image 124 | 125 | 126 | class PhotometricDistort(object): 127 | def __init__(self): 128 | self.pd = [ 129 | RandomContrast(), 130 | ConvertColor(transform='HSV'), 131 | RandomSaturation(), 132 | RandomHue(), 133 | ConvertColor(current='HSV', transform='BGR'), 134 | RandomContrast() 135 | ] 136 | self.rand_brightness = RandomBrightness() 137 | self.rand_light_noise = RandomLightingNoise() 138 | 139 | def __call__(self, example): 140 | image = np.float32(example['image']) 141 | im = image.copy() 142 | im = self.rand_brightness(im) 143 | if random.randint(0, 2): 144 | distort = Compose(self.pd[:-1]) 145 | else: 146 | distort = Compose(self.pd[1:]) 147 | im = distort(im) 148 | im = self.rand_light_noise(im) 149 | ret = { 150 | 'image': pimg.fromarray(np.uint8(im)), 151 | } 152 | return {**example, **ret} 153 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.utils.data import Dataset 3 | import torch 4 | import numpy as np 5 | import pickle 6 | from collections import defaultdict 7 | from PIL import Image as pimg 8 | 9 | 10 | def disparity_distribution_uniform(max_disp, num_bins): 11 | return np.linspace(0, max_disp, num_bins - 1) 12 | 13 | 14 | def disparity_distribution_log(num_bins): 15 | return np.power(np.sqrt(2), np.arange(num_bins - 1)) 16 | 17 | 18 | def downsample_distribution(labels, factor, num_classes): 19 | h, w = labels.shape 20 | assert h % factor == 0 and w % factor == 0 21 | new_h = h // factor 22 | new_w = w // factor 23 | labels_4d = np.ascontiguousarray(labels.reshape(new_h, factor, new_w, factor), labels.dtype) 24 | labels_oh = np.eye(num_classes, dtype=np.float32)[labels_4d] 25 | target_dist = labels_oh.sum((1, 3)) / factor ** 2 26 | return target_dist 27 | 28 | 29 | def downsample_distribution_th(labels, factor, num_classes, ignore_id=None): 30 | n, h, w = labels.shape 31 | assert h % factor == 0 and w % factor == 0 32 | new_h = h // factor 33 | new_w = w // factor 34 | labels_4d = labels.view(n, new_h, factor, new_w, factor) 35 | labels_oh = torch.eye(num_classes).to(labels_4d.device)[labels_4d] 36 | target_dist = labels_oh.sum(2).sum(3) / factor ** 2 37 | return target_dist 38 | 39 | 40 | def downsample_labels_th(labels, factor, num_classes): 41 | ''' 42 | :param labels: Tensor(N, H, W) 43 | :param factor: int 44 | :param num_classes: int 45 | :return: FloatTensor(-1, num_classes), ByteTensor(-1, 1) 46 | ''' 47 | n, h, w = labels.shape 48 | assert h % factor == 0 and w % factor == 0 49 | new_h = h // factor 50 | new_w = w // factor 51 | labels_4d = labels.view(n, new_h, factor, new_w, factor) 52 | # +1 class here because ignore id = num_classes 53 | labels_oh = torch.eye(num_classes + 1).to(labels_4d.device)[labels_4d] 54 | target_dist = labels_oh.sum(2).sum(3) / factor ** 2 55 | C = target_dist.shape[-1] 56 | target_dist = target_dist.view(-1, C) 57 | # keep only boxes which have p(ignore) < 0.5 58 | 
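# i.e. a downsampled cell is kept only if less than half of its pixels carry the ignore label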
valid_mask = target_dist[:, -1] < 0.5 59 | target_dist = target_dist[:, :-1].contiguous() 60 | dist_sum = target_dist.sum(1, keepdim=True) 61 | # avoid division by zero 62 | dist_sum[dist_sum == 0] = 1 63 | # renormalize distribution after removing p(ignore) 64 | target_dist /= dist_sum 65 | return target_dist, valid_mask 66 | 67 | 68 | def equalize_hist_disparity_distribution(d, L): 69 | cd = np.cumsum(d / d.sum()) 70 | Y = np.round((L - 1) * cd).astype(np.uint8) 71 | return np.array([np.argmax(Y == i) for i in range(L - 1)]) 72 | 73 | 74 | def bb_intersection_over_union(boxA, boxB): 75 | # determine the (x, y)-coordinates of the intersection rectangle 76 | xA = max(boxA[0], boxB[0]) 77 | yA = max(boxA[1], boxB[1]) 78 | xB = min(boxA[2], boxB[2]) 79 | yB = min(boxA[3], boxB[3]) 80 | 81 | # compute the area of intersection rectangle 82 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 83 | 84 | # compute the area of both the prediction and ground-truth 85 | # rectangles 86 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 87 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 88 | 89 | # compute the intersection over union by taking the intersection 90 | # area and dividing it by the sum of prediction + ground-truth 91 | # areas - the intersection area 92 | iou = interArea / float(boxAArea + boxBArea - interArea) 93 | 94 | # return the intersection over union value 95 | return iou 96 | 97 | 98 | def one_hot_encoding(labels, C): 99 | ''' 100 | Converts an integer label torch.autograd.Variable to a one-hot Variable. 101 | 102 | Parameters 103 | ---------- 104 | labels : torch.autograd.Variable of torch.cuda.LongTensor 105 | N x 1 x H x W, where N is batch size. 106 | Each value is an integer representing correct classification. 107 | C : integer. 108 | number of classes in labels. 109 | 110 | Returns 111 | ------- 112 | target : torch.autograd.Variable of torch.cuda.FloatTensor 113 | N x C x H x W, where C is class number. One-hot encoded.
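Example (sketch): with C == 3 and labels[n, 0, h, w] == 2, the result satisfies target[n, :, h, w] == [0., 0., 1.].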
114 | ''' 115 | one_hot = torch.FloatTensor(labels.size(0), C, labels.size(2), labels.size(3)).to(labels.device).zero_() 116 | target = one_hot.scatter_(1, labels.data, 1) 117 | 118 | return target 119 | 120 | 121 | def crop_and_scale_img(img: pimg, crop_box, target_size, pad_size, resample, blank_value): 122 | target = pimg.new(img.mode, pad_size, color=blank_value) 123 | target.paste(img) 124 | res = target.crop(crop_box).resize(target_size, resample=resample) 125 | return res 126 | -------------------------------------------------------------------------------- /data/vistas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/data/vistas/__init__.py -------------------------------------------------------------------------------- /data/vistas/vistas.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from pathlib import Path 3 | 4 | class_info = ['animal--bird', 'animal--ground-animal', 'construction--barrier--curb', 'construction--barrier--fence', 5 | 'construction--barrier--guard-rail', 'construction--barrier--other-barrier', 6 | 'construction--barrier--wall', 'construction--flat--bike-lane', 'construction--flat--crosswalk-plain', 7 | 'construction--flat--curb-cut', 'construction--flat--parking', 'construction--flat--pedestrian-area', 8 | 'construction--flat--rail-track', 'construction--flat--road', 'construction--flat--service-lane', 9 | 'construction--flat--sidewalk', 'construction--structure--bridge', 'construction--structure--building', 10 | 'construction--structure--tunnel', 'human--person', 'human--rider--bicyclist', 11 | 'human--rider--motorcyclist', 'human--rider--other-rider', 'marking--crosswalk-zebra', 'marking--general', 12 | 'nature--mountain', 'nature--sand', 'nature--sky', 'nature--snow', 'nature--terrain', 13 | 'nature--vegetation', 'nature--water', 'object--banner', 'object--bench', 'object--bike-rack', 14 | 'object--billboard', 'object--catch-basin', 'object--cctv-camera', 'object--fire-hydrant', 15 | 'object--junction-box', 'object--mailbox', 'object--manhole', 'object--phone-booth', 'object--pothole', 16 | 'object--street-light', 'object--support--pole', 'object--support--traffic-sign-frame', 17 | 'object--support--utility-pole', 'object--traffic-light', 'object--traffic-sign--back', 18 | 'object--traffic-sign--front', 'object--trash-can', 'object--vehicle--bicycle', 'object--vehicle--boat', 19 | 'object--vehicle--bus', 'object--vehicle--car', 'object--vehicle--caravan', 'object--vehicle--motorcycle', 20 | 'object--vehicle--on-rails', 'object--vehicle--other-vehicle', 'object--vehicle--trailer', 21 | 'object--vehicle--truck', 'object--vehicle--wheeled-slow', 'void--car-mount', 'void--ego-vehicle', 22 | 'void--unlabeled'] 23 | color_info = [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153], [180, 165, 180], [102, 102, 156], 24 | [102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96], 25 | [230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232], [150, 100, 100], [70, 70, 70], 26 | [150, 120, 90], [220, 20, 60], [255, 0, 0], [255, 0, 0], [255, 0, 0], [200, 128, 128], [255, 255, 255], 27 | [64, 170, 64], [128, 64, 64], [70, 130, 180], [255, 255, 255], [152, 251, 152], [107, 142, 35], 28 | [0, 170, 30], [255, 255, 128], [250, 0, 30], [0, 0, 0], [220, 220, 220], [170, 170, 170], [222, 40, 40], 29 | [100, 
170, 30], [40, 40, 40], [33, 33, 33], [170, 170, 170], [0, 0, 142], [170, 170, 170], 30 | [210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 142], [250, 170, 30], [192, 192, 192], 31 | [220, 220, 0], [180, 165, 180], [119, 11, 32], [0, 0, 142], [0, 60, 100], [0, 0, 142], [0, 0, 90], 32 | [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 192], [32, 32, 32], [0, 0, 0], 33 | [0, 0, 0]] 34 | 35 | 36 | class Vistas(Dataset): 37 | class_info = class_info 38 | color_info = color_info 39 | num_classes = 63 40 | 41 | def __init__(self, root: Path, transforms: lambda x: x, subset='training', open_images=True, epoch=None): 42 | self.root = root 43 | self.open_images = open_images 44 | self.images_dir = root / subset / 'images' 45 | self.labels_dir = root / subset / 'labels' 46 | 47 | self.images = list(sorted(self.images_dir.glob('*.jpg'))) 48 | self.labels = list(sorted(self.labels_dir.glob('*.png'))) 49 | 50 | self.transforms = transforms 51 | self.subset = subset 52 | self.epoch = epoch 53 | 54 | print(f'Num images: {len(self)}') 55 | 56 | def __len__(self): 57 | return len(self.images) 58 | 59 | def __getitem__(self, item): 60 | ret_dict = { 61 | 'name': self.images[item].stem, 62 | 'subset': self.subset, 63 | 'labels': self.labels[item] 64 | } 65 | if self.open_images: 66 | ret_dict['image'] = self.images[item] 67 | if self.epoch is not None: 68 | ret_dict['epoch'] = int(self.epoch.value) 69 | return self.transforms(ret_dict) 70 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/datasets/.gitkeep -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import importlib.util 4 | from evaluation import evaluate_semseg 5 | 6 | 7 | def import_module(path): 8 | spec = importlib.util.spec_from_file_location("module", path) 9 | module = importlib.util.module_from_spec(spec) 10 | spec.loader.exec_module(module) 11 | return module 12 | 13 | 14 | parser = argparse.ArgumentParser(description='Detector eval') 15 | parser.add_argument('config', type=str, help='Path to configuration .py file') 16 | parser.add_argument('--profile', dest='profile', action='store_true', help='Profile one forward pass') 17 | 18 | if __name__ == '__main__': 19 | args = parser.parse_args() 20 | conf_path = Path(args.config) 21 | conf = import_module(args.config) 22 | 23 | class_info = conf.dataset_val.class_info 24 | 25 | model = conf.model.cuda() 26 | 27 | for loader, name in conf.eval_loaders: 28 | iou, per_class_iou = evaluate_semseg(model, loader, class_info, observers=conf.eval_observers) 29 | print(f'{name}: {iou:.2f}') 30 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluate import * 2 | from .prediction import * 3 | -------------------------------------------------------------------------------- /evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | from time import perf_counter 7 | 8 | import lib.cylib as cylib 9
| 10 | __all__ = ['compute_errors', 'get_pred', 'evaluate_semseg'] 11 | 12 | 13 | def compute_errors(conf_mat, class_info, verbose=True): 14 | num_correct = conf_mat.trace() 15 | num_classes = conf_mat.shape[0] 16 | total_size = conf_mat.sum() 17 | avg_pixel_acc = num_correct / total_size * 100.0 18 | TPFP = conf_mat.sum(1) 19 | TPFN = conf_mat.sum(0) 20 | FN = TPFN - conf_mat.diagonal() 21 | FP = TPFP - conf_mat.diagonal() 22 | class_iou = np.zeros(num_classes) 23 | class_recall = np.zeros(num_classes) 24 | class_precision = np.zeros(num_classes) 25 | per_class_iou = [] 26 | if verbose: 27 | print('Errors:') 28 | for i in range(num_classes): 29 | TP = conf_mat[i, i] 30 | class_iou[i] = (TP / (TP + FP[i] + FN[i])) * 100.0 31 | if TPFN[i] > 0: 32 | class_recall[i] = (TP / TPFN[i]) * 100.0 33 | else: 34 | class_recall[i] = 0 35 | if TPFP[i] > 0: 36 | class_precision[i] = (TP / TPFP[i]) * 100.0 37 | else: 38 | class_precision[i] = 0 39 | 40 | class_name = class_info[i] 41 | per_class_iou += [(class_name, class_iou[i])] 42 | if verbose: 43 | print('\t%s IoU accuracy = %.2f %%' % (class_name, class_iou[i])) 44 | avg_class_iou = class_iou.mean() 45 | avg_class_recall = class_recall.mean() 46 | avg_class_precision = class_precision.mean() 47 | if verbose: 48 | print('IoU mean class accuracy -> TP / (TP+FN+FP) = %.2f %%' % avg_class_iou) 49 | print('mean class recall -> TP / (TP+FN) = %.2f %%' % avg_class_recall) 50 | print('mean class precision -> TP / (TP+FP) = %.2f %%' % avg_class_precision) 51 | print('pixel accuracy = %.2f %%' % avg_pixel_acc) 52 | return avg_pixel_acc, avg_class_iou, avg_class_recall, avg_class_precision, total_size, per_class_iou 53 | 54 | 55 | def get_pred(logits, labels, conf_mat): 56 | _, pred = torch.max(logits.data, dim=1) 57 | pred = pred.byte().cpu() 58 | pred = pred.numpy().astype(np.int32) 59 | true = labels.numpy().astype(np.int32) 60 | cylib.collect_confusion_matrix(pred.reshape(-1), true.reshape(-1), conf_mat) 61 | 62 | 63 | def mt(sync=False): 64 | if sync: 65 | torch.cuda.synchronize() 66 | return 1000 * perf_counter() 67 | 68 | 69 | def evaluate_semseg(model, data_loader, class_info, observers=()): 70 | model.eval() 71 | managers = [torch.no_grad()] + list(observers) 72 | with contextlib.ExitStack() as stack: 73 | for ctx_mgr in managers: 74 | stack.enter_context(ctx_mgr) 75 | conf_mat = np.zeros((model.num_classes, model.num_classes), dtype=np.uint64) 76 | for step, batch in tqdm(enumerate(data_loader), total=len(data_loader)): 77 | batch['original_labels'] = batch['original_labels'].numpy().astype(np.uint32) 78 | logits, additional = model.do_forward(batch, batch['original_labels'].shape[1:3]) 79 | pred = torch.argmax(logits.data, dim=1).byte().cpu().numpy().astype(np.uint32) 80 | for o in observers: 81 | o(pred, batch, additional) 82 | cylib.collect_confusion_matrix(pred.flatten(), batch['original_labels'].flatten(), conf_mat) 83 | print('') 84 | pixel_acc, iou_acc, recall, precision, _, per_class_iou = compute_errors(conf_mat, class_info, verbose=True) 85 | model.train() 86 | return iou_acc, per_class_iou 87 | -------------------------------------------------------------------------------- /evaluation/prediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image as pimg 3 | 4 | __all__ = ['StorePreds', 'StoreSubmissionPreds'] 5 | 6 | 7 | class StorePreds: 8 | def __init__(self, store_dir, to_img, to_color): 9 | self.store_dir = store_dir 10 | self.to_img = to_img 11 | 
self.to_color = to_color
12 |
13 |     def __enter__(self):
14 |         return self
15 |
16 |     def __exit__(self, exc_type, exc_val, exc_tb):
17 |         pass
18 |
19 |     def __str__(self):
20 |         return ''
21 |
22 |     def __call__(self, pred, batch, additional):
23 |         b = self.to_img(batch)
24 |         for p, im, gt, name, subset in zip(pred, b['image'], b['original_labels'], b['name'], b['subset']):
25 |             store_img = np.concatenate([i.astype(np.uint8) for i in [im, self.to_color(p), gt]], axis=0)
26 |             store_img = pimg.fromarray(store_img)
27 |             store_img.thumbnail((960, 1344))
28 |             store_img.save(f'{self.store_dir}/{subset}/{name}.jpg')
29 |
30 | class StoreSubmissionPreds:
31 |     def __init__(self, store_dir, remap, to_color=None, store_dir_color=None):
32 |         self.store_dir = store_dir
33 |         self.store_dir_color = store_dir_color
34 |         self.to_color = to_color
35 |         self.remap = remap
36 |
37 |     def __enter__(self):
38 |         return self
39 |
40 |     def __exit__(self, exc_type, exc_val, exc_tb):
41 |         pass
42 |
43 |     def __str__(self):
44 |         return ''
45 |
46 |     def __call__(self, pred, batch, additional):
47 |         for p, name in zip(pred.astype(np.uint8), batch['name']):
48 |             pimg.fromarray(self.remap(p)).save(f'{self.store_dir}/{name}.png')
49 |             if self.to_color is not None:  # to_color and store_dir_color are optional
50 |                 pimg.fromarray(self.to_color(p)).save(f'{self.store_dir_color}/{name}.png')
--------------------------------------------------------------------------------
/lib/build.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | rm -f cylib.so
3 |
4 | cython -a cylib.pyx -o cylib.cc
5 |
6 | #g++ -shared -pthread -fPIC -fwrapv -O3 -Wall -fno-strict-aliasing \
7 | #-I/usr/lib/python3.7/site-packages/numpy/core/include -I/usr/include/python3.7m -o cylib.so cylib.cc
8 | g++ -shared -pthread -fPIC -fwrapv -O3 -Wall -fno-strict-aliasing \
9 | -I/usr/lib/python3.8/site-packages/numpy/core/include -I/usr/include/python3.8 -o cylib.so cylib.cc
10 |
--------------------------------------------------------------------------------
/lib/cylib.h:
--------------------------------------------------------------------------------
1 | #include <stdint.h>
2 | #include <algorithm>
3 | #include <iomanip>
4 | #include <sstream>
5 | #include <unordered_map>
6 |
7 |
8 | void add_confusion_matrix(uint32_t* y, uint32_t* yt, uint32_t size, uint64_t* matrix, uint32_t num_classes) {
9 |     uint32_t target, i;
10 |     for (i = 0; i < size; i++) {
11 |         target = yt[i];
12 |         if (target < num_classes) {  // target is unsigned, so target >= 0 always holds
13 |             matrix[y[i]*num_classes + target] += 1;
14 |         }
15 |     }
16 | }
17 |
18 | void add_disp_class_error_number_matrix(uint32_t *y, uint32_t *yt, uint32_t *dt, uint32_t size, uint64_t* error_matrix,
19 |                                         uint64_t* num_matrix, uint32_t max_disp, uint32_t num_classes){
20 |     uint32_t disp, label, pred, index, i;
21 |     for (i = 0; i < size; i++) {
22 |         disp = dt[i]; label = yt[i]; pred = y[i];
23 |         if ((disp >= max_disp) || (label >= num_classes) || (pred >= num_classes)) { continue; }
24 |         index = label * max_disp + disp;
25 |         if (label != pred)
26 |             {error_matrix[index] += 1;}
27 |         num_matrix[index] += 1;
28 |     }
29 | }
30 |
31 |
32 | void impl_convert_colors_to_ids(int num_classes, int* color_data, int width, int height,
33 |                                 uint8_t* rgb_labels, uint8_t* id_labels, uint64_t* class_hist,
34 |                                 float max_wgt, float* class_weights, float* weights) {
35 |     std::unordered_map<std::string, int> color_map;
36 |     for (int i = 0; i < num_classes; i++) {
37 |         int s = i * 4;
38 |         std::ostringstream skey;
39 |         for (int j = 0; j < 3; j++)
40 |             skey << std::setw(3) << std::setfill('0') << color_data[s+j];
41 |
42 |         //skey << std::setw(3) << std::setfill('0') << r;
43 | 
//std::cout << skey.str() << '\n'; 44 | auto key = skey.str(); 45 | color_map[key] = color_data[s+3]; 46 | 47 | } 48 | //#pragma omp parallel for 49 | for (int r = 0; r < height; r++) { 50 | int stride = r * width * 3; 51 | for (int c = 0; c < width; c++) { 52 | std::ostringstream skey; 53 | for (int i = 0; i < 3; i++) 54 | skey << std::setw(3) << std::setfill('0') << int(rgb_labels[stride + c*3 + i]); 55 | auto key = skey.str(); 56 | //std::cout << key << " - " << int(color_map[key]) << '\n'; 57 | uint8_t class_id = color_map[key]; 58 | id_labels[r*width + c] = class_id; 59 | if (class_id < 255) { 60 | class_hist[class_id]++; 61 | } 62 | } 63 | } 64 | 65 | uint64_t num_labels = 0; 66 | for (int i = 0; i < num_classes; i++) 67 | num_labels += class_hist[i]; 68 | for (int i = 0; i < num_classes; i++) { 69 | if (class_hist[i] > 0) 70 | class_weights[i] = std::min(double(max_wgt), 1.0 / (double(class_hist[i]) / num_labels)); 71 | else 72 | class_weights[i] = 0.0; 73 | //std::cout << class_hist[i] << '\n'; 74 | //std::cout << class_weights[i] << '\n'; 75 | } 76 | //#pragma omp parallel for 77 | for (int r = 0; r < height; r++) { 78 | for (int c = 0; c < width; c++) { 79 | int pos = r*width + c; 80 | uint8_t cidx = id_labels[pos]; 81 | if (cidx < 255) 82 | weights[pos] = class_weights[cidx]; 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /lib/cylib.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | cimport numpy as np 3 | from libc cimport stdint 4 | 5 | #np.import_array() 6 | 7 | ctypedef stdint.int64_t int64_t 8 | ctypedef stdint.uint64_t uint64_t 9 | ctypedef stdint.uint32_t uint32_t 10 | ctypedef stdint.uint8_t uint8_t 11 | 12 | cdef extern from "cylib.h": 13 | void add_confusion_matrix(uint32_t*y, uint32_t*yt, uint32_t size, uint64_t*matrix, uint32_t num_classes) 14 | void impl_convert_colors_to_ids(int num_classes, int*color_data, int width, int height, 15 | uint8_t*rgb_labels, uint8_t*id_labels, uint64_t*class_hist, 16 | float max_wgt, float*class_weights, float*weights) 17 | void add_disp_class_error_number_matrix(uint32_t *y, uint32_t *yt, uint32_t *dt, uint32_t size, 18 | uint64_t*error_matrix, 19 | uint64_t*num_matrix, uint32_t max_disp, uint32_t num_classes) 20 | 21 | 22 | def collect_confusion_matrix(y, yt, confusion_mat): 23 | cdef uint32_t size = y.size 24 | cdef uint32_t num_classes = confusion_mat.shape[0] 25 | cdef np.ndarray[uint32_t, mode="c", ndim=1] y_c = np.ascontiguousarray(y) 26 | cdef np.ndarray[uint32_t, mode="c", ndim=1] yt_c = np.ascontiguousarray(yt) 27 | cdef np.ndarray[uint64_t, mode="c", ndim=2] confusion_mat_c = confusion_mat 28 | add_confusion_matrix(&y_c[0], &yt_c[0], size, &confusion_mat_c[0, 0], num_classes) 29 | 30 | def collect_disp_class_matrices(y, yt, dt, error_mat, num_mat): 31 | cdef uint32_t size = y.size 32 | cdef uint32_t num_classes = error_mat.shape[0] 33 | cdef uint32_t max_disp = error_mat.shape[1] 34 | cdef np.ndarray[uint32_t, mode="c", ndim=1] y_c = np.ascontiguousarray(y) 35 | cdef np.ndarray[uint32_t, mode="c", ndim=1] yt_c = np.ascontiguousarray(yt) 36 | cdef np.ndarray[uint32_t, mode="c", ndim=1] dt_c = np.ascontiguousarray(dt) 37 | cdef np.ndarray[uint64_t, mode="c", ndim=2] error_mat_c = error_mat 38 | cdef np.ndarray[uint64_t, mode="c", ndim=2] num_mat_c = num_mat 39 | add_disp_class_error_number_matrix(&y_c[0], &yt_c[0], &dt_c[0], size, &error_mat_c[0, 0], &num_mat_c[0, 0], 40 | max_disp, num_classes) 41 | 42 | 
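# Usage sketch for collect_confusion_matrix above, with hypothetical values
# (assumes the extension was compiled via lib/build.sh):
#
#   import numpy as np
#   import lib.cylib as cylib
#
#   conf = np.zeros((19, 19), dtype=np.uint64)
#   pred = np.array([0, 1, 1], dtype=np.uint32)  # predicted class ids
#   gt   = np.array([0, 1, 2], dtype=np.uint32)  # ground-truth class ids
#   cylib.collect_confusion_matrix(pred, gt, conf)
#   # now conf[0, 0] == 1, conf[1, 1] == 1 and conf[1, 2] == 1
#
# The index arrays must be flat uint32 and the matrix a C-contiguous uint64
# array, since the wrappers in this module pass raw pointers into cylib.h.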
def convert_colors_to_ids(color_data, rgb_labels, id_labels, class_histogram, max_wgt, 43 | class_weights, weights): 44 | cdef int num_classes = color_data.shape[0] 45 | cdef int width = rgb_labels.shape[1] 46 | cdef int height = rgb_labels.shape[0] 47 | cdef float max_wgt_c = max_wgt 48 | cdef np.ndarray[int, mode="c", ndim=2] color_data_c = color_data 49 | cdef np.ndarray[uint8_t, mode="c", ndim=3] rgb_labels_c = rgb_labels 50 | cdef np.ndarray[uint8_t, mode="c", ndim=2] id_labels_c = id_labels 51 | cdef np.ndarray[float, mode="c", ndim=2] weights_c = weights 52 | cdef np.ndarray[float, mode="c", ndim=1] class_weights_c = class_weights 53 | cdef np.ndarray[uint64_t, mode="c", ndim=1] class_hist_c = class_histogram 54 | impl_convert_colors_to_ids(num_classes, &color_data_c[0, 0], width, height, &rgb_labels_c[0, 0, 0], 55 | &id_labels_c[0, 0], &class_hist_c[0], max_wgt_c, &class_weights_c[0], &weights_c[0, 0]) 56 | 57 | #add_confusion_matrix(&y_c[0], &yt_c[0], size, &confusion_mat_c[0,0], num_classes) 58 | 59 | #def collect_confusion_matrix(y, yt, conf_mat): 60 | # print(y.size) 61 | # for i in range(y.size): 62 | # l = y[i] 63 | # lt = yt[i] 64 | # if lt >= 0: 65 | # conf_mat[l,lt] += 1 66 | # 67 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/models/__init__.py -------------------------------------------------------------------------------- /models/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .semseg_loss import SemsegCrossEntropy 2 | from .boundary_loss import BoundaryAwareFocalLoss 3 | from .util import * 4 | -------------------------------------------------------------------------------- /models/loss/boundary_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from models.util import upsample 6 | 7 | 8 | class BoundaryAwareFocalLoss(nn.Module): 9 | def __init__(self, gamma=0, num_classes=19, ignore_id=19, print_each=20): 10 | super(BoundaryAwareFocalLoss, self).__init__() 11 | self.num_classes = num_classes 12 | self.ignore_id = ignore_id 13 | self.print_each = print_each 14 | self.step_counter = 0 15 | self.gamma = gamma 16 | 17 | def forward(self, input, target, batch, **kwargs): 18 | if input.shape[-2:] != target.shape[-2:]: 19 | input = upsample(input, target.shape[-2:]) 20 | target[target == self.ignore_id] = 0 # we can do this because alphas are zero in ignore_id places 21 | label_distance_alphas = batch['label_distance_alphas'].to(input.device) 22 | N = (label_distance_alphas.data > 0.).sum() 23 | if N.le(0): 24 | return torch.zeros(size=(0,), device=label_distance_alphas.device, requires_grad=True).sum() 25 | if input.dim() > 2: 26 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 27 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 28 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 29 | target = target.view(-1, 1) 30 | alphas = label_distance_alphas.view(-1) 31 | 32 | logpt = F.log_softmax(input, dim=-1) 33 | logpt = logpt.gather(1, target) 34 | logpt = logpt.view(-1) 35 | pt = logpt.detach().exp() 36 | 37 | loss = -1 * alphas * torch.exp(self.gamma * (1 - pt)) * logpt 38 | loss = loss.sum() / N 39 | 40 | 
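        # label_distance_alphas carries per-pixel boundary weights (zero on
        # ignored pixels, which is why zeroing ignore_id above is safe), while
        # exp(gamma * (1 - pt)) plays the role of the usual (1 - pt)**gamma
        # focal term, up-weighting pixels the model finds hard.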
if (self.step_counter % self.print_each) == 0: 41 | print(f'Step: {self.step_counter} Loss: {loss.data.cpu().item():.4f}') 42 | self.step_counter += 1 43 | 44 | return loss 45 | -------------------------------------------------------------------------------- /models/loss/semseg_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from models.util import upsample 5 | 6 | 7 | class SemsegCrossEntropy(nn.Module): 8 | def __init__(self, num_classes=19, ignore_id=19, print_each=20): 9 | super(SemsegCrossEntropy, self).__init__() 10 | self.num_classes = num_classes 11 | self.ignore_id = ignore_id 12 | self.step_counter = 0 13 | self.print_each = print_each 14 | 15 | def loss(self, y, t): 16 | if y.shape[2:4] != t.shape[1:3]: 17 | y = upsample(y, t.shape[1:3]) 18 | return F.cross_entropy(y, target=t, ignore_index=self.ignore_id) 19 | 20 | def forward(self, logits, labels, **kwargs): 21 | loss = self.loss(logits, labels) 22 | if (self.step_counter % self.print_each) == 0: 23 | print(f'Step: {self.step_counter} Loss: {loss.data.cpu().item():.4f}') 24 | self.step_counter += 1 25 | return loss 26 | -------------------------------------------------------------------------------- /models/loss/util.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | __all__ = ['cross_entropy_with_logits', 'cross_entropy_with_logits_and_hist', 'mean_squared_error'] 4 | 5 | 6 | def cross_entropy_with_logits(y, t): 7 | ''' 8 | :param y: Tensor of logits 9 | :param t: Tensor of logits 10 | :return: 11 | ''' 12 | assert y.shape == t.shape 13 | return -(y.log_softmax(dim=1) * t.softmax(dim=1)).sum(dim=1).mean() 14 | 15 | 16 | def cross_entropy_with_logits_and_hist(y, t, reduce=True): 17 | ''' 18 | :param y: Tensor of logits 19 | :param t: Tensor of histograms 20 | :return: 21 | ''' 22 | assert y.shape == t.shape 23 | ce = -(y.log_softmax(dim=1) * t).sum(dim=1) 24 | if reduce: 25 | ce = ce.mean() 26 | return ce 27 | 28 | 29 | def mean_squared_error(y, t): 30 | ''' 31 | :param y: Tensor of logits 32 | :param t: Tensor of logits 33 | :return: 34 | ''' 35 | assert y.shape == t.shape 36 | return F.mse_loss(y, t, reduction='mean') 37 | -------------------------------------------------------------------------------- /models/resnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/models/resnet/__init__.py -------------------------------------------------------------------------------- /models/resnet/resnet_pyramid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.model_zoo as model_zoo 5 | from itertools import chain 6 | import torch.utils.checkpoint as cp 7 | from collections import defaultdict 8 | from math import log2 9 | 10 | from ..util import _UpsampleBlend 11 | 12 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'BasicBlock'] 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def convkxk(in_planes, out_planes, stride=1, k=3): 24 | """kxk convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=k, stride=stride, padding=k // 2, bias=False) 26 | 27 | 28 | def _bn_function_factory(conv, norm, relu=None): 29 | def bn_function(x): 30 | x = norm(conv(x)) 31 | if relu is not None: 32 | x = relu(x) 33 | return x 34 | 35 | return bn_function 36 | 37 | 38 | def do_efficient_fwd(block, x, efficient): 39 | # return block(x) 40 | if efficient and x.requires_grad: 41 | return cp.checkpoint(block, x) 42 | else: 43 | return block(x) 44 | 45 | 46 | class Identity(nn.Module): 47 | def __init__(self, *args, **kwargs): 48 | super(Identity, self).__init__() 49 | 50 | def forward(self, input): 51 | return input 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | expansion = 1 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, bn_class=nn.BatchNorm2d, levels=3): 58 | super(BasicBlock, self).__init__() 59 | self.conv1 = convkxk(inplanes, planes, stride) 60 | self.bn1 = nn.ModuleList([bn_class(planes) for _ in range(levels)]) 61 | self.relu_inp = nn.ReLU(inplace=True) 62 | self.relu = nn.ReLU(inplace=False) 63 | self.conv2 = convkxk(planes, planes) 64 | self.bn2 = nn.ModuleList([bn_class(planes) for _ in range(levels)]) 65 | self.downsample = downsample 66 | self.stride = stride 67 | self.efficient = efficient 68 | self.num_levels = levels 69 | 70 | def forward(self, x, level): 71 | residual = x 72 | 73 | bn_1 = _bn_function_factory(self.conv1, self.bn1[level], self.relu_inp) 74 | bn_2 = _bn_function_factory(self.conv2, self.bn2[level]) 75 | 76 | out = do_efficient_fwd(bn_1, x, self.efficient) 77 | out = do_efficient_fwd(bn_2, out, self.efficient) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 81 | 82 | out += residual 83 | relu = self.relu(out) 84 | 85 | return relu, out 86 | 87 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 88 | missing_keys, unexpected_keys, error_msgs): 89 | super(BasicBlock, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, 90 | unexpected_keys, error_msgs) 91 | missing_keys = [] 92 | unexpected_keys = [] 93 | for bn in self.bn1: 94 | bn._load_from_state_dict(state_dict, prefix + 'bn1.', local_metadata, strict, missing_keys, unexpected_keys, 95 | error_msgs) 96 | for bn in self.bn2: 97 | bn._load_from_state_dict(state_dict, prefix + 'bn2.', local_metadata, strict, missing_keys, unexpected_keys, 98 | error_msgs) 99 | 100 | 101 | class ResNet(nn.Module): 102 | def _make_layer(self, block, planes, blocks, stride=1, bn_class=nn.BatchNorm2d): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | nn.Conv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | bn_class(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | layers.append(block(self.inplanes, planes, stride, downsample, self.efficient, bn_class=bn_class, 113 | levels=self.pyramid_levels)) 114 | self.inplanes = planes * block.expansion 115 | for i in range(1, blocks): 116 | layers.append(block(self.inplanes, planes, bn_class=bn_class, levels=self.pyramid_levels, efficient=self.efficient)) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def __init__(self, block, layers, *, num_features=128, pyramid_levels=3, use_bn=True, 
k_bneck=1, k_upsample=3, 121 | efficient=False, upsample_skip=True, mean=(73.1584, 82.9090, 72.3924), 122 | std=(44.9149, 46.1529, 45.3192), scale=1, detach_upsample_skips=(), detach_upsample_in=False, 123 | align_corners=None, pyramid_subsample='bicubic', target_size=None, 124 | output_stride=4, **kwargs): 125 | self.inplanes = 64 126 | self.efficient = efficient 127 | super(ResNet, self).__init__() 128 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 129 | bn_class = nn.BatchNorm2d if use_bn else Identity 130 | self.register_buffer('img_mean', torch.tensor(mean).view(1, -1, 1, 1)) 131 | self.register_buffer('img_std', torch.tensor(std).view(1, -1, 1, 1)) 132 | if scale != 1: 133 | self.register_buffer('img_scale', torch.tensor(scale).view(1, -1, 1, 1).float()) 134 | 135 | self.pyramid_levels = pyramid_levels 136 | self.num_features = num_features 137 | self.replicated = False 138 | 139 | self.align_corners = align_corners 140 | self.pyramid_subsample = pyramid_subsample 141 | 142 | self.bn1 = nn.ModuleList([bn_class(64) for _ in range(pyramid_levels)]) 143 | self.relu = nn.ReLU(inplace=True) 144 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 145 | bottlenecks = [] 146 | self.layer1 = self._make_layer(block, 64, layers[0], bn_class=bn_class) 147 | bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)] 148 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, bn_class=bn_class) 149 | bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)] 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, bn_class=bn_class) 151 | bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)] 152 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, bn_class=bn_class) 153 | bottlenecks += [convkxk(self.inplanes, num_features, k=k_bneck)] 154 | 155 | num_bn_remove = max(0, int(log2(output_stride) - 2)) 156 | self.num_skip_levels = self.pyramid_levels + 3 - num_bn_remove 157 | bottlenecks = bottlenecks[num_bn_remove:] 158 | 159 | self.fine_tune = [self.conv1, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4, self.bn1] 160 | 161 | self.upsample_bottlenecks = nn.ModuleList(bottlenecks[::-1]) 162 | num_pyr_modules = 2 + pyramid_levels - num_bn_remove 163 | self.target_size = target_size 164 | if self.target_size is not None: 165 | h, w = target_size 166 | target_sizes = [(h // 2 ** i, w // 2 ** i) for i in range(2, 2 + num_pyr_modules)][::-1] 167 | else: 168 | target_sizes = [None] * num_pyr_modules 169 | self.upsample_blends = nn.ModuleList( 170 | [_UpsampleBlend(num_features, 171 | use_bn=use_bn, 172 | use_skip=upsample_skip, 173 | detach_skip=i in detach_upsample_skips, 174 | fixed_size=ts, 175 | k=k_upsample) 176 | for i, ts in enumerate(target_sizes)]) 177 | self.detach_upsample_in = detach_upsample_in 178 | 179 | self.random_init = [self.upsample_bottlenecks, self.upsample_blends] 180 | 181 | self.features = num_features 182 | 183 | for m in self.modules(): 184 | if isinstance(m, nn.Conv2d): 185 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 186 | elif isinstance(m, nn.BatchNorm2d): 187 | nn.init.constant_(m.weight, 1) 188 | nn.init.constant_(m.bias, 0) 189 | 190 | def random_init_params(self): 191 | return chain(*[f.parameters() for f in self.random_init]) 192 | 193 | def fine_tune_params(self): 194 | return chain(*[f.parameters() for f in self.fine_tune]) 195 | 196 | def forward_resblock(self, x, layers, idx): 197 | skip = None 198 | for l in layers: 
199 | x = l(x) if not isinstance(l, BasicBlock) else l(x, idx) 200 | if isinstance(x, tuple): 201 | x, skip = x 202 | return x, skip 203 | 204 | def forward_down(self, image, skips, idx=-1): 205 | x = self.conv1(image) 206 | x = self.bn1[idx](x) 207 | x = self.relu(x) 208 | x = self.maxpool(x) 209 | 210 | features = [] 211 | x, skip = self.forward_resblock(x, self.layer1, idx) 212 | features += [skip] 213 | x, skip = self.forward_resblock(x, self.layer2, idx) 214 | features += [skip] 215 | x, skip = self.forward_resblock(x, self.layer3, idx) 216 | features += [skip] 217 | x, skip = self.forward_resblock(x, self.layer4, idx) 218 | features += [skip] 219 | 220 | skip_feats = [b(f) for b, f in zip(self.upsample_bottlenecks, reversed(features))] 221 | 222 | for i, s in enumerate(reversed(skip_feats)): 223 | skips[idx + i] += [s] 224 | 225 | return skips 226 | 227 | def forward(self, image): 228 | if isinstance(self.bn1[0], nn.BatchNorm2d): 229 | if hasattr(self, 'img_scale'): 230 | image /= self.img_scale 231 | image -= self.img_mean 232 | image /= self.img_std 233 | pyramid = [image] 234 | for l in range(1, self.pyramid_levels): 235 | if self.target_size is not None: 236 | ts = list([si // 2 ** l for si in self.target_size]) 237 | pyramid += [ 238 | F.interpolate(image, size=ts, mode=self.pyramid_subsample, align_corners=self.align_corners)] 239 | else: 240 | pyramid += [F.interpolate(image, scale_factor=1 / 2 ** l, mode=self.pyramid_subsample, 241 | align_corners=self.align_corners)] 242 | skips = [[] for _ in range(self.num_skip_levels)] 243 | additional = {'pyramid': pyramid} 244 | for idx, p in enumerate(pyramid): 245 | skips = self.forward_down(p, skips, idx=idx) 246 | skips = skips[::-1] 247 | x = skips[0][0] 248 | if self.detach_upsample_in: 249 | x = x.detach() 250 | for i, (sk, blend) in enumerate(zip(skips[1:], self.upsample_blends)): 251 | x = blend(x, sum(sk)) 252 | return x, additional 253 | 254 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 255 | missing_keys, unexpected_keys, error_msgs): 256 | super(ResNet, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, 257 | unexpected_keys, error_msgs) 258 | for bn in self.bn1: 259 | bn._load_from_state_dict(state_dict, prefix + 'bn1.', local_metadata, strict, missing_keys, unexpected_keys, 260 | error_msgs) 261 | 262 | 263 | def resnet18(pretrained=True, **kwargs): 264 | """Constructs a ResNet-18 model. 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | """ 268 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 269 | if pretrained: 270 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 271 | return model 272 | 273 | 274 | def resnet34(pretrained=True, **kwargs): 275 | """Constructs a ResNet-34 model. 
276 | Args: 277 | pretrained (bool): If True, returns a model pre-trained on ImageNet 278 | """ 279 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 280 | if pretrained: 281 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 282 | return model 283 | -------------------------------------------------------------------------------- /models/resnet/resnet_single_scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from itertools import chain 5 | import torch.utils.checkpoint as cp 6 | from math import log2 7 | 8 | from ..util import _Upsample, SpatialPyramidPooling, SeparableConv2d 9 | from evaluation.evaluate import mt 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet18dws', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'BasicBlock'] 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet18dws': '/home/morsic/saves/imagenet/resnet18dws/model_best.pth.tar', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, separable=False): 24 | """3x3 convolution with padding""" 25 | conv_class = SeparableConv2d if separable else nn.Conv2d 26 | return conv_class(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 27 | 28 | 29 | def _bn_function_factory(conv, norm, relu=None): 30 | def bn_function(x): 31 | x = conv(x) 32 | if norm is not None: 33 | x = norm(x) 34 | if relu is not None: 35 | x = relu(x) 36 | return x 37 | 38 | return bn_function 39 | 40 | def do_efficient_fwd(block, x, efficient): 41 | if efficient and x.requires_grad: 42 | return cp.checkpoint(block, x) 43 | else: 44 | return block(x) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, use_bn=True, deleting=False, 51 | separable=False): 52 | super(BasicBlock, self).__init__() 53 | self.use_bn = use_bn 54 | self.conv1 = conv3x3(inplanes, planes, stride, separable=separable) 55 | self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None 56 | self.relu = nn.ReLU(inplace=True) 57 | self.conv2 = conv3x3(planes, planes, separable=separable) 58 | self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None 59 | self.downsample = downsample 60 | self.stride = stride 61 | self.efficient = efficient 62 | self.deleting = deleting 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | if self.deleting is False: 71 | bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) 72 | bn_2 = _bn_function_factory(self.conv2, self.bn2) 73 | 74 | out = do_efficient_fwd(bn_1, x, self.efficient) 75 | out = do_efficient_fwd(bn_2, out, self.efficient) 76 | else: 77 | out = torch.zeros_like(residual) 78 | 79 | out = out + residual 80 | relu = self.relu(out) 81 | # print(f'Basic Block memory: {torch.cuda.memory_allocated() // 2**20}') 82 | 83 | return relu, out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | expansion = 4 88 | 89 | def __init__(self, inplanes, planes, stride=1, downsample=None, efficient=True, 
use_bn=True, separable=False): 90 | super(Bottleneck, self).__init__() 91 | self.use_bn = use_bn 92 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 93 | self.bn1 = nn.BatchNorm2d(planes) if self.use_bn else None 94 | conv_class = SeparableConv2d if separable else nn.Conv2d 95 | self.conv2 = conv_class(planes, planes, kernel_size=3, stride=stride, 96 | padding=1, bias=False) 97 | self.bn2 = nn.BatchNorm2d(planes) if self.use_bn else None 98 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) if self.use_bn else None 100 | self.relu = nn.ReLU(inplace=False) 101 | self.downsample = downsample 102 | self.stride = stride 103 | self.efficient = efficient 104 | 105 | def forward(self, x): 106 | residual = x 107 | 108 | bn_1 = _bn_function_factory(self.conv1, self.bn1, self.relu) 109 | bn_2 = _bn_function_factory(self.conv2, self.bn2, self.relu) 110 | bn_3 = _bn_function_factory(self.conv3, self.bn3, self.relu) 111 | 112 | out = do_efficient_fwd(bn_1, x, self.efficient) 113 | out = do_efficient_fwd(bn_2, out, self.efficient) 114 | out = do_efficient_fwd(bn_3, out, self.efficient) 115 | 116 | if self.downsample is not None: 117 | residual = self.downsample(x) 118 | 119 | out = out + residual 120 | relu = self.relu(out) 121 | 122 | return relu, out 123 | 124 | 125 | class ResNet(nn.Module): 126 | def __init__(self, block, layers, *, num_features=128, k_up=3, efficient=False, use_bn=True, 127 | spp_grids=(8, 4, 2, 1), spp_square_grid=False, spp_drop_rate=0.0, 128 | upsample_skip=True, upsample_only_skip=False, 129 | detach_upsample_skips=(), detach_upsample_in=False, 130 | target_size=None, output_stride=4, mean=(73.1584, 82.9090, 72.3924), 131 | std=(44.9149, 46.1529, 45.3192), scale=1, separable=False, 132 | upsample_separable=False, **kwargs): 133 | super(ResNet, self).__init__() 134 | self.inplanes = 64 135 | self.efficient = efficient 136 | self.use_bn = use_bn 137 | self.separable = separable 138 | self.register_buffer('img_mean', torch.tensor(mean).view(1, -1, 1, 1)) 139 | self.register_buffer('img_std', torch.tensor(std).view(1, -1, 1, 1)) 140 | if scale != 1: 141 | self.register_buffer('img_scale', torch.tensor(scale).view(1, -1, 1, 1).float()) 142 | 143 | self.detach_upsample_in = detach_upsample_in 144 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 145 | bias=False) 146 | self.bn1 = nn.BatchNorm2d(64) if self.use_bn else lambda x: x 147 | self.relu = nn.ReLU(inplace=True) 148 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 149 | self.target_size = target_size 150 | if self.target_size is not None: 151 | h, w = target_size 152 | target_sizes = [(h // 2 ** i, w // 2 ** i) for i in range(2, 6)] 153 | else: 154 | target_sizes = [None] * 4 155 | upsamples = [] 156 | self.layer1 = self._make_layer(block, 64, layers[0]) 157 | upsamples += [ 158 | _Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up, use_skip=upsample_skip, 159 | only_skip=upsample_only_skip, detach_skip=2 in detach_upsample_skips, fixed_size=target_sizes[0], 160 | separable=upsample_separable)] 161 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 162 | upsamples += [ 163 | _Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up, use_skip=upsample_skip, 164 | only_skip=upsample_only_skip, detach_skip=1 in detach_upsample_skips, fixed_size=target_sizes[1], 165 | separable=upsample_separable)] 166 | self.layer3 = 
self._make_layer(block, 256, layers[2], stride=2) 167 | upsamples += [ 168 | _Upsample(num_features, self.inplanes, num_features, use_bn=self.use_bn, k=k_up, use_skip=upsample_skip, 169 | only_skip=upsample_only_skip, detach_skip=0 in detach_upsample_skips, fixed_size=target_sizes[2], 170 | separable=upsample_separable)] 171 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 172 | 173 | self.fine_tune = [self.conv1, self.maxpool, self.layer1, self.layer2, self.layer3, self.layer4] 174 | if self.use_bn: 175 | self.fine_tune += [self.bn1] 176 | 177 | num_levels = 3 178 | self.spp_size = kwargs.get('spp_size', num_features) 179 | bt_size = self.spp_size 180 | 181 | level_size = self.spp_size // num_levels 182 | 183 | self.spp = SpatialPyramidPooling(self.inplanes, num_levels, bt_size=bt_size, level_size=level_size, 184 | out_size=num_features, grids=spp_grids, square_grid=spp_square_grid, 185 | bn_momentum=0.01 / 2, use_bn=self.use_bn, drop_rate=spp_drop_rate 186 | , fixed_size=target_sizes[3]) 187 | num_up_remove = max(0, int(log2(output_stride) - 2)) 188 | self.upsample = nn.ModuleList(list(reversed(upsamples[num_up_remove:]))) 189 | 190 | self.random_init = [self.spp, self.upsample] 191 | 192 | self.num_features = num_features 193 | 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 197 | elif isinstance(m, nn.BatchNorm2d): 198 | nn.init.constant_(m.weight, 1) 199 | nn.init.constant_(m.bias, 0) 200 | 201 | def _make_layer(self, block, planes, blocks, stride=1): 202 | downsample = None 203 | if stride != 1 or self.inplanes != planes * block.expansion: 204 | layers = [nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False)] 205 | if self.use_bn: 206 | layers += [nn.BatchNorm2d(planes * block.expansion)] 207 | downsample = nn.Sequential(*layers) 208 | layers = [block(self.inplanes, planes, stride, downsample, efficient=self.efficient, use_bn=self.use_bn, 209 | separable=self.separable)] 210 | self.inplanes = planes * block.expansion 211 | for i in range(1, blocks): 212 | layers += [block(self.inplanes, planes, efficient=self.efficient, use_bn=self.use_bn, 213 | separable=self.separable)] 214 | 215 | return nn.Sequential(*layers) 216 | 217 | def random_init_params(self): 218 | return chain(*[f.parameters() for f in self.random_init]) 219 | 220 | def fine_tune_params(self): 221 | return chain(*[f.parameters() for f in self.fine_tune]) 222 | 223 | def forward_resblock(self, x, layers): 224 | skip = None 225 | for l in layers: 226 | x = l(x) 227 | if isinstance(x, tuple): 228 | x, skip = x 229 | return x, skip 230 | 231 | def forward_down(self, image): 232 | if hasattr(self, 'img_scale'): 233 | image /= self.img_scale 234 | image -= self.img_mean 235 | image /= self.img_std 236 | 237 | x = self.conv1(image) 238 | x = self.bn1(x) 239 | x = self.relu(x) 240 | x = self.maxpool(x) 241 | 242 | features = [] 243 | x, skip = self.forward_resblock(x, self.layer1) 244 | features += [skip] 245 | x, skip = self.forward_resblock(x, self.layer2) 246 | features += [skip] 247 | x, skip = self.forward_resblock(x, self.layer3) 248 | features += [skip] 249 | x, skip = self.forward_resblock(x, self.layer4) 250 | features += [self.spp.forward(skip)] 251 | return features 252 | 253 | def forward_up(self, features): 254 | features = features[::-1] 255 | 256 | x = features[0] 257 | if self.detach_upsample_in: 258 | x = x.detach() 259 | 260 | upsamples = [] 261 | for skip, up in 
zip(features[1:], self.upsample): 262 | x = up(x, skip) 263 | upsamples += [x] 264 | return x, {'features': features, 'upsamples': upsamples} 265 | 266 | def forward(self, image): 267 | return self.forward_up(self.forward_down(image)) 268 | 269 | 270 | def resnet18(pretrained=True, **kwargs): 271 | """Constructs a ResNet-18 model. 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | """ 275 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 276 | if pretrained: 277 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False) 278 | return model 279 | 280 | 281 | def resnet18dws(pretrained=True, **kwargs): 282 | """Constructs a ResNet-18 model. 283 | Args: 284 | pretrained (bool): If True, returns a model pre-trained on ImageNet 285 | """ 286 | model = ResNet(BasicBlock, [2, 2, 2, 2], separable=True, **kwargs) 287 | if pretrained: 288 | try: 289 | model.load_state_dict(torch.load(model_urls['resnet18dws'])['state_dict'], strict=True) 290 | except Exception as e: 291 | print(e) 292 | return model 293 | 294 | 295 | def resnet34(pretrained=True, **kwargs): 296 | """Constructs a ResNet-34 model. 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | """ 300 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 301 | if pretrained: 302 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False) 303 | return model 304 | 305 | 306 | def resnet50(pretrained=True, **kwargs): 307 | """Constructs a ResNet-50 model. 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 312 | if pretrained: 313 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False) 314 | return model 315 | 316 | 317 | def resnet101(pretrained=True, **kwargs): 318 | """Constructs a ResNet-101 model. 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | """ 322 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 323 | if pretrained: 324 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False) 325 | return model 326 | 327 | 328 | def resnet152(pretrained=True, **kwargs): 329 | """Constructs a ResNet-152 model. 
330 | Args: 331 | pretrained (bool): If True, returns a model pre-trained on ImageNet 332 | """ 333 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 334 | if pretrained: 335 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False) 336 | return model 337 | -------------------------------------------------------------------------------- /models/semseg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from itertools import chain 5 | import warnings 6 | 7 | from .util import _BNReluConv, upsample 8 | 9 | 10 | class SemsegModel(nn.Module): 11 | def __init__(self, backbone, num_classes, num_inst_classes=None, use_bn=True, k=1, bias=True, 12 | loss_ret_additional=False, upsample_logits=True, logit_class=_BNReluConv, 13 | multiscale_factors=(.5, .75, 1.5, 2.)): 14 | super(SemsegModel, self).__init__() 15 | self.backbone = backbone 16 | self.num_classes = num_classes 17 | self.logits = logit_class(self.backbone.num_features, self.num_classes, batch_norm=use_bn, k=k, bias=bias) 18 | if num_inst_classes is not None: 19 | self.border_logits = _BNReluConv(self.backbone.num_features, num_inst_classes, batch_norm=use_bn, 20 | k=k, bias=bias) 21 | self.criterion = None 22 | self.loss_ret_additional = loss_ret_additional 23 | self.img_req_grad = loss_ret_additional 24 | self.upsample_logits = upsample_logits 25 | self.multiscale_factors = multiscale_factors 26 | 27 | def forward(self, image, target_size, image_size): 28 | features, additional = self.backbone(image) 29 | logits = self.logits.forward(features) 30 | if (not self.training) or self.upsample_logits: 31 | logits = upsample(logits, image_size) 32 | if hasattr(self, 'border_logits'): 33 | additional['border_logits'] = self.border_logits(features).sigmoid() 34 | additional['logits'] = logits 35 | return logits, additional 36 | 37 | def forward_down(self, image, target_size, image_size): 38 | return self.backbone.forward_down(image), target_size, image_size 39 | 40 | def forward_up(self, feats, target_size, image_size): 41 | feats, additional = self.backbone.forward_up(feats) 42 | features = upsample(feats, target_size) 43 | logits = self.logits.forward(features) 44 | logits = upsample(logits, image_size) 45 | return logits, additional 46 | 47 | def prepare_data(self, batch, image_size, device=torch.device('cuda'), img_key='image'): 48 | if image_size is None: 49 | image_size = batch['target_size'] 50 | warnings.warn(f'Image requires grad: {self.img_req_grad}', UserWarning) 51 | image = batch[img_key].detach().requires_grad_(self.img_req_grad).to(device) 52 | return { 53 | 'image': image, 54 | 'image_size': image_size, 55 | 'target_size': batch.get('target_size_feats') 56 | } 57 | 58 | def do_forward(self, batch, image_size=None): 59 | data = self.prepare_data(batch, image_size) 60 | logits, additional = self.forward(**data) 61 | additional['model'] = self 62 | additional = {**additional, **data} 63 | return logits, additional 64 | 65 | def loss(self, batch): 66 | assert self.criterion is not None 67 | labels = batch['labels'].cuda() 68 | logits, additional = self.do_forward(batch, image_size=labels.shape[-2:]) 69 | if self.loss_ret_additional: 70 | return self.criterion(logits, labels, batch=batch, additional=additional), additional 71 | return self.criterion(logits, labels, batch=batch, additional=additional) 72 | 73 | def random_init_params(self): 74 | params = [self.logits.parameters(), 
self.backbone.random_init_params()] 75 | if hasattr(self, 'border_logits'): 76 | params += [self.border_logits.parameters()] 77 | return chain(*(params)) 78 | 79 | def fine_tune_params(self): 80 | return self.backbone.fine_tune_params() 81 | 82 | def ms_forward(self, batch, image_size=None): 83 | image_size = batch.get('target_size', image_size if image_size is not None else batch['image'].shape[-2:]) 84 | ms_logits = None 85 | pyramid = [batch['image'].cuda()] 86 | pyramid += [ 87 | F.interpolate(pyramid[0], scale_factor=sf, mode=self.backbone.pyramid_subsample, 88 | align_corners=self.backbone.align_corners) for sf in self.multiscale_factors 89 | ] 90 | for image in pyramid: 91 | batch['image'] = image 92 | logits, additional = self.do_forward(batch, image_size=image_size) 93 | if ms_logits is None: 94 | ms_logits = torch.zeros(logits.size()).to(logits.device) 95 | ms_logits += F.softmax(logits, dim=1) 96 | batch['image'] = pyramid[0].cpu() 97 | return ms_logits / len(pyramid), {} -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import warnings 4 | 5 | from torch import nn as nn 6 | 7 | upsample = lambda x, size: F.interpolate(x, size, mode='bilinear', align_corners=False) 8 | batchnorm_momentum = 0.01 / 2 9 | 10 | 11 | def get_n_params(parameters): 12 | pp = 0 13 | for p in parameters: 14 | nn = 1 15 | for s in list(p.size()): 16 | nn = nn * s 17 | pp += nn 18 | return pp 19 | 20 | 21 | class SeparableConv2d(nn.Module): 22 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False): 23 | super(SeparableConv2d, self).__init__() 24 | 25 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, 26 | bias=bias) 27 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.pointwise(x) 32 | return x 33 | 34 | 35 | class _BNReluConv(nn.Sequential): 36 | def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1, 37 | drop_rate=.0, separable=False): 38 | super(_BNReluConv, self).__init__() 39 | if batch_norm: 40 | self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum)) 41 | self.add_module('relu', nn.ReLU(inplace=batch_norm is True)) 42 | padding = k // 2 43 | conv_class = SeparableConv2d if separable else nn.Conv2d 44 | warnings.warn(f'Using conv type {k}x{k}: {conv_class}') 45 | self.add_module('conv', conv_class(num_maps_in, num_maps_out, kernel_size=k, padding=padding, bias=bias, 46 | dilation=dilation)) 47 | if drop_rate > 0: 48 | warnings.warn(f'Using dropout with p: {drop_rate}') 49 | self.add_module('dropout', nn.Dropout2d(drop_rate, inplace=True)) 50 | 51 | 52 | class _Upsample(nn.Module): 53 | def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3, use_skip=True, only_skip=False, 54 | detach_skip=False, fixed_size=None, separable=False, bneck_starts_with_bn=True): 55 | super(_Upsample, self).__init__() 56 | print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}') 57 | self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn and bneck_starts_with_bn) 58 | self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn, separable=separable) 59 | 
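        # Decoder contract: forward() below projects the skip tensor through
        # the 1x1 bottleneck, optionally detaches it, upsamples the deeper
        # features x to the skip resolution, adds the two when use_skip is
        # set, and blends the sum with a k x k convolution.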
self.use_skip = use_skip 60 | self.only_skip = only_skip 61 | self.detach_skip = detach_skip 62 | warnings.warn(f'\tUsing skips: {self.use_skip} (only skips: {self.only_skip})', UserWarning) 63 | self.upsampling_method = upsample 64 | if fixed_size is not None: 65 | self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size) 66 | warnings.warn(f'Fixed upsample size', UserWarning) 67 | 68 | def forward(self, x, skip): 69 | skip = self.bottleneck.forward(skip) 70 | if self.detach_skip: 71 | skip = skip.detach() 72 | skip_size = skip.size()[2:4] 73 | x = self.upsampling_method(x, skip_size) 74 | if self.use_skip: 75 | x = x + skip 76 | x = self.blend_conv.forward(x) 77 | return x 78 | 79 | 80 | class _UpsampleBlend(nn.Module): 81 | def __init__(self, num_features, use_bn=True, use_skip=True, detach_skip=False, fixed_size=None, k=3, 82 | separable=False): 83 | super(_UpsampleBlend, self).__init__() 84 | self.blend_conv = _BNReluConv(num_features, num_features, k=k, batch_norm=use_bn, separable=separable) 85 | self.use_skip = use_skip 86 | self.detach_skip = detach_skip 87 | warnings.warn(f'Using skip connections: {self.use_skip}', UserWarning) 88 | self.upsampling_method = upsample 89 | if fixed_size is not None: 90 | self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size) 91 | warnings.warn(f'Fixed upsample size', UserWarning) 92 | 93 | def forward(self, x, skip): 94 | if self.detach_skip: 95 | warnings.warn(f'Detaching skip connection {skip.shape[2:4]}', UserWarning) 96 | skip = skip.detach() 97 | skip_size = skip.size()[-2:] 98 | x = self.upsampling_method(x, skip_size) 99 | if self.use_skip: 100 | x = x + skip 101 | x = self.blend_conv.forward(x) 102 | return x 103 | 104 | 105 | class SpatialPyramidPooling(nn.Module): 106 | def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128, 107 | grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True, drop_rate=.0, 108 | fixed_size=None, starts_with_bn=True): 109 | super(SpatialPyramidPooling, self).__init__() 110 | self.fixed_size = fixed_size 111 | self.grids = grids 112 | if self.fixed_size: 113 | ref = min(self.fixed_size) 114 | self.grids = list(filter(lambda x: x <= ref, self.grids)) 115 | self.square_grid = square_grid 116 | self.upsampling_method = upsample 117 | if self.fixed_size is not None: 118 | self.upsampling_method = lambda x, size: F.interpolate(x, mode='nearest', size=fixed_size) 119 | warnings.warn(f'Fixed upsample size', UserWarning) 120 | self.spp = nn.Sequential() 121 | self.spp.add_module('spp_bn', _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum, 122 | batch_norm=use_bn and starts_with_bn)) 123 | num_features = bt_size 124 | final_size = num_features 125 | for i in range(num_levels): 126 | final_size += level_size 127 | self.spp.add_module('spp' + str(i), 128 | _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn, 129 | drop_rate=drop_rate)) 130 | self.spp.add_module('spp_fuse', 131 | _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) 132 | 133 | def forward(self, x): 134 | levels = [] 135 | target_size = self.fixed_size if self.fixed_size is not None else x.size()[2:4] 136 | 137 | ar = target_size[1] / target_size[0] 138 | 139 | x = self.spp[0].forward(x) 140 | levels.append(x) 141 | num = len(self.spp) - 1 142 | 143 | for i in range(1, num): 144 | if not self.square_grid: 145 | grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i 
- 1]))) 146 | x_pooled = F.adaptive_avg_pool2d(x, grid_size) 147 | else: 148 | x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1]) 149 | level = self.spp[i].forward(x_pooled) 150 | 151 | level = self.upsampling_method(level, target_size) 152 | levels.append(level) 153 | 154 | x = torch.cat(levels, 1) 155 | x = self.spp[-1].forward(x) 156 | return x 157 | 158 | 159 | class Identity(nn.Module): 160 | def __init__(self, *args, **kwargs): 161 | super(Identity, self).__init__() 162 | 163 | def forward(self, input): 164 | return input 165 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow==6.2.0 2 | torch==1.3.1 3 | torchvision==0.4.2 4 | numpy==1.17.4 5 | tqdm==4.28.1 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | import torch 5 | import importlib.util 6 | import datetime 7 | import sys 8 | from shutil import copy 9 | import pickle 10 | from time import perf_counter 11 | 12 | from evaluation import evaluate_semseg 13 | 14 | 15 | def import_module(path): 16 | spec = importlib.util.spec_from_file_location("module", path) 17 | module = importlib.util.module_from_spec(spec) 18 | spec.loader.exec_module(module) 19 | return module 20 | 21 | 22 | def store(model, store_path, name): 23 | with open(store_path.format(name), 'wb') as f: 24 | torch.save(model.state_dict(), f) 25 | 26 | 27 | class Logger(object): 28 | def __init__(self, *files): 29 | self.files = files 30 | 31 | def write(self, obj): 32 | for f in self.files: 33 | f.write(obj) 34 | f.flush() # If you want the output to be visible immediately 35 | 36 | def flush(self): 37 | for f in self.files: 38 | f.flush() 39 | 40 | 41 | class Trainer: 42 | def __init__(self, conf, args, name): 43 | self.conf = conf 44 | using_hparams = hasattr(conf, 'hyperparams') 45 | print(f'Using hparams: {using_hparams}') 46 | self.hyperparams = self.conf 47 | self.args = args 48 | self.name = name 49 | self.model = self.conf.model 50 | self.optimizer = self.conf.optimizer 51 | 52 | self.dataset_train = self.conf.dataset_train 53 | self.dataset_val = self.conf.dataset_val 54 | self.loader_train = self.conf.loader_train 55 | self.loader_val = self.conf.loader_val 56 | 57 | def __enter__(self): 58 | self.best_iou = -1 59 | self.best_iou_epoch = -1 60 | self.validation_ious = [] 61 | self.experiment_start = datetime.datetime.now() 62 | 63 | if self.args.resume: 64 | self.experiment_dir = Path(self.args.resume) 65 | print(f'Resuming experiment from {args.resume}') 66 | else: 67 | self.experiment_dir = Path(self.args.store_dir) / ( 68 | self.experiment_start.strftime('%Y_%m_%d_%H_%M_%S_') + self.name) 69 | 70 | self.checkpoint_dir = self.experiment_dir / 'stored' 71 | self.store_path = str(self.checkpoint_dir / '{}.pt') 72 | 73 | if not self.args.dry and not self.args.resume: 74 | os.makedirs(str(self.experiment_dir), exist_ok=True) 75 | os.makedirs(str(self.checkpoint_dir), exist_ok=True) 76 | copy(self.args.config, str(self.experiment_dir / 'config.py')) 77 | 78 | if self.args.log and not self.args.dry: 79 | f = (self.experiment_dir / 'log.txt').open(mode='a') 80 | sys.stdout = Logger(sys.stdout, f) 81 | 82 | self.model.cuda() 83 | 84 | return self 85 | 86 | def __exit__(self, exc_type, exc_val, exc_tb): 87 | if not self.args.dry: 88 
|             store(self.model, self.store_path, 'model')
89 |         if not self.args.dry:
90 |             with open(f'{self.experiment_dir}/val_ious.pkl', 'wb') as f:
91 |                 pickle.dump(self.validation_ious, f)
92 |             dir_iou = Path(self.args.store_dir) / (f'{self.best_iou:.2f}_'.replace('.', '-') + self.name)
93 |             os.rename(self.experiment_dir, dir_iou)
94 |
95 |     def train(self):
96 |         num_epochs = self.hyperparams.epochs
97 |         start_epoch = self.hyperparams.start_epoch if hasattr(self.hyperparams, 'start_epoch') else 0
98 |         for epoch in range(start_epoch, num_epochs):
99 |             if hasattr(self.conf, 'epoch'):
100 |                 self.conf.epoch.value = epoch
101 |                 print(self.conf.epoch)
102 |             self.model.train()
103 |             try:
104 |                 self.conf.lr_scheduler.step()
105 |                 print(f'Elapsed time: {datetime.datetime.now() - self.experiment_start}')
106 |                 for group in self.optimizer.param_groups:
107 |                     print('LR: {:.4e}'.format(group['lr']))
108 |                 eval_epoch = ((epoch % self.conf.eval_each == 0) or (epoch == num_epochs - 1))  # and (epoch > 0)
109 |                 self.model.criterion.step_counter = 0
110 |                 print(f'Epoch: {epoch} / {num_epochs - 1}')
111 |                 if eval_epoch and not self.args.dry:
112 |                     print("Experiment dir: %s" % self.experiment_dir)
113 |                 batch_iterator = iter(enumerate(self.loader_train))
114 |                 start_t = perf_counter()
115 |                 for step, batch in batch_iterator:
116 |                     self.optimizer.zero_grad()
117 |                     loss = self.model.loss(batch)
118 |                     loss.backward()
119 |                     self.optimizer.step()
120 |                     if step % 80 == 0 and step > 0:
121 |                         curr_t = perf_counter()
122 |                         print(f'{(step * self.conf.batch_size) / (curr_t - start_t):.2f}fps')
123 |                 if not self.args.dry:
124 |                     store(self.model, self.store_path, 'model')
125 |                     store(self.optimizer, self.store_path, 'optimizer')
126 |                 if eval_epoch and self.args.eval:
127 |                     print('Evaluating model')
128 |                     iou, per_class_iou = evaluate_semseg(self.model, self.loader_val, self.dataset_val.class_info)
129 |                     self.validation_ious += [iou]
130 |                     if self.args.eval_train:
131 |                         print('Evaluating train')
132 |                         evaluate_semseg(self.model, self.loader_train, self.dataset_train.class_info)
133 |                     if iou > self.best_iou:
134 |                         self.best_iou = iou
135 |                         self.best_iou_epoch = epoch
136 |                         if not self.args.dry:
137 |                             copy(self.store_path.format('model'), self.store_path.format('model_best'))
138 |                     print(f'Best mIoU: {self.best_iou:.2f}% (epoch {self.best_iou_epoch})')
139 |
140 |             except KeyboardInterrupt:
141 |                 break
142 |
143 |
144 | parser = argparse.ArgumentParser(description='Detector train')
145 | parser.add_argument('config', type=str, help='Path to configuration .py file')
146 | parser.add_argument('--store_dir', default='saves/', type=str, help='Path to experiments directory')
147 | parser.add_argument('--resume', default=None, type=str, help='Path to existing experiment dir')
148 | parser.add_argument('--no-log', dest='log', action='store_false', help='Turn off logging')
149 | parser.add_argument('--log', dest='log', action='store_true', help='Turn on logging')
150 | parser.add_argument('--no-eval-train', dest='eval_train', action='store_false', help='Turn off train evaluation')
151 | parser.add_argument('--eval-train', dest='eval_train', action='store_true', help='Turn on train evaluation')
152 | parser.add_argument('--no-eval', dest='eval', action='store_false', help='Turn off evaluation')
153 | parser.add_argument('--eval', dest='eval', action='store_true', help='Turn on evaluation')
154 | parser.add_argument('--dry-run', dest='dry', action='store_true', help='Don\'t store')
155 | parser.set_defaults(log=True)
156 | 
parser.set_defaults(eval_train=False) 157 | parser.set_defaults(eval=True) 158 | 159 | if __name__ == '__main__': 160 | args = parser.parse_args() 161 | conf_path = Path(args.config) 162 | conf = import_module(args.config) 163 | 164 | with Trainer(conf, args, conf_path.stem) as trainer: 165 | trainer.train() 166 | -------------------------------------------------------------------------------- /weights/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/orsic/swiftnet/2b88990e1ab674e8ef7cb533a1d8d49ef34ac93d/weights/.gitkeep --------------------------------------------------------------------------------
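For orientation, the components above compose as follows. This is a minimal smoke-test sketch (random weights, CPU, invented input size), not a script shipped with the repository; see configs/rn18_single_scale.py for the real training configuration.

import torch

from models.semseg import SemsegModel
from models.resnet.resnet_single_scale import resnet18

# Encoder-decoder backbone: ResNet-18, SPP, and three lightweight upsampling stages.
backbone = resnet18(pretrained=False, efficient=False)
# 19 classes matches the Cityscapes default used by the losses above.
model = SemsegModel(backbone, num_classes=19)
model.eval()

with torch.no_grad():
    image = torch.randn(1, 3, 512, 512)  # NCHW float tensor
    logits, additional = model(image, target_size=(512, 512), image_size=(512, 512))

print(logits.shape)  # torch.Size([1, 19, 512, 512]): per-pixel class scores
# 'additional' holds the decoder intermediates ('features', 'upsamples') and 'logits'.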