├── LICENSE ├── README.md ├── demo.py ├── main.py ├── pipeline ├── csra.py ├── dataset.py ├── resnet_csra.py ├── timm_utils │ ├── __init__.py │ ├── drop.py │ ├── tuple.py │ └── weight_init.py └── vit_csra.py ├── utils ├── demo_images │ ├── 000001.jpg │ ├── 000002.jpg │ ├── 000004.jpg │ ├── 000006.jpg │ ├── 000007.jpg │ └── 000009.jpg ├── evaluation │ ├── cal_PR.py │ ├── cal_mAP.py │ ├── eval.py │ └── warmUpLR.py ├── pipeline.PNG ├── prepare │ ├── prepare_coco.py │ ├── prepare_voc.py │ └── prepare_wider.py └── visualize.py └── val.py /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. 
This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. 
A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 
174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 
234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. 
A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 
348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 
409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 
464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. 
You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. 
If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 
631 | 
632 | <one line to give the program's name and a brief idea of what it does.>
633 | Copyright (C) <year> <name of author>
634 | 
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 | 
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 | 
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 | 
648 | Also add information on how to contact you by electronic and paper mail.
649 | 
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 | 
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSRA
2 | This is the official code of ICCV 2021 paper:
3 | [Residual Attention: A Simple but Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)
4 | 
5 | ![attention](https://github.com/Kevinz-code/CSRA/blob/master/utils/pipeline.PNG)
6 | 
7 | ### Demo, Train and Validation code have been released! (including ViT on Wider-Attribute)
8 | This package is developed by Mr. Ke Zhu (http://www.lamda.nju.edu.cn/zhuk/) and we have just finished the implementation code of the ViT models. If you have any questions about the code, please feel free to contact Mr. Ke Zhu (zhuk@lamda.nju.edu.cn). The package is free for academic usage. You can run it at your own risk. For other purposes, please contact Prof. Jianxin Wu (mail to
9 | wujx2001@gmail.com).
10 | 
11 | ## Requirements
12 | - Python 3.7
13 | - pytorch 1.6
14 | - torchvision 0.7.0
15 | - pycocotools 2.0
16 | - tqdm 4.49.0, pillow 7.2.0
17 | 
18 | ## Dataset
19 | We expect the VOC2007, COCO2014 and Wider-Attribute datasets to have the following structure:
20 | ```
21 | Dataset/
22 | |-- VOCdevkit/
23 | |---- VOC2007/
24 | |------ JPEGImages/
25 | |------ Annotations/
26 | |------ ImageSets/
27 | ......
28 | |-- COCO2014/
29 | |---- annotations/
30 | |---- images/
31 | |------ train2014/
32 | |------ val2014/
33 | ......
34 | |-- WIDER/
35 | |---- Annotations/
36 | |------ wider_attribute_test.json
37 | |------ wider_attribute_trainval.json
38 | |---- Image/
39 | |------ train/
40 | |------ val/
41 | |------ test/
42 | ...
43 | ```
44 | Then directly run the following commands to generate the annotation json files (used by our implementation) for these datasets:
45 | ```shell
46 | python utils/prepare/prepare_voc.py --data_path Dataset/VOCdevkit
47 | python utils/prepare/prepare_coco.py --data_path Dataset/COCO2014
48 | python utils/prepare/prepare_wider.py --data_path Dataset/WIDER
49 | ```
50 | which will automatically produce the annotation json files in *./data/voc07*, *./data/coco* and *./data/wider*.
51 | 
52 | ## Demo
53 | We provide prediction demos of our models. The demo images (picked from VOC2007) have already been put into *./utils/demo_images/*, so you can simply run demo.py using our CSRA models pretrained on VOC2007:
54 | ```shell
55 | CUDA_VISIBLE_DEVICES=0 python demo.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from OUR_VOC_PRETRAINED.pth --img_dir utils/demo_images
56 | ```
57 | which will produce output like this:
58 | ```shell
59 | utils/demo_images/000001.jpg prediction: dog,person,
60 | utils/demo_images/000004.jpg prediction: car,
61 | utils/demo_images/000002.jpg prediction: train,
62 | ...
63 | ```
64 | 
65 | 
66 | ## Validation
67 | We provide pretrained models on [Google Drive](https://www.google.com/drive/) for validation. ResNet101 trained on ImageNet with **CutMix** augmentation can be downloaded
68 | [here](https://drive.google.com/u/0/uc?export=download&confirm=kYfp&id=1T4AxsAO2tszvhn62KFN5kaknBtBZIpDV).
69 | |Dataset | Backbone | Head nums | mAP(%) | Resolution | Download |
70 | | ---------- | ------- | :--------: | ------ | :---: | -------- |
71 | | VOC2007 |ResNet-101 | 1 | 94.7 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=bXcv&id=1cQSRI_DWyKpLa0tvxltoH9rM4IZMIEWJ) |
72 | | VOC2007 |ResNet-cut | 1 | 95.2 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=otx_&id=1bzSsWhGG-zUNQRMB7rQCuPMqLZjnrzFh) |
73 | | VOC2007 (extra) |ResNet-cut | 1 | 96.8 | 448x448 |[download](https://drive.google.com/u/0/uc?id=1XgVE3Q3vmE8hjdDjqow_2GyjPx_5bDjU&export=download) |
74 | | COCO |ResNet-101 | 4 | 83.3 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=EWtH&id=1e_WzdVgF_sQc--ubN-DRnGVbbJGSJEZa) |
75 | | COCO |ResNet-cut | 6 | 85.6 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=uEcu&id=17FgLUe_vr5sJX6_TT-MPdP5TYYAcVEPF) |
76 | | COCO |VIT_L16_224 | 8 | 86.5 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=1Rmm&id=1TTzCpRadhYDwZSEow3OVdrh1TKezWHF_)|
77 | | COCO |VIT_L16_224* | 8 | 86.9 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=xpbJ&id=1zYE88pmWcZfcrdQsP8-9JMo4n_g5pO4l)|
78 | | Wider |VIT_B16_224| 1 | 89.0 | 224x224 |[download](https://drive.google.com/u/0/uc?id=1qkJgWQ2EOYri8ITLth_wgnR4kEsv0bfj&export=download) |
79 | | Wider |VIT_L16_224| 1 | 90.2 | 224x224 |[download](https://drive.google.com/u/0/uc?id=1da8D7UP9cMCgKO0bb1gyRvVqYoZ3Wh7O&export=download) |
80 | 
81 | For VOC2007, run the following validation example:
82 | ```shell
83 | CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --load_from MODEL.pth
84 | ```
85 | For COCO2014, run the following validation example:
86 | ```shell
87 | CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80 --load_from MODEL.pth
88 | ```
89 | For Wider-Attribute with ViT models, run the following:
90 | ```shell
91 | CUDA_VISIBLE_DEVICES=0 python val.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14 --load_from ViT_B16_MODEL.pth
92 | CUDA_VISIBLE_DEVICES=0 python val.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14 --load_from ViT_L16_MODEL.pth
93 | ```
94 | To provide pretrained ViT models on the Wider-Attribute dataset, we retrained them recently, which gives slightly different performance (~0.1% mAP) from what is reported in our paper. The structure of the ViT models is the original ViT (**An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale**, [link](https://arxiv.org/pdf/2010.11929.pdf)) and the implementation code of the ViT models is derived from [http://github.com/rwightman/pytorch-image-models/](http://github.com/rwightman/pytorch-image-models/).
95 | ## Training
96 | #### VOC2007
97 | You can run either of the two commands below:
98 | ```shell
99 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20
100 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --cutmix CutMix_ResNet101.pth
101 | ```
102 | Note that the first command uses the official ResNet-101 backbone, while the second command uses the ResNet-101 pretrained on ImageNet with CutMix augmentation
103 | [link](https://drive.google.com/u/0/uc?export=download&confirm=kYfp&id=1T4AxsAO2tszvhn62KFN5kaknBtBZIpDV) (which is expected to give better performance).
104 | 
105 | #### MS-COCO
106 | Run ResNet-101 with 4 heads:
107 | ```shell
108 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80
109 | ```
110 | Run ResNet-101 (pretrained with CutMix) with 6 heads:
111 | ```shell
112 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 6 --lam 0.4 --dataset coco --num_cls 80 --cutmix CutMix_ResNet101.pth
113 | ```
114 | Feel free to adjust the hyper-parameters, such as the number of attention heads (--num_heads) or the lambda (--lam). Still, the default values in the commands above are expected to perform best.
115 | 
116 | #### Wider-Attribute
117 | Run VIT_B16_224 with 1 head:
118 | ```shell
119 | CUDA_VISIBLE_DEVICES=0 python main.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14
120 | ```
121 | Run VIT_L16_224 with 1 head:
122 | ```shell
123 | CUDA_VISIBLE_DEVICES=0,1 python main.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14
124 | ```
125 | Note that the VIT_L16_224 model consumes more GPU memory, so we use 2 GPUs to train it.
126 | ## Notice
127 | To avoid confusion, please note that the **4 lines of code** in Figure 1 (in the paper) are only used at the **test** stage (without training); they are our motivation. When our model is trained and tested end-to-end, **multi-head attention** (H=1, H=2, H=4, etc.) is used with different T values. Also, when H=1 and T=infty, the implementation of **multi-head attention** is exactly the same as Figure 1 (see the sketch at the end of this README).
128 | 
129 | We did not use any extra augmentation such as **AutoAugment** or **RandAugment** in our ResNet-series models.
130 | 
131 | ## Acknowledgement
132 | 
133 | We thank Lin Sui (http://www.lamda.nju.edu.cn/suil/) for his initial contribution to this project.
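
## CSRA test-time sketch

As a companion to the Notice above, here is a minimal PyTorch sketch of the single-head, test-time case (H=1, T -> infinity) it describes. It mirrors the `CSRA` module in *pipeline/csra.py*; the function name `csra_test_time` and its argument names are illustrative placeholders, not part of the released code:

```python
import torch

def csra_test_time(features, fc_weight, lam=0.1):
    # features:  (B, d, H, W) feature map from the backbone
    # fc_weight: (num_classes, d) weight of the 1x1 classifier
    # lam:       the --lam hyper-parameter
    # Normalize each class weight vector, as done in pipeline/csra.py
    w = fc_weight / fc_weight.norm(dim=1, keepdim=True)
    # Per-location class scores, flattened over the H*W positions: (B, C, HW)
    score = torch.einsum("bdhw,cd->bchw", features, w).flatten(2)
    base_logit = score.mean(dim=2)    # global average pooling branch
    att_logit = score.max(dim=2)[0]   # class-specific attention (max pooling, T -> infinity)
    return base_logit + lam * att_logit
```

With a finite temperature T, the max pooling above is replaced by a softmax-weighted spatial average (see `CSRA.forward` in *pipeline/csra.py*), and the `MHA` module simply sums several such heads with different T values.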
134 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from tqdm import tqdm 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from pipeline.resnet_csra import ResNet_CSRA 11 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA 12 | from pipeline.dataset import DataSet 13 | from torchvision.transforms import transforms 14 | from utils.evaluation.eval import voc_classes, wider_classes, coco_classes, class_dict 15 | 16 | 17 | # Usage: 18 | # This demo is used to predict the label of each image 19 | # if you want to use our models to predict some labels of the VOC2007 images 20 | # 1st: use the models pretrained on VOC2007 21 | # 2nd: put the images in the utils/demo_images 22 | # 3rd: run demo.py 23 | 24 | def Args(): 25 | parser = argparse.ArgumentParser(description="settings") 26 | # model default resnet101 27 | parser.add_argument("--model", default="resnet101", type=str) 28 | parser.add_argument("--num_heads", default=1, type=int) 29 | parser.add_argument("--lam",default=0.1, type=float) 30 | parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str) 31 | parser.add_argument("--img_dir", default="images/", type=str) 32 | 33 | # dataset 34 | parser.add_argument("--dataset", default="voc07", type=str) 35 | parser.add_argument("--num_cls", default=20, type=int) 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def demo(): 42 | args = Args() 43 | 44 | # model 45 | if args.model == "resnet101": 46 | model = ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls) 47 | normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]) 48 | img_size = 448 49 | if args.model == "vit_B16_224": 50 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 51 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 52 | img_size = 224 53 | if args.model == "vit_L16_224": 54 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 55 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 56 | img_size = 224 57 | 58 | model.cuda() 59 | print("Loading weights from {}".format(args.load_from)) 60 | model.load_state_dict(torch.load(args.load_from)) 61 | 62 | # image pre-process 63 | transform = transforms.Compose([ 64 | transforms.Resize((img_size, img_size)), 65 | transforms.ToTensor(), 66 | normalize 67 | ]) 68 | 69 | # prediction of each image's label 70 | for img_file in os.listdir(args.img_dir): 71 | print(os.path.join(args.img_dir, img_file), end=" prediction: ") 72 | img = Image.open(os.path.join(args.img_dir, img_file)).convert("RGB") 73 | img = transform(img) 74 | img = img.cuda() 75 | img = img.unsqueeze(0) 76 | 77 | model.eval() 78 | logit = model(img).squeeze(0) 79 | logit = nn.Sigmoid()(logit) 80 | 81 | 82 | pos = torch.where(logit > 0.5)[0].cpu().numpy() 83 | for k in pos: 84 | print(class_dict[args.dataset][k], end=",") 85 | print() 86 | 87 | 88 | if __name__ == "__main__": 89 | demo() 90 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
time 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | from pipeline.resnet_csra import ResNet_CSRA 8 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA 9 | from pipeline.dataset import DataSet 10 | from utils.evaluation.eval import evaluation 11 | from utils.evaluation.warmUpLR import WarmUpLR 12 | from tqdm import tqdm 13 | 14 | 15 | # modify for wider dataset and vit models 16 | 17 | def Args(): 18 | parser = argparse.ArgumentParser(description="settings") 19 | # model 20 | parser.add_argument("--model", default="resnet101") 21 | parser.add_argument("--num_heads", default=1, type=int) 22 | parser.add_argument("--lam",default=0.1, type=float) 23 | parser.add_argument("--cutmix", default=None, type=str) # the path to load cutmix-pretrained backbone 24 | # dataset 25 | parser.add_argument("--dataset", default="voc07", type=str) 26 | parser.add_argument("--num_cls", default=20, type=int) 27 | parser.add_argument("--train_aug", default=["randomflip", "resizedcrop"], type=list) 28 | parser.add_argument("--test_aug", default=[], type=list) 29 | parser.add_argument("--img_size", default=448, type=int) 30 | parser.add_argument("--batch_size", default=16, type=int) 31 | # optimizer, default SGD 32 | parser.add_argument("--lr", default=0.01, type=float) 33 | parser.add_argument("--momentum", default=0.9, type=float) 34 | parser.add_argument("--w_d", default=0.0001, type=float, help="weight_decay") 35 | parser.add_argument("--warmup_epoch", default=2, type=int) 36 | parser.add_argument("--total_epoch", default=30, type=int) 37 | parser.add_argument("--print_freq", default=100, type=int) 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def train(i, args, model, train_loader, optimizer, warmup_scheduler): 43 | print() 44 | model.train() 45 | epoch_begin = time.time() 46 | for index, data in enumerate(train_loader): 47 | batch_begin = time.time() 48 | img = data['img'].cuda() 49 | target = data['target'].cuda() 50 | 51 | optimizer.zero_grad() 52 | logit, loss = model(img, target) 53 | loss = loss.mean() 54 | loss.backward() 55 | optimizer.step() 56 | t = time.time() - batch_begin 57 | 58 | if index % args.print_freq == 0: 59 | print("Epoch {}[{}/{}]: loss:{:.5f}, lr:{:.5f}, time:{:.4f}".format( 60 | i, 61 | args.batch_size * (index + 1), 62 | len(train_loader.dataset), 63 | loss, 64 | optimizer.param_groups[0]["lr"], 65 | float(t) 66 | )) 67 | 68 | if warmup_scheduler and i <= args.warmup_epoch: 69 | warmup_scheduler.step() 70 | 71 | 72 | t = time.time() - epoch_begin 73 | print("Epoch {} training ends, total {:.2f}s".format(i, t)) 74 | 75 | 76 | def val(i, args, model, test_loader, test_file): 77 | model.eval() 78 | print("Test on Epoch {}".format(i)) 79 | result_list = [] 80 | 81 | # calculate logit 82 | for index, data in enumerate(tqdm(test_loader)): 83 | img = data['img'].cuda() 84 | target = data['target'].cuda() 85 | img_path = data['img_path'] 86 | 87 | with torch.no_grad(): 88 | logit = model(img) 89 | 90 | result = nn.Sigmoid()(logit).cpu().detach().numpy().tolist() 91 | for k in range(len(img_path)): 92 | result_list.append( 93 | { 94 | "file_name": img_path[k].split("/")[-1].split(".")[0], 95 | "scores": result[k] 96 | } 97 | ) 98 | # cal_mAP OP OR 99 | evaluation(result=result_list, types=args.dataset, ann_path=test_file[0]) 100 | 101 | 102 | 103 | def main(): 104 | args = Args() 105 | 106 | # model 107 | if args.model == "resnet101": 108 | model = 
ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls, cutmix=args.cutmix) 109 | if args.model == "vit_B16_224": 110 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 111 | if args.model == "vit_L16_224": 112 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 113 | 114 | model.cuda() 115 | if torch.cuda.device_count() > 1: 116 | print("lets use {} GPUs.".format(torch.cuda.device_count())) 117 | model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) 118 | 119 | # data 120 | if args.dataset == "voc07": 121 | train_file = ["data/voc07/trainval_voc07.json"] 122 | test_file = ['data/voc07/test_voc07.json'] 123 | step_size = 4 124 | if args.dataset == "coco": 125 | train_file = ['data/coco/train_coco2014.json'] 126 | test_file = ['data/coco/val_coco2014.json'] 127 | step_size = 5 128 | if args.dataset == "wider": 129 | train_file = ['data/wider/trainval_wider.json'] 130 | test_file = ["data/wider/test_wider.json"] 131 | step_size = 5 132 | args.train_aug = ["randomflip"] 133 | 134 | train_dataset = DataSet(train_file, args.train_aug, args.img_size, args.dataset) 135 | test_dataset = DataSet(test_file, args.test_aug, args.img_size, args.dataset) 136 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8) 137 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 138 | 139 | # optimizer and warmup 140 | backbone, classifier = [], [] 141 | for name, param in model.named_parameters(): 142 | if 'classifier' in name: 143 | classifier.append(param) 144 | else: 145 | backbone.append(param) 146 | optimizer = optim.SGD( 147 | [ 148 | {'params': backbone, 'lr': args.lr}, 149 | {'params': classifier, 'lr': args.lr * 10} 150 | ], 151 | momentum=args.momentum, weight_decay=args.w_d) 152 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1) 153 | 154 | iter_per_epoch = len(train_loader) 155 | if args.warmup_epoch > 0: 156 | warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warmup_epoch) 157 | else: 158 | warmup_scheduler = None 159 | 160 | # training and validation 161 | for i in range(1, args.total_epoch + 1): 162 | train(i, args, model, train_loader, optimizer, warmup_scheduler) 163 | torch.save(model.state_dict(), "checkpoint/{}/epoch_{}.pth".format(args.model, i)) 164 | val(i, args, model, test_loader, test_file) 165 | scheduler.step() 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /pipeline/csra.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class CSRA(nn.Module): # one basic block 7 | def __init__(self, input_dim, num_classes, T, lam): 8 | super(CSRA, self).__init__() 9 | self.T = T # temperature 10 | self.lam = lam # Lambda 11 | self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False) 12 | self.softmax = nn.Softmax(dim=2) 13 | 14 | def forward(self, x): 15 | # x (B d H W) 16 | # normalize classifier 17 | # score (B C HxW) 18 | score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1) 19 | score = score.flatten(2) 20 | base_logit = torch.mean(score, dim=2) 21 | 22 | if self.T == 99: # max-pooling 23 | att_logit = torch.max(score, dim=2)[0] 24 | else: 25 | score_soft = self.softmax(score * self.T) 26 | att_logit = 
torch.sum(score * score_soft, dim=2) 27 | 28 | return base_logit + self.lam * att_logit 29 | 30 | 31 | 32 | 33 | class MHA(nn.Module): # multi-head attention 34 | temp_settings = { # softmax temperature settings 35 | 1: [1], 36 | 2: [1, 99], 37 | 4: [1, 2, 4, 99], 38 | 6: [1, 2, 3, 4, 5, 99], 39 | 8: [1, 2, 3, 4, 5, 6, 7, 99] 40 | } 41 | 42 | def __init__(self, num_heads, lam, input_dim, num_classes): 43 | super(MHA, self).__init__() 44 | self.temp_list = self.temp_settings[num_heads] 45 | self.multi_head = nn.ModuleList([ 46 | CSRA(input_dim, num_classes, self.temp_list[i], lam) 47 | for i in range(num_heads) 48 | ]) 49 | 50 | def forward(self, x): 51 | logit = 0. 52 | for head in self.multi_head: 53 | logit += head(x) 54 | return logit 55 | -------------------------------------------------------------------------------- /pipeline/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from torchvision.transforms import transforms 5 | import torch 6 | import numpy as np 7 | 8 | # modify for transformation for vit 9 | # modfify wider crop-person images 10 | 11 | 12 | class DataSet(Dataset): 13 | def __init__(self, 14 | ann_files, 15 | augs, 16 | img_size, 17 | dataset, 18 | ): 19 | self.dataset = dataset 20 | self.ann_files = ann_files 21 | self.augment = self.augs_function(augs, img_size) 22 | self.transform = transforms.Compose( 23 | [ 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]) 26 | ] 27 | # In this paper, we normalize the image data to [0, 1] 28 | # You can also use the so called 'ImageNet' Normalization method 29 | ) 30 | self.anns = [] 31 | self.load_anns() 32 | print(self.augment) 33 | 34 | # in wider dataset we use vit models 35 | # so transformation has been changed 36 | if self.dataset == "wider": 37 | self.transform = transforms.Compose( 38 | [ 39 | transforms.ToTensor(), 40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 41 | ] 42 | ) 43 | 44 | def augs_function(self, augs, img_size): 45 | t = [] 46 | if 'randomflip' in augs: 47 | t.append(transforms.RandomHorizontalFlip()) 48 | if 'ColorJitter' in augs: 49 | t.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0)) 50 | if 'resizedcrop' in augs: 51 | t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0))) 52 | if 'RandAugment' in augs: 53 | t.append(RandAugment()) 54 | 55 | t.append(transforms.Resize((img_size, img_size))) 56 | 57 | return transforms.Compose(t) 58 | 59 | def load_anns(self): 60 | self.anns = [] 61 | for ann_file in self.ann_files: 62 | json_data = json.load(open(ann_file, "r")) 63 | self.anns += json_data 64 | 65 | def __len__(self): 66 | return len(self.anns) 67 | 68 | def __getitem__(self, idx): 69 | idx = idx % len(self) 70 | ann = self.anns[idx] 71 | img = Image.open(ann["img_path"]).convert("RGB") 72 | 73 | if self.dataset == "wider": 74 | x, y, w, h = ann['bbox'] 75 | img_area = img.crop([x, y, x+w, y+h]) 76 | img_area = self.augment(img_area) 77 | img_area = self.transform(img_area) 78 | message = { 79 | "img_path": ann['img_path'], 80 | "target": torch.Tensor(ann['target']), 81 | "img": img_area 82 | } 83 | else: # voc and coco 84 | img = self.augment(img) 85 | img = self.transform(img) 86 | message = { 87 | "img_path": ann["img_path"], 88 | "target": torch.Tensor(ann["target"]), 89 | "img": img 90 | } 91 | 92 | return message 93 | # finally, if we use dataloader to get the data, we 
will get 94 | # { 95 | # "img_path": list, # length = batch_size 96 | # "target": Tensor, # shape: batch_size * num_classes 97 | # "img": Tensor, # shape: batch_size * 3 * 224 * 224 98 | # } 99 | -------------------------------------------------------------------------------- /pipeline/resnet_csra.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import ResNet 2 | from torchvision.models.resnet import Bottleneck, BasicBlock 3 | from .csra import CSRA, MHA 4 | import torch.utils.model_zoo as model_zoo 5 | import logging 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | } 18 | 19 | 20 | 21 | 22 | 23 | 24 | class ResNet_CSRA(ResNet): 25 | arch_settings = { 26 | 18: (BasicBlock, (2, 2, 2, 2)), 27 | 34: (BasicBlock, (3, 4, 6, 3)), 28 | 50: (Bottleneck, (3, 4, 6, 3)), 29 | 101: (Bottleneck, (3, 4, 23, 3)), 30 | 152: (Bottleneck, (3, 8, 36, 3)) 31 | } 32 | 33 | def __init__(self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None): 34 | self.block, self.layers = self.arch_settings[depth] 35 | self.depth = depth 36 | super(ResNet_CSRA, self).__init__(self.block, self.layers) 37 | self.init_weights(pretrained=True, cutmix=cutmix) 38 | 39 | self.classifier = MHA(num_heads, lam, input_dim, num_classes) 40 | self.loss_func = F.binary_cross_entropy_with_logits 41 | 42 | def backbone(self, x): 43 | x = self.conv1(x) 44 | x = self.bn1(x) 45 | x = self.relu(x) 46 | x = self.maxpool(x) 47 | 48 | x = self.layer1(x) 49 | x = self.layer2(x) 50 | x = self.layer3(x) 51 | x = self.layer4(x) 52 | 53 | return x 54 | 55 | def forward_train(self, x, target): 56 | x = self.backbone(x) 57 | logit = self.classifier(x) 58 | loss = self.loss_func(logit, target, reduction="mean") 59 | return logit, loss 60 | 61 | def forward_test(self, x): 62 | x = self.backbone(x) 63 | x = self.classifier(x) 64 | return x 65 | 66 | def forward(self, x, target=None): 67 | if target is not None: 68 | return self.forward_train(x, target) 69 | else: 70 | return self.forward_test(x) 71 | 72 | def init_weights(self, pretrained=True, cutmix=None): 73 | if cutmix is not None: 74 | print("backbone params inited by CutMix pretrained model") 75 | state_dict = torch.load(cutmix) 76 | elif pretrained: 77 | print("backbone params inited by Pytorch official model") 78 | model_url = model_urls["resnet{}".format(self.depth)] 79 | state_dict = model_zoo.load_url(model_url) 80 | 81 | model_dict = self.state_dict() 82 | try: 83 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict} 84 | self.load_state_dict(pretrained_dict) 85 | except: 86 | logger = logging.getLogger() 87 | logger.info( 88 | "the keys in pretrained model is not equal to the keys in the ResNet you choose, trying to fix...") 89 | state_dict = self._keysFix(model_dict, state_dict) 90 | self.load_state_dict(state_dict) 91 | 92 | # remove the original 1000-class fc 93 | self.fc = nn.Sequential() -------------------------------------------------------------------------------- /pipeline/timm_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .tuple import to_ntuple, to_2tuple, to_3tuple, to_4tuple 2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 3 | from .weight_init import trunc_normal_ 4 | 5 | -------------------------------------------------------------------------------- /pipeline/timm_utils/drop.py: -------------------------------------------------------------------------------- 1 | """ DropBlock, DropPath 2 | 3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. 4 | 5 | Papers: 6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) 7 | 8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) 9 | 10 | Code: 11 | DropBlock impl inspired by two Tensorflow impl that I liked: 12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py 14 | 15 | Hacked together by / Copyright 2020 Ross Wightman 16 | """ 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | def drop_block_2d( 23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, 24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 26 | 27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training 28 | runs with success, but needs further validation and possibly optimization for lower runtime impact. 29 | """ 30 | B, C, H, W = x.shape 31 | total_size = W * H 32 | clipped_block_size = min(block_size, min(W, H)) 33 | # seed_drop_rate, the gamma parameter 34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 35 | (W - block_size + 1) * (H - block_size + 1)) 36 | 37 | # Forces the block to be inside the feature map. 
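    # A DropBlock seed may only be planted at positions where a whole
    # clipped_block_size x clipped_block_size block fits inside the H x W map; the
    # coordinate grids built below encode exactly that constraint, mirroring the
    # (W - block_size + 1) * (H - block_size + 1) factor in gamma above.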
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) 39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \ 40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2)) 41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype) 42 | 43 | if batchwise: 44 | # one mask for whole batch, quite a bit faster 45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) 46 | else: 47 | uniform_noise = torch.rand_like(x) 48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype) 49 | block_mask = -F.max_pool2d( 50 | -block_mask, 51 | kernel_size=clipped_block_size, # block_size, 52 | stride=1, 53 | padding=clipped_block_size // 2) 54 | 55 | if with_noise: 56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 57 | if inplace: 58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask)) 59 | else: 60 | x = x * block_mask + normal_noise * (1 - block_mask) 61 | else: 62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype) 63 | if inplace: 64 | x.mul_(block_mask * normalize_scale) 65 | else: 66 | x = x * block_mask * normalize_scale 67 | return x 68 | 69 | 70 | def drop_block_fast_2d( 71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7, 72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): 73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf 74 | 75 | DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid 76 | block mask at edges. 77 | """ 78 | B, C, H, W = x.shape 79 | total_size = W * H 80 | clipped_block_size = min(block_size, min(W, H)) 81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( 82 | (W - block_size + 1) * (H - block_size + 1)) 83 | 84 | if batchwise: 85 | # one mask for whole batch, quite a bit faster 86 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma 87 | else: 88 | # mask per batch element 89 | block_mask = torch.rand_like(x) < gamma 90 | block_mask = F.max_pool2d( 91 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2) 92 | 93 | if with_noise: 94 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x) 95 | if inplace: 96 | x.mul_(1. - block_mask).add_(normal_noise * block_mask) 97 | else: 98 | x = x * (1. - block_mask) + normal_noise * block_mask 99 | else: 100 | block_mask = 1 - block_mask 101 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype) 102 | if inplace: 103 | x.mul_(block_mask * normalize_scale) 104 | else: 105 | x = x * block_mask * normalize_scale 106 | return x 107 | 108 | 109 | class DropBlock2d(nn.Module): 110 | """ DropBlock. 
See https://arxiv.org/pdf/1810.12890.pdf 111 | """ 112 | def __init__(self, 113 | drop_prob=0.1, 114 | block_size=7, 115 | gamma_scale=1.0, 116 | with_noise=False, 117 | inplace=False, 118 | batchwise=False, 119 | fast=True): 120 | super(DropBlock2d, self).__init__() 121 | self.drop_prob = drop_prob 122 | self.gamma_scale = gamma_scale 123 | self.block_size = block_size 124 | self.with_noise = with_noise 125 | self.inplace = inplace 126 | self.batchwise = batchwise 127 | self.fast = fast # FIXME finish comparisons of fast vs not 128 | 129 | def forward(self, x): 130 | if not self.training or not self.drop_prob: 131 | return x 132 | if self.fast: 133 | return drop_block_fast_2d( 134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 135 | else: 136 | return drop_block_2d( 137 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise) 138 | 139 | 140 | def drop_path(x, drop_prob: float = 0., training: bool = False): 141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 142 | 143 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 144 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 145 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 146 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 147 | 'survival rate' as the argument. 148 | 149 | """ 150 | if drop_prob == 0. or not training: 151 | return x 152 | keep_prob = 1 - drop_prob 153 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 154 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 155 | random_tensor.floor_() # binarize 156 | output = x.div(keep_prob) * random_tensor 157 | return output 158 | 159 | 160 | class DropPath(nn.Module): 161 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
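    At training time the surviving samples are rescaled by 1 / keep_prob (see drop_path above),
    so the expected magnitude of the output matches that of the input.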
162 | """ 163 | def __init__(self, drop_prob=None): 164 | super(DropPath, self).__init__() 165 | self.drop_prob = drop_prob 166 | 167 | def forward(self, x): 168 | return drop_path(x, self.drop_prob, self.training) 169 | -------------------------------------------------------------------------------- /pipeline/timm_utils/tuple.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | from torch._six import container_abcs 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, container_abcs.Iterable): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /pipeline/timm_utils/weight_init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import warnings 4 | 5 | 6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 9 | def norm_cdf(x): 10 | # Computes standard normal cumulative distribution function 11 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 12 | 13 | if (mean < a - 2 * std) or (mean > b + 2 * std): 14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 15 | "The distribution of values may be incorrect.", 16 | stacklevel=2) 17 | 18 | with torch.no_grad(): 19 | # Values are generated by using a truncated uniform distribution and 20 | # then using the inverse CDF for the normal distribution. 21 | # Get upper and lower cdf values 22 | l = norm_cdf((a - mean) / std) 23 | u = norm_cdf((b - mean) / std) 24 | 25 | # Uniformly fill tensor with values from [l, u], then translate to 26 | # [2l-1, 2u-1]. 27 | tensor.uniform_(2 * l - 1, 2 * u - 1) 28 | 29 | # Use inverse cdf transform for normal distribution to get truncated 30 | # standard normal 31 | tensor.erfinv_() 32 | 33 | # Transform to proper mean, std 34 | tensor.mul_(std * math.sqrt(2.)) 35 | tensor.add_(mean) 36 | 37 | # Clamp to ensure it's in the proper range 38 | tensor.clamp_(min=a, max=b) 39 | return tensor 40 | 41 | 42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 43 | # type: (Tensor, float, float, float, float) -> Tensor 44 | r"""Fills the input Tensor with values drawn from a truncated 45 | normal distribution. The values are effectively drawn from the 46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 47 | with values outside :math:`[a, b]` redrawn until they are within 48 | the bounds. The method used for generating the random values works 49 | best when :math:`a \leq \text{mean} \leq b`. 
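    NOTE: internally the values are produced by inverse transform sampling -- uniform draws on
    :math:`[\Phi((a - \text{mean}) / \text{std}),\ \Phi((b - \text{mean}) / \text{std})]` are mapped through
    the inverse Gaussian CDF -- rather than by repeatedly redrawing out-of-range samples.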
50 | Args: 51 | tensor: an n-dimensional `torch.Tensor` 52 | mean: the mean of the normal distribution 53 | std: the standard deviation of the normal distribution 54 | a: the minimum cutoff value 55 | b: the maximum cutoff value 56 | Examples: 57 | >>> w = torch.empty(3, 5) 58 | >>> nn.init.trunc_normal_(w) 59 | """ 60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 61 | -------------------------------------------------------------------------------- /pipeline/vit_csra.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in 4 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 5 | 6 | The official jax code is released and available at https://github.com/google-research/vision_transformer 7 | 8 | Status/TODO: 9 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 10 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 11 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 12 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 13 | 14 | Acknowledgments: 15 | * The paper authors for releasing code and weights, thanks! 16 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 17 | for some einops/einsum fun 18 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 19 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 20 | 21 | Hacked together by / Copyright 2020 Ross Wightman 22 | """ 23 | import math 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.utils.model_zoo as model_zoo 28 | from functools import partial 29 | from .timm_utils import DropPath, to_2tuple, trunc_normal_ 30 | from .csra import MHA, CSRA 31 | 32 | 33 | default_cfgs = { 34 | 'vit_base_patch16_224': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 35 | 'vit_large_patch16_224':'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth' 36 | } 37 | 38 | 39 | 40 | class Mlp(nn.Module): 41 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 42 | super().__init__() 43 | out_features = out_features or in_features 44 | hidden_features = hidden_features or in_features 45 | self.fc1 = nn.Linear(in_features, hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.drop = nn.Dropout(drop) 49 | 50 | def forward(self, x): 51 | x = self.fc1(x) 52 | x = self.act(x) 53 | x = self.drop(x) 54 | x = self.fc2(x) 55 | x = self.drop(x) 56 | return x 57 | 58 | 59 | class Attention(nn.Module): 60 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 61 | super().__init__() 62 | self.num_heads = num_heads 63 | head_dim = dim // num_heads # 64 64 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 65 | self.scale = qk_scale or head_dim ** -0.5 66 | 67 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 68 | 
self.attn_drop = nn.Dropout(attn_drop) 69 | self.proj = nn.Linear(dim, dim) 70 | self.proj_drop = nn.Dropout(proj_drop) 71 | 72 | def forward(self, x): 73 | B, N, C = x.shape 74 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 75 | # qkv (3, B, 12, N, C/12) 76 | # q (B, 12, N, C/12) 77 | # k (B, 12, N, C/12) 78 | # v (B, 12, N, C/12) 79 | # attn (B, 12, N, N) 80 | # x (B, 12, N, C/12) 81 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 82 | 83 | attn = (q @ k.transpose(-2, -1)) * self.scale 84 | attn = attn.softmax(dim=-1) 85 | attn = self.attn_drop(attn) 86 | 87 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 88 | 89 | x = self.proj(x) 90 | x = self.proj_drop(x) 91 | 92 | return x 93 | 94 | 95 | class Block(nn.Module): 96 | 97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 98 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 99 | super().__init__() 100 | self.norm1 = norm_layer(dim) 101 | self.attn = Attention( 102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 103 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 104 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 105 | self.norm2 = norm_layer(dim) 106 | mlp_hidden_dim = int(dim * mlp_ratio) 107 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 108 | 109 | def forward(self, x): 110 | x = x + self.drop_path(self.attn(self.norm1(x))) 111 | x = x + self.drop_path(self.mlp(self.norm2(x))) 112 | return x 113 | 114 | 115 | class PatchEmbed(nn.Module): 116 | """ Image to Patch Embedding 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 119 | super().__init__() 120 | img_size = to_2tuple(img_size) 121 | patch_size = to_2tuple(patch_size) 122 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 123 | self.img_size = img_size 124 | self.patch_size = patch_size 125 | self.num_patches = num_patches 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 128 | 129 | def forward(self, x): 130 | B, C, H, W = x.shape 131 | # FIXME look at relaxing size constraints 132 | assert H == self.img_size[0] and W == self.img_size[1], \ 133 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 134 | x = self.proj(x).flatten(2).transpose(1, 2) 135 | return x 136 | 137 | 138 | class HybridEmbed(nn.Module): 139 | """ CNN Feature Map Embedding 140 | Extract feature map from CNN, flatten, project to embedding dim. 141 | """ 142 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 143 | super().__init__() 144 | assert isinstance(backbone, nn.Module) 145 | img_size = to_2tuple(img_size) 146 | self.img_size = img_size 147 | self.backbone = backbone 148 | if feature_size is None: 149 | with torch.no_grad(): 150 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 151 | # map for all networks, the feature metadata has reliable channel and stride info, but using 152 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
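                # The dummy forward pass below runs the backbone in eval mode on a zero image of
                # the configured size and reads the spatial size and channel count off the last
                # returned feature map; the original training/eval flag is restored afterwards, so
                # constructing HybridEmbed leaves the backbone's state unchanged.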
153 | training = backbone.training 154 | if training: 155 | backbone.eval() 156 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 157 | feature_size = o.shape[-2:] 158 | feature_dim = o.shape[1] 159 | backbone.train(training) 160 | else: 161 | feature_size = to_2tuple(feature_size) 162 | feature_dim = self.backbone.feature_info.channels()[-1] 163 | self.num_patches = feature_size[0] * feature_size[1] 164 | self.proj = nn.Linear(feature_dim, embed_dim) 165 | 166 | def forward(self, x): 167 | x = self.backbone(x)[-1] 168 | x = x.flatten(2).transpose(1, 2) 169 | x = self.proj(x) 170 | return x 171 | 172 | 173 | class VIT_CSRA(nn.Module): 174 | """ Vision Transformer with support for patch or hybrid CNN input stage 175 | """ 176 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 177 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 178 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cls_num_heads=1, cls_num_cls=80, lam=0.3): 179 | super().__init__() 180 | self.add_w = 0. 181 | self.normalize = False 182 | self.num_classes = num_classes 183 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 184 | 185 | if hybrid_backbone is not None: 186 | self.patch_embed = HybridEmbed( 187 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 188 | else: 189 | self.patch_embed = PatchEmbed( 190 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 191 | num_patches = self.patch_embed.num_patches 192 | self.HW = int(math.sqrt(num_patches)) 193 | 194 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 195 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 196 | self.pos_drop = nn.Dropout(p=drop_rate) 197 | 198 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 199 | self.blocks = nn.ModuleList([ 200 | Block( 201 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 202 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 203 | for i in range(depth)]) 204 | self.norm = norm_layer(embed_dim) 205 | 206 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 207 | #self.repr = nn.Linear(embed_dim, representation_size) 208 | #self.repr_act = nn.Tanh() 209 | 210 | trunc_normal_(self.pos_embed, std=.02) 211 | trunc_normal_(self.cls_token, std=.02) 212 | self.apply(self._init_weights) 213 | 214 | # We add our MHA (CSRA) beside the orginal VIT structure below 215 | self.head = nn.Sequential() # delete original classifier 216 | self.classifier = MHA(input_dim=embed_dim, num_heads=cls_num_heads, num_classes=cls_num_cls, lam=lam) 217 | 218 | self.loss_func = F.binary_cross_entropy_with_logits 219 | 220 | def _init_weights(self, m): 221 | if isinstance(m, nn.Linear): 222 | trunc_normal_(m.weight, std=.02) 223 | if isinstance(m, nn.Linear) and m.bias is not None: 224 | nn.init.constant_(m.bias, 0) 225 | elif isinstance(m, nn.LayerNorm): 226 | nn.init.constant_(m.bias, 0) 227 | nn.init.constant_(m.weight, 1.0) 228 | 229 | def backbone(self, x): 230 | B = x.shape[0] 231 | x = self.patch_embed(x) 232 | 233 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 234 | x = torch.cat((cls_tokens, x), dim=1) 235 | x = x + self.pos_embed 236 | x = 
self.pos_drop(x) 237 | 238 | for blk in self.blocks: 239 | x = blk(x) 240 | x = self.norm(x) 241 | 242 | # (B, 1+HW, C) 243 | # we use all the feature to form the tensor like B C H W 244 | x = x[:, 1:] 245 | b, hw, c = x.shape 246 | x = x.transpose(1, 2) 247 | x = x.reshape(b, c, self.HW, self.HW) 248 | 249 | return x 250 | 251 | def forward_train(self, x, target): 252 | x = self.backbone(x) 253 | logit = self.classifier(x) 254 | loss = self.loss_func(logit, target, reduction="mean") 255 | return logit, loss 256 | 257 | def forward_test(self, x): 258 | x = self.backbone(x) 259 | x = self.classifier(x) 260 | return x 261 | 262 | def forward(self, x, target=None): 263 | if target is not None: 264 | return self.forward_train(x, target) 265 | else: 266 | return self.forward_test(x) 267 | 268 | 269 | 270 | 271 | def _conv_filter(state_dict, patch_size=16): 272 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 273 | out_dict = {} 274 | for k, v in state_dict.items(): 275 | if 'patch_embed.proj.weight' in k: 276 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 277 | out_dict[k] = v 278 | return out_dict 279 | 280 | 281 | def VIT_B16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3): 282 | model = VIT_CSRA( 283 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 284 | norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam) 285 | 286 | model_url = default_cfgs['vit_base_patch16_224'] 287 | if pretrained: 288 | state_dict = model_zoo.load_url(model_url) 289 | model.load_state_dict(state_dict, strict=False) 290 | return model 291 | 292 | 293 | def VIT_L16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3): 294 | model = VIT_CSRA( 295 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 296 | norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam) 297 | 298 | model_url = default_cfgs['vit_large_patch16_224'] 299 | if pretrained: 300 | state_dict = model_zoo.load_url(model_url) 301 | model.load_state_dict(state_dict, strict=False) 302 | # load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 303 | return model -------------------------------------------------------------------------------- /utils/demo_images/000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000001.jpg -------------------------------------------------------------------------------- /utils/demo_images/000002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000002.jpg -------------------------------------------------------------------------------- /utils/demo_images/000004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000004.jpg -------------------------------------------------------------------------------- /utils/demo_images/000006.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000006.jpg -------------------------------------------------------------------------------- /utils/demo_images/000007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000007.jpg -------------------------------------------------------------------------------- /utils/demo_images/000009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000009.jpg -------------------------------------------------------------------------------- /utils/evaluation/cal_PR.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | 5 | 6 | def json_metric(score_json, target_json, num_classes, types): 7 | assert len(score_json) == len(target_json) 8 | scores = np.zeros((len(score_json), num_classes)) 9 | targets = np.zeros((len(target_json), num_classes)) 10 | for index in range(len(score_json)): 11 | scores[index] = score_json[index]["scores"] 12 | targets[index] = target_json[index]["target"] 13 | 14 | 15 | return metric(scores, targets, types) 16 | 17 | def json_metric_top3(score_json, target_json, num_classes, types): 18 | assert len(score_json) == len(target_json) 19 | scores = np.zeros((len(score_json), num_classes)) 20 | targets = np.zeros((len(target_json), num_classes)) 21 | for index in range(len(score_json)): 22 | tmp = np.array(score_json[index]['scores']) 23 | idx = np.argsort(-tmp) 24 | idx_after_3 = idx[3:] 25 | tmp[idx_after_3] = 0. 26 | 27 | scores[index] = tmp 28 | # scores[index] = score_json[index]["scores"] 29 | targets[index] = target_json[index]["target"] 30 | 31 | return metric(scores, targets, types) 32 | 33 | 34 | def metric(scores, targets, types): 35 | """ 36 | :param scores: the output the model predict 37 | :param targets: the gt label 38 | :return: OP, OR, OF1, CP, CR, CF1 39 | calculate the Precision of every class by: TP/TP+FP i.e. 
TP/total predict 40 | calculate the Recall by: TP/total GT 41 | """ 42 | num, num_class = scores.shape 43 | gt_num = np.zeros(num_class) 44 | tp_num = np.zeros(num_class) 45 | predict_num = np.zeros(num_class) 46 | 47 | 48 | for index in range(num_class): 49 | score = scores[:, index] 50 | target = targets[:, index] 51 | if types == 'wider': 52 | tmp = np.where(target == 99)[0] 53 | # score[tmp] = 0 54 | target[tmp] = 0 55 | 56 | if types == 'voc07': 57 | tmp = np.where(target != 0)[0] 58 | score = score[tmp] 59 | target = target[tmp] 60 | neg_id = np.where(target == -1)[0] 61 | target[neg_id] = 0 62 | 63 | 64 | gt_num[index] = np.sum(target == 1) 65 | predict_num[index] = np.sum(score >= 0.5) 66 | tp_num[index] = np.sum(target * (score >= 0.5)) 67 | 68 | predict_num[predict_num == 0] = 1 # avoid dividing 0 69 | OP = np.sum(tp_num) / np.sum(predict_num) 70 | OR = np.sum(tp_num) / np.sum(gt_num) 71 | OF1 = (2 * OP * OR) / (OP + OR) 72 | 73 | #print(tp_num / predict_num) 74 | #print(tp_num / gt_num) 75 | CP = np.sum(tp_num / predict_num) / num_class 76 | CR = np.sum(tp_num / gt_num) / num_class 77 | CF1 = (2 * CP * CR) / (CP + CR) 78 | 79 | return OP, OR, OF1, CP, CR, CF1 80 | -------------------------------------------------------------------------------- /utils/evaluation/cal_mAP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import json 5 | 6 | 7 | def json_map(cls_id, pred_json, ann_json, types): 8 | assert len(ann_json) == len(pred_json) 9 | num = len(ann_json) 10 | predict = np.zeros((num), dtype=np.float64) 11 | target = np.zeros((num), dtype=np.float64) 12 | 13 | for i in range(num): 14 | predict[i] = pred_json[i]["scores"][cls_id] 15 | target[i] = ann_json[i]["target"][cls_id] 16 | 17 | if types == 'wider': 18 | tmp = np.where(target != 99)[0] 19 | predict = predict[tmp] 20 | target = target[tmp] 21 | num = len(tmp) 22 | 23 | if types == 'voc07': 24 | tmp = np.where(target != 0)[0] 25 | predict = predict[tmp] 26 | target = target[tmp] 27 | neg_id = np.where(target == -1)[0] 28 | target[neg_id] = 0 29 | num = len(tmp) 30 | 31 | 32 | tmp = np.argsort(-predict) 33 | target = target[tmp] 34 | predict = predict[tmp] 35 | 36 | 37 | pre, obj = 0, 0 38 | for i in range(num): 39 | if target[i] == 1: 40 | obj += 1.0 41 | pre += obj / (i+1) 42 | pre /= obj 43 | return pre 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /utils/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | import json 5 | from tqdm import tqdm 6 | from .cal_mAP import json_map 7 | from .cal_PR import json_metric, metric, json_metric_top3 8 | 9 | 10 | voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle", 11 | "bus", "car", "cat", "chair", "cow", "diningtable", 12 | "dog", "horse", "motorbike", "person", "pottedplant", 13 | "sheep", "sofa", "train", "tvmonitor") 14 | coco_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 15 | 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant', 16 | 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog', 17 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 18 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 19 | 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat', 20 | 'baseball_glove', 'skateboard', 
'surfboard', 'tennis_racket', 21 | 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 22 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 23 | 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 24 | 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop', 25 | 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', 26 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 27 | 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush') 28 | wider_classes = ( 29 | "Male","longHair","sunglass","Hat","Tshiirt","longSleeve","formal", 30 | "shorts","jeans","longPants","skirt","faceMask", "logo","stripe") 31 | 32 | class_dict = { 33 | "voc07": voc_classes, 34 | "coco": coco_classes, 35 | "wider": wider_classes, 36 | } 37 | 38 | 39 | 40 | def evaluation(result, types, ann_path): 41 | print("Evaluation") 42 | classes = class_dict[types] 43 | aps = np.zeros(len(classes), dtype=np.float64) 44 | 45 | ann_json = json.load(open(ann_path, "r")) 46 | pred_json = result 47 | 48 | for i, _ in enumerate(tqdm(classes)): 49 | ap = json_map(i, pred_json, ann_json, types) 50 | aps[i] = ap 51 | OP, OR, OF1, CP, CR, CF1 = json_metric(pred_json, ann_json, len(classes), types) 52 | print("mAP: {:4f}".format(np.mean(aps))) 53 | print("CP: {:4f}, CR: {:4f}, CF1 :{:4F}".format(CP, CR, CF1)) 54 | print("OP: {:4f}, OR: {:4f}, OF1 {:4F}".format(OP, OR, OF1)) 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /utils/evaluation/warmUpLR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class WarmUpLR(torch.optim.lr_scheduler._LRScheduler): 5 | def __init__(self, optimizer, total_iters, last_epoch=-1): 6 | self.total_iters = total_iters 7 | super().__init__(optimizer, last_epoch=last_epoch) 8 | 9 | def get_lr(self): 10 | return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs] 11 | 12 | -------------------------------------------------------------------------------- /utils/pipeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/pipeline.PNG -------------------------------------------------------------------------------- /utils/prepare/prepare_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | from pycocotools.coco import COCO 6 | 7 | 8 | 9 | def make_data(data_path=None, tag="train"): 10 | annFile = os.path.join(data_path, "annotations/instances_{}2014.json".format(tag)) 11 | coco = COCO(annFile) 12 | 13 | img_id = coco.getImgIds() 14 | cat_id = coco.getCatIds() 15 | img_id = list(sorted(img_id)) 16 | cat_trans = {} 17 | for i in range(len(cat_id)): 18 | cat_trans[cat_id[i]] = i 19 | 20 | message = [] 21 | 22 | 23 | for i in img_id: 24 | data = {} 25 | target = [0] * 80 26 | path = "" 27 | img_info = coco.loadImgs(i)[0] 28 | ann_ids = coco.getAnnIds(imgIds = i) 29 | anns = coco.loadAnns(ann_ids) 30 | if len(anns) == 0: 31 | continue 32 | else: 33 | for i in range(len(anns)): 34 | cls = anns[i]['category_id'] 35 | cls = cat_trans[cls] 36 | target[cls] = 1 37 | path = img_info['file_name'] 38 | data['target'] = target 39 | data['img_path'] = os.path.join(os.path.join(data_path, "images/{}2014/".format(tag)), path) 40 | message.append(data) 41 | 42 | with 
open('data/coco/{}_coco2014.json'.format(tag), 'w') as f: 43 | json.dump(message, f) 44 | 45 | 46 | 47 | # The final json file include: train_coco2014.json & val_coco2014.json 48 | # which is the following format: 49 | # [item1, item2, item3, ......,] 50 | # item1 = { 51 | # "target": 52 | # "img_path": 53 | # } 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | # Usage: --data_path /your/dataset/path/COCO2014 57 | parser.add_argument("--data_path", default="Dataset/COCO2014/", type=str, help="The absolute path of COCO2014") 58 | args = parser.parse_args() 59 | 60 | if not os.path.exists("data/coco"): 61 | os.makedirs("data/coco") 62 | 63 | make_data(data_path=args.data_path, tag="train") 64 | make_data(data_path=args.data_path, tag="val") 65 | 66 | print("COCO data ready!") 67 | print("data/coco/train_coco2014.json, data/coco/val_coco2014.json") 68 | -------------------------------------------------------------------------------- /utils/prepare/prepare_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | import xml.dom.minidom as XML 6 | 7 | 8 | 9 | voc_cls_id = {"aeroplane":0, "bicycle":1, "bird":2, "boat":3, "bottle":4, 10 | "bus":5, "car":6, "cat":7, "chair":8, "cow":9, 11 | "diningtable":10, "dog":11, "horse":12, "motorbike":13, "person":14, 12 | "pottedplant":15, "sheep":16, "sofa":17, "train":18, "tvmonitor":19} 13 | 14 | 15 | def get_label(data_path): 16 | print("generating labels for VOC07 dataset") 17 | xml_paths = os.path.join(data_path, "VOC2007/Annotations/") 18 | save_dir = "data/voc07/labels" 19 | 20 | if not os.path.exists(save_dir): 21 | os.makedirs(save_dir) 22 | 23 | for i in os.listdir(xml_paths): 24 | if not i.endswith(".xml"): 25 | continue 26 | s_name = i.split('.')[0] + ".txt" 27 | s_dir = os.path.join(save_dir, s_name) 28 | xml_path = os.path.join(xml_paths, i) 29 | DomTree = XML.parse(xml_path) 30 | Root = DomTree.documentElement 31 | 32 | obj_all = Root.getElementsByTagName("object") 33 | leng = len(obj_all) 34 | cls = [] 35 | difi_tag = [] 36 | for obj in obj_all: 37 | # get the classes 38 | obj_name = obj.getElementsByTagName('name')[0] 39 | one_class = obj_name.childNodes[0].data 40 | cls.append(voc_cls_id[one_class]) 41 | 42 | difficult = obj.getElementsByTagName('difficult')[0] 43 | difi_tag.append(difficult.childNodes[0].data) 44 | 45 | for i, c in enumerate(cls): 46 | with open(s_dir, "a") as f: 47 | f.writelines("%s,%s\n" % (c, difi_tag[i])) 48 | 49 | 50 | def transdifi(data_path): 51 | print("generating final json file for VOC07 dataset") 52 | label_dir = "data/voc07/labels/" 53 | img_dir = os.path.join(data_path, "VOC2007/JPEGImages/") 54 | 55 | # get trainval test id 56 | id_dirs = os.path.join(data_path, "VOC2007/ImageSets/Main/") 57 | f_train = open(os.path.join(id_dirs, "train.txt"), "r").readlines() 58 | f_val = open(os.path.join(id_dirs, "val.txt"), "r").readlines() 59 | f_trainval = f_train + f_val 60 | f_test = open(os.path.join(id_dirs, "test.txt"), "r") 61 | 62 | trainval_id = np.sort([int(line.strip()) for line in f_trainval]).tolist() 63 | test_id = [int(line.strip()) for line in f_test] 64 | trainval_data = [] 65 | test_data = [] 66 | 67 | # ternary label 68 | # -1 means negative 69 | # 0 means difficult 70 | # +1 means positive 71 | 72 | # binary label 73 | # 0 means negative 74 | # +1 means positive 75 | 76 | # we use binary labels in our implementation 77 | 78 | for item in 
sorted(os.listdir(label_dir)): 79 | with open(os.path.join(label_dir, item), "r") as f: 80 | 81 | target = np.array([-1] * 20) 82 | classes = [] 83 | diffi_tag = [] 84 | 85 | for line in f.readlines(): 86 | cls, tag = map(int, line.strip().split(',')) 87 | classes.append(cls) 88 | diffi_tag.append(tag) 89 | 90 | classes = np.array(classes) 91 | diffi_tag = np.array(diffi_tag) 92 | for i in range(20): 93 | if i in classes: 94 | i_index = np.where(classes == i)[0] 95 | if len(i_index) == 1: 96 | target[i] = 1 - diffi_tag[i_index] 97 | else: 98 | if len(i_index) == sum(diffi_tag[i_index]): 99 | target[i] = 0 100 | else: 101 | target[i] = 1 102 | else: 103 | continue 104 | img_path = os.path.join(img_dir, item.split('.')[0]+".jpg") 105 | 106 | if int(item.split('.')[0]) in trainval_id: 107 | target[target == -1] = 0 # from ternary to binary by treating difficult as negatives 108 | data = {"target": target.tolist(), "img_path": img_path} 109 | trainval_data.append(data) 110 | if int(item.split('.')[0]) in test_id: 111 | data = {"target": target.tolist(), "img_path": img_path} 112 | test_data.append(data) 113 | 114 | json.dump(trainval_data, open("data/voc07/trainval_voc07.json", "w")) 115 | json.dump(test_data, open("data/voc07/test_voc07.json", "w")) 116 | print("VOC07 data preparing finished!") 117 | print("data/voc07/trainval_voc07.json data/voc07/test_voc07.json") 118 | 119 | # remove label cash 120 | for item in os.listdir(label_dir): 121 | os.remove(os.path.join(label_dir, item)) 122 | os.rmdir(label_dir) 123 | 124 | 125 | # We treat difficult classes in trainval_data as negtive while ignore them in test_data 126 | # The ignoring operation can be automatically done during evaluation (testing). 127 | # The final json file include: trainval_voc07.json & test_voc07.json 128 | # which is the following format: 129 | # [item1, item2, item3, ......,] 130 | # item1 = { 131 | # "target": 132 | # "img_path": 133 | # } 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser() 137 | # Usage: --data_path /your/dataset/path/VOCdevkit 138 | parser.add_argument("--data_path", default="Dataset/VOCdevkit/", type=str, help="The absolute path of VOCdevkit") 139 | args = parser.parse_args() 140 | 141 | if not os.path.exists("data/voc07"): 142 | os.makedirs("data/voc07") 143 | 144 | if 'VOCdevkit' not in args.data_path: 145 | print("WARNING: please include \'VOCdevkit\' str in your args.data_path") 146 | # exit() 147 | 148 | get_label(args.data_path) 149 | transdifi(args.data_path) -------------------------------------------------------------------------------- /utils/prepare/prepare_wider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import argparse 5 | 6 | 7 | def make_wider(tag, value, data_path): 8 | img_path = os.path.join(data_path, "Image") 9 | ann_path = os.path.join(data_path, "Annotations") 10 | ann_file = os.path.join(ann_path, "wider_attribute_{}.json".format(tag)) 11 | 12 | data = json.load(open(ann_file, "r")) 13 | 14 | final = [] 15 | image_list = data['images'] 16 | for image in image_list: 17 | for person in image["targets"]: # iterate over each person 18 | tmp = {} 19 | tmp['img_path'] = os.path.join(img_path, image['file_name']) 20 | tmp['bbox'] = person['bbox'] 21 | attr = person["attribute"] 22 | for i, item in enumerate(attr): 23 | if item == -1: 24 | attr[i] = 0 25 | if item == 0: 26 | attr[i] = value # pad un-specified samples 27 | if item == 1: 28 | attr[i] = 1 29 | 
tmp["target"] = attr 30 | final.append(tmp) 31 | 32 | json.dump(final, open("data/wider/{}_wider.json".format(tag), "w")) 33 | print("data/wider/{}_wider.json".format(tag)) 34 | 35 | 36 | 37 | # which is the following format: 38 | # [item1, item2, item3, ......,] 39 | # item1 = { 40 | # "target": 41 | # "img_path": 42 | # } 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument("--data_path", default="Dataset/WIDER_ATTRIBUTE", type=str) 48 | args = parser.parse_args() 49 | 50 | if not os.path.exists("data/wider"): 51 | os.makedirs("data/wider") 52 | 53 | # 0 (zero) means negative, we treat un-specified attribute as negative in the trainval set 54 | make_wider(tag='trainval', value=0, data_path=args.data_path) 55 | 56 | # 99 means we ignore un-specified attribute in the test set, following previous work 57 | # the number 99 can be properly identified when evaluating mAP 58 | make_wider(tag='test', value=99, data_path=args.data_path) 59 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import json 3 | import torch 4 | from torchvision import transforms 5 | import cv2 6 | import numpy as np 7 | import os 8 | import torch.nn as nn 9 | 10 | def show_cam_on_img(img, mask, img_path_save): 11 | heat_map = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET) 12 | heat_map = np.float32(heat_map) / 255 13 | 14 | cam = heat_map + np.float32(img) 15 | cam = cam / np.max(cam) 16 | cv2.imwrite(img_path_save, np.uint8(255 * cam)) 17 | 18 | 19 | img_path_read = "" 20 | img_path_save = "" 21 | 22 | 23 | 24 | 25 | def main(): 26 | img = cv2.imread(img_path_read, flags=1) 27 | 28 | img = np.float32(cv2.resize(img, (224, 224))) / 255 29 | 30 | # cam_all is the score tensor of shape (B, C, H, W), similar to y_raw in out Figure 1 31 | # cls_idx specifying the i-th class out of C class 32 | # visualize the 0's class heatmap 33 | cls_idx = 0 34 | cam = cam_all[cls_idx] 35 | 36 | 37 | # cam = nn.ReLU()(cam) 38 | cam = cam / torch.max(cam) 39 | 40 | cam = cv2.resize(np.array(cam), (224, 224)) 41 | show_cam_on_img(img, cam, img_path_save) 42 | 43 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | from pipeline.resnet_csra import ResNet_CSRA 8 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA 9 | from pipeline.dataset import DataSet 10 | from utils.evaluation.eval import evaluation 11 | from utils.evaluation.eval import WarmUpLR 12 | from tqdm import tqdm 13 | 14 | 15 | def Args(): 16 | parser = argparse.ArgumentParser(description="settings") 17 | # model default resnet101 18 | parser.add_argument("--model", default="resnet101", type=str) 19 | parser.add_argument("--num_heads", default=1, type=int) 20 | parser.add_argument("--lam",default=0.1, type=float) 21 | parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str) 22 | # dataset 23 | parser.add_argument("--dataset", default="voc07", type=str) 24 | parser.add_argument("--num_cls", default=20, type=int) 25 | parser.add_argument("--test_aug", default=[], type=list) 26 | 
parser.add_argument("--img_size", default=448, type=int) 27 | parser.add_argument("--batch_size", default=16, type=int) 28 | 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def val(args, model, test_loader, test_file): 34 | model.eval() 35 | print("Test on Pretrained Models") 36 | result_list = [] 37 | 38 | # calculate logit 39 | for index, data in enumerate(tqdm(test_loader)): 40 | img = data['img'].cuda() 41 | target = data['target'].cuda() 42 | img_path = data['img_path'] 43 | 44 | with torch.no_grad(): 45 | logit = model(img) 46 | 47 | result = nn.Sigmoid()(logit).cpu().detach().numpy().tolist() 48 | for k in range(len(img_path)): 49 | result_list.append( 50 | { 51 | "file_name": img_path[k].split("/")[-1].split(".")[0], 52 | "scores": result[k] 53 | } 54 | ) 55 | 56 | # cal_mAP OP OR 57 | evaluation(result=result_list, types=args.dataset, ann_path=test_file[0]) 58 | 59 | 60 | 61 | def main(): 62 | args = Args() 63 | 64 | # model 65 | if args.model == "resnet101": 66 | model = ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls, cutmix=args.cutmix) 67 | if args.model == "vit_B16_224": 68 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 69 | if args.model == "vit_L16_224": 70 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls) 71 | 72 | model.cuda() 73 | print("Loading weights from {}".format(args.load_from)) 74 | if torch.cuda.device_count() > 1: 75 | print("lets use {} GPUs.".format(torch.cuda.device_count())) 76 | model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count()))) 77 | model.module.load_state_dict(torch.load(args.load_from)) 78 | else: 79 | model.load_state_dict(torch.load(args.load_from)) 80 | 81 | # data 82 | if args.dataset == "voc07": 83 | test_file = ['data/voc07/test_voc07.json'] 84 | if args.dataset == "coco": 85 | test_file = ['data/coco/val_coco2014.json'] 86 | if args.dataset == "wider": 87 | test_file = ['data/wider/test_wider.json'] 88 | 89 | 90 | test_dataset = DataSet(test_file, args.test_aug, args.img_size, args.dataset) 91 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 92 | 93 | val(args, model, test_loader, test_file) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | --------------------------------------------------------------------------------