├── LICENSE
├── README.md
├── demo.py
├── main.py
├── pipeline
│   ├── csra.py
│   ├── dataset.py
│   ├── resnet_csra.py
│   ├── timm_utils
│   │   ├── __init__.py
│   │   ├── drop.py
│   │   ├── tuple.py
│   │   └── weight_init.py
│   └── vit_csra.py
├── utils
│   ├── demo_images
│   │   ├── 000001.jpg
│   │   ├── 000002.jpg
│   │   ├── 000004.jpg
│   │   ├── 000006.jpg
│   │   ├── 000007.jpg
│   │   └── 000009.jpg
│   ├── evaluation
│   │   ├── cal_PR.py
│   │   ├── cal_mAP.py
│   │   ├── eval.py
│   │   └── warmUpLR.py
│   ├── pipeline.PNG
│   ├── prepare
│   │   ├── prepare_coco.py
│   │   ├── prepare_voc.py
│   │   └── prepare_wider.py
│   └── visualize.py
└── val.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU AFFERO GENERAL PUBLIC LICENSE
2 | Version 3, 19 November 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU Affero General Public License is a free, copyleft license for
11 | software and other kinds of works, specifically designed to ensure
12 | cooperation with the community in the case of network server software.
13 |
14 | The licenses for most software and other practical works are designed
15 | to take away your freedom to share and change the works. By contrast,
16 | our General Public Licenses are intended to guarantee your freedom to
17 | share and change all versions of a program--to make sure it remains free
18 | software for all its users.
19 |
20 | When we speak of free software, we are referring to freedom, not
21 | price. Our General Public Licenses are designed to make sure that you
22 | have the freedom to distribute copies of free software (and charge for
23 | them if you wish), that you receive source code or can get it if you
24 | want it, that you can change the software or use pieces of it in new
25 | free programs, and that you know you can do these things.
26 |
27 | Developers that use our General Public Licenses protect your rights
28 | with two steps: (1) assert copyright on the software, and (2) offer
29 | you this License which gives you legal permission to copy, distribute
30 | and/or modify the software.
31 |
32 | A secondary benefit of defending all users' freedom is that
33 | improvements made in alternate versions of the program, if they
34 | receive widespread use, become available for other developers to
35 | incorporate. Many developers of free software are heartened and
36 | encouraged by the resulting cooperation. However, in the case of
37 | software used on network servers, this result may fail to come about.
38 | The GNU General Public License permits making a modified version and
39 | letting the public access it on a server without ever releasing its
40 | source code to the public.
41 |
42 | The GNU Affero General Public License is designed specifically to
43 | ensure that, in such cases, the modified source code becomes available
44 | to the community. It requires the operator of a network server to
45 | provide the source code of the modified version running there to the
46 | users of that server. Therefore, public use of a modified version, on
47 | a publicly accessible server, gives the public access to the source
48 | code of the modified version.
49 |
50 | An older license, called the Affero General Public License and
51 | published by Affero, was designed to accomplish similar goals. This is
52 | a different license, not a version of the Affero GPL, but Affero has
53 | released a new version of the Affero GPL which permits relicensing under
54 | this license.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | TERMS AND CONDITIONS
60 |
61 | 0. Definitions.
62 |
63 | "This License" refers to version 3 of the GNU Affero General Public License.
64 |
65 | "Copyright" also means copyright-like laws that apply to other kinds of
66 | works, such as semiconductor masks.
67 |
68 | "The Program" refers to any copyrightable work licensed under this
69 | License. Each licensee is addressed as "you". "Licensees" and
70 | "recipients" may be individuals or organizations.
71 |
72 | To "modify" a work means to copy from or adapt all or part of the work
73 | in a fashion requiring copyright permission, other than the making of an
74 | exact copy. The resulting work is called a "modified version" of the
75 | earlier work or a work "based on" the earlier work.
76 |
77 | A "covered work" means either the unmodified Program or a work based
78 | on the Program.
79 |
80 | To "propagate" a work means to do anything with it that, without
81 | permission, would make you directly or secondarily liable for
82 | infringement under applicable copyright law, except executing it on a
83 | computer or modifying a private copy. Propagation includes copying,
84 | distribution (with or without modification), making available to the
85 | public, and in some countries other activities as well.
86 |
87 | To "convey" a work means any kind of propagation that enables other
88 | parties to make or receive copies. Mere interaction with a user through
89 | a computer network, with no transfer of a copy, is not conveying.
90 |
91 | An interactive user interface displays "Appropriate Legal Notices"
92 | to the extent that it includes a convenient and prominently visible
93 | feature that (1) displays an appropriate copyright notice, and (2)
94 | tells the user that there is no warranty for the work (except to the
95 | extent that warranties are provided), that licensees may convey the
96 | work under this License, and how to view a copy of this License. If
97 | the interface presents a list of user commands or options, such as a
98 | menu, a prominent item in the list meets this criterion.
99 |
100 | 1. Source Code.
101 |
102 | The "source code" for a work means the preferred form of the work
103 | for making modifications to it. "Object code" means any non-source
104 | form of a work.
105 |
106 | A "Standard Interface" means an interface that either is an official
107 | standard defined by a recognized standards body, or, in the case of
108 | interfaces specified for a particular programming language, one that
109 | is widely used among developers working in that language.
110 |
111 | The "System Libraries" of an executable work include anything, other
112 | than the work as a whole, that (a) is included in the normal form of
113 | packaging a Major Component, but which is not part of that Major
114 | Component, and (b) serves only to enable use of the work with that
115 | Major Component, or to implement a Standard Interface for which an
116 | implementation is available to the public in source code form. A
117 | "Major Component", in this context, means a major essential component
118 | (kernel, window system, and so on) of the specific operating system
119 | (if any) on which the executable work runs, or a compiler used to
120 | produce the work, or an object code interpreter used to run it.
121 |
122 | The "Corresponding Source" for a work in object code form means all
123 | the source code needed to generate, install, and (for an executable
124 | work) run the object code and to modify the work, including scripts to
125 | control those activities. However, it does not include the work's
126 | System Libraries, or general-purpose tools or generally available free
127 | programs which are used unmodified in performing those activities but
128 | which are not part of the work. For example, Corresponding Source
129 | includes interface definition files associated with source files for
130 | the work, and the source code for shared libraries and dynamically
131 | linked subprograms that the work is specifically designed to require,
132 | such as by intimate data communication or control flow between those
133 | subprograms and other parts of the work.
134 |
135 | The Corresponding Source need not include anything that users
136 | can regenerate automatically from other parts of the Corresponding
137 | Source.
138 |
139 | The Corresponding Source for a work in source code form is that
140 | same work.
141 |
142 | 2. Basic Permissions.
143 |
144 | All rights granted under this License are granted for the term of
145 | copyright on the Program, and are irrevocable provided the stated
146 | conditions are met. This License explicitly affirms your unlimited
147 | permission to run the unmodified Program. The output from running a
148 | covered work is covered by this License only if the output, given its
149 | content, constitutes a covered work. This License acknowledges your
150 | rights of fair use or other equivalent, as provided by copyright law.
151 |
152 | You may make, run and propagate covered works that you do not
153 | convey, without conditions so long as your license otherwise remains
154 | in force. You may convey covered works to others for the sole purpose
155 | of having them make modifications exclusively for you, or provide you
156 | with facilities for running those works, provided that you comply with
157 | the terms of this License in conveying all material for which you do
158 | not control copyright. Those thus making or running the covered works
159 | for you must do so exclusively on your behalf, under your direction
160 | and control, on terms that prohibit them from making any copies of
161 | your copyrighted material outside their relationship with you.
162 |
163 | Conveying under any other circumstances is permitted solely under
164 | the conditions stated below. Sublicensing is not allowed; section 10
165 | makes it unnecessary.
166 |
167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168 |
169 | No covered work shall be deemed part of an effective technological
170 | measure under any applicable law fulfilling obligations under article
171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172 | similar laws prohibiting or restricting circumvention of such
173 | measures.
174 |
175 | When you convey a covered work, you waive any legal power to forbid
176 | circumvention of technological measures to the extent such circumvention
177 | is effected by exercising rights under this License with respect to
178 | the covered work, and you disclaim any intention to limit operation or
179 | modification of the work as a means of enforcing, against the work's
180 | users, your or third parties' legal rights to forbid circumvention of
181 | technological measures.
182 |
183 | 4. Conveying Verbatim Copies.
184 |
185 | You may convey verbatim copies of the Program's source code as you
186 | receive it, in any medium, provided that you conspicuously and
187 | appropriately publish on each copy an appropriate copyright notice;
188 | keep intact all notices stating that this License and any
189 | non-permissive terms added in accord with section 7 apply to the code;
190 | keep intact all notices of the absence of any warranty; and give all
191 | recipients a copy of this License along with the Program.
192 |
193 | You may charge any price or no price for each copy that you convey,
194 | and you may offer support or warranty protection for a fee.
195 |
196 | 5. Conveying Modified Source Versions.
197 |
198 | You may convey a work based on the Program, or the modifications to
199 | produce it from the Program, in the form of source code under the
200 | terms of section 4, provided that you also meet all of these conditions:
201 |
202 | a) The work must carry prominent notices stating that you modified
203 | it, and giving a relevant date.
204 |
205 | b) The work must carry prominent notices stating that it is
206 | released under this License and any conditions added under section
207 | 7. This requirement modifies the requirement in section 4 to
208 | "keep intact all notices".
209 |
210 | c) You must license the entire work, as a whole, under this
211 | License to anyone who comes into possession of a copy. This
212 | License will therefore apply, along with any applicable section 7
213 | additional terms, to the whole of the work, and all its parts,
214 | regardless of how they are packaged. This License gives no
215 | permission to license the work in any other way, but it does not
216 | invalidate such permission if you have separately received it.
217 |
218 | d) If the work has interactive user interfaces, each must display
219 | Appropriate Legal Notices; however, if the Program has interactive
220 | interfaces that do not display Appropriate Legal Notices, your
221 | work need not make them do so.
222 |
223 | A compilation of a covered work with other separate and independent
224 | works, which are not by their nature extensions of the covered work,
225 | and which are not combined with it such as to form a larger program,
226 | in or on a volume of a storage or distribution medium, is called an
227 | "aggregate" if the compilation and its resulting copyright are not
228 | used to limit the access or legal rights of the compilation's users
229 | beyond what the individual works permit. Inclusion of a covered work
230 | in an aggregate does not cause this License to apply to the other
231 | parts of the aggregate.
232 |
233 | 6. Conveying Non-Source Forms.
234 |
235 | You may convey a covered work in object code form under the terms
236 | of sections 4 and 5, provided that you also convey the
237 | machine-readable Corresponding Source under the terms of this License,
238 | in one of these ways:
239 |
240 | a) Convey the object code in, or embodied in, a physical product
241 | (including a physical distribution medium), accompanied by the
242 | Corresponding Source fixed on a durable physical medium
243 | customarily used for software interchange.
244 |
245 | b) Convey the object code in, or embodied in, a physical product
246 | (including a physical distribution medium), accompanied by a
247 | written offer, valid for at least three years and valid for as
248 | long as you offer spare parts or customer support for that product
249 | model, to give anyone who possesses the object code either (1) a
250 | copy of the Corresponding Source for all the software in the
251 | product that is covered by this License, on a durable physical
252 | medium customarily used for software interchange, for a price no
253 | more than your reasonable cost of physically performing this
254 | conveying of source, or (2) access to copy the
255 | Corresponding Source from a network server at no charge.
256 |
257 | c) Convey individual copies of the object code with a copy of the
258 | written offer to provide the Corresponding Source. This
259 | alternative is allowed only occasionally and noncommercially, and
260 | only if you received the object code with such an offer, in accord
261 | with subsection 6b.
262 |
263 | d) Convey the object code by offering access from a designated
264 | place (gratis or for a charge), and offer equivalent access to the
265 | Corresponding Source in the same way through the same place at no
266 | further charge. You need not require recipients to copy the
267 | Corresponding Source along with the object code. If the place to
268 | copy the object code is a network server, the Corresponding Source
269 | may be on a different server (operated by you or a third party)
270 | that supports equivalent copying facilities, provided you maintain
271 | clear directions next to the object code saying where to find the
272 | Corresponding Source. Regardless of what server hosts the
273 | Corresponding Source, you remain obligated to ensure that it is
274 | available for as long as needed to satisfy these requirements.
275 |
276 | e) Convey the object code using peer-to-peer transmission, provided
277 | you inform other peers where the object code and Corresponding
278 | Source of the work are being offered to the general public at no
279 | charge under subsection 6d.
280 |
281 | A separable portion of the object code, whose source code is excluded
282 | from the Corresponding Source as a System Library, need not be
283 | included in conveying the object code work.
284 |
285 | A "User Product" is either (1) a "consumer product", which means any
286 | tangible personal property which is normally used for personal, family,
287 | or household purposes, or (2) anything designed or sold for incorporation
288 | into a dwelling. In determining whether a product is a consumer product,
289 | doubtful cases shall be resolved in favor of coverage. For a particular
290 | product received by a particular user, "normally used" refers to a
291 | typical or common use of that class of product, regardless of the status
292 | of the particular user or of the way in which the particular user
293 | actually uses, or expects or is expected to use, the product. A product
294 | is a consumer product regardless of whether the product has substantial
295 | commercial, industrial or non-consumer uses, unless such uses represent
296 | the only significant mode of use of the product.
297 |
298 | "Installation Information" for a User Product means any methods,
299 | procedures, authorization keys, or other information required to install
300 | and execute modified versions of a covered work in that User Product from
301 | a modified version of its Corresponding Source. The information must
302 | suffice to ensure that the continued functioning of the modified object
303 | code is in no case prevented or interfered with solely because
304 | modification has been made.
305 |
306 | If you convey an object code work under this section in, or with, or
307 | specifically for use in, a User Product, and the conveying occurs as
308 | part of a transaction in which the right of possession and use of the
309 | User Product is transferred to the recipient in perpetuity or for a
310 | fixed term (regardless of how the transaction is characterized), the
311 | Corresponding Source conveyed under this section must be accompanied
312 | by the Installation Information. But this requirement does not apply
313 | if neither you nor any third party retains the ability to install
314 | modified object code on the User Product (for example, the work has
315 | been installed in ROM).
316 |
317 | The requirement to provide Installation Information does not include a
318 | requirement to continue to provide support service, warranty, or updates
319 | for a work that has been modified or installed by the recipient, or for
320 | the User Product in which it has been modified or installed. Access to a
321 | network may be denied when the modification itself materially and
322 | adversely affects the operation of the network or violates the rules and
323 | protocols for communication across the network.
324 |
325 | Corresponding Source conveyed, and Installation Information provided,
326 | in accord with this section must be in a format that is publicly
327 | documented (and with an implementation available to the public in
328 | source code form), and must require no special password or key for
329 | unpacking, reading or copying.
330 |
331 | 7. Additional Terms.
332 |
333 | "Additional permissions" are terms that supplement the terms of this
334 | License by making exceptions from one or more of its conditions.
335 | Additional permissions that are applicable to the entire Program shall
336 | be treated as though they were included in this License, to the extent
337 | that they are valid under applicable law. If additional permissions
338 | apply only to part of the Program, that part may be used separately
339 | under those permissions, but the entire Program remains governed by
340 | this License without regard to the additional permissions.
341 |
342 | When you convey a copy of a covered work, you may at your option
343 | remove any additional permissions from that copy, or from any part of
344 | it. (Additional permissions may be written to require their own
345 | removal in certain cases when you modify the work.) You may place
346 | additional permissions on material, added by you to a covered work,
347 | for which you have or can give appropriate copyright permission.
348 |
349 | Notwithstanding any other provision of this License, for material you
350 | add to a covered work, you may (if authorized by the copyright holders of
351 | that material) supplement the terms of this License with terms:
352 |
353 | a) Disclaiming warranty or limiting liability differently from the
354 | terms of sections 15 and 16 of this License; or
355 |
356 | b) Requiring preservation of specified reasonable legal notices or
357 | author attributions in that material or in the Appropriate Legal
358 | Notices displayed by works containing it; or
359 |
360 | c) Prohibiting misrepresentation of the origin of that material, or
361 | requiring that modified versions of such material be marked in
362 | reasonable ways as different from the original version; or
363 |
364 | d) Limiting the use for publicity purposes of names of licensors or
365 | authors of the material; or
366 |
367 | e) Declining to grant rights under trademark law for use of some
368 | trade names, trademarks, or service marks; or
369 |
370 | f) Requiring indemnification of licensors and authors of that
371 | material by anyone who conveys the material (or modified versions of
372 | it) with contractual assumptions of liability to the recipient, for
373 | any liability that these contractual assumptions directly impose on
374 | those licensors and authors.
375 |
376 | All other non-permissive additional terms are considered "further
377 | restrictions" within the meaning of section 10. If the Program as you
378 | received it, or any part of it, contains a notice stating that it is
379 | governed by this License along with a term that is a further
380 | restriction, you may remove that term. If a license document contains
381 | a further restriction but permits relicensing or conveying under this
382 | License, you may add to a covered work material governed by the terms
383 | of that license document, provided that the further restriction does
384 | not survive such relicensing or conveying.
385 |
386 | If you add terms to a covered work in accord with this section, you
387 | must place, in the relevant source files, a statement of the
388 | additional terms that apply to those files, or a notice indicating
389 | where to find the applicable terms.
390 |
391 | Additional terms, permissive or non-permissive, may be stated in the
392 | form of a separately written license, or stated as exceptions;
393 | the above requirements apply either way.
394 |
395 | 8. Termination.
396 |
397 | You may not propagate or modify a covered work except as expressly
398 | provided under this License. Any attempt otherwise to propagate or
399 | modify it is void, and will automatically terminate your rights under
400 | this License (including any patent licenses granted under the third
401 | paragraph of section 11).
402 |
403 | However, if you cease all violation of this License, then your
404 | license from a particular copyright holder is reinstated (a)
405 | provisionally, unless and until the copyright holder explicitly and
406 | finally terminates your license, and (b) permanently, if the copyright
407 | holder fails to notify you of the violation by some reasonable means
408 | prior to 60 days after the cessation.
409 |
410 | Moreover, your license from a particular copyright holder is
411 | reinstated permanently if the copyright holder notifies you of the
412 | violation by some reasonable means, this is the first time you have
413 | received notice of violation of this License (for any work) from that
414 | copyright holder, and you cure the violation prior to 30 days after
415 | your receipt of the notice.
416 |
417 | Termination of your rights under this section does not terminate the
418 | licenses of parties who have received copies or rights from you under
419 | this License. If your rights have been terminated and not permanently
420 | reinstated, you do not qualify to receive new licenses for the same
421 | material under section 10.
422 |
423 | 9. Acceptance Not Required for Having Copies.
424 |
425 | You are not required to accept this License in order to receive or
426 | run a copy of the Program. Ancillary propagation of a covered work
427 | occurring solely as a consequence of using peer-to-peer transmission
428 | to receive a copy likewise does not require acceptance. However,
429 | nothing other than this License grants you permission to propagate or
430 | modify any covered work. These actions infringe copyright if you do
431 | not accept this License. Therefore, by modifying or propagating a
432 | covered work, you indicate your acceptance of this License to do so.
433 |
434 | 10. Automatic Licensing of Downstream Recipients.
435 |
436 | Each time you convey a covered work, the recipient automatically
437 | receives a license from the original licensors, to run, modify and
438 | propagate that work, subject to this License. You are not responsible
439 | for enforcing compliance by third parties with this License.
440 |
441 | An "entity transaction" is a transaction transferring control of an
442 | organization, or substantially all assets of one, or subdividing an
443 | organization, or merging organizations. If propagation of a covered
444 | work results from an entity transaction, each party to that
445 | transaction who receives a copy of the work also receives whatever
446 | licenses to the work the party's predecessor in interest had or could
447 | give under the previous paragraph, plus a right to possession of the
448 | Corresponding Source of the work from the predecessor in interest, if
449 | the predecessor has it or can get it with reasonable efforts.
450 |
451 | You may not impose any further restrictions on the exercise of the
452 | rights granted or affirmed under this License. For example, you may
453 | not impose a license fee, royalty, or other charge for exercise of
454 | rights granted under this License, and you may not initiate litigation
455 | (including a cross-claim or counterclaim in a lawsuit) alleging that
456 | any patent claim is infringed by making, using, selling, offering for
457 | sale, or importing the Program or any portion of it.
458 |
459 | 11. Patents.
460 |
461 | A "contributor" is a copyright holder who authorizes use under this
462 | License of the Program or a work on which the Program is based. The
463 | work thus licensed is called the contributor's "contributor version".
464 |
465 | A contributor's "essential patent claims" are all patent claims
466 | owned or controlled by the contributor, whether already acquired or
467 | hereafter acquired, that would be infringed by some manner, permitted
468 | by this License, of making, using, or selling its contributor version,
469 | but do not include claims that would be infringed only as a
470 | consequence of further modification of the contributor version. For
471 | purposes of this definition, "control" includes the right to grant
472 | patent sublicenses in a manner consistent with the requirements of
473 | this License.
474 |
475 | Each contributor grants you a non-exclusive, worldwide, royalty-free
476 | patent license under the contributor's essential patent claims, to
477 | make, use, sell, offer for sale, import and otherwise run, modify and
478 | propagate the contents of its contributor version.
479 |
480 | In the following three paragraphs, a "patent license" is any express
481 | agreement or commitment, however denominated, not to enforce a patent
482 | (such as an express permission to practice a patent or covenant not to
483 | sue for patent infringement). To "grant" such a patent license to a
484 | party means to make such an agreement or commitment not to enforce a
485 | patent against the party.
486 |
487 | If you convey a covered work, knowingly relying on a patent license,
488 | and the Corresponding Source of the work is not available for anyone
489 | to copy, free of charge and under the terms of this License, through a
490 | publicly available network server or other readily accessible means,
491 | then you must either (1) cause the Corresponding Source to be so
492 | available, or (2) arrange to deprive yourself of the benefit of the
493 | patent license for this particular work, or (3) arrange, in a manner
494 | consistent with the requirements of this License, to extend the patent
495 | license to downstream recipients. "Knowingly relying" means you have
496 | actual knowledge that, but for the patent license, your conveying the
497 | covered work in a country, or your recipient's use of the covered work
498 | in a country, would infringe one or more identifiable patents in that
499 | country that you have reason to believe are valid.
500 |
501 | If, pursuant to or in connection with a single transaction or
502 | arrangement, you convey, or propagate by procuring conveyance of, a
503 | covered work, and grant a patent license to some of the parties
504 | receiving the covered work authorizing them to use, propagate, modify
505 | or convey a specific copy of the covered work, then the patent license
506 | you grant is automatically extended to all recipients of the covered
507 | work and works based on it.
508 |
509 | A patent license is "discriminatory" if it does not include within
510 | the scope of its coverage, prohibits the exercise of, or is
511 | conditioned on the non-exercise of one or more of the rights that are
512 | specifically granted under this License. You may not convey a covered
513 | work if you are a party to an arrangement with a third party that is
514 | in the business of distributing software, under which you make payment
515 | to the third party based on the extent of your activity of conveying
516 | the work, and under which the third party grants, to any of the
517 | parties who would receive the covered work from you, a discriminatory
518 | patent license (a) in connection with copies of the covered work
519 | conveyed by you (or copies made from those copies), or (b) primarily
520 | for and in connection with specific products or compilations that
521 | contain the covered work, unless you entered into that arrangement,
522 | or that patent license was granted, prior to 28 March 2007.
523 |
524 | Nothing in this License shall be construed as excluding or limiting
525 | any implied license or other defenses to infringement that may
526 | otherwise be available to you under applicable patent law.
527 |
528 | 12. No Surrender of Others' Freedom.
529 |
530 | If conditions are imposed on you (whether by court order, agreement or
531 | otherwise) that contradict the conditions of this License, they do not
532 | excuse you from the conditions of this License. If you cannot convey a
533 | covered work so as to satisfy simultaneously your obligations under this
534 | License and any other pertinent obligations, then as a consequence you may
535 | not convey it at all. For example, if you agree to terms that obligate you
536 | to collect a royalty for further conveying from those to whom you convey
537 | the Program, the only way you could satisfy both those terms and this
538 | License would be to refrain entirely from conveying the Program.
539 |
540 | 13. Remote Network Interaction; Use with the GNU General Public License.
541 |
542 | Notwithstanding any other provision of this License, if you modify the
543 | Program, your modified version must prominently offer all users
544 | interacting with it remotely through a computer network (if your version
545 | supports such interaction) an opportunity to receive the Corresponding
546 | Source of your version by providing access to the Corresponding Source
547 | from a network server at no charge, through some standard or customary
548 | means of facilitating copying of software. This Corresponding Source
549 | shall include the Corresponding Source for any work covered by version 3
550 | of the GNU General Public License that is incorporated pursuant to the
551 | following paragraph.
552 |
553 | Notwithstanding any other provision of this License, you have
554 | permission to link or combine any covered work with a work licensed
555 | under version 3 of the GNU General Public License into a single
556 | combined work, and to convey the resulting work. The terms of this
557 | License will continue to apply to the part which is the covered work,
558 | but the work with which it is combined will remain governed by version
559 | 3 of the GNU General Public License.
560 |
561 | 14. Revised Versions of this License.
562 |
563 | The Free Software Foundation may publish revised and/or new versions of
564 | the GNU Affero General Public License from time to time. Such new versions
565 | will be similar in spirit to the present version, but may differ in detail to
566 | address new problems or concerns.
567 |
568 | Each version is given a distinguishing version number. If the
569 | Program specifies that a certain numbered version of the GNU Affero General
570 | Public License "or any later version" applies to it, you have the
571 | option of following the terms and conditions either of that numbered
572 | version or of any later version published by the Free Software
573 | Foundation. If the Program does not specify a version number of the
574 | GNU Affero General Public License, you may choose any version ever published
575 | by the Free Software Foundation.
576 |
577 | If the Program specifies that a proxy can decide which future
578 | versions of the GNU Affero General Public License can be used, that proxy's
579 | public statement of acceptance of a version permanently authorizes you
580 | to choose that version for the Program.
581 |
582 | Later license versions may give you additional or different
583 | permissions. However, no additional obligations are imposed on any
584 | author or copyright holder as a result of your choosing to follow a
585 | later version.
586 |
587 | 15. Disclaimer of Warranty.
588 |
589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597 |
598 | 16. Limitation of Liability.
599 |
600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608 | SUCH DAMAGES.
609 |
610 | 17. Interpretation of Sections 15 and 16.
611 |
612 | If the disclaimer of warranty and limitation of liability provided
613 | above cannot be given local legal effect according to their terms,
614 | reviewing courts shall apply local law that most closely approximates
615 | an absolute waiver of all civil liability in connection with the
616 | Program, unless a warranty or assumption of liability accompanies a
617 | copy of the Program in return for a fee.
618 |
619 | END OF TERMS AND CONDITIONS
620 |
621 | How to Apply These Terms to Your New Programs
622 |
623 | If you develop a new program, and you want it to be of the greatest
624 | possible use to the public, the best way to achieve this is to make it
625 | free software which everyone can redistribute and change under these terms.
626 |
627 | To do so, attach the following notices to the program. It is safest
628 | to attach them to the start of each source file to most effectively
629 | state the exclusion of warranty; and each file should have at least
630 | the "copyright" line and a pointer to where the full notice is found.
631 |
632 |
633 | Copyright (C) <year>  <name of author>
634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CSRA
2 | This is the official code of the ICCV 2021 paper:
3 | [Residual Attention: A Simple But Effective Method for Multi-Label Recognition](https://arxiv.org/abs/2108.02456)
4 |
5 | 
6 |
7 | ### Demo, Train and Validation code have been released! (including VIT on Wider-Attribute)
8 | This package was developed by Mr. Ke Zhu (http://www.lamda.nju.edu.cn/zhuk/), and we have just finished the implementation of the ViT models. If you have any questions about the code, please feel free to contact Mr. Ke Zhu (zhuk@lamda.nju.edu.cn). The package is free for academic usage; you run it at your own risk. For other purposes, please contact Prof. Jianxin Wu (mail to
9 | wujx2001@gmail.com).
10 |
11 | ## Requirements
12 | - Python 3.7
13 | - pytorch 1.6
14 | - torchvision 0.7.0
15 | - pycocotools 2.0
16 | - tqdm 4.49.0, pillow 7.2.0
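A minimal environment-setup sketch, assuming a fresh Python 3.7 virtual environment (the exact pip pins below are our guess at matching versions, not an official install script):
```shell
pip install torch==1.6.0 torchvision==0.7.0 pycocotools==2.0.0 tqdm==4.49.0 pillow==7.2.0
```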
17 |
18 | ## Dataset
19 | We expect the VOC2007, COCO2014 and Wider-Attribute datasets to have the following structure:
20 | ```
21 | Dataset/
22 | |-- VOCdevkit/
23 | |---- VOC2007/
24 | |------ JPEGImages/
25 | |------ Annotations/
26 | |------ ImageSets/
27 | ......
28 | |-- COCO2014/
29 | |---- annotations/
30 | |---- images/
31 | |------ train2014/
32 | |------ val2014/
33 | ......
34 | |-- WIDER/
35 | |---- Annotations/
36 | |------ wider_attribute_test.json
37 | |------ wider_attribute_trainval.json
38 | |---- Image/
39 | |------ train/
40 | |------ val/
41 | |------ test/
42 | ...
43 | ```
44 | Then run the following commands to generate the JSON annotation files (used by our implementation) for these datasets.
45 | ```shell
46 | python utils/prepare/prepare_voc.py --data_path Dataset/VOCdevkit
47 | python utils/prepare/prepare_coco.py --data_path Dataset/COCO2014
48 | python utils/prepare/prepare_wider.py --data_path Dataset/WIDER
49 | ```
50 | This will automatically produce annotation JSON files in *./data/voc07*, *./data/coco* and *./data/wider*.
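Each generated JSON file is a list of per-image entries whose fields are the ones read by *pipeline/dataset.py* (`img_path`, `target`, plus a `bbox` of `[x, y, w, h]` for Wider). As an illustrative sketch only (the path and label values below are made up), a VOC2007 entry looks roughly like:
```
{
    "img_path": "Dataset/VOCdevkit/VOC2007/JPEGImages/000005.jpg",
    "target": [0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
}
```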
51 |
52 | ## Demo
53 | We provide prediction demos of our models. The demo images (picked from VOC2007) have already been put into *./utils/demo_images/*; you can simply run demo.py using our CSRA models pretrained on VOC2007:
54 | ```shell
55 | CUDA_VISIBLE_DEVICES=0 python demo.py --model resnet101 --num_heads 1 --lam 0.1 --dataset voc07 --load_from OUR_VOC_PRETRAINED.pth --img_dir utils/demo_images
56 | ```
57 | which will produce output like this:
58 | ```shell
59 | utils/demo_images/000001.jpg prediction: dog,person,
60 | utils/demo_images/000004.jpg prediction: car,
61 | utils/demo_images/000002.jpg prediction: train,
62 | ...
63 | ```
64 |
65 |
66 | ## Validation
67 | We provide pretrained models on [Google Drive](https://www.google.com/drive/) for validation. ResNet101 trained on ImageNet with **CutMix** augmentation can be downloaded
68 | [here](https://drive.google.com/u/0/uc?export=download&confirm=kYfp&id=1T4AxsAO2tszvhn62KFN5kaknBtBZIpDV).
69 | |Dataset | Backbone | Head nums | mAP(%) | Resolution | Download |
70 | | ---------- | ------- | :--------: | ------ | :---: | -------- |
71 | | VOC2007 |ResNet-101 | 1 | 94.7 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=bXcv&id=1cQSRI_DWyKpLa0tvxltoH9rM4IZMIEWJ) |
72 | | VOC2007 |ResNet-cut | 1 | 95.2 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=otx_&id=1bzSsWhGG-zUNQRMB7rQCuPMqLZjnrzFh) |
73 | | VOC2007 (extra) |ResNet-cut | 1 | 96.8 | 448x448 |[download](https://drive.google.com/u/0/uc?id=1XgVE3Q3vmE8hjdDjqow_2GyjPx_5bDjU&export=download) |
74 | | COCO |ResNet-101 | 4 | 83.3 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=EWtH&id=1e_WzdVgF_sQc--ubN-DRnGVbbJGSJEZa) |
75 | | COCO |ResNet-cut | 6 | 85.6 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=uEcu&id=17FgLUe_vr5sJX6_TT-MPdP5TYYAcVEPF) |
76 | | COCO |VIT_L16_224 | 8 | 86.5 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=1Rmm&id=1TTzCpRadhYDwZSEow3OVdrh1TKezWHF_)|
77 | | COCO |VIT_L16_224* | 8 | 86.9 | 448x448 |[download](https://drive.google.com/u/0/uc?export=download&confirm=xpbJ&id=1zYE88pmWcZfcrdQsP8-9JMo4n_g5pO4l)|
78 | | Wider |VIT_B16_224| 1 | 89.0 | 224x224 |[download](https://drive.google.com/u/0/uc?id=1qkJgWQ2EOYri8ITLth_wgnR4kEsv0bfj&export=download) |
79 | | Wider |VIT_L16_224| 1 | 90.2 | 224x224 |[download](https://drive.google.com/u/0/uc?id=1da8D7UP9cMCgKO0bb1gyRvVqYoZ3Wh7O&export=download) |
80 |
81 | For voc2007, run the following validation example:
82 | ```shell
83 | CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --load_from MODEL.pth
84 | ```
85 | For coco2014, run the following validation example:
86 | ```shell
87 | CUDA_VISIBLE_DEVICES=0 python val.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80 --load_from MODEL.pth
88 | ```
89 | For Wider-Attribute with ViT models, run the following:
90 | ```shell
91 | CUDA_VISIBLE_DEVICES=0 python val.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14 --load_from ViT_B16_MODEL.pth
92 | CUDA_VISIBLE_DEVICES=0 python val.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14 --load_from ViT_L16_MODEL.pth
93 | ```
94 | To provide pretrained ViT models on the Wider-Attribute dataset, we retrained them recently, so their performance differs slightly (~0.1% mAP) from what is reported in our paper. The structure of the ViT models is the original ViT version (**An image is worth 16x16 words: Transformers for image recognition at scale**, [link](https://arxiv.org/pdf/2010.11929.pdf)) and the implementation code of the ViT models is derived from [http://github.com/rwightman/pytorch-image-models/](http://github.com/rwightman/pytorch-image-models/).
95 | ## Training
96 | #### VOC2007
97 | You can run either of the two commands below
98 | ```shell
99 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20
100 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 1 --lam 0.1 --dataset voc07 --num_cls 20 --cutmix CutMix_ResNet101.pth
101 | ```
102 | Note that the first command uses the official ResNet-101 backbone, while the second uses a ResNet-101 pretrained on ImageNet with CutMix augmentation
103 | [link](https://drive.google.com/u/0/uc?export=download&confirm=kYfp&id=1T4AxsAO2tszvhn62KFN5kaknBtBZIpDV) (which is expected to give better performance).
104 |
105 | #### MS-COCO
106 | run the ResNet-101 with 4 heads
107 | ```shell
108 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 4 --lam 0.5 --dataset coco --num_cls 80
109 | ```
110 | run the ResNet-101 (pretrained with CutMix) with 6 heads
111 | ```shell
112 | CUDA_VISIBLE_DEVICES=0 python main.py --num_heads 6 --lam 0.4 --dataset coco --num_cls 80 --cutmix CutMix_ResNet101.pth
113 | ```
114 | Feel free to adjust hyper-parameters such as the number of attention heads (--num_heads) or the Lambda (--lam). Still, the default values used in the commands above are expected to give the best results.
115 |
116 | #### Wider-Attribute
117 | run the VIT_B16_224 with 1 head
118 | ```shell
119 | CUDA_VISIBLE_DEVICES=0 python main.py --model vit_B16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14
120 | ```
121 | run the VIT_L16_224 with 1 head
122 | ```shell
123 | CUDA_VISIBLE_DEVICES=0,1 python main.py --model vit_L16_224 --img_size 224 --num_heads 1 --lam 0.3 --dataset wider --num_cls 14
124 | ```
125 | Note that the VIT_L16_224 model consumes more GPU memory, so we use 2 GPUs to train it.
126 | ## Notice
127 | To avoid confusion, please note that the **4 lines of code** in Figure 1 (in the paper) are only used at the **test** stage (without training), which is our motivation. When our model is trained and tested end-to-end, **multi-head attention** (H=1, H=2, H=4, etc.) is used with different T values. Also, when H=1 and T=infty, the implementation of **multi-head attention** is exactly the same as Figure 1.
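For reference, here is a minimal sketch of that test-stage computation, adapted from *pipeline/csra.py* with max-pooling standing in for the T=infty case; `features` and `fc_weight` are placeholder tensors, not part of the released code:
```python
import torch

B, d, H, W, C, lam = 2, 2048, 14, 14, 20, 0.1
features = torch.randn(B, d, H, W)      # backbone feature map (B, d, H, W)
fc_weight = torch.randn(C, d)           # classifier (FC) weights, one row per class

score = torch.einsum("bdhw,cd->bchw", features, fc_weight)    # per-location class scores
score = score / fc_weight.norm(dim=1)[None, :, None, None]    # normalize by classifier norm
score = score.flatten(2)                                       # (B, C, HxW)
logit = score.mean(dim=2) + lam * score.max(dim=2)[0]          # average pool + lam * max pool (residual attention)
```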
128 |
129 | We did not use any extra augmentation such as **AutoAugment or RandAugment** in our ResNet series models.
130 |
131 | ## Acknowledgement
132 |
133 | We thank Lin Sui (http://www.lamda.nju.edu.cn/suil/) for his initial contribution to this project.
134 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | from tqdm import tqdm
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from pipeline.resnet_csra import ResNet_CSRA
11 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
12 | from pipeline.dataset import DataSet
13 | from torchvision.transforms import transforms
14 | from utils.evaluation.eval import voc_classes, wider_classes, coco_classes, class_dict
15 |
16 |
17 | # Usage:
18 | # This demo is used to predict the label of each image
19 | # if you want to use our models to predict some labels of the VOC2007 images
20 | # 1st: use the models pretrained on VOC2007
21 | # 2nd: put the images in the utils/demo_images
22 | # 3rd: run demo.py
23 |
24 | def Args():
25 | parser = argparse.ArgumentParser(description="settings")
26 | # model default resnet101
27 | parser.add_argument("--model", default="resnet101", type=str)
28 | parser.add_argument("--num_heads", default=1, type=int)
29 | parser.add_argument("--lam",default=0.1, type=float)
30 | parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str)
31 | parser.add_argument("--img_dir", default="images/", type=str)
32 |
33 | # dataset
34 | parser.add_argument("--dataset", default="voc07", type=str)
35 | parser.add_argument("--num_cls", default=20, type=int)
36 |
37 | args = parser.parse_args()
38 | return args
39 |
40 |
41 | def demo():
42 | args = Args()
43 |
44 | # model
45 | if args.model == "resnet101":
46 | model = ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls)
47 | normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
48 | img_size = 448
49 | if args.model == "vit_B16_224":
50 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
51 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
52 | img_size = 224
53 | if args.model == "vit_L16_224":
54 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
55 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
56 | img_size = 224
57 |
58 | model.cuda()
59 | print("Loading weights from {}".format(args.load_from))
60 | model.load_state_dict(torch.load(args.load_from))
61 |
62 | # image pre-process
63 | transform = transforms.Compose([
64 | transforms.Resize((img_size, img_size)),
65 | transforms.ToTensor(),
66 | normalize
67 | ])
68 |
69 | # prediction of each image's label
70 | for img_file in os.listdir(args.img_dir):
71 | print(os.path.join(args.img_dir, img_file), end=" prediction: ")
72 | img = Image.open(os.path.join(args.img_dir, img_file)).convert("RGB")
73 | img = transform(img)
74 | img = img.cuda()
75 | img = img.unsqueeze(0)
76 |
77 |         model.eval()
78 |         with torch.no_grad():  # inference only, no gradients needed
79 |             logit = nn.Sigmoid()(model(img).squeeze(0))
80 |
81 |
82 | pos = torch.where(logit > 0.5)[0].cpu().numpy()
83 | for k in pos:
84 | print(class_dict[args.dataset][k], end=",")
85 | print()
86 |
87 |
88 | if __name__ == "__main__":
89 | demo()
90 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | from pipeline.resnet_csra import ResNet_CSRA
8 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
9 | from pipeline.dataset import DataSet
10 | from utils.evaluation.eval import evaluation
11 | from utils.evaluation.warmUpLR import WarmUpLR
12 | from tqdm import tqdm
13 |
14 |
15 | # modify for wider dataset and vit models
16 |
17 | def Args():
18 | parser = argparse.ArgumentParser(description="settings")
19 | # model
20 | parser.add_argument("--model", default="resnet101")
21 | parser.add_argument("--num_heads", default=1, type=int)
22 | parser.add_argument("--lam",default=0.1, type=float)
23 | parser.add_argument("--cutmix", default=None, type=str) # the path to load cutmix-pretrained backbone
24 | # dataset
25 | parser.add_argument("--dataset", default="voc07", type=str)
26 | parser.add_argument("--num_cls", default=20, type=int)
27 | parser.add_argument("--train_aug", default=["randomflip", "resizedcrop"], type=list)
28 | parser.add_argument("--test_aug", default=[], type=list)
29 | parser.add_argument("--img_size", default=448, type=int)
30 | parser.add_argument("--batch_size", default=16, type=int)
31 | # optimizer, default SGD
32 | parser.add_argument("--lr", default=0.01, type=float)
33 | parser.add_argument("--momentum", default=0.9, type=float)
34 | parser.add_argument("--w_d", default=0.0001, type=float, help="weight_decay")
35 | parser.add_argument("--warmup_epoch", default=2, type=int)
36 | parser.add_argument("--total_epoch", default=30, type=int)
37 | parser.add_argument("--print_freq", default=100, type=int)
38 | args = parser.parse_args()
39 | return args
40 |
41 |
42 | def train(i, args, model, train_loader, optimizer, warmup_scheduler):
43 | print()
44 | model.train()
45 | epoch_begin = time.time()
46 | for index, data in enumerate(train_loader):
47 | batch_begin = time.time()
48 | img = data['img'].cuda()
49 | target = data['target'].cuda()
50 |
51 | optimizer.zero_grad()
52 | logit, loss = model(img, target)
53 | loss = loss.mean()
54 | loss.backward()
55 | optimizer.step()
56 | t = time.time() - batch_begin
57 |
58 | if index % args.print_freq == 0:
59 | print("Epoch {}[{}/{}]: loss:{:.5f}, lr:{:.5f}, time:{:.4f}".format(
60 | i,
61 | args.batch_size * (index + 1),
62 | len(train_loader.dataset),
63 | loss,
64 | optimizer.param_groups[0]["lr"],
65 | float(t)
66 | ))
67 |
68 | if warmup_scheduler and i <= args.warmup_epoch:
69 | warmup_scheduler.step()
70 |
71 |
72 | t = time.time() - epoch_begin
73 | print("Epoch {} training ends, total {:.2f}s".format(i, t))
74 |
75 |
76 | def val(i, args, model, test_loader, test_file):
77 | model.eval()
78 | print("Test on Epoch {}".format(i))
79 | result_list = []
80 |
81 | # calculate logit
82 | for index, data in enumerate(tqdm(test_loader)):
83 | img = data['img'].cuda()
84 | target = data['target'].cuda()
85 | img_path = data['img_path']
86 |
87 | with torch.no_grad():
88 | logit = model(img)
89 |
90 | result = nn.Sigmoid()(logit).cpu().detach().numpy().tolist()
91 | for k in range(len(img_path)):
92 | result_list.append(
93 | {
94 | "file_name": img_path[k].split("/")[-1].split(".")[0],
95 | "scores": result[k]
96 | }
97 | )
98 | # cal_mAP OP OR
99 | evaluation(result=result_list, types=args.dataset, ann_path=test_file[0])
100 |
101 |
102 |
103 | def main():
104 | args = Args()
105 |
106 | # model
107 | if args.model == "resnet101":
108 | model = ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls, cutmix=args.cutmix)
109 | if args.model == "vit_B16_224":
110 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
111 | if args.model == "vit_L16_224":
112 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
113 |
114 | model.cuda()
115 | if torch.cuda.device_count() > 1:
116 | print("lets use {} GPUs.".format(torch.cuda.device_count()))
117 | model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
118 |
119 | # data
120 | if args.dataset == "voc07":
121 | train_file = ["data/voc07/trainval_voc07.json"]
122 | test_file = ['data/voc07/test_voc07.json']
123 | step_size = 4
124 | if args.dataset == "coco":
125 | train_file = ['data/coco/train_coco2014.json']
126 | test_file = ['data/coco/val_coco2014.json']
127 | step_size = 5
128 | if args.dataset == "wider":
129 | train_file = ['data/wider/trainval_wider.json']
130 | test_file = ["data/wider/test_wider.json"]
131 | step_size = 5
132 | args.train_aug = ["randomflip"]
133 |
134 | train_dataset = DataSet(train_file, args.train_aug, args.img_size, args.dataset)
135 | test_dataset = DataSet(test_file, args.test_aug, args.img_size, args.dataset)
136 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
137 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
138 |
139 | # optimizer and warmup
140 | backbone, classifier = [], []
141 | for name, param in model.named_parameters():
142 | if 'classifier' in name:
143 | classifier.append(param)
144 | else:
145 | backbone.append(param)
146 | optimizer = optim.SGD(
147 | [
148 | {'params': backbone, 'lr': args.lr},
149 | {'params': classifier, 'lr': args.lr * 10}
150 | ],
151 | momentum=args.momentum, weight_decay=args.w_d)
152 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
153 |
154 | iter_per_epoch = len(train_loader)
155 | if args.warmup_epoch > 0:
156 | warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warmup_epoch)
157 | else:
158 | warmup_scheduler = None
159 |
160 | # training and validation
161 | for i in range(1, args.total_epoch + 1):
162 | train(i, args, model, train_loader, optimizer, warmup_scheduler)
163 | torch.save(model.state_dict(), "checkpoint/{}/epoch_{}.pth".format(args.model, i))
164 | val(i, args, model, test_loader, test_file)
165 | scheduler.step()
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
/pipeline/csra.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class CSRA(nn.Module): # one basic block
7 | def __init__(self, input_dim, num_classes, T, lam):
8 | super(CSRA, self).__init__()
9 | self.T = T # temperature
10 | self.lam = lam # Lambda
11 | self.head = nn.Conv2d(input_dim, num_classes, 1, bias=False)
12 | self.softmax = nn.Softmax(dim=2)
13 |
14 | def forward(self, x):
15 | # x (B d H W)
16 | # normalize classifier
17 | # score (B C HxW)
18 | score = self.head(x) / torch.norm(self.head.weight, dim=1, keepdim=True).transpose(0,1)
19 | score = score.flatten(2)
20 | base_logit = torch.mean(score, dim=2)
21 |
22 | if self.T == 99: # max-pooling
23 | att_logit = torch.max(score, dim=2)[0]
24 | else:
25 | score_soft = self.softmax(score * self.T)
26 | att_logit = torch.sum(score * score_soft, dim=2)
27 |
28 | return base_logit + self.lam * att_logit
29 |
30 |
31 |
32 |
33 | class MHA(nn.Module): # multi-head attention
34 | temp_settings = { # softmax temperature settings
35 | 1: [1],
36 | 2: [1, 99],
37 | 4: [1, 2, 4, 99],
38 | 6: [1, 2, 3, 4, 5, 99],
39 | 8: [1, 2, 3, 4, 5, 6, 7, 99]
40 | }
41 |
42 | def __init__(self, num_heads, lam, input_dim, num_classes):
43 | super(MHA, self).__init__()
44 | self.temp_list = self.temp_settings[num_heads]
45 | self.multi_head = nn.ModuleList([
46 | CSRA(input_dim, num_classes, self.temp_list[i], lam)
47 | for i in range(num_heads)
48 | ])
49 |
50 | def forward(self, x):
51 | logit = 0.
52 | for head in self.multi_head:
53 | logit += head(x)
54 | return logit
55 |
--------------------------------------------------------------------------------
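
A minimal usage sketch of the MHA (multi-head CSRA) classifier defined in pipeline/csra.py above. It assumes the repository root is on the Python path; the 2048-channel feature map and 20 classes mirror the ResNet-101 / VOC-style setting used elsewhere in this repo and are illustrative only.

import torch
from pipeline.csra import MHA

head = MHA(num_heads=1, lam=0.1, input_dim=2048, num_classes=20)
feat = torch.randn(2, 2048, 14, 14)   # dummy backbone feature map (B, d, H, W)
logits = head(feat)                   # (2, 20): base logit + lam * attention logit, summed over heads
probs = torch.sigmoid(logits)         # multi-label probabilities
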
/pipeline/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 | from torchvision.transforms import transforms
5 | import torch
6 | import numpy as np
7 |
8 | # NOTE: the normalization transform is switched for ViT models (WIDER) below
9 | # NOTE: WIDER images are cropped to the person bounding box in __getitem__
10 |
11 |
12 | class DataSet(Dataset):
13 | def __init__(self,
14 | ann_files,
15 | augs,
16 | img_size,
17 | dataset,
18 | ):
19 | self.dataset = dataset
20 | self.ann_files = ann_files
21 | self.augment = self.augs_function(augs, img_size)
22 | self.transform = transforms.Compose(
23 | [
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
26 | ]
27 | # In this paper, we normalize the image data to [0, 1]
28 |             # You can also use the so-called 'ImageNet' normalization instead
29 | )
30 | self.anns = []
31 | self.load_anns()
32 | print(self.augment)
33 |
34 |         # for the WIDER dataset we use ViT models,
35 |         # so the normalization transform is changed accordingly
36 | if self.dataset == "wider":
37 | self.transform = transforms.Compose(
38 | [
39 | transforms.ToTensor(),
40 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
41 | ]
42 | )
43 |
44 | def augs_function(self, augs, img_size):
45 | t = []
46 | if 'randomflip' in augs:
47 | t.append(transforms.RandomHorizontalFlip())
48 | if 'ColorJitter' in augs:
49 | t.append(transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0))
50 | if 'resizedcrop' in augs:
51 | t.append(transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)))
52 | if 'RandAugment' in augs:
53 |             t.append(RandAugment())  # NOTE: RandAugment is not imported in this file; add an import before enabling this option
54 |
55 | t.append(transforms.Resize((img_size, img_size)))
56 |
57 | return transforms.Compose(t)
58 |
59 | def load_anns(self):
60 | self.anns = []
61 | for ann_file in self.ann_files:
62 | json_data = json.load(open(ann_file, "r"))
63 | self.anns += json_data
64 |
65 | def __len__(self):
66 | return len(self.anns)
67 |
68 | def __getitem__(self, idx):
69 | idx = idx % len(self)
70 | ann = self.anns[idx]
71 | img = Image.open(ann["img_path"]).convert("RGB")
72 |
73 | if self.dataset == "wider":
74 | x, y, w, h = ann['bbox']
75 | img_area = img.crop([x, y, x+w, y+h])
76 | img_area = self.augment(img_area)
77 | img_area = self.transform(img_area)
78 | message = {
79 | "img_path": ann['img_path'],
80 | "target": torch.Tensor(ann['target']),
81 | "img": img_area
82 | }
83 | else: # voc and coco
84 | img = self.augment(img)
85 | img = self.transform(img)
86 | message = {
87 | "img_path": ann["img_path"],
88 | "target": torch.Tensor(ann["target"]),
89 | "img": img
90 | }
91 |
92 | return message
93 | # finally, if we use dataloader to get the data, we will get
94 | # {
95 | # "img_path": list, # length = batch_size
96 | # "target": Tensor, # shape: batch_size * num_classes
97 | # "img": Tensor, # shape: batch_size * 3 * 224 * 224
98 | # }
99 |
--------------------------------------------------------------------------------
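
A minimal sketch of how DataSet is consumed, matching the batch layout described in the comment at the end of pipeline/dataset.py. It assumes data/voc07/trainval_voc07.json has been generated by utils/prepare/prepare_voc.py and that the referenced VOC images exist on disk.

from torch.utils.data import DataLoader
from pipeline.dataset import DataSet

dataset = DataSet(["data/voc07/trainval_voc07.json"], ["randomflip"], 448, "voc07")
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch["img"].shape)      # torch.Size([16, 3, 448, 448])
print(batch["target"].shape)   # torch.Size([16, 20])
print(len(batch["img_path"]))  # 16
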
/pipeline/resnet_csra.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import ResNet
2 | from torchvision.models.resnet import Bottleneck, BasicBlock
3 | from .csra import CSRA, MHA
4 | import torch.utils.model_zoo as model_zoo
5 | import logging
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | }
18 |
19 |
20 |
21 |
22 |
23 |
24 | class ResNet_CSRA(ResNet):
25 | arch_settings = {
26 | 18: (BasicBlock, (2, 2, 2, 2)),
27 | 34: (BasicBlock, (3, 4, 6, 3)),
28 | 50: (Bottleneck, (3, 4, 6, 3)),
29 | 101: (Bottleneck, (3, 4, 23, 3)),
30 | 152: (Bottleneck, (3, 8, 36, 3))
31 | }
32 |
33 | def __init__(self, num_heads, lam, num_classes, depth=101, input_dim=2048, cutmix=None):
34 | self.block, self.layers = self.arch_settings[depth]
35 | self.depth = depth
36 | super(ResNet_CSRA, self).__init__(self.block, self.layers)
37 | self.init_weights(pretrained=True, cutmix=cutmix)
38 |
39 | self.classifier = MHA(num_heads, lam, input_dim, num_classes)
40 | self.loss_func = F.binary_cross_entropy_with_logits
41 |
42 | def backbone(self, x):
43 | x = self.conv1(x)
44 | x = self.bn1(x)
45 | x = self.relu(x)
46 | x = self.maxpool(x)
47 |
48 | x = self.layer1(x)
49 | x = self.layer2(x)
50 | x = self.layer3(x)
51 | x = self.layer4(x)
52 |
53 | return x
54 |
55 | def forward_train(self, x, target):
56 | x = self.backbone(x)
57 | logit = self.classifier(x)
58 | loss = self.loss_func(logit, target, reduction="mean")
59 | return logit, loss
60 |
61 | def forward_test(self, x):
62 | x = self.backbone(x)
63 | x = self.classifier(x)
64 | return x
65 |
66 | def forward(self, x, target=None):
67 | if target is not None:
68 | return self.forward_train(x, target)
69 | else:
70 | return self.forward_test(x)
71 |
72 | def init_weights(self, pretrained=True, cutmix=None):
73 | if cutmix is not None:
74 | print("backbone params inited by CutMix pretrained model")
75 | state_dict = torch.load(cutmix)
76 | elif pretrained:
77 | print("backbone params inited by Pytorch official model")
78 | model_url = model_urls["resnet{}".format(self.depth)]
79 | state_dict = model_zoo.load_url(model_url)
80 |
81 | model_dict = self.state_dict()
82 | try:
83 | pretrained_dict = {k: v for k, v in state_dict.items() if k in model_dict}
84 | self.load_state_dict(pretrained_dict)
85 | except:
86 | logger = logging.getLogger()
87 | logger.info(
88 |                 "the keys in the pretrained model do not match the ResNet you chose, trying to fix...")
89 |             state_dict = self._keysFix(model_dict, state_dict)  # NOTE: _keysFix is not defined in this class
90 | self.load_state_dict(state_dict)
91 |
92 | # remove the original 1000-class fc
93 | self.fc = nn.Sequential()
--------------------------------------------------------------------------------
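
A minimal sketch of the dual forward interface of ResNet_CSRA above: passing a target returns (logits, BCE loss) via forward_train, while omitting it returns logits only via forward_test. Constructing the model downloads the torchvision ResNet-101 weights (or loads a CutMix checkpoint if cutmix is given); the shapes below are illustrative.

import torch
from pipeline.resnet_csra import ResNet_CSRA

model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
img = torch.randn(2, 3, 448, 448)
target = torch.randint(0, 2, (2, 20)).float()   # multi-hot labels

logit, loss = model(img, target)   # training path
logits_only = model(img)           # inference path
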
/pipeline/timm_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .tuple import to_ntuple, to_2tuple, to_3tuple, to_4tuple
2 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
3 | from .weight_init import trunc_normal_
4 |
5 |
--------------------------------------------------------------------------------
/pipeline/timm_utils/drop.py:
--------------------------------------------------------------------------------
1 | """ DropBlock, DropPath
2 |
3 | PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4 |
5 | Papers:
6 | DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7 |
8 | Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9 |
10 | Code:
11 | DropBlock impl inspired by two Tensorflow impl that I liked:
12 | - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13 | - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14 |
15 | Hacked together by / Copyright 2020 Ross Wightman
16 | """
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | def drop_block_2d(
23 | x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
24 | with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
25 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
26 |
27 | DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
28 | runs with success, but needs further validation and possibly optimization for lower runtime impact.
29 | """
30 | B, C, H, W = x.shape
31 | total_size = W * H
32 | clipped_block_size = min(block_size, min(W, H))
33 | # seed_drop_rate, the gamma parameter
34 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
35 | (W - block_size + 1) * (H - block_size + 1))
36 |
37 | # Forces the block to be inside the feature map.
38 | w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
39 | valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
40 | ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
41 | valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
42 |
43 | if batchwise:
44 | # one mask for whole batch, quite a bit faster
45 | uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
46 | else:
47 | uniform_noise = torch.rand_like(x)
48 | block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
49 | block_mask = -F.max_pool2d(
50 | -block_mask,
51 | kernel_size=clipped_block_size, # block_size,
52 | stride=1,
53 | padding=clipped_block_size // 2)
54 |
55 | if with_noise:
56 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
57 | if inplace:
58 | x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
59 | else:
60 | x = x * block_mask + normal_noise * (1 - block_mask)
61 | else:
62 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
63 | if inplace:
64 | x.mul_(block_mask * normalize_scale)
65 | else:
66 | x = x * block_mask * normalize_scale
67 | return x
68 |
69 |
70 | def drop_block_fast_2d(
71 | x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
72 | gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
73 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
74 |
75 | DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
76 | block mask at edges.
77 | """
78 | B, C, H, W = x.shape
79 | total_size = W * H
80 | clipped_block_size = min(block_size, min(W, H))
81 | gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
82 | (W - block_size + 1) * (H - block_size + 1))
83 |
84 | if batchwise:
85 | # one mask for whole batch, quite a bit faster
86 | block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
87 | else:
88 | # mask per batch element
89 | block_mask = torch.rand_like(x) < gamma
90 | block_mask = F.max_pool2d(
91 | block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
92 |
93 | if with_noise:
94 | normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
95 | if inplace:
96 | x.mul_(1. - block_mask).add_(normal_noise * block_mask)
97 | else:
98 | x = x * (1. - block_mask) + normal_noise * block_mask
99 | else:
100 | block_mask = 1 - block_mask
101 | normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
102 | if inplace:
103 | x.mul_(block_mask * normalize_scale)
104 | else:
105 | x = x * block_mask * normalize_scale
106 | return x
107 |
108 |
109 | class DropBlock2d(nn.Module):
110 | """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
111 | """
112 | def __init__(self,
113 | drop_prob=0.1,
114 | block_size=7,
115 | gamma_scale=1.0,
116 | with_noise=False,
117 | inplace=False,
118 | batchwise=False,
119 | fast=True):
120 | super(DropBlock2d, self).__init__()
121 | self.drop_prob = drop_prob
122 | self.gamma_scale = gamma_scale
123 | self.block_size = block_size
124 | self.with_noise = with_noise
125 | self.inplace = inplace
126 | self.batchwise = batchwise
127 | self.fast = fast # FIXME finish comparisons of fast vs not
128 |
129 | def forward(self, x):
130 | if not self.training or not self.drop_prob:
131 | return x
132 | if self.fast:
133 | return drop_block_fast_2d(
134 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
135 | else:
136 | return drop_block_2d(
137 | x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
138 |
139 |
140 | def drop_path(x, drop_prob: float = 0., training: bool = False):
141 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
142 |
143 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
144 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
145 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
146 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
147 | 'survival rate' as the argument.
148 |
149 | """
150 | if drop_prob == 0. or not training:
151 | return x
152 | keep_prob = 1 - drop_prob
153 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
154 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
155 | random_tensor.floor_() # binarize
156 | output = x.div(keep_prob) * random_tensor
157 | return output
158 |
159 |
160 | class DropPath(nn.Module):
161 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
162 | """
163 | def __init__(self, drop_prob=None):
164 | super(DropPath, self).__init__()
165 | self.drop_prob = drop_prob
166 |
167 | def forward(self, x):
168 | return drop_path(x, self.drop_prob, self.training)
169 |
--------------------------------------------------------------------------------
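
A small sketch of DropPath (stochastic depth), the only piece of this file used by vit_csra.py: during training it zeroes entire residual branches per sample and rescales the survivors by 1 / keep_prob, and at inference it is the identity.

import torch
from pipeline.timm_utils import DropPath

dp = DropPath(drop_prob=0.2)
x = torch.randn(4, 197, 768)       # (batch, tokens, dim), ViT-style activations

dp.train()
y = dp(x)                          # roughly 20% of samples have this branch zeroed

dp.eval()
assert torch.equal(dp(x), x)       # identity at inference time
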
/pipeline/timm_utils/tuple.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc as container_abcs  # torch._six was removed in recent PyTorch releases
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, container_abcs.Iterable):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/pipeline/timm_utils/weight_init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import warnings
4 |
5 |
6 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
7 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
8 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
9 | def norm_cdf(x):
10 | # Computes standard normal cumulative distribution function
11 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
12 |
13 | if (mean < a - 2 * std) or (mean > b + 2 * std):
14 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
15 | "The distribution of values may be incorrect.",
16 | stacklevel=2)
17 |
18 | with torch.no_grad():
19 | # Values are generated by using a truncated uniform distribution and
20 | # then using the inverse CDF for the normal distribution.
21 | # Get upper and lower cdf values
22 | l = norm_cdf((a - mean) / std)
23 | u = norm_cdf((b - mean) / std)
24 |
25 | # Uniformly fill tensor with values from [l, u], then translate to
26 | # [2l-1, 2u-1].
27 | tensor.uniform_(2 * l - 1, 2 * u - 1)
28 |
29 | # Use inverse cdf transform for normal distribution to get truncated
30 | # standard normal
31 | tensor.erfinv_()
32 |
33 | # Transform to proper mean, std
34 | tensor.mul_(std * math.sqrt(2.))
35 | tensor.add_(mean)
36 |
37 | # Clamp to ensure it's in the proper range
38 | tensor.clamp_(min=a, max=b)
39 | return tensor
40 |
41 |
42 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
43 | # type: (Tensor, float, float, float, float) -> Tensor
44 | r"""Fills the input Tensor with values drawn from a truncated
45 | normal distribution. The values are effectively drawn from the
46 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
47 | with values outside :math:`[a, b]` redrawn until they are within
48 | the bounds. The method used for generating the random values works
49 | best when :math:`a \leq \text{mean} \leq b`.
50 | Args:
51 | tensor: an n-dimensional `torch.Tensor`
52 | mean: the mean of the normal distribution
53 | std: the standard deviation of the normal distribution
54 | a: the minimum cutoff value
55 | b: the maximum cutoff value
56 | Examples:
57 | >>> w = torch.empty(3, 5)
58 | >>> nn.init.trunc_normal_(w)
59 | """
60 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
61 |
--------------------------------------------------------------------------------
/pipeline/vit_csra.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 |
3 | A PyTorch implementation of Vision Transformers as described in
4 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
5 |
6 | The official jax code is released and available at https://github.com/google-research/vision_transformer
7 |
8 | Status/TODO:
9 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
10 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
11 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
12 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
13 |
14 | Acknowledgments:
15 | * The paper authors for releasing code and weights, thanks!
16 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
17 | for some einops/einsum fun
18 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
19 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
20 |
21 | Hacked together by / Copyright 2020 Ross Wightman
22 | """
23 | import math
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 | import torch.utils.model_zoo as model_zoo
28 | from functools import partial
29 | from .timm_utils import DropPath, to_2tuple, trunc_normal_
30 | from .csra import MHA, CSRA
31 |
32 |
33 | default_cfgs = {
34 | 'vit_base_patch16_224': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
35 | 'vit_large_patch16_224':'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth'
36 | }
37 |
38 |
39 |
40 | class Mlp(nn.Module):
41 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42 | super().__init__()
43 | out_features = out_features or in_features
44 | hidden_features = hidden_features or in_features
45 | self.fc1 = nn.Linear(in_features, hidden_features)
46 | self.act = act_layer()
47 | self.fc2 = nn.Linear(hidden_features, out_features)
48 | self.drop = nn.Dropout(drop)
49 |
50 | def forward(self, x):
51 | x = self.fc1(x)
52 | x = self.act(x)
53 | x = self.drop(x)
54 | x = self.fc2(x)
55 | x = self.drop(x)
56 | return x
57 |
58 |
59 | class Attention(nn.Module):
60 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
61 | super().__init__()
62 | self.num_heads = num_heads
63 | head_dim = dim // num_heads # 64
64 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
65 | self.scale = qk_scale or head_dim ** -0.5
66 |
67 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
68 | self.attn_drop = nn.Dropout(attn_drop)
69 | self.proj = nn.Linear(dim, dim)
70 | self.proj_drop = nn.Dropout(proj_drop)
71 |
72 | def forward(self, x):
73 | B, N, C = x.shape
74 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
75 | # qkv (3, B, 12, N, C/12)
76 | # q (B, 12, N, C/12)
77 | # k (B, 12, N, C/12)
78 | # v (B, 12, N, C/12)
79 | # attn (B, 12, N, N)
80 | # x (B, 12, N, C/12)
81 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
82 |
83 | attn = (q @ k.transpose(-2, -1)) * self.scale
84 | attn = attn.softmax(dim=-1)
85 | attn = self.attn_drop(attn)
86 |
87 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
88 |
89 | x = self.proj(x)
90 | x = self.proj_drop(x)
91 |
92 | return x
93 |
94 |
95 | class Block(nn.Module):
96 |
97 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
98 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
99 | super().__init__()
100 | self.norm1 = norm_layer(dim)
101 | self.attn = Attention(
102 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
103 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
104 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
105 | self.norm2 = norm_layer(dim)
106 | mlp_hidden_dim = int(dim * mlp_ratio)
107 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
108 |
109 | def forward(self, x):
110 | x = x + self.drop_path(self.attn(self.norm1(x)))
111 | x = x + self.drop_path(self.mlp(self.norm2(x)))
112 | return x
113 |
114 |
115 | class PatchEmbed(nn.Module):
116 | """ Image to Patch Embedding
117 | """
118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
119 | super().__init__()
120 | img_size = to_2tuple(img_size)
121 | patch_size = to_2tuple(patch_size)
122 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
123 | self.img_size = img_size
124 | self.patch_size = patch_size
125 | self.num_patches = num_patches
126 |
127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
128 |
129 | def forward(self, x):
130 | B, C, H, W = x.shape
131 | # FIXME look at relaxing size constraints
132 | assert H == self.img_size[0] and W == self.img_size[1], \
133 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
134 | x = self.proj(x).flatten(2).transpose(1, 2)
135 | return x
136 |
137 |
138 | class HybridEmbed(nn.Module):
139 | """ CNN Feature Map Embedding
140 | Extract feature map from CNN, flatten, project to embedding dim.
141 | """
142 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
143 | super().__init__()
144 | assert isinstance(backbone, nn.Module)
145 | img_size = to_2tuple(img_size)
146 | self.img_size = img_size
147 | self.backbone = backbone
148 | if feature_size is None:
149 | with torch.no_grad():
150 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
151 | # map for all networks, the feature metadata has reliable channel and stride info, but using
152 | # stride to calc feature dim requires info about padding of each stage that isn't captured.
153 | training = backbone.training
154 | if training:
155 | backbone.eval()
156 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
157 | feature_size = o.shape[-2:]
158 | feature_dim = o.shape[1]
159 | backbone.train(training)
160 | else:
161 | feature_size = to_2tuple(feature_size)
162 | feature_dim = self.backbone.feature_info.channels()[-1]
163 | self.num_patches = feature_size[0] * feature_size[1]
164 | self.proj = nn.Linear(feature_dim, embed_dim)
165 |
166 | def forward(self, x):
167 | x = self.backbone(x)[-1]
168 | x = x.flatten(2).transpose(1, 2)
169 | x = self.proj(x)
170 | return x
171 |
172 |
173 | class VIT_CSRA(nn.Module):
174 | """ Vision Transformer with support for patch or hybrid CNN input stage
175 | """
176 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
177 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
178 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, cls_num_heads=1, cls_num_cls=80, lam=0.3):
179 | super().__init__()
180 | self.add_w = 0.
181 | self.normalize = False
182 | self.num_classes = num_classes
183 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
184 |
185 | if hybrid_backbone is not None:
186 | self.patch_embed = HybridEmbed(
187 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
188 | else:
189 | self.patch_embed = PatchEmbed(
190 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
191 | num_patches = self.patch_embed.num_patches
192 | self.HW = int(math.sqrt(num_patches))
193 |
194 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
195 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
196 | self.pos_drop = nn.Dropout(p=drop_rate)
197 |
198 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
199 | self.blocks = nn.ModuleList([
200 | Block(
201 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
202 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
203 | for i in range(depth)])
204 | self.norm = norm_layer(embed_dim)
205 |
206 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
207 | #self.repr = nn.Linear(embed_dim, representation_size)
208 | #self.repr_act = nn.Tanh()
209 |
210 | trunc_normal_(self.pos_embed, std=.02)
211 | trunc_normal_(self.cls_token, std=.02)
212 | self.apply(self._init_weights)
213 |
214 |         # We add our MHA (CSRA) beside the original VIT structure below
215 | self.head = nn.Sequential() # delete original classifier
216 | self.classifier = MHA(input_dim=embed_dim, num_heads=cls_num_heads, num_classes=cls_num_cls, lam=lam)
217 |
218 | self.loss_func = F.binary_cross_entropy_with_logits
219 |
220 | def _init_weights(self, m):
221 | if isinstance(m, nn.Linear):
222 | trunc_normal_(m.weight, std=.02)
223 | if isinstance(m, nn.Linear) and m.bias is not None:
224 | nn.init.constant_(m.bias, 0)
225 | elif isinstance(m, nn.LayerNorm):
226 | nn.init.constant_(m.bias, 0)
227 | nn.init.constant_(m.weight, 1.0)
228 |
229 | def backbone(self, x):
230 | B = x.shape[0]
231 | x = self.patch_embed(x)
232 |
233 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
234 | x = torch.cat((cls_tokens, x), dim=1)
235 | x = x + self.pos_embed
236 | x = self.pos_drop(x)
237 |
238 | for blk in self.blocks:
239 | x = blk(x)
240 | x = self.norm(x)
241 |
242 | # (B, 1+HW, C)
243 | # we use all the feature to form the tensor like B C H W
244 | x = x[:, 1:]
245 | b, hw, c = x.shape
246 | x = x.transpose(1, 2)
247 | x = x.reshape(b, c, self.HW, self.HW)
248 |
249 | return x
250 |
251 | def forward_train(self, x, target):
252 | x = self.backbone(x)
253 | logit = self.classifier(x)
254 | loss = self.loss_func(logit, target, reduction="mean")
255 | return logit, loss
256 |
257 | def forward_test(self, x):
258 | x = self.backbone(x)
259 | x = self.classifier(x)
260 | return x
261 |
262 | def forward(self, x, target=None):
263 | if target is not None:
264 | return self.forward_train(x, target)
265 | else:
266 | return self.forward_test(x)
267 |
268 |
269 |
270 |
271 | def _conv_filter(state_dict, patch_size=16):
272 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
273 | out_dict = {}
274 | for k, v in state_dict.items():
275 | if 'patch_embed.proj.weight' in k:
276 | v = v.reshape((v.shape[0], 3, patch_size, patch_size))
277 | out_dict[k] = v
278 | return out_dict
279 |
280 |
281 | def VIT_B16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
282 | model = VIT_CSRA(
283 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
284 | norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)
285 |
286 | model_url = default_cfgs['vit_base_patch16_224']
287 | if pretrained:
288 | state_dict = model_zoo.load_url(model_url)
289 | model.load_state_dict(state_dict, strict=False)
290 | return model
291 |
292 |
293 | def VIT_L16_224_CSRA(pretrained=True, cls_num_heads=1, cls_num_cls=80, lam=0.3):
294 | model = VIT_CSRA(
295 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
296 | norm_layer=partial(nn.LayerNorm, eps=1e-6), cls_num_heads=cls_num_heads, cls_num_cls=cls_num_cls, lam=lam)
297 |
298 | model_url = default_cfgs['vit_large_patch16_224']
299 | if pretrained:
300 | state_dict = model_zoo.load_url(model_url)
301 | model.load_state_dict(state_dict, strict=False)
302 | # load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
303 | return model
--------------------------------------------------------------------------------
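
A minimal sketch of building the ViT-B/16 CSRA model above for the WIDER-Attribute setting used in this repo (224x224 person crops, 14 attribute classes). pretrained=False skips the jax-weight download here; the real pipeline loads the pretrained weights.

import torch
from pipeline.vit_csra import VIT_B16_224_CSRA

model = VIT_B16_224_CSRA(pretrained=False, cls_num_heads=1, cls_num_cls=14, lam=0.3)
img = torch.randn(2, 3, 224, 224)   # PatchEmbed asserts a fixed 224x224 input
logits = model(img)                 # (2, 14) attribute logits
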
/utils/demo_images/000001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000001.jpg
--------------------------------------------------------------------------------
/utils/demo_images/000002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000002.jpg
--------------------------------------------------------------------------------
/utils/demo_images/000004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000004.jpg
--------------------------------------------------------------------------------
/utils/demo_images/000006.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000006.jpg
--------------------------------------------------------------------------------
/utils/demo_images/000007.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000007.jpg
--------------------------------------------------------------------------------
/utils/demo_images/000009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/demo_images/000009.jpg
--------------------------------------------------------------------------------
/utils/evaluation/cal_PR.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 |
4 |
5 |
6 | def json_metric(score_json, target_json, num_classes, types):
7 | assert len(score_json) == len(target_json)
8 | scores = np.zeros((len(score_json), num_classes))
9 | targets = np.zeros((len(target_json), num_classes))
10 | for index in range(len(score_json)):
11 | scores[index] = score_json[index]["scores"]
12 | targets[index] = target_json[index]["target"]
13 |
14 |
15 | return metric(scores, targets, types)
16 |
17 | def json_metric_top3(score_json, target_json, num_classes, types):
18 | assert len(score_json) == len(target_json)
19 | scores = np.zeros((len(score_json), num_classes))
20 | targets = np.zeros((len(target_json), num_classes))
21 | for index in range(len(score_json)):
22 | tmp = np.array(score_json[index]['scores'])
23 | idx = np.argsort(-tmp)
24 | idx_after_3 = idx[3:]
25 | tmp[idx_after_3] = 0.
26 |
27 | scores[index] = tmp
28 | # scores[index] = score_json[index]["scores"]
29 | targets[index] = target_json[index]["target"]
30 |
31 | return metric(scores, targets, types)
32 |
33 |
34 | def metric(scores, targets, types):
35 | """
36 |     :param scores: the scores predicted by the model
37 |     :param targets: the ground-truth labels
38 |     :return: OP, OR, OF1, CP, CR, CF1
39 |     per-class Precision = TP / (TP + FP), i.e. TP / total predicted positives
40 |     per-class Recall    = TP / total ground-truth positives
41 | """
42 | num, num_class = scores.shape
43 | gt_num = np.zeros(num_class)
44 | tp_num = np.zeros(num_class)
45 | predict_num = np.zeros(num_class)
46 |
47 |
48 | for index in range(num_class):
49 | score = scores[:, index]
50 | target = targets[:, index]
51 | if types == 'wider':
52 | tmp = np.where(target == 99)[0]
53 | # score[tmp] = 0
54 | target[tmp] = 0
55 |
56 | if types == 'voc07':
57 | tmp = np.where(target != 0)[0]
58 | score = score[tmp]
59 | target = target[tmp]
60 | neg_id = np.where(target == -1)[0]
61 | target[neg_id] = 0
62 |
63 |
64 | gt_num[index] = np.sum(target == 1)
65 | predict_num[index] = np.sum(score >= 0.5)
66 | tp_num[index] = np.sum(target * (score >= 0.5))
67 |
68 |     predict_num[predict_num == 0] = 1  # avoid division by zero
69 | OP = np.sum(tp_num) / np.sum(predict_num)
70 | OR = np.sum(tp_num) / np.sum(gt_num)
71 | OF1 = (2 * OP * OR) / (OP + OR)
72 |
73 | #print(tp_num / predict_num)
74 | #print(tp_num / gt_num)
75 | CP = np.sum(tp_num / predict_num) / num_class
76 | CR = np.sum(tp_num / gt_num) / num_class
77 | CF1 = (2 * CP * CR) / (CP + CR)
78 |
79 | return OP, OR, OF1, CP, CR, CF1
80 |
--------------------------------------------------------------------------------
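
A tiny synthetic check of metric() above: 3 samples, 2 classes, predictions thresholded at 0.5. The numbers are made up purely to exercise the overall (OP/OR/OF1) and per-class (CP/CR/CF1) computations.

import numpy as np
from utils.evaluation.cal_PR import metric

scores  = np.array([[0.9, 0.2],
                    [0.4, 0.7],
                    [0.6, 0.8]])
targets = np.array([[1, 0],
                    [1, 1],
                    [0, 1]])

OP, OR, OF1, CP, CR, CF1 = metric(scores, targets, types="coco")
print(OP, OR, OF1, CP, CR, CF1)   # all 0.75 for this toy example
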
/utils/evaluation/cal_mAP.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import json
5 |
6 |
7 | def json_map(cls_id, pred_json, ann_json, types):
8 | assert len(ann_json) == len(pred_json)
9 | num = len(ann_json)
10 | predict = np.zeros((num), dtype=np.float64)
11 | target = np.zeros((num), dtype=np.float64)
12 |
13 | for i in range(num):
14 | predict[i] = pred_json[i]["scores"][cls_id]
15 | target[i] = ann_json[i]["target"][cls_id]
16 |
17 | if types == 'wider':
18 | tmp = np.where(target != 99)[0]
19 | predict = predict[tmp]
20 | target = target[tmp]
21 | num = len(tmp)
22 |
23 | if types == 'voc07':
24 | tmp = np.where(target != 0)[0]
25 | predict = predict[tmp]
26 | target = target[tmp]
27 | neg_id = np.where(target == -1)[0]
28 | target[neg_id] = 0
29 | num = len(tmp)
30 |
31 |
32 | tmp = np.argsort(-predict)
33 | target = target[tmp]
34 | predict = predict[tmp]
35 |
36 |
37 | pre, obj = 0, 0
38 | for i in range(num):
39 | if target[i] == 1:
40 | obj += 1.0
41 | pre += obj / (i+1)
42 | pre /= obj
43 | return pre
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
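
A toy example of json_map() above: average precision for class 0 over three images, using the same item layout ({"scores": ...} / {"target": ...}) produced by val.py and the prepare scripts.

from utils.evaluation.cal_mAP import json_map

pred_json = [{"scores": [0.9]}, {"scores": [0.3]}, {"scores": [0.6]}]
ann_json  = [{"target": [1]},   {"target": [0]},   {"target": [1]}]

ap = json_map(cls_id=0, pred_json=pred_json, ann_json=ann_json, types="coco")
print(ap)   # 1.0: both positives are ranked above the single negative
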
/utils/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import numpy as np
4 | import json
5 | from tqdm import tqdm
6 | from .cal_mAP import json_map
7 | from .cal_PR import json_metric, metric, json_metric_top3
8 |
9 |
10 | voc_classes = ("aeroplane", "bicycle", "bird", "boat", "bottle",
11 | "bus", "car", "cat", "chair", "cow", "diningtable",
12 | "dog", "horse", "motorbike", "person", "pottedplant",
13 | "sheep", "sofa", "train", "tvmonitor")
14 | coco_classes = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
15 | 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
16 | 'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
17 | 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
18 | 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
19 | 'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
20 | 'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
21 | 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
22 | 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
23 | 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
24 | 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
25 | 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
26 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
27 | 'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush')
28 | wider_classes = (
29 |     "Male","longHair","sunglass","Hat","Tshirt","longSleeve","formal",
30 | "shorts","jeans","longPants","skirt","faceMask", "logo","stripe")
31 |
32 | class_dict = {
33 | "voc07": voc_classes,
34 | "coco": coco_classes,
35 | "wider": wider_classes,
36 | }
37 |
38 |
39 |
40 | def evaluation(result, types, ann_path):
41 | print("Evaluation")
42 | classes = class_dict[types]
43 | aps = np.zeros(len(classes), dtype=np.float64)
44 |
45 | ann_json = json.load(open(ann_path, "r"))
46 | pred_json = result
47 |
48 | for i, _ in enumerate(tqdm(classes)):
49 | ap = json_map(i, pred_json, ann_json, types)
50 | aps[i] = ap
51 | OP, OR, OF1, CP, CR, CF1 = json_metric(pred_json, ann_json, len(classes), types)
52 |     print("mAP: {:.4f}".format(np.mean(aps)))
53 |     print("CP: {:.4f}, CR: {:.4f}, CF1: {:.4f}".format(CP, CR, CF1))
54 |     print("OP: {:.4f}, OR: {:.4f}, OF1: {:.4f}".format(OP, OR, OF1))
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/utils/evaluation/warmUpLR.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class WarmUpLR(torch.optim.lr_scheduler._LRScheduler):
5 | def __init__(self, optimizer, total_iters, last_epoch=-1):
6 | self.total_iters = total_iters
7 | super().__init__(optimizer, last_epoch=last_epoch)
8 |
9 | def get_lr(self):
10 | return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
11 |
12 |
--------------------------------------------------------------------------------
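
A minimal sketch of WarmUpLR above: the learning rate ramps linearly from ~0 to the base LR over total_iters steps. In main.py the scheduler is sized as iter_per_epoch * warmup_epoch, i.e. it is intended to be stepped once per batch during the warmup epochs.

import torch
from utils.evaluation.warmUpLR import WarmUpLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.01)
warmup = WarmUpLR(optimizer, total_iters=100)

for _ in range(100):
    optimizer.step()
    warmup.step()

print(optimizer.param_groups[0]["lr"])   # ~0.01 once warmup is complete
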
/utils/pipeline.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kevinz-code/CSRA/c8480d12742459809179eb0fc4ee0a88b6b98bfa/utils/pipeline.PNG
--------------------------------------------------------------------------------
/utils/prepare/prepare_coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | from pycocotools.coco import COCO
6 |
7 |
8 |
9 | def make_data(data_path=None, tag="train"):
10 | annFile = os.path.join(data_path, "annotations/instances_{}2014.json".format(tag))
11 | coco = COCO(annFile)
12 |
13 | img_id = coco.getImgIds()
14 | cat_id = coco.getCatIds()
15 | img_id = list(sorted(img_id))
16 | cat_trans = {}
17 | for i in range(len(cat_id)):
18 | cat_trans[cat_id[i]] = i
19 |
20 | message = []
21 |
22 |
23 | for i in img_id:
24 | data = {}
25 | target = [0] * 80
26 | path = ""
27 | img_info = coco.loadImgs(i)[0]
28 | ann_ids = coco.getAnnIds(imgIds = i)
29 | anns = coco.loadAnns(ann_ids)
30 | if len(anns) == 0:
31 | continue
32 | else:
33 |             for j in range(len(anns)):  # avoid shadowing the outer image-id loop variable
34 |                 cls = anns[j]['category_id']
35 | cls = cat_trans[cls]
36 | target[cls] = 1
37 | path = img_info['file_name']
38 | data['target'] = target
39 | data['img_path'] = os.path.join(os.path.join(data_path, "images/{}2014/".format(tag)), path)
40 | message.append(data)
41 |
42 | with open('data/coco/{}_coco2014.json'.format(tag), 'w') as f:
43 | json.dump(message, f)
44 |
45 |
46 |
47 | # The final json file include: train_coco2014.json & val_coco2014.json
48 | # which is the following format:
49 | # [item1, item2, item3, ......,]
50 | # item1 = {
51 | # "target":
52 | # "img_path":
53 | # }
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | # Usage: --data_path /your/dataset/path/COCO2014
57 | parser.add_argument("--data_path", default="Dataset/COCO2014/", type=str, help="The absolute path of COCO2014")
58 | args = parser.parse_args()
59 |
60 | if not os.path.exists("data/coco"):
61 | os.makedirs("data/coco")
62 |
63 | make_data(data_path=args.data_path, tag="train")
64 | make_data(data_path=args.data_path, tag="val")
65 |
66 | print("COCO data ready!")
67 | print("data/coco/train_coco2014.json, data/coco/val_coco2014.json")
68 |
--------------------------------------------------------------------------------
/utils/prepare/prepare_voc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 | import xml.dom.minidom as XML
6 |
7 |
8 |
9 | voc_cls_id = {"aeroplane":0, "bicycle":1, "bird":2, "boat":3, "bottle":4,
10 | "bus":5, "car":6, "cat":7, "chair":8, "cow":9,
11 | "diningtable":10, "dog":11, "horse":12, "motorbike":13, "person":14,
12 | "pottedplant":15, "sheep":16, "sofa":17, "train":18, "tvmonitor":19}
13 |
14 |
15 | def get_label(data_path):
16 | print("generating labels for VOC07 dataset")
17 | xml_paths = os.path.join(data_path, "VOC2007/Annotations/")
18 | save_dir = "data/voc07/labels"
19 |
20 | if not os.path.exists(save_dir):
21 | os.makedirs(save_dir)
22 |
23 | for i in os.listdir(xml_paths):
24 | if not i.endswith(".xml"):
25 | continue
26 | s_name = i.split('.')[0] + ".txt"
27 | s_dir = os.path.join(save_dir, s_name)
28 | xml_path = os.path.join(xml_paths, i)
29 | DomTree = XML.parse(xml_path)
30 | Root = DomTree.documentElement
31 |
32 | obj_all = Root.getElementsByTagName("object")
33 | leng = len(obj_all)
34 | cls = []
35 | difi_tag = []
36 | for obj in obj_all:
37 | # get the classes
38 | obj_name = obj.getElementsByTagName('name')[0]
39 | one_class = obj_name.childNodes[0].data
40 | cls.append(voc_cls_id[one_class])
41 |
42 | difficult = obj.getElementsByTagName('difficult')[0]
43 | difi_tag.append(difficult.childNodes[0].data)
44 |
45 | for i, c in enumerate(cls):
46 | with open(s_dir, "a") as f:
47 | f.writelines("%s,%s\n" % (c, difi_tag[i]))
48 |
49 |
50 | def transdifi(data_path):
51 | print("generating final json file for VOC07 dataset")
52 | label_dir = "data/voc07/labels/"
53 | img_dir = os.path.join(data_path, "VOC2007/JPEGImages/")
54 |
55 | # get trainval test id
56 | id_dirs = os.path.join(data_path, "VOC2007/ImageSets/Main/")
57 | f_train = open(os.path.join(id_dirs, "train.txt"), "r").readlines()
58 | f_val = open(os.path.join(id_dirs, "val.txt"), "r").readlines()
59 | f_trainval = f_train + f_val
60 | f_test = open(os.path.join(id_dirs, "test.txt"), "r")
61 |
62 | trainval_id = np.sort([int(line.strip()) for line in f_trainval]).tolist()
63 | test_id = [int(line.strip()) for line in f_test]
64 | trainval_data = []
65 | test_data = []
66 |
67 | # ternary label
68 | # -1 means negative
69 | # 0 means difficult
70 | # +1 means positive
71 |
72 | # binary label
73 | # 0 means negative
74 | # +1 means positive
75 |
76 | # we use binary labels in our implementation
77 |
78 | for item in sorted(os.listdir(label_dir)):
79 | with open(os.path.join(label_dir, item), "r") as f:
80 |
81 | target = np.array([-1] * 20)
82 | classes = []
83 | diffi_tag = []
84 |
85 | for line in f.readlines():
86 | cls, tag = map(int, line.strip().split(','))
87 | classes.append(cls)
88 | diffi_tag.append(tag)
89 |
90 | classes = np.array(classes)
91 | diffi_tag = np.array(diffi_tag)
92 | for i in range(20):
93 | if i in classes:
94 | i_index = np.where(classes == i)[0]
95 | if len(i_index) == 1:
96 |                     target[i] = 1 - diffi_tag[i_index][0]  # single occurrence: positive unless marked difficult
97 | else:
98 | if len(i_index) == sum(diffi_tag[i_index]):
99 | target[i] = 0
100 | else:
101 | target[i] = 1
102 | else:
103 | continue
104 | img_path = os.path.join(img_dir, item.split('.')[0]+".jpg")
105 |
106 | if int(item.split('.')[0]) in trainval_id:
107 | target[target == -1] = 0 # from ternary to binary by treating difficult as negatives
108 | data = {"target": target.tolist(), "img_path": img_path}
109 | trainval_data.append(data)
110 | if int(item.split('.')[0]) in test_id:
111 | data = {"target": target.tolist(), "img_path": img_path}
112 | test_data.append(data)
113 |
114 | json.dump(trainval_data, open("data/voc07/trainval_voc07.json", "w"))
115 | json.dump(test_data, open("data/voc07/test_voc07.json", "w"))
116 | print("VOC07 data preparing finished!")
117 | print("data/voc07/trainval_voc07.json data/voc07/test_voc07.json")
118 |
119 |     # remove label cache
120 | for item in os.listdir(label_dir):
121 | os.remove(os.path.join(label_dir, item))
122 | os.rmdir(label_dir)
123 |
124 |
125 | # We treat difficult classes in trainval_data as negative, while ignoring them in test_data
126 | # The ignoring operation can be automatically done during evaluation (testing).
127 | # The final json file include: trainval_voc07.json & test_voc07.json
128 | # which is the following format:
129 | # [item1, item2, item3, ......,]
130 | # item1 = {
131 | # "target":
132 | # "img_path":
133 | # }
134 |
135 | if __name__ == "__main__":
136 | parser = argparse.ArgumentParser()
137 | # Usage: --data_path /your/dataset/path/VOCdevkit
138 | parser.add_argument("--data_path", default="Dataset/VOCdevkit/", type=str, help="The absolute path of VOCdevkit")
139 | args = parser.parse_args()
140 |
141 | if not os.path.exists("data/voc07"):
142 | os.makedirs("data/voc07")
143 |
144 | if 'VOCdevkit' not in args.data_path:
145 | print("WARNING: please include \'VOCdevkit\' str in your args.data_path")
146 | # exit()
147 |
148 | get_label(args.data_path)
149 | transdifi(args.data_path)
--------------------------------------------------------------------------------
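
A tiny worked example of the per-class difficult-tag rule implemented in transdifi() above (same logic, isolated for clarity): a class that appears only with the difficult flag stays 0, a single non-difficult occurrence gives 1, and a mix of difficult and non-difficult occurrences gives 1.

import numpy as np

classes   = np.array([3, 3, 7])     # class 3 appears twice, class 7 once
diffi_tag = np.array([1, 0, 1])     # one class-3 box and the class-7 box are difficult
target = np.array([-1] * 20)        # ternary init: -1 negative, 0 difficult, 1 positive

for i in range(20):
    idx = np.where(classes == i)[0]
    if len(idx) == 0:
        continue
    if len(idx) == 1:
        target[i] = 1 - diffi_tag[idx][0]
    elif len(idx) == sum(diffi_tag[idx]):
        target[i] = 0
    else:
        target[i] = 1

print(target[3], target[7])   # 1 0
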
/utils/prepare/prepare_wider.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import argparse
5 |
6 |
7 | def make_wider(tag, value, data_path):
8 | img_path = os.path.join(data_path, "Image")
9 | ann_path = os.path.join(data_path, "Annotations")
10 | ann_file = os.path.join(ann_path, "wider_attribute_{}.json".format(tag))
11 |
12 | data = json.load(open(ann_file, "r"))
13 |
14 | final = []
15 | image_list = data['images']
16 | for image in image_list:
17 | for person in image["targets"]: # iterate over each person
18 | tmp = {}
19 | tmp['img_path'] = os.path.join(img_path, image['file_name'])
20 | tmp['bbox'] = person['bbox']
21 | attr = person["attribute"]
22 | for i, item in enumerate(attr):
23 | if item == -1:
24 | attr[i] = 0
25 | if item == 0:
26 | attr[i] = value # pad un-specified samples
27 | if item == 1:
28 | attr[i] = 1
29 | tmp["target"] = attr
30 | final.append(tmp)
31 |
32 | json.dump(final, open("data/wider/{}_wider.json".format(tag), "w"))
33 | print("data/wider/{}_wider.json".format(tag))
34 |
35 |
36 |
37 | # which is the following format:
38 | # [item1, item2, item3, ......,]
39 | # item1 = {
40 | # "target":
41 | # "img_path":
42 | # }
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument("--data_path", default="Dataset/WIDER_ATTRIBUTE", type=str)
48 | args = parser.parse_args()
49 |
50 | if not os.path.exists("data/wider"):
51 | os.makedirs("data/wider")
52 |
53 | # 0 (zero) means negative, we treat un-specified attribute as negative in the trainval set
54 | make_wider(tag='trainval', value=0, data_path=args.data_path)
55 |
56 | # 99 means we ignore un-specified attribute in the test set, following previous work
57 | # the number 99 can be properly identified when evaluating mAP
58 | make_wider(tag='test', value=99, data_path=args.data_path)
59 |
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import json
3 | import torch
4 | from torchvision import transforms
5 | import cv2
6 | import numpy as np
7 | import os
8 | import torch.nn as nn
9 |
10 | def show_cam_on_img(img, mask, img_path_save):
11 | heat_map = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
12 | heat_map = np.float32(heat_map) / 255
13 |
14 | cam = heat_map + np.float32(img)
15 | cam = cam / np.max(cam)
16 | cv2.imwrite(img_path_save, np.uint8(255 * cam))
17 |
18 |
19 | img_path_read = ""
20 | img_path_save = ""
21 |
22 |
23 |
24 |
25 | def main():
26 | img = cv2.imread(img_path_read, flags=1)
27 |
28 | img = np.float32(cv2.resize(img, (224, 224))) / 255
29 |
30 |     # cam_all is the class-wise score tensor of shape (C, H, W) for one image, similar to y_raw in our Figure 1
31 |     # (cam_all is not computed in this script; a sketch of obtaining it follows this file)
32 |     # cls_idx selects the i-th of the C classes; here we visualize the heatmap of class 0
33 |     cls_idx = 0
34 |     cam = cam_all[cls_idx]
35 |
36 |
37 | # cam = nn.ReLU()(cam)
38 | cam = cam / torch.max(cam)
39 |
40 | cam = cv2.resize(np.array(cam), (224, 224))
41 | show_cam_on_img(img, cam, img_path_save)
42 |
43 |
--------------------------------------------------------------------------------
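
A hedged sketch of one way to obtain the cam_all tensor that utils/visualize.py above leaves undefined: the class-wise spatial scores are the normalized 1x1-conv responses computed inside CSRA.forward (pipeline/csra.py), taken before spatial pooling. The model below is only ImageNet-initialized; load a trained CSRA checkpoint via load_state_dict for meaningful heatmaps.

import torch
from pipeline.resnet_csra import ResNet_CSRA

model = ResNet_CSRA(num_heads=1, lam=0.1, num_classes=20)
model.eval()

img_tensor = torch.randn(1, 3, 224, 224)   # replace with a preprocessed input image
with torch.no_grad():
    feat = model.backbone(img_tensor)                       # (1, 2048, 7, 7)
    csra_head = model.classifier.multi_head[0]              # first CSRA head
    score = csra_head.head(feat) / torch.norm(
        csra_head.head.weight, dim=1, keepdim=True).transpose(0, 1)

cam_all = score[0]   # (num_classes, 7, 7), indexable as cam_all[cls_idx] in main() above
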
/val.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | from pipeline.resnet_csra import ResNet_CSRA
8 | from pipeline.vit_csra import VIT_B16_224_CSRA, VIT_L16_224_CSRA, VIT_CSRA
9 | from pipeline.dataset import DataSet
10 | from utils.evaluation.eval import evaluation
11 | from utils.evaluation.warmUpLR import WarmUpLR
12 | from tqdm import tqdm
13 |
14 |
15 | def Args():
16 | parser = argparse.ArgumentParser(description="settings")
17 | # model default resnet101
18 | parser.add_argument("--model", default="resnet101", type=str)
19 | parser.add_argument("--num_heads", default=1, type=int)
20 | parser.add_argument("--lam",default=0.1, type=float)
21 | parser.add_argument("--load_from", default="models_local/resnet101_voc07_head1_lam0.1_94.7.pth", type=str)
22 | # dataset
23 | parser.add_argument("--dataset", default="voc07", type=str)
24 | parser.add_argument("--num_cls", default=20, type=int)
25 | parser.add_argument("--test_aug", default=[], type=list)
26 | parser.add_argument("--img_size", default=448, type=int)
27 | parser.add_argument("--batch_size", default=16, type=int)
28 |
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def val(args, model, test_loader, test_file):
34 | model.eval()
35 | print("Test on Pretrained Models")
36 | result_list = []
37 |
38 | # calculate logit
39 | for index, data in enumerate(tqdm(test_loader)):
40 | img = data['img'].cuda()
41 | target = data['target'].cuda()
42 | img_path = data['img_path']
43 |
44 | with torch.no_grad():
45 | logit = model(img)
46 |
47 | result = nn.Sigmoid()(logit).cpu().detach().numpy().tolist()
48 | for k in range(len(img_path)):
49 | result_list.append(
50 | {
51 | "file_name": img_path[k].split("/")[-1].split(".")[0],
52 | "scores": result[k]
53 | }
54 | )
55 |
56 | # cal_mAP OP OR
57 | evaluation(result=result_list, types=args.dataset, ann_path=test_file[0])
58 |
59 |
60 |
61 | def main():
62 | args = Args()
63 |
64 | # model
65 | if args.model == "resnet101":
66 |         model = ResNet_CSRA(num_heads=args.num_heads, lam=args.lam, num_classes=args.num_cls)  # no --cutmix arg here; weights come from --load_from
67 | if args.model == "vit_B16_224":
68 | model = VIT_B16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
69 | if args.model == "vit_L16_224":
70 | model = VIT_L16_224_CSRA(cls_num_heads=args.num_heads, lam=args.lam, cls_num_cls=args.num_cls)
71 |
72 | model.cuda()
73 | print("Loading weights from {}".format(args.load_from))
74 | if torch.cuda.device_count() > 1:
75 |         print("Let's use {} GPUs.".format(torch.cuda.device_count()))
76 | model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
77 | model.module.load_state_dict(torch.load(args.load_from))
78 | else:
79 | model.load_state_dict(torch.load(args.load_from))
80 |
81 | # data
82 | if args.dataset == "voc07":
83 | test_file = ['data/voc07/test_voc07.json']
84 | if args.dataset == "coco":
85 | test_file = ['data/coco/val_coco2014.json']
86 | if args.dataset == "wider":
87 | test_file = ['data/wider/test_wider.json']
88 |
89 |
90 | test_dataset = DataSet(test_file, args.test_aug, args.img_size, args.dataset)
91 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
92 |
93 | val(args, model, test_loader, test_file)
94 |
95 |
96 | if __name__ == "__main__":
97 | main()
98 |
--------------------------------------------------------------------------------