├── LICENSE
├── MAGNeto
│   ├── .gitignore
│   ├── README.md
│   ├── data
│   │   └── nus_wide
│   │       └── notebooks
│   │           ├── Move Images.ipynb
│   │           └── Prepare Tag Data.ipynb
│   ├── infer.py
│   ├── magneto
│   │   ├── __init__.py
│   │   ├── augment_helper.py
│   │   ├── autoaugment.py
│   │   ├── data.py
│   │   ├── layers.py
│   │   ├── loss.py
│   │   ├── metrics.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── start_infer.sh
│   │   ├── start_preprocess.sh
│   │   ├── start_train.sh
│   │   └── start_train_usp.sh
│   └── train.py
└── README.md
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C)
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see .
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | Copyright (C)
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | .
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | .
--------------------------------------------------------------------------------
/MAGNeto/.gitignore:
--------------------------------------------------------------------------------
1 | # Local files and directories
2 | data/nus_wide/images
3 | data/nus_wide/annotations
4 | data/nus_wide/raw_data
5 | runs
6 | snapshots
7 | tmp
8 | .vscode
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/MAGNeto/README.md:
--------------------------------------------------------------------------------
1 | # [MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem](https://arxiv.org/abs/2011.04349)
2 |
3 | ## Downloading NUS-WIDE dataset
4 | - Official: https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html
5 | - Unofficial: http://cs-people.bu.edu/hekun/data/TALR/NUSWIDE.zip (recommended for downloading all images)
6 |
7 | ## Data preparation
8 |
9 | ### Moving images to a single directory
10 |
11 | ```
12 | ./data/nus_wide/notebooks/Move\ Images.ipynb
13 | ```
14 |
15 | ### Preparing tag data
16 |
17 | ```
18 | ./data/nus_wide/notebooks/Prepare\ Tag\ Data.ipynb
19 | ```
20 |
21 | ## Setting up the environment
22 |
23 | ```bash
24 | pip install -U pip
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Generating labels for raw data
29 |
30 | - Step 1: Reconfigure `scripts/start_preprocess.sh`
31 |
32 | To list all configurable parameters, run
33 |
34 | ```bash
35 | python preprocess.py -h
36 | ```
37 |
38 | - Step 2: Run
39 |
40 | ```bash
41 | bash scripts/start_preprocess.sh
42 | ```
43 |
44 | ## Training the model
45 |
46 | - Step 1: Reconfigure `scripts/start_train.sh`
47 |
48 | To list all configurable parameters, run
49 |
50 | ```bash
51 | python train.py -h
52 | ```
53 |
54 | - Step 2: Run
55 |
56 | ```bash
57 | bash scripts/start_train.sh
58 | ```
59 |
60 | ## Running inference on test data
61 |
62 | - Step 1: Reconfigure `scripts/start_infer.sh`
63 |
64 | To list all configurable parameters, run
65 |
66 | ```bash
67 | python infer.py -h
68 | ```
69 |
70 | - Step 2: Run
71 |
72 | ```bash
73 | bash scripts/start_infer.sh
74 | ```
75 |
76 | ## Reference
77 |
78 | Please cite the following paper if you use this code as part of any published research:
79 |
80 | **"MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem."**
81 | Hieu Trong Phung, Anh Tuan Vu, Tung Dinh Nguyen, Lam Thanh Do, Giang Nam Ngo, Trung Thanh Tran, Ngoc C. Lê.
82 | ```bibtex
83 | @article{Hieu2020,
84 |   title={MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem},
85 |   author={Hieu Trong Phung and Anh Tuan Vu and Tung Dinh Nguyen and Lam Thanh Do and Giang Nam Ngo and Trung Thanh Tran and Ngoc C. L\^{e}},
86 |   journal={arXiv preprint arXiv:2011.04349},
87 |   year={2020}
88 | }
89 | ```
90 | ## License
91 |
92 | The code is released under the [GPLv3 License](https://www.gnu.org/licenses/gpl-3.0.en.html).
93 |
--------------------------------------------------------------------------------
/MAGNeto/data/nus_wide/notebooks/Move Images.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pathlib import Path\n",
10 | "from shutil import copy\n",
11 | "\n",
12 | "from tqdm import tqdm"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "IMG_DIR = Path('../downloads/Flickr/') # Path to the directory that contains all images\n",
22 | "SAVE_DIR = Path('../images') # Path to the target directory"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stderr",
32 | "output_type": "stream",
33 | "text": [
34 | "100%|██████████| 704/704 [00:51<00:00, 13.67it/s]\n"
35 | ]
36 | }
37 | ],
38 | "source": [
39 | "for subdir in tqdm(list(IMG_DIR.glob('*'))):\n",
40 | " for img_path in subdir.glob('*.jpg'):\n",
41 | " trg = SAVE_DIR / str(img_path).split('_')[-1] # Only use the IDs of available images to name the new moved images\n",
42 | " copy(img_path, trg)"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 4,
48 | "metadata": {},
49 | "outputs": [
50 | {
51 | "data": {
52 | "text/plain": [
53 | "269642"
54 | ]
55 | },
56 | "execution_count": 4,
57 | "metadata": {},
58 | "output_type": "execute_result"
59 | }
60 | ],
61 | "source": [
62 | "len(list(SAVE_DIR.glob('*.jpg')))"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 5,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "!rm -rf $IMG_DIR"
72 | ]
73 | }
74 | ],
75 | "metadata": {
76 | "kernelspec": {
77 | "display_name": "Python 3",
78 | "language": "python",
79 | "name": "python3"
80 | },
81 | "language_info": {
82 | "codemirror_mode": {
83 | "name": "ipython",
84 | "version": 3
85 | },
86 | "file_extension": ".py",
87 | "mimetype": "text/x-python",
88 | "name": "python",
89 | "nbconvert_exporter": "python",
90 | "pygments_lexer": "ipython3",
91 | "version": "3.6.12"
92 | }
93 | },
94 | "nbformat": 4,
95 | "nbformat_minor": 4
96 | }
97 |
--------------------------------------------------------------------------------
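The `Move Images.ipynb` notebook flattens the per-concept Flickr folders into a single directory whose files are named by photo ID. The function below is a minimal sketch of the same step as a plain script. `flatten_images` is a hypothetical name, and the sketch assumes, as the notebook does, that the photo ID is the text after the last underscore of the file name. It differs from the notebook in two deliberate ways: only the file name is split, so a file without an underscore keeps its own name instead of picking up part of its directory path, and the source directory is left in place (the notebook's last cell runs `rm -rf` on it).

```python
from pathlib import Path
from shutil import copy2


def flatten_images(img_dir: Path, save_dir: Path) -> int:
    """Copy <img_dir>/<concept>/<prefix>_<photo_id>.jpg to <save_dir>/<photo_id>.jpg.

    Assumes the photo ID is whatever follows the last underscore of the file name.
    Returns the number of distinct target files written.
    """
    save_dir.mkdir(parents=True, exist_ok=True)
    written = set()
    # One sub-directory per concept; loose files directly under img_dir are ignored
    for subdir in sorted(p for p in img_dir.iterdir() if p.is_dir()):
        for img_path in sorted(subdir.glob('*.jpg')):
            # Hypothetical name '0001_123.jpg' -> '123.jpg'; directory names are never split
            photo_id = img_path.name.split('_')[-1]
            # A repeated photo ID overwrites the earlier copy, as in the notebook
            copy2(img_path, save_dir / photo_id)
            written.add(photo_id)
    return len(written)
```

Called with the notebook's paths, `flatten_images(Path('../downloads/Flickr'), Path('../images'))` should return the same count as the notebook's `len(list(SAVE_DIR.glob('*.jpg')))` check, provided the target directory starts empty.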
/MAGNeto/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import multiprocessing as mp
4 | import copy
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | import pandas as pd
10 | from tqdm import tqdm
11 |
12 | from magneto.model import MAGNeto
13 | from magneto.data import TagAndImageDataset
14 | from magneto.augment_helper import val_transform
15 | from magneto.utils import parse_infer_args
16 |
17 |
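18 | def predict(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, opt: argparse.Namespace) -> tuple:
19 |     '''
20 |     input:
21 |         + model: trained MAGNeto model; it should already be in eval mode (main() calls model.eval()).
22 |         + dataloader: yields (image, tags, mask, item_id) batches, or (image, tags, mask, label, item_id) when opt.has_label is set.
23 |         + opt: configuration.
24 |     output:
25 |         (all_preds, all_item_ids): one tuple of tag scores per item, with padded positions removed, and the matching item ids.
26 |     '''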
27 | all_preds = []
28 | all_item_ids = []
29 |
30 | with torch.no_grad():
31 | for batch_val_idx, data in enumerate(tqdm(dataloader)):
32 | if opt.has_label:
33 | image_batch, tags_batch, mask_batch, _, item_id_batch = data
34 | else:
35 | image_batch, tags_batch, mask_batch, item_id_batch = data
36 | image_batch = image_batch.to(opt.device)
37 | tags_batch = tags_batch.to(opt.device)
38 | mask_batch = mask_batch.to(opt.device)
39 |
40 | preds, _, _, _, _ = model(tags_batch, image_batch, mask_batch)
41 |
42 | preds = preds.detach().cpu().numpy()
43 | mask_batch = mask_batch.detach().cpu().numpy()
44 | preds = [tuple(pred[~mask].tolist()) for pred, mask in zip(preds, mask_batch)]
45 |
46 | all_preds.extend(preds)
47 | all_item_ids.extend(item_id_batch.detach().cpu().tolist())
48 |
49 | return all_preds, all_item_ids
50 |
51 |
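52 | def postprocess_prediction(row, opt: argparse.Namespace):
53 |     '''
54 |     input:
55 |         + row: dataframe row with 'tags', 'raw_prediction' and, when opt.has_label is set, 'label'.
56 |         + opt: configuration.
57 |     output:
58 |         post_prediction, or (important_tags, post_prediction) when opt.has_label is set. post_prediction holds the opt.top highest-scoring tags plus any further tags scoring above opt.threshold, sorted by score, as (tag, score) or (tag, score, is_important) tuples.
59 |     '''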
60 | tags = np.array(row['tags'].split(','))
61 | tags = tags[:opt.max_len]
62 |
63 | if opt.has_label:
64 | label = np.array(row['label'].split(','), dtype=np.uint8)
65 | label = label[:opt.max_len]
66 | mask = label == 1
67 | important_tags = tags[mask]
68 |
69 | final_results = sorted(
70 | zip(tags, row.raw_prediction, mask), key=lambda x: x[1], reverse=True)
71 | else:
72 | final_results = sorted(zip(tags, row.raw_prediction),
73 | key=lambda x: x[1], reverse=True)
74 |
75 | # Get at least top n important tags
76 | post_prediction = final_results[:opt.top]
77 | # Get other accepted important tags based on threshold value
78 | for final_result in final_results[opt.top:]:
79 | if final_result[1] > opt.threshold:
80 | post_prediction.append(final_result)
81 | else:
82 | break
83 |
84 | if opt.has_label:
85 | return important_tags, post_prediction
86 | else:
87 | return post_prediction
88 |
89 |
90 | def postprocess_predictions(df: pd.DataFrame, opt: argparse.Namespace) -> pd.DataFrame:
91 | '''
92 | input:
93 | + df: input pandas dataframe.
94 | + opt: configuration.
95 | output:
96 | postprocessed pandas dataframe.
97 | '''
98 | post_predictions = []
99 | if opt.has_label:
100 | list_of_important_tags = []
101 |
102 | if opt.use_multiprocessing:
103 | import multiprocessing as mp
104 |
105 | # Apply a patch for the multiprocessing module
106 | import multiprocessing.pool as mpp
107 | from magneto.utils import istarmap
108 | mpp.Pool.istarmap = istarmap
109 |
110 | all_rows = [row for idx, row in df.iterrows()]
111 |
112 | inputs = list(zip(
113 | all_rows,
114 | [copy.deepcopy(opt) for _ in range(len(df))]
115 | ))
116 |
117 | with mp.Pool(opt.num_workers) as pool:
118 | for result in tqdm(pool.istarmap(postprocess_prediction, inputs), total=len(inputs)):
119 | if opt.has_label:
120 | important_tags, post_prediction = result
121 | list_of_important_tags.append(important_tags)
122 | else:
123 | post_prediction = result
124 |
125 | post_predictions.append(post_prediction)
126 |
127 | else:
128 | for idx, row in tqdm(list(df.iterrows())):
129 | if opt.has_label:
130 | important_tags, post_prediction = postprocess_prediction(
131 | row, opt)
132 | list_of_important_tags.append(important_tags)
133 | else:
134 | post_prediction = postprocess_prediction(
135 | row, opt)
136 |
137 | post_predictions.append(post_prediction)
138 |
139 | list_of_pred_tags = []
140 | list_of_probs = []
141 |
142 | for post_prediction in post_predictions:
143 | post_prediction = list(zip(*post_prediction))
144 | if len(post_prediction) >= 2:
145 | # TODO we will take care of masks later.
146 | pred_tags, probs = post_prediction[0], post_prediction[1]
147 |
148 | list_of_pred_tags.append('\n'.join(pred_tags))
149 | probs = np.round(probs, decimals=3)
150 | probs = np.array(probs, dtype=str)
151 | list_of_probs.append('\n'.join(probs))
152 | else:
153 | list_of_pred_tags.append('')
154 | list_of_probs.append('')
155 |
156 | df['pred_tags'] = list_of_pred_tags
157 | df['probs'] = list_of_probs
158 |
159 | if opt.has_label:
160 | list_of_important_tags = list(
161 | map(lambda x: '\n'.join(x), list_of_important_tags))
162 |
163 | df['important_tags'] = list_of_important_tags
164 |
165 | return df
166 |
167 |
168 | def main():
169 | opt = parse_infer_args()
170 |
171 | states = torch.load(
172 | opt.model_path, map_location=lambda storage, loc: storage)
173 |
174 | # Load model's configuration
175 | model_config = states['config']
176 | opt.max_len = model_config['max_len']
177 | opt.d_model = model_config['d_model']
178 | opt.t_blocks = model_config['t_blocks']
179 | opt.t_heads = model_config['t_heads']
180 | opt.t_dim_feedforward = model_config['t_dim_feedforward']
181 | opt.i_blocks = model_config['i_blocks']
182 | opt.i_heads = model_config['i_heads']
183 | opt.i_dim_feedforward = model_config['i_dim_feedforward']
184 | opt.img_backbone = model_config['img_backbone']
185 | opt.g_dim_feedforward = model_config['g_dim_feedforward']
186 |
187 | test_dataset = TagAndImageDataset(
188 | csv_path=opt.csv_path,
189 | vocab_path=opt.vocab_path,
190 | img_dir=opt.img_dir,
191 | max_len=opt.max_len,
192 | has_label=opt.has_label,
193 | return_item_id=True,
194 | img_preprocess_fn=val_transform
195 | )
196 |
197 | test_dataloader = DataLoader(
198 | dataset=test_dataset,
199 | batch_size=opt.batch_size,
200 | num_workers=opt.num_workers,
201 | pin_memory=True if not opt.no_cuda else False
202 | )
203 | model = MAGNeto(
204 | d_model=opt.d_model,
205 | vocab_size=test_dataset.vocab_size,
206 | t_blocks=opt.t_blocks,
207 | t_heads=opt.t_heads,
208 | t_dim_feedforward=opt.t_dim_feedforward,
209 | i_blocks=opt.i_blocks,
210 | i_heads=opt.i_heads,
211 | i_dim_feedforward=opt.i_dim_feedforward,
212 | img_backbone=opt.img_backbone,
213 | g_dim_feedforward=opt.g_dim_feedforward,
214 | dropout=0
215 | )
216 | model.load_state_dict(states['model'])
217 | model.to(opt.device)
218 | model.eval()
219 |
220 | all_preds, all_item_ids = predict(model, test_dataloader, opt)
221 | raw_prediction_df = pd.DataFrame({
222 | 'item_id': all_item_ids,
223 | 'raw_prediction': all_preds
224 | }).drop_duplicates().set_index('item_id')
225 |
226 | base_df = pd.read_csv(opt.csv_path, index_col='item_id')
227 |
228 | # Log all error item ids
229 | error_item_ids = np.setdiff1d(base_df.index.unique(), raw_prediction_df.index.unique(), assume_unique=True).astype(str)
230 | if len(error_item_ids) > 0:
231 | print('Error item ids:', ', '.join(error_item_ids))
232 | with open('error_item_ids.txt', 'w') as f:
233 | f.write('\n'.join(error_item_ids))
234 |
235 | final_df = raw_prediction_df.join(base_df).reset_index()
236 |
237 | final_df = postprocess_predictions(final_df, opt)
238 |
239 | if opt.has_label:
240 | final_df.rename(columns={'important_tags': 'ground_truth'}, inplace=True)
241 | final_df[['item_id', 'tags', 'pred_tags', 'probs', 'ground_truth']].to_csv(
242 | 'prediction.csv', index=False)
243 | else:
244 | final_df[['item_id', 'tags', 'pred_tags', 'probs']].to_csv(
245 | 'prediction.csv', index=False)
246 |
247 |
248 | if __name__ == '__main__':
249 | main()
250 |
--------------------------------------------------------------------------------
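`predict` removes padded positions from each score vector (`pred[~mask]`), and `postprocess_prediction` then ranks the tags, keeps the `opt.top` best unconditionally and continues down the ranking only while scores exceed `opt.threshold`. The sketch below restates those two steps on plain Python lists so the selection rule can be checked in isolation. `strip_padding` and `select_tags` are hypothetical helpers, and the sketch assumes, as `pred[~mask]` implies, that a mask value of `True` marks a padded position.

```python
from typing import List, Sequence, Tuple


def strip_padding(scores: Sequence[float], pad_mask: Sequence[bool]) -> List[float]:
    """Drop scores at padded positions; pad_mask[i] is True where position i is padding."""
    return [s for s, is_pad in zip(scores, pad_mask) if not is_pad]


def select_tags(tags: Sequence[str], scores: Sequence[float],
                top: int, threshold: float) -> List[Tuple[str, float]]:
    """Keep the `top` highest-scoring tags, then any further tags scoring above `threshold`."""
    # zip pairs tags and scores up to the shorter of the two sequences
    ranked = sorted(zip(tags, scores), key=lambda pair: pair[1], reverse=True)
    selected = ranked[:top]
    for tag, score in ranked[top:]:
        if score > threshold:
            selected.append((tag, score))
        else:
            break  # ranked is sorted, so no later tag can clear the threshold
    return selected
```

With hypothetical tags `['sunset', 'beach', 'canon', 'sea', 'nikon']` and padded scores `[0.9, 0.2, 0.7, 0.55, 0.1, 0.0, 0.0]` whose last two positions are padding, `select_tags(tags, strip_padding(scores, mask), top=2, threshold=0.5)` returns `[('sunset', 0.9), ('canon', 0.7), ('sea', 0.55)]`: the top two are kept regardless of score, `sea` clears the threshold, and the scan stops at `beach`.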
/MAGNeto/magneto/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pixta-dev/labteam/2c14e0605520c100eca24f92d79461167c765c2f/MAGNeto/magneto/__init__.py
--------------------------------------------------------------------------------
/MAGNeto/magneto/augment_helper.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import pandas as pd
5 | from scipy.special import softmax
6 | from torchvision import transforms
7 |
8 | from magneto.autoaugment import ImageNetPolicy
9 |
10 |
11 | MEAN = [0.485, 0.456, 0.406]
12 | STD = [0.229, 0.224, 0.225]
13 | INPUT_SHAPE = 112
14 |
15 | train_transform = transforms.Compose([
16 | transforms.RandomResizedCrop(INPUT_SHAPE, scale=(0.3, 1.0)),
17 | transforms.RandomHorizontalFlip(),
18 | ImageNetPolicy(),
19 | transforms.ToTensor(),
20 | transforms.Normalize(mean=MEAN, std=STD)
21 | ])
22 |
23 | val_transform = transforms.Compose([
24 | transforms.Resize(INPUT_SHAPE),
25 | transforms.CenterCrop(INPUT_SHAPE),
26 | transforms.ToTensor(),
27 | transforms.Normalize(mean=MEAN, std=STD)
28 | ])
29 |
30 |
31 | class TagAugmentation():
32 | def __init__(
33 | self,
34 | vocab_path: str,
35 | drop: float = 0.0,
36 | add: float = 0.0,
37 | path: str = None
38 | ):
39 | self.vocab = pd.read_csv(
40 | vocab_path, keep_default_na=False, na_values=['']).word.tolist()
41 | self.drop, self.add = drop, add
42 |
43 |     def __call__(self, tags: np.ndarray, label: np.ndarray) -> 'tuple[np.ndarray, np.ndarray]':
44 | '''
45 | input:
46 | + tags: raw tags.
47 | + label: raw label.
48 | output:
49 | Processed tags and corresponding label.
50 | '''
51 | self.tags, self.label = tags, label
52 | self._seperate_indices()
53 |
54 | # NOTE: Dropping must be performed prior to adding process
55 | if self.drop:
56 | self.tags, self.label = self._drop_tag()
57 |
58 | if self.add:
59 | self.tags, self.label = self._add_tag()
60 |
61 | return self.tags, self.label
62 |
63 | def _get_num(self, prob: float):
64 | return random.randint(0, min(int(prob * len(self.unimportant_indices)), len(self.vocab)))
65 |
66 | def _seperate_indices(self):
67 | unimportant_mask = self.label == 0
68 | self.unimportant_indices = np.array(range(len(self.tags)))[
69 | unimportant_mask]
70 | self.important_indices = np.array(range(len(self.tags)))[
71 | np.logical_not(unimportant_mask)]
72 |
73 | def _drop_tag(self):
74 | # Randomly select the number of unimportant tags to keep
75 | num_unimportant_drop = self._get_num(self.drop)
76 | num_unimportant_keep = len(
77 | self.unimportant_indices) - num_unimportant_drop
78 |
79 | # Randomly choose indices of unimportant tags to keep based on the number above
80 | unimportant_keep_indices = np.array(random.sample(
81 | list(self.unimportant_indices), k=num_unimportant_keep), dtype=int)
82 | keep_indices = np.concatenate(
83 | (unimportant_keep_indices, self.important_indices))
84 | keep_indices.sort()
85 |
86 | return self.tags[keep_indices], self.label[keep_indices]
87 |
88 | def _add_tag(self):
89 | num_add = self._get_num(self.add)
90 |
91 | tags_add = []
92 | sampled_tags = 0
93 | while sampled_tags < num_add:
94 | noise_tags = random.sample(self.vocab, k=num_add - sampled_tags)
95 |             valid_tags = [t for t in noise_tags if t not in self.tags and t not in tags_add]
96 | tags_add += valid_tags
97 | sampled_tags = len(tags_add)
98 |
99 | tags = np.concatenate((self.tags, np.asarray(tags_add)))
100 |         label = np.concatenate((self.label, np.zeros(num_add, dtype=self.label.dtype)))
101 |
102 | return tags, label
103 |
--------------------------------------------------------------------------------
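`TagAugmentation` first drops a random number of unimportant tags (label 0) and then appends random vocabulary words as extra unimportant tags. The function below is a list-based sketch of that drop-then-add scheme, not the repository's implementation; `drop_and_add_tags` and its `rng` parameter are hypothetical. It departs from the class in two deliberate ways: the drop count is capped by the number of unimportant tags and the add count by the number of unused vocabulary words (the class caps both by the vocabulary size). This removes the resampling loop and guarantees that noise words never duplicate a tag already in the list.

```python
import random
from typing import List, Optional, Sequence, Tuple


def drop_and_add_tags(tags: Sequence[str], label: Sequence[int], vocab: Sequence[str],
                      drop: float = 0.0, add: float = 0.0,
                      rng: Optional[random.Random] = None) -> Tuple[List[str], List[int]]:
    """Randomly drop unimportant tags, then append unused vocabulary words as noise (label 0)."""
    if rng is None:
        rng = random.Random()
    # Split positions before any change; both counts below use the pre-drop unimportant count
    unimportant = [i for i, y in enumerate(label) if y == 0]
    important = [i for i, y in enumerate(label) if y != 0]

    keep = list(range(len(tags)))
    if drop:
        # Drop between 0 and drop * (#unimportant) unimportant tags; important tags always stay
        n_drop = rng.randint(0, min(int(drop * len(unimportant)), len(unimportant)))
        kept = rng.sample(unimportant, k=len(unimportant) - n_drop)
        keep = sorted(kept + important)  # restore the original tag order
    new_tags = [tags[i] for i in keep]
    new_label = [label[i] for i in keep]

    if add:
        # Sample without replacement from words not already present, so no duplicates appear
        present = set(new_tags)
        candidates = [w for w in vocab if w not in present]
        n_add = rng.randint(0, min(int(add * len(unimportant)), len(candidates)))
        noise = rng.sample(candidates, k=n_add)
        new_tags += noise
        new_label += [0] * n_add
    return new_tags, new_label
```

Passing a seeded `random.Random` makes the augmentation reproducible in tests; inside a training loop `rng` would normally be left as `None` so each epoch sees different noise.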
/MAGNeto/magneto/autoaugment.py:
--------------------------------------------------------------------------------
1 | # Source: https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
2 | from PIL import Image, ImageEnhance, ImageOps
3 | import numpy as np
4 | import random
5 |
6 |
7 | class ImageNetPolicy(object):
8 | """ Randomly choose one of the best 24 Sub-policies on ImageNet.
9 |
10 | Example:
11 | >>> policy = ImageNetPolicy()
12 | >>> transformed = policy(image)
13 |
14 | Example as a PyTorch Transform:
15 | >>> transform=transforms.Compose([
16 | >>> transforms.Resize(256),
17 | >>> ImageNetPolicy(),
18 | >>> transforms.ToTensor()])
19 | """
20 |
21 | def __init__(self, fillcolor=(128, 128, 128)):
22 | self.policies = [
23 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
24 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
25 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
26 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
27 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
28 |
29 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
30 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
31 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
32 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
33 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),
34 |
35 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
36 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
37 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
38 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
39 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
40 |
41 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
42 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
43 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
44 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
45 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),
46 |
47 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
48 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
49 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
50 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
51 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
52 | ]
53 |
54 | def __call__(self, img):
55 | policy_idx = random.randint(0, len(self.policies) - 1)
56 | return self.policies[policy_idx](img)
57 |
58 | def __repr__(self):
59 | return "AutoAugment ImageNet Policy"
60 |
61 |
62 | class CIFAR10Policy(object):
63 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10.
64 |
65 | Example:
66 | >>> policy = CIFAR10Policy()
67 | >>> transformed = policy(image)
68 |
69 | Example as a PyTorch Transform:
70 | >>> transform=transforms.Compose([
71 | >>> transforms.Resize(256),
72 | >>> CIFAR10Policy(),
73 | >>> transforms.ToTensor()])
74 | """
75 |
76 | def __init__(self, fillcolor=(128, 128, 128)):
77 | self.policies = [
78 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
79 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
80 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
81 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
82 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),
83 |
84 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
85 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
86 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
87 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
88 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),
89 |
90 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
91 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
92 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
93 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
94 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),
95 |
96 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
97 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
98 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
99 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
100 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),
101 |
102 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
103 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
104 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
105 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
106 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
107 | ]
108 |
109 | def __call__(self, img):
110 | policy_idx = random.randint(0, len(self.policies) - 1)
111 | return self.policies[policy_idx](img)
112 |
113 | def __repr__(self):
114 | return "AutoAugment CIFAR10 Policy"
115 |
116 |
117 | class SVHNPolicy(object):
118 | """ Randomly choose one of the best 25 Sub-policies on SVHN.
119 |
120 | Example:
121 | >>> policy = SVHNPolicy()
122 | >>> transformed = policy(image)
123 |
124 | Example as a PyTorch Transform:
125 | >>> transform=transforms.Compose([
126 | >>> transforms.Resize(256),
127 | >>> SVHNPolicy(),
128 | >>> transforms.ToTensor()])
129 | """
130 |
131 | def __init__(self, fillcolor=(128, 128, 128)):
132 | self.policies = [
133 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
134 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
135 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
136 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
137 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),
138 |
139 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
140 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
141 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
142 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
143 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),
144 |
145 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
146 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
147 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
148 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
149 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),
150 |
151 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
152 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
153 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
154 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
155 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),
156 |
157 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
158 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
159 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
160 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
161 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
162 | ]
163 |
164 | def __call__(self, img):
165 | policy_idx = random.randint(0, len(self.policies) - 1)
166 | return self.policies[policy_idx](img)
167 |
168 | def __repr__(self):
169 | return "AutoAugment SVHN Policy"
170 |
171 |
172 | class SubPolicy(object):
173 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
174 | ranges = {
175 | "shearX": np.linspace(0, 0.3, 10),
176 | "shearY": np.linspace(0, 0.3, 10),
177 | "translateX": np.linspace(0, 150 / 331, 10),
178 | "translateY": np.linspace(0, 150 / 331, 10),
179 | "rotate": np.linspace(0, 30, 10),
180 | "color": np.linspace(0.0, 0.9, 10),
181 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int), # np.int was removed in NumPy 1.24
182 | "solarize": np.linspace(256, 0, 10),
183 | "contrast": np.linspace(0.0, 0.9, 10),
184 | "sharpness": np.linspace(0.0, 0.9, 10),
185 | "brightness": np.linspace(0.0, 0.9, 10),
186 | "autocontrast": [0] * 10,
187 | "equalize": [0] * 10,
188 | "invert": [0] * 10
189 | }
190 |
191 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
192 | def rotate_with_fill(img, magnitude):
193 | rot = img.convert("RGBA").rotate(magnitude)
194 | return Image.composite(rot, Image.new("RGBA", rot.size, tuple(fillcolor) + (128,)), rot).convert(img.mode)
195 |
196 | func = {
197 | "shearX": lambda img, magnitude: img.transform(
198 | img.size, Image.AFFINE, (1, magnitude *
199 | random.choice([-1, 1]), 0, 0, 1, 0),
200 | Image.BICUBIC, fillcolor=fillcolor),
201 | "shearY": lambda img, magnitude: img.transform(
202 | img.size, Image.AFFINE, (1, 0, 0, magnitude *
203 | random.choice([-1, 1]), 1, 0),
204 | Image.BICUBIC, fillcolor=fillcolor),
205 | "translateX": lambda img, magnitude: img.transform(
206 | img.size, Image.AFFINE, (1, 0, magnitude *
207 | img.size[0] * random.choice([-1, 1]), 0, 1, 0),
208 | fillcolor=fillcolor),
209 | "translateY": lambda img, magnitude: img.transform(
210 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude *
211 | img.size[1] * random.choice([-1, 1])),
212 | fillcolor=fillcolor),
213 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
214 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
215 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
216 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
217 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
218 | 1 + magnitude * random.choice([-1, 1])),
219 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
220 | 1 + magnitude * random.choice([-1, 1])),
221 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
222 | 1 + magnitude * random.choice([-1, 1])),
223 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
224 | "equalize": lambda img, magnitude: ImageOps.equalize(img),
225 | "invert": lambda img, magnitude: ImageOps.invert(img)
226 | }
227 |
228 | self.p1 = p1
229 | self.operation1 = func[operation1]
230 | self.magnitude1 = ranges[operation1][magnitude_idx1]
231 | self.p2 = p2
232 | self.operation2 = func[operation2]
233 | self.magnitude2 = ranges[operation2][magnitude_idx2]
234 | def __call__(self, img):
235 | if random.random() < self.p1:
236 | img = self.operation1(img, self.magnitude1)
237 | if random.random() < self.p2:
238 | img = self.operation2(img, self.magnitude2)
239 | return img
240 |
--------------------------------------------------------------------------------
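SubPolicy turns each magnitude index (0 to 9) into a concrete magnitude through its `ranges` lookup tables, then applies each of its two operations independently with its own probability; shear, translate and enhance operations also flip the sign of the magnitude at random. The following is a minimal NumPy sketch of just that mechanism, without the PIL image operations. `RANGES`, `magnitude`, `signed` and `apply_two_step` are hypothetical names chosen for illustration, and only three of the original tables are reproduced. It uses the builtin `int` rather than `np.int`, which was removed in NumPy 1.24.

```python
import random

import numpy as np

# Hypothetical subset of SubPolicy's `ranges`: each op maps index 0..9 to a magnitude.
RANGES = {
    "rotate": np.linspace(0, 30, 10),  # degrees
    "solarize": np.linspace(256, 0, 10),  # threshold; a lower threshold inverts more pixels
    "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # bits kept per channel
}


def magnitude(op: str, idx: int):
    """Look up the concrete magnitude for a discrete magnitude index."""
    return RANGES[op][idx]


def signed(m: float, rng: random.Random) -> float:
    """Randomly flip the sign, as the shear/translate/enhance ops do."""
    return m * rng.choice([-1, 1])


def apply_two_step(x, p1, op1, p2, op2, rng: random.Random):
    """Apply op1 with probability p1, then op2 with probability p2."""
    if rng.random() < p1:
        x = op1(x)
    if rng.random() < p2:
        x = op2(x)
    return x


rng = random.Random(0)
print(magnitude("rotate", 3), magnitude("posterize", 9))
print(apply_two_step(1, 0.9, lambda v: v + 1, 0.2, lambda v: v * 10, rng))
```

Because the two draws are independent, a sub-policy such as `SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3)` applies both operations only about 18% of the time.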
/MAGNeto/magneto/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import torch
7 | from torch.utils.data import Dataset, DataLoader
8 | from PIL import Image
9 |
10 | from magneto.augment_helper import train_transform, val_transform, TagAugmentation
11 |
12 |
13 | class TagAndImageDataset(Dataset):
14 | def __init__(
15 | self,
16 | csv_path: str,
17 | vocab_path: str,
18 | img_dir: str,
19 | max_len: int,
20 | has_label: bool = True,
21 | return_item_id: bool = False,
22 | tag_preprocess_fn: object = None,
23 | img_preprocess_fn: object = None,
24 | ):
25 | '''
26 | input:
27 | + csv_path: path to the csv file that contains "item_id", "tags"[, "label"].
28 | + vocab_path: path to the csv file that contains the vocabulary of the dataset.
29 | + img_dir: the directory that contains corresponding images.
30 | + max_len: the maximum number of tags.
31 | + has_label: whether to prepare and return label or not.
32 | + return_item_id: whether to return item_id or not.
33 | + tag_preprocess_fn: the preprocessing function for tags; only applied when labels are provided.
34 | + img_preprocess_fn: the preprocessing function for images.
35 | '''
36 | df = pd.read_csv(csv_path)
37 |
38 | self.has_label = has_label
39 | self.return_item_id = return_item_id
40 |
41 | self.list_of_tags = df['tags'].apply(
42 | lambda x: np.array(x.split(','))).tolist()
43 | self.list_of_image_path = df['item_id'].map(
44 | lambda x: os.path.join(img_dir, str(x) + '.jpg')).tolist()
45 | if self.has_label:
46 | self.list_of_label = df['label'].apply(
47 | lambda x: np.array(x.split(','), dtype=np.float32)).tolist()
48 | if self.return_item_id:
49 | self.list_of_item_id = df['item_id'].tolist()
50 |
51 | self.vocab = pd.read_csv(
52 | vocab_path, keep_default_na=False, na_values=[''])
53 | self.word_to_index = self.vocab.set_index('word')
54 | self.vocab_size = len(self.vocab)
55 |
56 | self.max_num_of_tags = max_len
57 | self.tag_preprocess_fn = tag_preprocess_fn
58 | self.img_preprocess_fn = img_preprocess_fn
59 |
60 | def __len__(self) -> int:
61 | return len(self.list_of_tags)
62 |
63 | def __getitem__(self, idx: object) -> list:
64 | '''
65 | input:
66 | + idx: item's index.
67 | output (a list, in this order):
68 | + image: the (optionally preprocessed) image.
69 | + indices: vocabulary indices of the tags, right-padded with vocab_size up to max_len.
70 | + mask: boolean mask that is True at padding positions.
71 | + label, item_id: appended only when has_label and return_item_id are set, respectively.
72 | '''
73 | if torch.is_tensor(idx):
74 | idx = idx.tolist()
75 |
76 | # Get image
77 | image_path = self.list_of_image_path[idx]
78 | try:
79 | image = Image.open(image_path).convert('RGB')
80 | except Exception: # unreadable or missing image: fall back to a random item
81 | return self.__getitem__(random.randrange(self.__len__()))
82 |
83 | if self.img_preprocess_fn is not None:
84 | image = self.img_preprocess_fn(image)
85 |
86 | # Get indices of tags, corresponding mask and label (if provided)
87 | tags = self.list_of_tags[idx]
88 | if self.has_label:
89 | label = self.list_of_label[idx]
90 |
91 | assert len(tags) == len(label)
92 |
93 | if self.tag_preprocess_fn is not None:
94 | tags, label = self.tag_preprocess_fn(tags, label)
95 | if self.return_item_id:
96 | item_id = self.list_of_item_id[idx]
97 |
98 | # Create default mask
99 | mask = torch.zeros(self.max_num_of_tags, dtype=torch.bool)
100 |
101 | # Fix the number of tags to max_len (truncate or right-pad)
102 | if len(tags) >= self.max_num_of_tags:
103 | # Get top N
104 | tags = tags[:self.max_num_of_tags]
105 | indices = self.word_to_index.loc[tags, 'index']
106 | indices = torch.tensor(indices.values, dtype=torch.int64)
107 | if self.has_label:
108 | label = torch.tensor(
109 | label[:self.max_num_of_tags],
110 | dtype=torch.float32
111 | )
112 | else:
113 | indices = self.word_to_index.loc[tags, 'index']
114 |
115 | # Right-padding
116 | # Padding idx will be n where n = vocab_size
117 | padding_vector = np.ones(
118 | self.max_num_of_tags, dtype=np.int64) * (self.vocab_size)
119 | padding_vector[:len(tags)] = indices
120 | indices = torch.tensor(padding_vector, dtype=torch.int64)
121 |
122 | mask[len(tags):] = True
123 |
124 | if self.has_label:
125 | zeros_vector = np.zeros(
126 | self.max_num_of_tags, dtype=np.float32)
127 | zeros_vector[:len(tags)] += label
128 | label = torch.tensor(zeros_vector, dtype=torch.float32)
129 |
130 | results = [image, indices, mask]
131 | if self.has_label:
132 | results.append(label)
133 | if self.return_item_id:
134 | results.append(item_id)
135 |
136 | return results
137 |
138 |
139 | def get_dataloaders(
140 | train_csv_path: str,
141 | val_csv_path: str,
142 | vocab_path: str,
143 | img_dir: str,
144 | tagaug_add_max_ratio: float,
145 | tagaug_drop_max_ratio: float,
146 | train_batch_size: int = 32,
147 | val_batch_size: int = 32,
148 | max_len: int = 100,
149 | num_workers: int = 0,
150 | pin_memory: bool = True
151 | ) -> (DataLoader, DataLoader, int):
152 | '''
153 | input:
154 | + train_csv_path: path to the csv file of the training dataset.
155 | + val_csv_path: path to the csv file of the validation dataset.
156 | + vocab_path: path to the csv file that contains the vocabulary of the dataset.
157 | + img_dir: the directory that contains all images for training and validation sets.
158 | + tagaug_add_max_ratio: the maximum ratio between the number of added tags and the number of non-important ones.
159 | + tagaug_drop_max_ratio: the maximum ratio between the number of dropped tags and the number of non-important ones.
160 | + train_batch_size: the batch-size of the training dataloader.
161 | + val_batch_size: the batch-size of the validation dataloader.
162 | + max_len: the maximum length for each set of tags.
163 | + num_workers: the number of workers used to load data.
164 | + pin_memory: the pin_memory param of PyTorch's DataLoader class.
165 | output:
166 | the dataloaders for the training and validation sets, and the vocabulary size.
167 | '''
168 | train_dataset = TagAndImageDataset(
169 | csv_path=train_csv_path,
170 | vocab_path=vocab_path,
171 | img_dir=img_dir,
172 | max_len=max_len,
173 | tag_preprocess_fn=TagAugmentation(
174 | vocab_path=vocab_path,
175 | add=tagaug_add_max_ratio,
176 | drop=tagaug_drop_max_ratio
177 | ) if (tagaug_add_max_ratio or tagaug_drop_max_ratio) else None, # Only build it when a ratio is non-zero
178 | img_preprocess_fn=train_transform
179 | )
180 | val_dataset = TagAndImageDataset(
181 | csv_path=val_csv_path,
182 | vocab_path=vocab_path,
183 | img_dir=img_dir,
184 | max_len=max_len,
185 | img_preprocess_fn=val_transform
186 | )
187 |
188 | train_dataloader = DataLoader(
189 | dataset=train_dataset,
190 | batch_size=train_batch_size,
191 | shuffle=True,
192 | num_workers=num_workers,
193 | pin_memory=pin_memory
194 | )
195 | val_dataloader = DataLoader(
196 | dataset=val_dataset,
197 | batch_size=val_batch_size,
198 | num_workers=num_workers,
199 | pin_memory=pin_memory
200 | )
201 |
202 | return train_dataloader, val_dataloader, train_dataset.vocab_size
203 |
--------------------------------------------------------------------------------
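`TagAndImageDataset.__getitem__` fixes every tag sequence to `max_len`: it truncates long sequences, right-pads short ones with the index `vocab_size` (the extra embedding row reserved by `TagEmbedder`), marks padding positions with True in the mask, and pads labels with 0.0. Below is a NumPy-only sketch of that logic, assuming the tags have already been mapped to vocabulary indices. `pad_tag_indices` is a hypothetical helper; the dataset itself additionally wraps the results in torch tensors.

```python
import numpy as np


def pad_tag_indices(indices, max_len, vocab_size, labels=None):
    """Truncate or right-pad tag indices to max_len (hypothetical helper).

    Padding uses index `vocab_size`; the mask is True at padding positions,
    and padded label entries are 0.0.
    """
    indices = list(indices)[:max_len]
    n = len(indices)
    padded = np.full(max_len, vocab_size, dtype=np.int64)
    padded[:n] = indices
    mask = np.zeros(max_len, dtype=bool)
    mask[n:] = True
    if labels is None:
        return padded, mask
    labels = np.asarray(labels, dtype=np.float32)[:max_len]
    assert len(labels) == n, "tags and labels must have the same length"
    padded_labels = np.zeros(max_len, dtype=np.float32)
    padded_labels[:n] = labels
    return padded, mask, padded_labels
```

The True-means-padding convention matches what `nn.TransformerEncoder` expects for `src_key_padding_mask`, so the mask can be passed through unchanged.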
/MAGNeto/magneto/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torchvision import models
7 |
8 |
9 | def freeze_all_parameters(module: nn.Module):
10 | ''' Freeze all parameters of a PyTorch Module.
11 | input:
12 | + module: self-explanatory.
13 | '''
14 | for param in module.parameters():
15 | param.requires_grad = False
16 |
17 |
18 | def unfreeze_all_parameters(module: nn.Module):
19 | ''' Unfreeze all parameters of a PyTorch Module.
20 | input:
21 | + module: self-explanatory.
22 | '''
23 | for param in module.parameters():
24 | param.requires_grad = True
25 |
26 |
27 | class TagEmbedder(nn.Module):
28 | def __init__(self, vocab_size, d_model):
29 | super(TagEmbedder, self).__init__()
30 | self.embed = nn.Embedding(
31 | num_embeddings=vocab_size+1, # One extra row for the padding index (= vocab_size)
32 | embedding_dim=d_model,
33 | )
34 |
35 | def forward(self, x):
36 | return self.embed(x)
37 |
38 |
39 | class MultiHeadMaskedScaledDotProduct(nn.Module):
40 | def __init__(self, d_k: int):
41 | '''
42 | input:
43 | + d_k: the dimensionality of the subspace.
44 | '''
45 | super(MultiHeadMaskedScaledDotProduct, self).__init__()
46 |
47 | self.d_k = d_k
48 |
49 | def forward(
50 | self,
51 | q: torch.tensor,
52 | k: torch.tensor,
53 | mask: torch.tensor = None
54 | ) -> torch.tensor:
55 | '''
56 | input:
57 | + q: the matrix of queries.
58 | + k: the matrix of keys.
59 | + mask: used to mask out padding positions.
60 | output:
61 | The matrix of scaled attention scores (before softmax), with masked key positions set to -inf.
62 | '''
63 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
64 |
65 | if mask is not None:
66 | scores = scores.masked_fill(
67 | mask.unsqueeze(1).unsqueeze(2),
68 | float('-inf')
69 | )
70 |
71 | return scores
72 |
73 |
74 | class MultiHeadAttention(nn.Module):
75 | def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
76 | '''
77 | input:
78 | + heads: the number of heads of each Multi-Head Attention layer.
79 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
80 | + dropout: dropout rate applied to the attention weights.
81 | '''
82 | super(MultiHeadAttention, self).__init__()
83 |
84 | self.d_model = d_model
85 | self.d_k = d_model // heads
86 | self.h = heads
87 |
88 | self.q_linear = nn.Linear(d_model, d_model)
89 | self.v_linear = nn.Linear(d_model, d_model)
90 | self.k_linear = nn.Linear(d_model, d_model)
91 |
92 | self.dp = MultiHeadMaskedScaledDotProduct(self.d_k)
93 | self.softmax = nn.Softmax(-1)
94 | self.dropout = nn.Dropout(dropout)
95 | self.out = nn.Linear(d_model, d_model)
96 |
97 | def forward(
98 | self,
99 | q: torch.tensor,
100 | k: torch.tensor,
101 | v: torch.tensor,
102 | mask: torch.tensor = None
103 | ) -> torch.tensor:
104 | '''
105 | input:
106 | + q: the matrix of queries.
107 | + k: the matrix of keys.
108 | + v: the matrix of values.
109 | + mask: used to mask out padding positions.
110 | output:
111 | context vectors.
112 | '''
113 | bs = q.size(0)
114 |
115 | # Perform linear operation and split into N heads
116 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
117 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
118 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
119 |
120 | # Transpose to get dimensions bs * h * sl * d_k
121 | k = k.transpose(1, 2)
122 | q = q.transpose(1, 2)
123 | v = v.transpose(1, 2)
124 |
125 | scores = self.dp(q, k, mask)
126 | scores = self.softmax(scores)
127 |
128 | scores = self.dropout(scores)
129 |
130 | # Compute context vectors based on calculated scores above
131 | context = torch.matmul(scores, v)
132 |
133 | # Concatenate heads and put through final linear layer
134 | concat = context.transpose(1, 2).contiguous()\
135 | .view(bs, -1, self.d_model)
136 | output = self.out(concat)
137 |
138 | return output
139 |
140 |
141 | class TagToImageLayer(nn.Module):
142 | def __init__(self, d_model: int, heads: int, dropout: float = 0.1):
143 | '''
144 | input:
145 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
146 | + heads: the number of heads of the Multi-Head Attention sub-layer.
147 | + dropout: the dropout value of the Multi-Head Attention sub-layer.
148 | '''
149 | super(TagToImageLayer, self).__init__()
150 |
151 | self.attn = MultiHeadAttention(
152 | heads, d_model, dropout=dropout)
153 |
154 | def forward(self, tag_vectors: torch.tensor, img_regions: torch.tensor) -> torch.tensor:
155 | '''
156 | input:
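157 | + tag_vectors: tag embeddings used as queries, shape [bs, max_len, d_model].
158 | + img_regions: flattened image-region features used as keys and values, shape [bs, H*W, d_model].
159 | output:
160 | one attended vector per tag, shape [bs, max_len, d_model].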
161 | '''
162 | out = self.attn(tag_vectors, img_regions, img_regions)
163 |
164 | return out
165 |
166 |
167 | class GatingLayer(nn.Module):
168 | def __init__(self, in_features: int, dim_feedforward: int, dropout: float = 0.1):
169 | '''
170 | input:
171 | + in_features: the dimensionality of the input vectors.
172 | + dim_feedforward: the dimensionality of the hidden layer.
173 | + dropout: the dropout value of the Gating layer.
174 | '''
175 | super(GatingLayer, self).__init__()
176 |
177 | self.dropout_1 = nn.Dropout(dropout)
178 | self.linear_1 = nn.Linear(in_features, dim_feedforward)
179 | self.relu = nn.ReLU()
180 | self.dropout_2 = nn.Dropout(dropout)
181 | self.linear_2 = nn.Linear(dim_feedforward, 1)
182 | self.sigmoid = nn.Sigmoid()
183 |
184 | def forward(self, tag_vectors: torch.tensor) -> torch.tensor:
185 | '''
186 | input:
187 | + tag_vectors: self-explanatory.
188 | output:
189 | gating values in (0, 1), with the trailing size-1 dimension squeezed out.
190 | '''
191 | out = self.dropout_1(tag_vectors)
192 | out = self.linear_1(out)
193 | out = self.relu(out)
194 | out = self.dropout_2(out)
195 | out = self.linear_2(out)
196 | out = self.sigmoid(out.squeeze(dim=-1))
197 |
198 | return out
199 |
200 |
201 | class ImageFeatureExtractor(nn.Module):
202 | def __init__(self, d_model: int, img_backbone: str):
203 | '''
204 | input:
205 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
206 | '''
207 | super(ImageFeatureExtractor, self).__init__()
208 |
209 | if img_backbone == 'resnet18':
210 | encoder = models.resnet18(pretrained=True)
211 | ex_out_dim = 512
212 | elif img_backbone == 'resnet50':
213 | encoder = models.resnet50(pretrained=True)
214 | ex_out_dim = 2048
215 |
216 | # Get all layers
217 | encoder_children = list(encoder.children())
218 | # Drop the last avg & fc layers
219 | self.backbone = nn.Sequential(*encoder_children[:-2])
220 |
221 | self.conv_1x1 = nn.Conv2d(ex_out_dim, d_model, kernel_size=(
222 | 1, 1), stride=(1, 1), bias=True)
223 | self.bn = nn.BatchNorm2d(d_model)
224 | self.flatten = nn.Flatten(start_dim=1, end_dim=2)
225 |
226 | self.freeze_all_layers()
227 | self.unfreeze_top_layers()
228 | # self.unfreeze_the_fourth_block()
229 | # self.unfreeze_the_third_block()
230 | # self.unfreeze_the_second_block()
231 | # self.unfreeze_the_first_block()
232 | # self.unfreeze_the_bottom_layers()
233 |
234 | def freeze_all_layers(self):
235 | ''' Freeze all image encoder's layers.
236 | '''
237 | freeze_all_parameters(self)
238 |
239 | def unfreeze_top_layers(self):
240 | ''' Unfreeze the top layers of the image feature extractor.
241 | '''
242 | # conv_1x1
243 | unfreeze_all_parameters(self.conv_1x1)
244 |
245 | # bn
246 | unfreeze_all_parameters(self.bn)
247 |
248 | def unfreeze_the_first_block(self):
249 | ''' Unfreeze the first block of the image encoder.
250 | '''
251 | assert type(self.backbone[4]) is nn.Sequential
252 |
253 | unfreeze_all_parameters(self.backbone[4])
254 |
255 | def unfreeze_the_second_block(self):
256 | ''' Unfreeze the second block of the image encoder.
257 | '''
258 | assert type(self.backbone[5]) is nn.Sequential
259 |
260 | unfreeze_all_parameters(self.backbone[5])
261 |
262 | def unfreeze_the_third_block(self):
263 | ''' Unfreeze the third block of the image encoder.
264 | '''
265 | assert type(self.backbone[6]) is nn.Sequential
266 |
267 | unfreeze_all_parameters(self.backbone[6])
268 |
269 | def unfreeze_the_fourth_block(self):
270 | ''' Unfreeze the fourth block of the image encoder.
271 | '''
272 | assert type(self.backbone[7]) is nn.Sequential
273 |
274 | unfreeze_all_parameters(self.backbone[7])
275 |
276 | def unfreeze_the_bottom_layers(self):
277 | ''' Unfreeze the bottom layers of the image encoder.
278 | '''
279 | # Unfreeze the first conv layer and bn
280 | assert type(self.backbone[0]) is nn.Conv2d
281 | assert type(self.backbone[1]) is nn.BatchNorm2d
282 | assert type(self.backbone[2]) is nn.ReLU
283 | assert type(self.backbone[3]) is nn.MaxPool2d
284 |
285 | # conv
286 | unfreeze_all_parameters(self.backbone[0])
287 |
288 | # bn
289 | unfreeze_all_parameters(self.backbone[1])
290 |
291 | def forward(self, x: torch.tensor) -> torch.tensor:
292 | '''
293 | input:
294 | + x: input image.
295 | output:
296 | image's features.
297 | '''
298 | features = self.backbone(x)
299 |
300 | out = self.conv_1x1(features)
301 | out = self.bn(out)
302 |
303 | # Convert the tensor from channel first to channel last
304 | out = out.permute(0, 2, 3, 1)
305 |
306 | out = self.flatten(out)
307 |
308 | return out
309 |
--------------------------------------------------------------------------------
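`MultiHeadMaskedScaledDotProduct` reshapes a `[bs, seq_len]` key-padding mask to `[bs, 1, 1, seq_len]` so that it broadcasts over heads and query positions, and fills padded keys with -inf so that softmax gives them zero weight. A minimal NumPy sketch of the same computation follows; `masked_attention_weights` is a hypothetical name, and the max-subtraction step is a standard numerical-stability trick, not something the original layer does. Note that `TagToImageLayer` calls the attention without a mask, since its keys are image regions, which are never padded.

```python
import numpy as np


def masked_attention_weights(q, k, key_padding_mask=None):
    """Softmax attention weights with a key-padding mask.

    q: [B, H, Lq, d_k], k: [B, H, Lk, d_k],
    key_padding_mask: [B, Lk] bool, True marks a padded key.
    """
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(d_k)  # [B, H, Lq, Lk]
    if key_padding_mask is not None:
        # Same broadcast as mask.unsqueeze(1).unsqueeze(2) -> [B, 1, 1, Lk]
        scores = np.where(key_padding_mask[:, None, None, :], -np.inf, scores)
    # Subtract the row max for numerical stability; exp(-inf) becomes exactly 0.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)
```

If every key in a row is masked, this sketch produces NaN, as does a softmax over an all -inf row in PyTorch, so at least one unpadded position per sequence is assumed.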
/MAGNeto/magneto/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def dice_loss(
7 | input: torch.tensor,
8 | target: torch.tensor,
9 | beta: float = 1.,
10 | reduction: str = 'mean',
11 | smooth: float = 1.
12 | ) -> torch.tensor:
13 | intersection = input * target
14 | score = ((1. + beta**2) * torch.sum(intersection, dim=-1) + smooth) \
15 | / (torch.sum(input, dim=-1) + (beta**2) * torch.sum(target, dim=-1) + smooth)
16 |
17 | loss = 1. - score
18 |
19 | if reduction == 'mean':
20 | return loss.mean()
21 | elif reduction == 'sum':
22 | return loss.sum()
23 |
24 | return loss
25 |
26 |
27 | class DiceLoss(nn.Module):
28 | """
29 | Dice loss's implementation.
30 | """
31 |
32 | def __init__(self, reduction: str = 'mean', beta: float = 1., smooth: float = 1.):
33 | """
34 | input:
35 | + reduction: specifies the reduction to apply to the output:
36 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
37 | ``'mean'``: the sum of the output will be divided by the number of
38 | elements in the output, ``'sum'``: the output will be summed.
39 | + beta: β is chosen such that recall is considered β times as important as precision.
40 | + smooth: smoothing constant added to the numerator and denominator to avoid division by zero.
41 | """
42 | super(DiceLoss, self).__init__()
43 |
44 | assert beta >= 0, 'β must be a non-negative real value!'
45 | assert reduction in ['none', 'mean', 'sum']
46 |
47 | self.beta = beta
48 | self.reduction = reduction
49 | self.smooth = smooth
50 |
51 | def forward(self, input: torch.tensor, target: torch.tensor) -> torch.tensor:
52 | """
53 | input:
54 | + input: prediction.
55 | + target: ground-truth.
56 | output:
57 | + loss value.
58 | """
59 | return dice_loss(
60 | input,
61 | target,
62 | beta=self.beta,
63 | reduction=self.reduction,
64 | smooth=self.smooth
65 | )
66 |
67 |
68 | class BCEDiceLoss(nn.Module):
69 | """
70 | The combination of Binary Cross-Entropy & Dice losses.
71 | """
72 |
73 | def __init__(
74 | self,
75 | reduction: str = 'mean',
76 | weight: torch.tensor = None,
77 | beta: float = 1.,
78 | smooth: float = 1.0
79 | ):
80 | """
81 | input:
82 | + reduction: specifies the reduction to apply to the output:
83 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
84 | ``'mean'``: the sum of the output will be divided by the number of
85 | elements in the output, ``'sum'``: the output will be summed.
86 | + weight: a manual rescaling weight given to the loss
87 | of each batch element. If given, has to be a Tensor of size `nbatch`,
88 | used for BCE part.
89 | + beta: β is chosen such that recall is considered β times as important as precision.
90 | + smooth: smoothing constant for the Dice part, added to the numerator and denominator.
91 | """
92 | super(BCEDiceLoss, self).__init__()
93 |
94 | assert beta >= 0, 'β must be a non-negative real value!'
95 | assert reduction in ['none', 'mean', 'sum']
96 |
97 | self.reduction = reduction
98 |
99 | # BCE's params
100 | self.weight = weight
101 |
102 | # Dice's params
103 | self.beta = beta
104 | self.smooth = smooth
105 |
106 | def forward(self, input: torch.tensor, target: torch.tensor) -> torch.tensor:
107 | """
108 | input:
109 | + input: prediction.
110 | + target: ground-truth.
111 | output:
112 | + loss value.
113 | """
114 | bce = F.binary_cross_entropy(
115 | input,
116 | target,
117 | weight=self.weight,
118 | reduction=self.reduction
119 | )
120 | dice = dice_loss(
121 | input,
122 | target,
123 | beta=self.beta,
124 | reduction=self.reduction,
125 | smooth=self.smooth
126 | )
127 |
128 | return bce + dice
129 |
--------------------------------------------------------------------------------
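`dice_loss` computes, per row, score = ((1 + β²)·Σ p·t + s) / (Σ p + β²·Σ t + s) and returns 1 − score. With hard 0/1 predictions and s = 0, Σ p = TP + FP and Σ t = TP + FN, so the score reduces to (1 + β²)·TP / ((1 + β²)·TP + FP + β²·FN), which is exactly the F-beta score that `PrecisionRecallFk` reports. The NumPy sketch below mirrors the formula to make that link checkable; `dice_loss_np` is a hypothetical name, not part of the repository.

```python
import numpy as np


def dice_loss_np(pred, target, beta=1.0, smooth=1.0, reduction="mean"):
    """NumPy mirror of dice_loss: per-row soft F-beta, turned into a loss."""
    inter = np.sum(pred * target, axis=-1)
    score = ((1 + beta**2) * inter + smooth) / (
        np.sum(pred, axis=-1) + beta**2 * np.sum(target, axis=-1) + smooth
    )
    loss = 1.0 - score
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss


# TP = 1, FP = 1, FN = 1: with beta=2 and smooth=0, F2 = 5 / (5 + 1 + 4) = 0.5.
pred = np.array([[1.0, 0.0, 1.0]])
target = np.array([[1.0, 1.0, 0.0]])
print(dice_loss_np(pred, target, beta=2.0, smooth=0.0))
```

A larger β weights Σ t, and therefore missed tags (FN), more heavily in the denominator, which is why the docstrings describe recall as β times as important as precision.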
/MAGNeto/magneto/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 |
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class PrecisionRecallFk:
10 | """
11 | Calculate precision, recall and F-beta scores for multi-label predictions.
12 | """
13 |
14 | def __init__(self, enable_logger=False, threshold=0.5, eps=1e-9):
15 | if enable_logger:
16 | global logger
17 | self.logger = logger
18 |
19 | self.threshold = threshold
20 | self.eps = eps
21 |
22 | def __call__(self,
23 | prediction: np.array,
24 | ground_truth: np.array,
25 | betas: list = [1],
26 | top_ks: list = None) -> dict:
27 | """
28 | compute F-beta score
29 |
30 | input:
31 | + prediction: predictions of the model, an np.array of shape [B, N], where B is the batch size
32 | and N is the number of classes
33 | + ground_truth: self explanatory, must have the same shape as prediction
34 | + betas: a list of betas
35 | + top_ks: if specified, compute the F-beta scores of the top_k most confident predictions
36 | output:
37 | + the f_k_score output, or the f_k_score_top output when top_ks is given; F = (1+beta**2) * precision*recall/(beta**2 * precision+recall)
38 | """
39 | if top_ks is None:
40 | return self.f_k_score(prediction,
41 | ground_truth,
42 | betas)
43 | else:
44 | return self.f_k_score_top(prediction,
45 | ground_truth,
46 | betas,
47 | top_ks)
48 |
49 | def f_k_score(self,
50 | prediction: np.array,
51 | ground_truth: np.array,
52 | betas: list = [1],
53 | threshold: float = None):
54 | """
55 | compute F-beta score
56 |
57 | input:
58 | + prediction: unthresholded output of the model, an np.array of shape [B, N], where B is the
59 | batch size and N is the number of classes
60 | + ground_truth: self explanatory, must have the same shape as prediction
61 | + betas: a list of betas
62 | output:
63 | + a dict with the mean "precision", "recall" and "f_score" (F-beta scores keyed "F{beta}"), where F = (1+beta**2) * precision*recall/(beta**2 * precision+recall)
64 | """
65 | assert prediction.shape == ground_truth.shape
66 |
67 | if threshold is None:
68 | prediction = prediction >= self.threshold
69 | else:
70 | prediction = (prediction >= threshold)
71 |
72 | prediction = prediction.astype(int)
73 |
74 | ground_truth = ground_truth.reshape(prediction.shape)
75 | num_prediction = np.count_nonzero(prediction, axis=1)
76 | num_ground_truth = np.count_nonzero(ground_truth, axis=1)
77 |
78 | if hasattr(self, "logger"):
79 | self.logger.info(
80 | "Predictions per item: {}, Labels per item: {}".format(np.mean(num_prediction),
81 | np.mean(num_ground_truth))
82 | )
83 |
84 | num_true_positive_pred = np.count_nonzero(
85 | ground_truth & prediction, axis=1)
86 |
87 | precision = num_true_positive_pred/num_prediction + self.eps
88 | recall = num_true_positive_pred/num_ground_truth + self.eps
89 |
90 | f_scores = {}
91 | for beta in betas:
92 | beta_squared = beta ** 2
93 | f_score = np.nan_to_num(
94 | (1 + beta_squared)*precision*recall / (beta_squared * precision+recall))
95 | f_scores["F{}".format(beta)] = np.nanmean(f_score)
96 |
97 | if hasattr(self, "logger"):
98 | self.logger.info(
99 | "Can't give predictions to {} items".format(
100 | np.count_nonzero(np.isnan(precision)))
101 | )
102 |
103 | return {"precision": np.nanmean(precision), "recall": np.nanmean(recall), "f_score": f_scores}
104 |
105 | def f_k_score_top(self,
106 | prediction: np.array,
107 | ground_truth: np.array,
108 | betas: list,
109 | top_ks: list):
110 | """
111 | compute F-beta score
112 |
113 | input:
114 | + prediction: unthresholded output of the model, an np.array of shape [B, N], where B is the
115 | batch size and N is the number of classes
116 | + ground_truth: self explanatory, must have the same shape as prediction
117 | + betas: a list of betas
118 | + top_ks: list of top_ks to compute the f_score
119 | output:
120 | + a dict keyed "top_{k}", each value being the f_k_score output computed with that row-wise top-k threshold
121 | """
122 |
123 | assert len(top_ks) > 0, "please specify top_k"
124 |
125 | outputs = {}
126 |
127 | for top_k in top_ks:
128 | # the k-th highest score of each row becomes that row's threshold
129 | k_indices = np.argsort(prediction)[:, ::-1][:, top_k - 1]
130 |
131 | k_thresh = prediction[np.arange(len(k_indices)), k_indices]
132 | k_thresh = k_thresh[..., np.newaxis]
133 | outputs["top_{}".format(top_k)] = self.f_k_score(
134 | prediction, ground_truth, betas, k_thresh)
135 |
136 | return outputs
137 |
--------------------------------------------------------------------------------
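`f_k_score_top` converts "top-k" into an ordinary thresholded evaluation: it takes each row's k-th highest score as that row's threshold, reshapes it to `[B, 1]` so that it broadcasts against `[B, N]`, and reuses `f_k_score`. The row-wise pick must use a tuple of index arrays (`prediction[rows, cols]`); indexing with a list such as `[range(B), cols]` was deprecated and, since NumPy 1.23, is interpreted as a single array index, which gives a different result. A minimal sketch of the threshold step, with `top_k_threshold` as a hypothetical name:

```python
import numpy as np


def top_k_threshold(prediction: np.ndarray, k: int) -> np.ndarray:
    """Return each row's k-th highest score, shaped [B, 1] for broadcasting."""
    order = np.argsort(prediction, axis=1)[:, ::-1]  # descending per row
    kth_idx = order[:, k - 1]
    rows = np.arange(prediction.shape[0])
    # A tuple of index arrays picks exactly one element per row.
    return prediction[rows, kth_idx][:, np.newaxis]


pred = np.array([[0.1, 0.9, 0.5, 0.3], [0.8, 0.2, 0.6, 0.4]])
thr = top_k_threshold(pred, 2)
selected = pred >= thr  # boolean [B, N], as f_k_score builds it
```

Because the comparison is `>=`, tied scores at the k-th position can select more than k tags in a row.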
/MAGNeto/magneto/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from magneto.layers import (
5 | TagEmbedder,
6 | ImageFeatureExtractor,
7 | TagToImageLayer,
8 | GatingLayer
9 | )
10 |
11 |
12 | class MAGNeto(nn.Module):
13 | def __init__(
14 | self,
15 | d_model: int,
16 | vocab_size: int,
17 | t_blocks: int,
18 | t_heads: int,
19 | i_blocks: int,
20 | i_heads: int,
21 | dropout: float,
22 | t_dim_feedforward: int = 2048,
23 | i_dim_feedforward: int = 2048,
24 | g_dim_feedforward: int = 2048,
25 | img_backbone: str = 'resnet50',
26 | ):
27 | '''
28 | input:
29 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
30 | + vocab_size: self explanatory.
31 | + t_blocks: the number of encoder layers, or blocks, for tag branch.
32 | + t_heads: the number of heads of each Multi-Head Attention layer of the tag branch.
33 | + i_blocks: the number of encoder layers, or blocks, for image branch.
34 | + i_heads: the number of heads of each Multi-Head Attention layer of the image branch.
35 | + dropout: dropout value of the whole network.
36 | + t_dim_feedforward: the dimension of the feedforward network model in the TransformerEncoderLayer class of the tag branch.
37 | + i_dim_feedforward: the dimension of the feedforward network model in the TransformerEncoderLayer class of the image branch.
38 | + g_dim_feedforward: the dimension of the feedforward network model in the GatingLayer class.
39 | + img_backbone: resnet18 or resnet50.
40 | '''
41 | super(MAGNeto, self).__init__()
42 |
43 | self.tag_embedder = TagEmbedder(vocab_size, d_model)
44 | self.tag_dropout = nn.Dropout(dropout)
45 | self.tag_encoder = nn.TransformerEncoder(
46 | nn.TransformerEncoderLayer(
47 | d_model=d_model, nhead=t_heads, dim_feedforward=t_dim_feedforward, dropout=dropout),
48 | num_layers=t_blocks
49 | )
50 | self.tag_linear = nn.Linear(d_model, 1)
51 | self.tag_sigmoid = nn.Sigmoid()
52 |
53 | self.img_feature_extractor = ImageFeatureExtractor(
54 | d_model, img_backbone)
55 | self.tag_to_img = TagToImageLayer(d_model, i_heads, dropout)
56 | self.img_dropout = nn.Dropout(dropout)
57 | self.img_encoder = nn.TransformerEncoder(
58 | nn.TransformerEncoderLayer(
59 | d_model=d_model, nhead=i_heads, dim_feedforward=i_dim_feedforward, dropout=dropout),
60 | num_layers=i_blocks
61 | )
62 | self.img_linear = nn.Linear(d_model, 1)
63 | self.img_sigmoid = nn.Sigmoid()
64 |
65 | self.gating = GatingLayer(
66 | d_model * 2, dim_feedforward=g_dim_feedforward, dropout=dropout)
67 |
68 | def forward(self, src: torch.Tensor, img: torch.Tensor, mask: torch.Tensor) \
69 | -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
70 | '''
71 | input:
72 | + src: input tag indices.
73 | + img: input image.
74 | + mask: boolean padding mask, True at padding positions.
75 | output:
76 | fused prediction, tag-branch prediction, image-branch prediction, tag weight and image weight.
77 | '''
78 | tag_vectors = self.tag_dropout(self.tag_embedder(src))
79 | tag_out = self.tag_encoder(tag_vectors.permute(
80 | 1, 0, 2), src_key_padding_mask=mask)
81 | tag_out = torch.relu(tag_out.permute(1, 0, 2))
82 |
83 | img_regions = self.img_feature_extractor(img)
84 | attn_out = self.img_dropout(torch.relu(
85 | self.tag_to_img(tag_vectors, img_regions)
86 | ))
87 | img_out = self.img_encoder(attn_out.permute(
88 | 1, 0, 2), src_key_padding_mask=mask)
89 | img_out = torch.relu(img_out.permute(1, 0, 2))
90 |
91 | img_weight = self.gating(torch.cat((tag_out, img_out), dim=-1))
92 | tag_weight = 1 - img_weight
93 |
94 | tag_out = self.tag_sigmoid(self.tag_linear(tag_out).squeeze(dim=-1))
95 | img_out = self.img_sigmoid(self.img_linear(img_out).squeeze(dim=-1))
96 |
97 | out = tag_weight * tag_out + img_weight * img_out
98 |
99 | return out, tag_out, img_out, tag_weight, img_weight
100 |
--------------------------------------------------------------------------------
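`MAGNeto.forward` fuses the two branches per tag with a learned gate: `tag_weight = 1 - img_weight`, so the final score is a convex combination of the tag-branch and image-branch sigmoid outputs and always lies between them. The sketch below reproduces only that last step in plain Python. The gate values are supplied by hand because `GatingLayer`'s internals are not shown here, and `fuse_scores` is a hypothetical helper, not part of the repository.

```python
from typing import List


def fuse_scores(
    tag_scores: List[float],
    img_scores: List[float],
    img_gates: List[float],
) -> List[float]:
    """Per-tag gated fusion mirroring
    `out = tag_weight * tag_out + img_weight * img_out`
    with `tag_weight = 1 - img_weight` (hypothetical helper)."""
    fused = []
    for tag_p, img_p, img_w in zip(tag_scores, img_scores, img_gates):
        tag_w = 1.0 - img_w  # the two weights always sum to 1
        fused.append(tag_w * tag_p + img_w * img_p)
    return fused


if __name__ == "__main__":
    # A gate of 0.0 trusts the tag branch fully, 1.0 the image branch, 0.5 averages both.
    print(fuse_scores([0.9, 0.2, 0.6], [0.1, 0.8, 0.6], [0.0, 1.0, 0.5]))
```

Because the weights are complementary, the gate can only shift trust between branches; it cannot push a fused score outside the range spanned by the two branch scores.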
/MAGNeto/magneto/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import multiprocessing as mp
4 | import multiprocessing.pool as mpp
5 |
6 | import numpy as np
7 | import torch
8 | from torch import nn
9 | from torch import optim
10 | from torch.utils.data import DataLoader
11 | from torch.utils.tensorboard import SummaryWriter
12 | from tqdm import tqdm
13 |
14 | from magneto.loss import BCEDiceLoss
15 | from magneto.metrics import PrecisionRecallFk
16 |
17 |
18 | def istarmap(self, func, iterable, chunksize=1):
19 | ''' starmap-version of imap
20 | '''
21 | if self._state != mpp.RUN:
22 | raise ValueError("Pool not running")
23 |
24 | if chunksize < 1:
25 | raise ValueError(
26 | "Chunksize must be 1+, not {0:n}".format(
27 | chunksize))
28 |
29 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize)
30 | result = mpp.IMapIterator(self)  # Python 3.8+ expects the pool itself (older versions took self._cache)
31 | self._taskqueue.put(
32 | (
33 | self._guarded_task_generation(result._job,
34 | mpp.starmapstar,
35 | task_batches),
36 | result._set_length
37 | ))
38 | return (item for chunk in result for item in chunk)
39 |
40 |
41 | def moving_avg(avg, update, alpha):
42 | return (alpha * avg) + ((1 - alpha) * update)
43 |
44 |
45 | def parse_train_args() -> argparse.Namespace:
46 | '''
47 | output:
48 | parsed arguments.
49 | '''
50 | parser = argparse.ArgumentParser(description='MAGNeto training process.')
51 | parser.add_argument(
52 | '--train-csv-path',
53 | type=str,
54 | help='[/path/to/train_data.csv]',
55 | required=True
56 | )
57 | parser.add_argument(
58 | '--val-csv-path',
59 | type=str,
60 | help='[/path/to/val_data.csv]',
61 | required=True
62 | )
63 | parser.add_argument(
64 | '--vocab-path',
65 | type=str,
66 | help='[/path/to/vocab.csv]',
67 | required=True
68 | )
69 | parser.add_argument(
70 | '--img-dir',
71 | type=str,
72 | help='[/path/to/img_dir]',
73 | required=True
74 | )
75 | parser.add_argument(
76 | '--save-dir',
77 | type=str,
78 | help='[/path/to/save_dir]',
79 | required=True
80 | )
81 | parser.add_argument(
82 | '--checkpoint-path',
83 | type=str,
84 | help='[/path/to/checkpoint.pth]'
85 | )
86 | parser.add_argument(
87 | '--load-weights-only',
88 | action='store_true',
89 | help='Only load the model\'s weights from the checkpoint.'
90 | )
91 | parser.add_argument(
92 | '--exclude-top',
93 | action='store_true',
94 | help='Whether to exclude the top layers (output linear and gating layers) when loading the checkpoint.'
95 | )
96 | parser.add_argument(
97 | '--start-from-epoch',
98 | type=int,
99 | help='The epoch to resume training from (default: 0).',
100 | default=0
101 | )
102 | parser.add_argument(
103 | '--max-len',
104 | type=int,
105 | help='The maximum length for each set of tags (default: 100).',
106 | default=100
107 | )
108 | parser.add_argument(
109 | '--t-heads',
110 | type=int,
111 | help='The number of heads of each Multi-Head Attention layer of the tag branch (default: 8).',
112 | default=8
113 | )
114 | parser.add_argument(
115 | '--t-blocks',
116 | type=int,
117 | help='The number of encoder layers, or blocks, for tag branch (default: 6).',
118 | default=6
119 | )
120 | parser.add_argument(
121 | '--t-dim-feedforward',
122 | type=int,
123 | help='The dimension of the feedforward network model in the TransformerEncoderLayer class of the tag branch, (default: 2048).',
124 | default=2048
125 | )
126 | parser.add_argument(
127 | '--i-heads',
128 | type=int,
129 | help='The number of heads of each Multi-Head Attention layer of the image branch (default: 8).',
130 | default=8
131 | )
132 | parser.add_argument(
133 | '--i-blocks',
134 | type=int,
135 | help='The number of encoder layers, or blocks, for image branch (default: 2).',
136 | default=2
137 | )
138 | parser.add_argument(
139 | '--i-dim-feedforward',
140 | type=int,
141 | help='The dimension of the feedforward network model in the TransformerEncoderLayer class of the image branch (default: 2048).',
142 | default=2048
143 | )
144 | parser.add_argument(
145 | '--d-model',
146 | type=int,
147 | help='The dimensionality of a context vector, must be divisible by the number of heads (default: 512).',
148 | default=512
149 | )
150 | parser.add_argument(
151 | '--img-backbone',
152 | type=str,
153 | help='resnet18 or resnet50, (default: resnet50).',
154 | default='resnet50'
155 | )
156 | parser.add_argument(
157 | '--g-dim-feedforward',
158 | type=int,
159 | help='The dimension of the feedforward network model in the GatingLayer class, (default: 2048).',
160 | default=2048
161 | )
162 | parser.add_argument(
163 | '--dropout',
164 | type=float,
165 | help='Dropout value of the whole network (default: 0.1).',
166 | default=0.1
167 | )
168 | parser.add_argument(
169 | '--tagaug-add-max-ratio',
170 | type=float,
171 | help='The maximum ratio between the number of added tags and non-important ones (default: 0.3).',
172 | default=0.3
173 | )
174 | parser.add_argument(
175 | '--tagaug-drop-max-ratio',
176 | type=float,
177 | help='The maximum ratio between the number of dropped tags and non-important ones (default: 0.3).',
178 | default=0.3
179 | )
180 | parser.add_argument(
181 | '--train-batch-size',
182 | type=int,
183 | help='The batch size used in the training process (default: 64).',
184 | default=64
185 | )
186 | parser.add_argument(
187 | '--val-batch-size',
188 | type=int,
189 | help='The batch size used in the validation process (default: 128).',
190 | default=128
191 | )
192 | parser.add_argument(
193 | '--num-workers',
194 | type=int,
195 | help='The number of workers used for data loaders, \
196 | -1 means using all available processors, \
197 | rule of thumb: num_workers ~ num_gpu * 4, \
198 | (default: 4).',
199 | default=4
200 | )
201 | parser.add_argument(
202 | '--epochs',
203 | type=int,
204 | required=True
205 | )
206 | parser.add_argument(
207 | '--lr',
208 | type=float,
209 | help='The learning rate (default: 3e-2).',
210 | default=3e-2
211 | )
212 | parser.add_argument(
213 | '--threshold',
214 | type=float,
215 | help='The threshold used to binarize predicted scores (default: 0.5).',
216 | default=0.5
217 | )
218 | parser.add_argument(
219 | '--no-cuda',
220 | action='store_true'
221 | )
222 | parser.add_argument(
223 | '--gpu-id',
224 | type=int,
225 | help='The ID of the selected GPU, ignored when --no-cuda is set (default: 0).',
226 | default=0
227 | )
228 | parser.add_argument(
229 | '--use-steplr-scheduler',
230 | action='store_true',
231 | help='Whether or not to use the StepLR scheduler. \
232 | All other schedulers should be disabled.'
233 | )
234 | parser.add_argument(
235 | '--sl-gamma',
236 | type=float,
237 | help='StepLR-scheduler\'s multiplicative factor of learning rate decay (default: 0.9).',
238 | default=0.9
239 | )
240 | parser.add_argument(
241 | '--use-rop-scheduler',
242 | action='store_true',
243 | help='Whether or not to use the ReduceLROnPlateau scheduler. \
244 | All other schedulers should be disabled.'
245 | )
246 | parser.add_argument(
247 | '--rop-factor',
248 | type=float,
249 | help='ReduceLROnPlateau scheduler\'s factor parameter (default: 0.3).',
250 | default=0.3
251 | )
252 | parser.add_argument(
253 | '--rop-patience',
254 | type=int,
255 | help='ReduceLROnPlateau scheduler\'s patience parameter (default: 3).',
256 | default=3
257 | )
258 | parser.add_argument(
259 | '--log-graph',
260 | action='store_true',
261 | help='Log the model graph to TensorBoard.'
262 | )
263 | parser.add_argument(
264 | '--save-latest',
265 | action='store_true',
266 | help='Save the latest checkpoint.'
267 | )
268 | parser.add_argument(
269 | '--save-best-f1',
270 | action='store_true',
271 | help='Save the checkpoint based on val F1.'
272 | )
273 | parser.add_argument(
274 | '--save-best-loss',
275 | action='store_true',
276 | help='Save the checkpoint based on val loss.'
277 | )
278 | parser.add_argument(
279 | '--save-all-epochs',
280 | action='store_true',
281 | help='Save a checkpoint for each epoch.'
282 | )
283 | parser.add_argument(
284 | '--log-weight-hist',
285 | action='store_true',
286 | help='Log the histogram of image and tag weights during the validation process.'
287 | )
288 |
289 | opt = parser.parse_args()
290 |
291 | # Check configuration
292 | assert not (opt.use_steplr_scheduler and opt.use_rop_scheduler), \
293 | 'Cannot use multiple schedulers at the same time!'
294 |
295 | if opt.num_workers == -1:
296 | opt.num_workers = mp.cpu_count()
297 |
298 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu'
299 | if not opt.no_cuda:
300 | assert torch.cuda.is_available()
301 |
302 | if not os.path.exists(opt.save_dir):
303 | os.makedirs(opt.save_dir)
304 |
305 | assert os.path.isfile(opt.train_csv_path)
306 | assert os.path.isfile(opt.val_csv_path)
307 | assert os.path.exists(opt.img_dir)
308 |
309 | return opt
310 |
311 |
312 | def parse_infer_args() -> argparse.Namespace:
313 | '''
314 | output:
315 | parsed arguments.
316 | '''
317 | parser = argparse.ArgumentParser(description='Inference module.')
318 | parser.add_argument(
319 | '--csv-path',
320 | type=str,
321 | help='[/path/to/data.csv]',
322 | required=True
323 | )
324 | parser.add_argument(
325 | '--img-dir',
326 | type=str,
327 | help='[/path/to/img_dir]',
328 | required=True
329 | )
330 | parser.add_argument(
331 | '--vocab-path',
332 | type=str,
333 | help='[/path/to/vocab.csv]',
334 | required=True
335 | )
336 | parser.add_argument(
337 | '--model-path',
338 | type=str,
339 | help='[/path/to/model.pth]',
340 | required=True
341 | )
342 | parser.add_argument(
343 | '--has-label',
344 | action='store_true'
345 | )
346 | parser.add_argument(
347 | '--batch-size',
348 | type=int,
349 | help='The batch size used in the inference process (default: 64).',
350 | default=64
351 | )
352 | parser.add_argument(
353 | '--num-workers',
354 | type=int,
355 | help='The number of workers used for data loaders, \
356 | -1 means using all available processors, \
357 | rule of thumb: num_workers ~ num_gpu * 4, \
358 | (default: 4).',
359 | default=4
360 | )
361 | parser.add_argument(
362 | '--threshold',
363 | type=float,
364 | help='The threshold used to binarize predicted scores (default: 0.5).',
365 | default=0.5
366 | )
367 | parser.add_argument(
368 | '--top',
369 | type=int,
370 | help='The minimum number of selected important tags for each item (default: 5).',
371 | default=5
372 | )
373 | parser.add_argument(
374 | '--no-cuda',
375 | action='store_true'
376 | )
377 | parser.add_argument(
378 | '--gpu-id',
379 | type=int,
380 | help='The ID of the selected GPU, ignored when --no-cuda is set (default: 0).',
381 | default=0
382 | )
383 | parser.add_argument(
384 | '-m',
385 | '--use-multiprocessing',
386 | action='store_true',
387 | help='Activate multiprocessing.'
388 | )
389 |
390 | opt = parser.parse_args()
391 |
392 | # Check configuration
393 | if opt.num_workers == -1:
394 | opt.num_workers = mp.cpu_count()
395 |
396 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu'
397 | if not opt.no_cuda:
398 | assert torch.cuda.is_available()
399 |
400 | return opt
401 |
402 |
403 | def parse_preprocessing_args() -> argparse.Namespace:
404 | '''
405 | output:
406 | parsed arguments.
407 | '''
408 | parser = argparse.ArgumentParser(
409 | description='Generate labels for "tags" and "important_tags" pairs.')
410 | parser.add_argument(
411 | '-c',
412 | '--csv-path',
413 | type=str,
414 | help='/path/to/raw_data.csv',
415 | required=True
416 | )
417 | parser.add_argument(
418 | '-s',
419 | '--save-path',
420 | type=str,
421 | help='/path/to/result.csv',
422 | default='./result.csv'
423 | )
424 | parser.add_argument(
425 | '-tt',
426 | '--tags-field-type',
427 | type=str,
428 | help='str or list (default: str).',
429 | default='str'
430 | )
431 | parser.add_argument(
432 | '-it',
433 | '--important-tags-field-type',
434 | type=str,
435 | help='str or list (default: str).',
436 | default='str'
437 | )
438 | parser.add_argument(
439 | '-m',
440 | '--use-multiprocessing',
441 | action='store_true',
442 | help='Activate multiprocessing.'
443 | )
444 | parser.add_argument(
445 | '--num-workers',
446 | type=int,
447 | help='The number of worker processes used for multiprocessing, -1 means using all available processors (default: -1).',
448 | default=-1
449 | )
450 |
451 | return parser.parse_args()
452 |
453 |
454 | def parse_pseudo_label_args() -> argparse.Namespace:
455 | '''
456 | output:
457 | parsed arguments.
458 | '''
459 | parser = argparse.ArgumentParser(description='Pseudo labeling module.')
460 | parser.add_argument(
461 | '--csv-path',
462 | type=str,
463 | help='[/path/to/data.csv]',
464 | required=True
465 | )
466 | parser.add_argument(
467 | '--img-dir',
468 | type=str,
469 | help='[/path/to/img_dir]',
470 | required=True
471 | )
472 | parser.add_argument(
473 | '--model-path',
474 | type=str,
475 | help='[/path/to/model.pth]',
476 | required=True
477 | )
478 | parser.add_argument(
479 | '--save-path',
480 | type=str,
481 | help='[/path/to/result.csv]',
482 | required=True
483 | )
484 | parser.add_argument(
485 | '--item-id-field',
486 | type=str,
487 | help='(default: "item_id".)',
488 | default='item_id'
489 | )
490 | parser.add_argument(
491 | '--tags-field',
492 | type=str,
493 | help='(default: "tags".)',
494 | default='tags'
495 | )
496 | parser.add_argument(
497 | '--batch-size',
498 | type=int,
499 | help='The batch size used in the inference process (default: 64).',
500 | default=64
501 | )
502 | parser.add_argument(
503 | '--num-workers',
504 | type=int,
505 | help='The number of workers used for data loaders, \
506 | -1 means using all available processors, \
507 | rule of thumb: num_workers ~ num_gpu * 4, \
508 | (default: 4).',
509 | default=4
510 | )
511 | parser.add_argument(
512 | '--threshold',
513 | type=float,
514 | help='The threshold used to binarize predicted scores (default: 0.5).',
515 | default=0.5
516 | )
517 | parser.add_argument(
518 | '--pos-threshold',
519 | type=float,
520 | help='The threshold used to classify an item into positive or non-positive class, \
521 | an item with a score higher than the threshold will be considered a positive sample (default: 0.95).',
522 | default=0.95
523 | )
524 | parser.add_argument(
525 | '--neg-threshold',
526 | type=float,
527 | help='The threshold used to classify an item into negative or non-negative class, \
528 | an item with a score lower than the threshold will be considered a negative sample (default: 0.05).',
529 | default=0.05
530 | )
531 | parser.add_argument(
532 | '--max-ratio',
533 | type=float,
534 | help='The maximum value for the ratio of the number of the confident tags to the number of all tags (default: 0.05).',
535 | default=0.05
536 | )
537 | parser.add_argument(
538 | '--min-positive',
539 | type=int,
540 | help='The minimum number of positive tags in each item (default: 0).',
541 | default=0
542 | )
543 | parser.add_argument(
544 | '--no-cuda',
545 | action='store_true'
546 | )
547 | parser.add_argument(
548 | '--gpu-id',
549 | type=int,
550 | help='The ID of the selected GPU, ignored when --no-cuda is set (default: 0).',
551 | default=0
552 | )
553 |
554 | opt = parser.parse_args()
555 |
556 | # Check configuration
557 | if opt.num_workers == -1:
558 | opt.num_workers = mp.cpu_count()
559 |
560 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu'
561 | if not opt.no_cuda:
562 | assert torch.cuda.is_available()
563 |
564 | return opt
565 |
566 |
567 | class TensorBoardWriter(object):
568 | def __init__(self, log_dir: str, purge_step: int = 0):
569 | self.log_dir = log_dir
570 | self.purge_step = purge_step
571 |
572 | def __enter__(self):
573 | self.writer = SummaryWriter(
574 | log_dir=self.log_dir,
575 | purge_step=self.purge_step
576 | )
577 |
578 | return self.writer
579 |
580 | def __exit__(self, type, value, traceback):
581 | self.writer.close()
582 |
583 |
584 | class Trainer(object):
585 | def __init__(
586 | self,
587 | model: nn.Module,
588 | optimizer: optim.Optimizer,
589 | opt: argparse.Namespace
590 | ):
591 | self.model = model
592 | self.optimizer = optimizer
593 | self.opt = opt
594 |
595 | self.start_from_epoch = self.opt.start_from_epoch
596 | self.stop_at_epoch = self.opt.start_from_epoch + self.opt.epochs
597 | self.log_dir = './runs/{0}'.format(
598 | "_".join(os.path.basename(os.path.normpath(self.opt.save_dir)).split(".")))
599 |
600 | self.criterion = {
601 | 'both': BCEDiceLoss(beta=1.0),
602 | 'tag': BCEDiceLoss(beta=1.0),
603 | 'img': BCEDiceLoss(beta=1.0)
604 | }
605 |
606 | # Initialize monitoring params
607 | self.best_val_loss = np.inf
608 | self.best_val_f1 = 0
609 | self.best_val_precision = 0
610 | self.best_val_recall = 0
611 | self.alpha = 0.9 # EMA factor, roughly a mean over the last 10 iterations
612 |
613 | self.fk_eval = PrecisionRecallFk(
614 | enable_logger=False, threshold=self.opt.threshold)
615 |
616 | if self.opt.use_rop_scheduler:
617 | self.rop_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
618 | optimizer=self.optimizer,
619 | factor=self.opt.rop_factor,
620 | patience=self.opt.rop_patience,
621 | min_lr=1e-7,
622 | verbose=True
623 | )
624 | elif self.opt.use_steplr_scheduler:
625 | self.steplr_scheduler = optim.lr_scheduler.StepLR(
626 | optimizer=self.optimizer,
627 | step_size=1,
628 | gamma=self.opt.sl_gamma
629 | )
630 |
631 | if self.opt.checkpoint_path is not None:
632 | self._load_checkpoint()
633 |
634 | def fit(
635 | self,
636 | train_dataloader: DataLoader,
637 | val_dataloader: DataLoader
638 | ):
639 | with TensorBoardWriter(self.log_dir, purge_step=self.start_from_epoch) as writer:
640 | if self.opt.log_graph:
641 | self._log_graph(train_dataloader, writer)
642 |
643 | print('\nTraining model...')
644 | for epoch in range(self.start_from_epoch, self.stop_at_epoch):
645 | self._fit_an_epoch(
646 | train_dataloader, val_dataloader, writer, epoch)
647 |
648 | def _log_graph(self, dataloader, writer):
649 | image_batch, tags_batch, mask_batch, _ = next(
650 | iter(dataloader))
651 | image_batch = image_batch.to(self.opt.device)
652 | tags_batch = tags_batch.to(self.opt.device)
653 | mask_batch = mask_batch.to(self.opt.device)
654 | writer.add_graph(self.model, (tags_batch, image_batch, mask_batch))
655 |
656 | def _load_checkpoint(self):
657 | assert os.path.isfile(self.opt.checkpoint_path)
658 |
659 | print('\nLoading checkpoint...')
660 | states = torch.load(self.opt.checkpoint_path,
661 | map_location=lambda storage, loc: storage)
662 | print('|`-- Loading model...')
663 | print('+--------------------')
664 | model_dict = self.model.state_dict()
665 | excluding_layers = [
666 | 'img_linear.weight',
667 | 'img_linear.bias',
668 | 'tag_linear.weight',
669 | 'tag_linear.bias',
670 | 'gating.linear_1.weight',
671 | 'gating.linear_1.bias',
672 | 'gating.linear_2.weight',
673 | 'gating.linear_2.bias'
674 | ] if self.opt.exclude_top else []
675 | pretrained_dict = {k: v for k, v in states['model'].items()
676 | if k in model_dict and k not in excluding_layers}
677 | model_dict.update(pretrained_dict)
678 | self.model.load_state_dict(model_dict)
679 | if not self.opt.load_weights_only:
680 | print('|`-- Loading optimizer...')
681 | self.optimizer.load_state_dict(states['optimizer'])
682 | print('|`-- Loading best val loss...')
683 | self.best_val_loss = states['best_val_loss']
684 | print('|`-- Loading best val f1...')
685 | self.best_val_f1 = states['best_val_f1']
686 | print('|`-- Loading best val precision...')
687 | self.best_val_precision = states['best_val_precision']
688 | print(' `-- Loading best val recall...')
689 | self.best_val_recall = states['best_val_recall']
690 |
691 | def _save_checkpoint(
692 | self,
693 | new_loss,
694 | new_f1,
695 | new_precision,
696 | new_recall,
697 | epoch
698 | ):
699 | found_better_val_loss = new_loss < self.best_val_loss
700 | found_better_val_f1 = new_f1 > self.best_val_f1
701 |
702 | self.best_val_loss = np.minimum(
703 | self.best_val_loss, new_loss)
704 | self.best_val_f1 = np.maximum(
705 | self.best_val_f1, new_f1)
706 | self.best_val_precision = np.maximum(
707 | self.best_val_precision, new_precision)
708 | self.best_val_recall = np.maximum(
709 | self.best_val_recall, new_recall)
710 |
711 | states = {
712 | 'model': self.model.state_dict(),
713 | 'optimizer': self.optimizer.state_dict(),
714 | 'best_val_loss': self.best_val_loss,
715 | 'best_val_f1': self.best_val_f1,
716 | 'best_val_precision': self.best_val_precision,
717 | 'best_val_recall': self.best_val_recall,
718 | 'config': vars(self.opt)
719 | }
720 |
721 | if self.opt.save_best_loss and found_better_val_loss:
722 | print(r' \__ Found a better checkpoint based on val loss -> Saving...')
723 | torch.save(states, os.path.join(
724 | self.opt.save_dir, 'best_loss.pth'))
725 |
726 | if self.opt.save_best_f1 and found_better_val_f1:
727 | print(r' \__ Found a better checkpoint based on val F1 -> Saving...')
728 | torch.save(states, os.path.join(self.opt.save_dir, 'best_f1.pth'))
729 |
730 | if self.opt.save_latest:
731 | torch.save(states, os.path.join(self.opt.save_dir, 'latest.pth'))
732 |
733 | if self.opt.save_all_epochs:
734 | torch.save(states, os.path.join(
735 | self.opt.save_dir, 'epoch_{0}.pth'.format(epoch+1)))
736 |
737 | def _compute_running_precision_recall_f1(
738 | self,
739 | pred,
740 | label,
741 | running_precision,
742 | running_recall,
743 | running_f1
744 | ):
745 | fk_eval_dict = self.fk_eval(pred, label, betas=[1])
746 | running_precision = moving_avg(
747 | running_precision, np.nan_to_num(fk_eval_dict['precision']), self.alpha)
748 | running_recall = moving_avg(
749 | running_recall, np.nan_to_num(fk_eval_dict['recall']), self.alpha)
750 | running_f1 = moving_avg(
751 | running_f1, np.nan_to_num(fk_eval_dict['f_score']['F1']), self.alpha)
752 |
753 | return running_precision, running_recall, running_f1
754 |
755 | def _compute_batch_precision_recall_f1(
756 | self,
757 | pred,
758 | label,
759 | batch_val_idx,
760 | local_batch_size,
761 | batch_val_precision,
762 | batch_val_recall,
763 | batch_val_f1
764 | ):
765 | fk_eval_dict = self.fk_eval(pred, label, betas=[1])
766 | batch_val_precision[batch_val_idx] = np.nan_to_num(
767 | fk_eval_dict['precision']) * local_batch_size
768 | batch_val_recall[batch_val_idx] = np.nan_to_num(
769 | fk_eval_dict['recall']) * local_batch_size
770 | batch_val_f1[batch_val_idx] = np.nan_to_num(
771 | fk_eval_dict['f_score']['F1']) * local_batch_size
772 |
773 | return batch_val_precision, batch_val_recall, batch_val_f1
774 |
775 | def _fit_an_epoch(self, train_dataloader, val_dataloader, writer, epoch):
776 | # Training process
777 | self.model.train()
778 |
779 | # Initialize a dictionary to store numeric values
780 | running = {
781 | 'loss': {
782 | 'both': 0,
783 | 'tag': 0,
784 | 'img': 0,
785 | 'sum': 0
786 | },
787 | 'f1': {
788 | 'both': 0,
789 | 'tag': 0,
790 | 'img': 0
791 | },
792 | 'precision': {
793 | 'both': 0,
794 | 'tag': 0,
795 | 'img': 0
796 | },
797 | 'recall': {
798 | 'both': 0,
799 | 'tag': 0,
800 | 'img': 0
801 | },
802 | 'weight': {
803 | 'tag': 0,
804 | 'img': 0
805 | }
806 | }
807 |
808 | train_pbar = tqdm(train_dataloader)
809 | train_pbar.desc = '* Epoch {0}'.format(epoch+1)
810 |
811 | for batch_idx, (image_batch, tags_batch, mask_batch, label_batch) in enumerate(train_pbar):
812 | image_batch = image_batch.to(self.opt.device)
813 | tags_batch = tags_batch.to(self.opt.device)
814 | label_batch = label_batch.to(self.opt.device)
815 | mask_batch = mask_batch.to(self.opt.device)
816 |
817 | preds = dict()
818 | weight = dict()
819 | preds['both'], preds['tag'], preds['img'], weight['tag'], weight['img'] = \
820 | self.model(tags_batch, image_batch, mask_batch)
821 |
822 | for key in preds.keys():
823 | preds[key] = preds[key].masked_fill(
824 | mask_batch,
825 | 0.0
826 | )
827 |
828 | loss = dict()
829 | for key in preds.keys():
830 | loss[key] = self.criterion[key](preds[key], label_batch)
831 | loss['sum'] = loss['both'] + loss['tag'] + loss['img']
832 |
833 | self.optimizer.zero_grad()
834 | loss['sum'].backward()
835 | self.optimizer.step()
836 |
837 | processed_label = label_batch.detach().cpu().numpy().astype(np.uint8)
838 | processed_pred = dict()
839 | for key in preds.keys():
840 | processed_pred[key] = preds[key].detach().cpu().numpy()
841 |
842 | # Compute running losses
843 | for key in running['loss'].keys():
844 | running['loss'][key] = moving_avg(
845 | running['loss'][key], loss[key].item(), self.alpha)
846 |
847 | # Compute running weights
848 | for key in running['weight'].keys():
849 | running['weight'][key] = moving_avg(
850 | running['weight'][key], weight[key].mean().item(), self.alpha)
851 |
852 | # Compute running precision, recall and f1
853 | for key in processed_pred.keys():
854 | running['precision'][key], running['recall'][key], running['f1'][key] = \
855 | self._compute_running_precision_recall_f1(
856 | processed_pred[key],
857 | processed_label,
858 | running['precision'][key],
859 | running['recall'][key],
860 | running['f1'][key]
861 | )
862 |
863 | train_pbar.set_postfix({
864 | 'loss': running['loss']['both'],
865 | 'f1': running['f1']['both'],
866 | 'prec': running['precision']['both'],
867 | 'recall': running['recall']['both'],
868 | })
869 |
870 | # Log to TensorBoard
871 | for key in running.keys():
872 | for subkey in running[key]:
873 | writer.add_scalar('{0}/train_{1}'.format(key, subkey),
874 | running[key][subkey], epoch)
875 |
876 | # Validation process
877 | self.model.eval()
878 |
879 | with torch.no_grad():
880 | # Initialize a dictionary to store 1d arrays
881 | batch_val = {
882 | 'loss': {
883 | 'both': np.zeros(len(val_dataloader)),
884 | 'tag': np.zeros(len(val_dataloader)),
885 | 'img': np.zeros(len(val_dataloader)),
886 | 'sum': np.zeros(len(val_dataloader))
887 | },
888 | 'f1': {
889 | 'both': np.zeros(len(val_dataloader)),
890 | 'tag': np.zeros(len(val_dataloader)),
891 | 'img': np.zeros(len(val_dataloader))
892 | },
893 | 'precision': {
894 | 'both': np.zeros(len(val_dataloader)),
895 | 'tag': np.zeros(len(val_dataloader)),
896 | 'img': np.zeros(len(val_dataloader))
897 | },
898 | 'recall': {
899 | 'both': np.zeros(len(val_dataloader)),
900 | 'tag': np.zeros(len(val_dataloader)),
901 | 'img': np.zeros(len(val_dataloader))
902 | },
903 | 'weight': {
904 | 'tag': np.zeros(len(val_dataloader)),
905 | 'img': np.zeros(len(val_dataloader))
906 | },
907 | }
908 |
909 | if self.opt.log_weight_hist:
910 | all_val_weights = {
911 | 'tag': [],
912 | 'img': []
913 | }
914 |
915 | num_items = 0
916 |
917 | val_pbar = tqdm(val_dataloader)
918 | val_pbar.desc = r'\__ Validating'
919 |
920 | for batch_val_idx, (image_batch, tags_batch, mask_batch, label_batch) in enumerate(val_pbar):
921 | image_batch = image_batch.to(self.opt.device)
922 | tags_batch = tags_batch.to(self.opt.device)
923 | label_batch = label_batch.to(self.opt.device)
924 | mask_batch = mask_batch.to(self.opt.device)
925 |
926 | preds = dict()
927 | weight = dict()
928 | preds['both'], preds['tag'], preds['img'], weight['tag'], weight['img'] = \
929 | self.model(tags_batch, image_batch, mask_batch)
930 |
931 | for key in preds.keys():
932 | preds[key] = preds[key].masked_fill(
933 | mask_batch,
934 | 0.0
935 | )
936 |
937 | val_loss = dict()
938 | for key in preds.keys():
939 | val_loss[key] = self.criterion[key](
940 | preds[key], label_batch)
941 | val_loss['sum'] = val_loss['both'] + \
942 | val_loss['tag'] + val_loss['img']
943 |
944 | processed_label = label_batch.detach().cpu().numpy().astype(np.uint8)
945 | processed_pred = dict()
946 | for key in preds.keys():
947 | processed_pred[key] = preds[key].detach().cpu().numpy()
948 |
949 | local_batch_size = label_batch.size(0)
950 | num_items += local_batch_size
951 |
952 | # Compute sum batch losses
953 | for key in batch_val['loss'].keys():
954 | batch_val['loss'][key][batch_val_idx] = \
955 | val_loss[key].item() * local_batch_size
956 |
957 | # Compute sum batch weights
958 | for key in batch_val['weight'].keys():
959 | batch_val['weight'][key][batch_val_idx] = \
960 | weight[key].mean().item() * local_batch_size
961 |
962 | # Keep all available weight values (if needed)
963 | if self.opt.log_weight_hist:
964 | for key in batch_val['weight'].keys():
965 | all_val_weights[key].extend(
966 | weight[key].detach().cpu().numpy().reshape(-1))
967 |
968 | # Compute sum batch precision, recall and f1
969 | for key in processed_pred.keys():
970 | batch_val['precision'][key], batch_val['recall'][key], batch_val['f1'][key] = \
971 | self._compute_batch_precision_recall_f1(
972 | processed_pred[key],
973 | processed_label,
974 | batch_val_idx,
975 | local_batch_size,
976 | batch_val['precision'][key],
977 | batch_val['recall'][key],
978 | batch_val['f1'][key]
979 | )
980 |
981 | val_pbar.set_postfix({
982 | 'loss': batch_val['loss']['both'][batch_val_idx] / local_batch_size,
983 | 'f1': batch_val['f1']['both'][batch_val_idx] / local_batch_size,
984 | 'prec': batch_val['precision']['both'][batch_val_idx] / local_batch_size,
985 | 'recall': batch_val['recall']['both'][batch_val_idx] / local_batch_size,
986 | })
987 |
988 | mean_val = dict()
989 | for key in batch_val.keys():
990 | mean_val[key] = dict()
991 | for subkey in batch_val[key].keys():
992 | mean_val[key][subkey] = np.sum(
993 | batch_val[key][subkey]) / num_items
994 |
995 | if self.opt.use_rop_scheduler:
996 | self.rop_scheduler.step(mean_val['loss']['sum'])
997 | elif self.opt.use_steplr_scheduler:
998 | self.steplr_scheduler.step()
999 |
1000 | # Log to TensorBoard
1001 | for key in mean_val.keys():
1002 | for subkey in mean_val[key].keys():
1003 | writer.add_scalar('{0}/val_{1}'.format(key, subkey),
1004 | mean_val[key][subkey], epoch)
1005 | if self.opt.log_weight_hist:
1006 | for key in all_val_weights.keys():
1007 | writer.add_histogram(
1008 | 'weight/{0}'.format(key), np.array(all_val_weights[key]), epoch)
1009 |
1010 | # Save checkpoint
1011 | self._save_checkpoint(
1012 | new_loss=mean_val['loss']['both'],
1013 | new_f1=mean_val['f1']['both'],
1014 | new_precision=mean_val['precision']['both'],
1015 | new_recall=mean_val['recall']['both'],
1016 | epoch=epoch
1017 | )
1018 |
--------------------------------------------------------------------------------
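`Trainer` aggregates metrics in two different ways. During training, `moving_avg` with `alpha = 0.9` keeps an exponential moving average whose effective window is about `1 / (1 - alpha) = 10` iterations; because the running values start at 0, early readings are biased low (after `n` constant updates of value `v`, the average equals `v * (1 - alpha**n)`). During validation, each batch metric is multiplied by `local_batch_size`, summed, and divided by `num_items`, which gives a per-sample mean that a smaller final batch cannot skew. The sketch below reproduces both in plain Python; `weighted_mean` is a hypothetical helper and the numbers are illustrative only.

```python
from typing import List


def moving_avg(avg: float, update: float, alpha: float) -> float:
    # Same formula as magneto.utils.moving_avg.
    return (alpha * avg) + ((1 - alpha) * update)


def weighted_mean(batch_values: List[float], batch_sizes: List[int]) -> float:
    """Per-sample mean: sum(value * size) / total samples, as done for the
    validation metrics in Trainer._fit_an_epoch (hypothetical helper)."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)


if __name__ == "__main__":
    # Running average starting from 0, fed a constant value of 1.0.
    running = 0.0
    for _ in range(10):
        running = moving_avg(running, 1.0, 0.9)
    print(f"EMA after 10 steps: {running:.4f}")  # approximately 1 - 0.9**10

    # Two full batches of 128 and a final partial batch of 16.
    print(weighted_mean([0.50, 0.60, 0.90], [128, 128, 16]))
```

In the second example, a plain mean of the three batch values would give about 0.667, while the per-sample mean is about 0.571, because the high value comes from only 16 samples.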
/MAGNeto/preprocess.py:
--------------------------------------------------------------------------------
1 | import ast
2 |
3 | import pandas as pd
4 | from tqdm import tqdm
5 |
6 | from magneto.utils import parse_preprocessing_args
7 |
8 |
9 | def make_label(tags, important_tags) -> list:
10 | '''
11 | input:
12 | + tags: all available tags of an item.
13 | + important_tags: tags that are marked as important.
14 | output:
15 | a list of '0'/'1' strings, with '0' for unimportant tags and '1' for important ones.
16 | '''
17 | return ['1' if tag in important_tags else '0' for tag in tags]
18 |
19 |
20 | def label_important_tags(
21 | item_id,
22 | tags,
23 | important_tags
24 | ) -> dict:
25 | '''
26 | input:
27 | + item_id: the ID of an item.
28 | + tags: all available tags of an item.
29 | + important_tags: tags that are marked as important.
30 | output:
31 | a dictionary which includes all needed information of an item.
32 | '''
33 | label = make_label(tags, important_tags)
34 |
35 | return {
36 | 'item_id': item_id,
37 | 'tags': ','.join(tags),
38 | 'important_tags': ','.join(important_tags),
39 | 'label': ','.join(label)
40 | }
41 |
42 |
43 | def main():
44 |     opt = parse_preprocessing_args()
45 |
46 |     df = pd.read_csv(opt.csv_path)
47 |
48 |     assert 'item_id' in df.columns and 'tags' in df.columns
49 |     assert 'important_tags' in df.columns
50 |     assert opt.tags_field_type in ['str', 'list']
51 |     assert opt.important_tags_field_type in ['str', 'list']
52 |
53 |     series_of_item_id = df['item_id']
54 |     series_of_tags = df['tags']
55 |     series_of_important_tags = df['important_tags']
56 |
57 |     if opt.tags_field_type == 'str':
58 |         series_of_tags = series_of_tags.apply(lambda x: x.split(','))
59 |     elif opt.tags_field_type == 'list':
60 |         series_of_tags = series_of_tags.apply(ast.literal_eval)
61 |
62 |     if opt.important_tags_field_type == 'str':
63 |         series_of_important_tags = series_of_important_tags.apply(
64 |             lambda x: x.split(','))
65 |     elif opt.important_tags_field_type == 'list':
66 |         series_of_important_tags = series_of_important_tags.apply(ast.literal_eval)
67 |
68 |     rows_dict = dict()
69 |     i = 0
70 |
71 |     if opt.use_multiprocessing:
72 |         import multiprocessing as mp
73 |
74 |         # Patch Pool with an istarmap method (a lazy starmap) so tqdm can track progress
75 |         import multiprocessing.pool as mpp
76 |         from magneto.utils import istarmap
77 |         mpp.Pool.istarmap = istarmap
78 |
79 |         if opt.num_workers == -1:
80 |             opt.num_workers = mp.cpu_count()
81 |
82 |         inputs = list(zip(
83 |             series_of_item_id,
84 |             series_of_tags,
85 |             series_of_important_tags
86 |         ))
87 |         with mp.Pool(opt.num_workers) as pool:
88 |             for result in tqdm(pool.istarmap(label_important_tags, inputs), total=len(inputs)):
89 |                 rows_dict[i] = result
90 |                 i += 1
91 |     else:
92 |         for item_id, tags, important_tags \
93 |                 in tqdm(list(zip(
94 |                     series_of_item_id,
95 |                     series_of_tags,
96 |                     series_of_important_tags
97 |                 ))):
98 |
99 |             result = label_important_tags(
100 |                 item_id,
101 |                 tags,
102 |                 important_tags
103 |             )
104 |
105 |             rows_dict[i] = result
106 |             i += 1
107 |
108 |     new_df = pd.DataFrame.from_dict(rows_dict, orient='index')
109 |     new_df.to_csv(opt.save_path, index=False)
111 |
112 | if __name__ == '__main__':
113 |     main()
114 |
--------------------------------------------------------------------------------
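To see what `preprocess.py` writes for one row without pandas, here is a dependency-free sketch of the same transformation. `parse_tag_field` and `label_row` are hypothetical helpers that mirror the `str` / `list` branches of `main()` and the output of `label_important_tags`; unlike the original `make_label`, this version tests membership against a set, which produces the same mask but avoids a linear scan of `important_tags` for every tag:

```python
import ast
from typing import Dict, List


def parse_tag_field(value: str, field_type: str) -> List[str]:
    """Hypothetical helper mirroring the two accepted encodings of a tags cell."""
    if field_type == "str":
        return value.split(",")         # e.g. "sky,clouds"
    if field_type == "list":
        return ast.literal_eval(value)  # e.g. "['sky', 'clouds']"
    raise ValueError("field_type must be 'str' or 'list'")


def make_label(tags: List[str], important_tags: List[str]) -> List[str]:
    """Binary mask aligned with tags: '1' if the tag is important, else '0'."""
    important = set(important_tags)  # O(1) lookups; same result as a list check
    return ["1" if tag in important else "0" for tag in tags]


def label_row(item_id, tags: List[str], important_tags: List[str]) -> Dict[str, object]:
    """Hypothetical helper: one output row in the comma-joined format of preprocess.py."""
    return {
        "item_id": item_id,
        "tags": ",".join(tags),
        "important_tags": ",".join(important_tags),
        "label": ",".join(make_label(tags, important_tags)),
    }


tags = parse_tag_field("sky,clouds,beach,sunset", "str")
important = parse_tag_field("['beach', 'sunset']", "list")
row = label_row(42, tags, important)
print(row["label"])  # 0,0,1,1
```

Two input details are worth checking before running the real script: `'sky, clouds'.split(',')` keeps the leading space in `' clouds'`, so such a tag will not match `'clouds'`; and pandas reads an empty cell as `NaN` (a float), so `x.split(',')` would raise `AttributeError` on such rows unless they are dropped first, e.g. with `df.dropna(subset=['tags', 'important_tags'])`.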
/MAGNeto/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.19.0
2 | scipy==1.5.2
3 | opencv-python==4.2.0.34
4 | pandas==1.0.5
5 | Pillow==6.1.0
6 | tensorboard==2.2.2
7 | torch==1.5.1
8 | torchvision==0.6.1
9 | tqdm==4.47.0
10 | matplotlib==3.2.2
11 |
--------------------------------------------------------------------------------
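Every entry in `requirements.txt` is an exact `==` pin. A minimal sketch for comparing an existing environment against those pins uses the standard-library `importlib.metadata` (Python 3.8+); `parse_pins` and `check_pins` are hypothetical names, and the parser only handles the plain `name==version` lines this file uses:

```python
from importlib import metadata
from typing import Dict


def parse_pins(text: str) -> Dict[str, str]:
    """Hypothetical helper: read 'name==version' lines, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if not line:
            continue
        name, sep, version = line.partition("==")
        if not sep:
            raise ValueError("expected an exact pin: {!r}".format(line))
        pins[name.strip()] = version.strip()
    return pins


def check_pins(pins: Dict[str, str]) -> Dict[str, str]:
    """Hypothetical helper: return {package: problem} for missing or mismatched packages."""
    problems = {}
    for name, wanted in pins.items():
        try:
            found = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if found != wanted:
            problems[name] = "installed {}, pinned {}".format(found, wanted)
    return problems
```

For example, run `check_pins(parse_pins(open('requirements.txt').read()))` from the `MAGNeto` directory. CUDA builds of torch often report a local version suffix such as `+cu101`, which this exact-string comparison flags as a mismatch.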
/MAGNeto/scripts/start_infer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m infer \
4 | --csv-path ./data/nus_wide/annotations/val_81_with_label.csv \
5 | --img-dir ./data/nus_wide/images \
6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \
7 | --model-path ./snapshots/demo/best_f1.pth \
8 | --batch-size 32 \
9 | --num-workers 4 \
10 | --threshold 0.5 \
11 | --top 0 \
12 | --gpu-id 0 \
13 | --has-label \
14 | -m
15 |
--------------------------------------------------------------------------------
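`start_infer.sh` runs with `--threshold 0.5` and `--has-label`, and the trainer logs precision, recall and F1. The sketch below turns per-tag scores into the same 0/1 mask that `preprocess.py` stores in the `label` column and scores one item against it. It assumes the model emits one probability per input tag and that a tag is kept when its score is at or above the threshold; whether `infer.py` uses `>=` or `>`, how it combines this with `--top`, and how the trainer averages the metrics across items are not shown here. The function names are hypothetical:

```python
from typing import List, Tuple


def scores_to_mask(scores: List[float], threshold: float = 0.5) -> List[int]:
    """Hypothetical helper: 1 where a tag's score reaches the threshold (assumed >=)."""
    return [1 if s >= threshold else 0 for s in scores]


def precision_recall_f1(pred: List[int], label: List[int]) -> Tuple[float, float, float]:
    """Hypothetical helper: precision, recall and F1 over one item's tags."""
    if len(pred) != len(label):
        raise ValueError("pred and label must have the same length")
    tp = sum(1 for p, y in zip(pred, label) if p == 1 and y == 1)
    n_pred, n_true = sum(pred), sum(label)
    # Zero denominators are scored as 0.0 here; other conventions exist.
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


label = [int(x) for x in "0,0,1,1".split(",")]  # a 'label' cell from preprocess.py
pred = scores_to_mask([0.1, 0.7, 0.9, 0.4], threshold=0.5)
print(pred)                              # [0, 1, 1, 0]
print(precision_recall_f1(pred, label))  # (0.5, 0.5, 0.5)
```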
/MAGNeto/scripts/start_preprocess.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m preprocess \
4 | -c data/nus_wide/annotations/train_81.csv \
5 | -s data/nus_wide/annotations/train_81_with_label.csv \
6 | -tt str \
7 | -it str \
8 | -m \
9 | --num-workers 4
--------------------------------------------------------------------------------
/MAGNeto/scripts/start_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m train \
4 | --train-csv-path ./data/nus_wide/annotations/train_81_with_label.csv \
5 | --val-csv-path ./data/nus_wide/annotations/val_81_with_label.csv \
6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \
7 | --img-dir ./data/nus_wide/images \
8 | --save-dir ./snapshots/demo \
9 | --start-from-epoch 0 \
10 | --t-heads 4 \
11 | --t-blocks 2 \
12 | --t-dim-feedforward 512 \
13 | --i-heads 4 \
14 | --i-blocks 1 \
15 | --i-dim-feedforward 512 \
16 | --img-backbone resnet18 \
17 | --d-model 128 \
18 | --max-len 16 \
19 | --g-dim-feedforward 512 \
20 | --dropout 0.3 \
21 | --threshold 0.5 \
22 | --tagaug-add-max-ratio 1.0 \
23 | --tagaug-drop-max-ratio 0.0 \
24 | --train-batch-size 32 \
25 | --val-batch-size 32 \
26 | --epochs 500 \
27 | --gpu-id 0 \
28 | --num-workers 8 \
29 | --log-graph \
30 | --save-best-loss \
31 | --save-best-f1 \
32 | --save-latest \
33 | --lr 1e-2 \
34 | --log-weight-hist
--------------------------------------------------------------------------------
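The flags in `start_train.sh` become attributes of the `opt` object that `train.py` reads (`opt.t_heads`, `opt.train_batch_size`, `opt.no_cuda`, and so on). `argparse` derives each attribute name by dropping the leading dashes of the long option and replacing the remaining `-` with `_`. `parse_train_args` itself lives in `magneto/utils.py` and is not shown here, so the parser below is a hypothetical subset whose types and defaults are assumptions. It only illustrates that mapping:

```python
import argparse
import shlex


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of the training CLI; types and defaults are assumed."""
    parser = argparse.ArgumentParser(description="MAGNeto training (subset sketch)")
    parser.add_argument("--train-csv-path", type=str, required=True)
    parser.add_argument("--t-heads", type=int, default=4)
    parser.add_argument("--d-model", type=int, default=128)
    parser.add_argument("--dropout", type=float, default=0.3)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--train-batch-size", type=int, default=32)
    # Bare switches such as --log-graph are modeled as booleans that default to False.
    parser.add_argument("--log-graph", action="store_true")
    parser.add_argument("--no-cuda", action="store_true")
    return parser


argv = shlex.split(
    "--train-csv-path ./data/nus_wide/annotations/train_81_with_label.csv "
    "--t-heads 4 --d-model 128 --lr 1e-2 --log-graph"
)
opt = build_parser().parse_args(argv)
print(opt.t_heads, opt.d_model, opt.lr, opt.log_graph, opt.no_cuda)
# 4 128 0.01 True False
```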
/MAGNeto/scripts/start_train_usp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m train \
4 | --train-csv-path ./data/nus_wide/annotations/train_81_with_label.csv \
5 | --val-csv-path ./data/nus_wide/annotations/val_81_with_label.csv \
6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \
7 | --img-dir ./data/nus_wide/images \
8 | --save-dir ./snapshots/nus_wide_81_add0p0_drop0p0_with_unsupervised_pretraining_Sep_12_20 \
9 | --checkpoint-path ./snapshots/nuswide_top_81_ver_1_unsupervised_pretraining_Sep_10_20/ckpt.pth \
10 | --load-weights-only \
11 | --exclude-top \
12 | --start-from-epoch 0 \
13 | --t-heads 4 \
14 | --t-blocks 2 \
15 | --t-dim-feedforward 512 \
16 | --i-heads 4 \
17 | --i-blocks 1 \
18 | --i-dim-feedforward 512 \
19 | --img-backbone resnet18 \
20 | --d-model 128 \
21 | --max-len 16 \
22 | --g-dim-feedforward 512 \
23 | --dropout 0.3 \
24 | --threshold 0.5 \
25 | --tagaug-add-max-ratio 0.0 \
26 | --tagaug-drop-max-ratio 0.0 \
27 | --train-batch-size 32 \
28 | --val-batch-size 32 \
29 | --epochs 500 \
30 | --gpu-id 3 \
31 | --num-workers 8 \
32 | --log-graph \
33 | --save-best-loss \
34 | --lr 1e-2 \
35 | --log-weight-hist
--------------------------------------------------------------------------------
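`start_train_usp.sh` starts from an unsupervised-pretraining checkpoint with `--load-weights-only --exclude-top`. `magneto/utils.py` is not shown here, so its exact handling of these options is unknown. One common pattern, sketched below as an assumption, is to drop the output-head entries from a flat state dict and load the rest with `strict=False`. The `'top.'` prefix is hypothetical, and plain floats stand in for tensors so the sketch runs without PyTorch:

```python
from typing import Dict, Iterable, Mapping


def drop_prefixed(state_dict: Mapping[str, object],
                  prefixes: Iterable[str]) -> Dict[str, object]:
    """Hypothetical helper: copy state_dict without keys that start with any prefix."""
    prefixes = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    return {k: v for k, v in state_dict.items() if not k.startswith(prefixes)}


# Toy flat checkpoint; with PyTorch the values would be tensors, e.g. from
# torch.load(path, map_location="cpu").
checkpoint = {
    "encoder.layer0.weight": 0.1,
    "encoder.layer0.bias": 0.0,
    "top.weight": 0.5,  # hypothetical task-specific output head
    "top.bias": 0.2,
}
backbone_only = drop_prefixed(checkpoint, ["top."])
print(sorted(backbone_only))  # ['encoder.layer0.bias', 'encoder.layer0.weight']
# With PyTorch: missing, unexpected = model.load_state_dict(backbone_only, strict=False)
```

With `strict=False`, `load_state_dict` reports the skipped head parameters as missing keys instead of raising, so the new head keeps its fresh initialization.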
/MAGNeto/train.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | from torch import optim
4 |
5 | from magneto.utils import parse_train_args
6 | from magneto.data import get_dataloaders
7 | from magneto.model import MAGNeto
8 | from magneto.utils import Trainer
9 |
10 | warnings.filterwarnings("ignore")
11 |
12 |
13 | def main():
14 |     ##### GET CONFIGURATION #####
15 |     opt = parse_train_args()
16 |
17 |     ##### PREPARE DATASETS #####
18 |     print('\nPreparing datasets...')
19 |     train_dataloader, val_dataloader, vocab_size = get_dataloaders(
20 |         train_csv_path=opt.train_csv_path,
21 |         val_csv_path=opt.val_csv_path,
22 |         vocab_path=opt.vocab_path,
23 |         img_dir=opt.img_dir,
24 |         tagaug_add_max_ratio=opt.tagaug_add_max_ratio,
25 |         tagaug_drop_max_ratio=opt.tagaug_drop_max_ratio,
26 |         train_batch_size=opt.train_batch_size,
27 |         val_batch_size=opt.val_batch_size,
28 |         max_len=opt.max_len,
29 |         num_workers=opt.num_workers,
30 |         pin_memory=not opt.no_cuda
31 |     )
32 |
33 |     ##### CREATE MODEL #####
34 |     model = MAGNeto(
35 |         d_model=opt.d_model,
36 |         vocab_size=vocab_size,
37 |         t_blocks=opt.t_blocks,
38 |         t_heads=opt.t_heads,
39 |         t_dim_feedforward=opt.t_dim_feedforward,
40 |         i_blocks=opt.i_blocks,
41 |         i_heads=opt.i_heads,
42 |         i_dim_feedforward=opt.i_dim_feedforward,
43 |         img_backbone=opt.img_backbone,
44 |         g_dim_feedforward=opt.g_dim_feedforward,
45 |         dropout=opt.dropout,
46 |     )
47 |     model = model.to(opt.device)
48 |
49 |     ##### CREATE OPTIMIZER #####
50 |     optimizer = optim.SGD(
51 |         filter(lambda p: p.requires_grad, model.parameters()),
52 |         lr=opt.lr,
53 |         momentum=0.9
54 |     )
55 |
56 |     ##### CREATE TRAINER AND START THE TRAINING PROCESS #####
57 |     trainer = Trainer(
58 |         model=model,
59 |         optimizer=optimizer,
60 |         opt=opt
61 |     )
62 |     trainer.fit(train_dataloader, val_dataloader)
63 |
64 |
65 | if __name__ == '__main__':
66 |     main()
67 |
--------------------------------------------------------------------------------
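`train.py` passes only the parameters with `requires_grad=True` to `optim.SGD`, with `lr=opt.lr` and `momentum=0.9` (`start_train.sh` passes `--lr 1e-2`). With PyTorch's defaults (no dampening, no weight decay, no Nesterov), SGD keeps a velocity buffer `b` per parameter. The buffer starts at the first gradient, and each later step applies `b = momentum * b + grad` followed by `p = p - lr * b`. The scalar, pure-Python sketch below shows that rule; the `Param` class is a hypothetical stand-in for a tensor:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Param:
    """Hypothetical scalar stand-in for a tensor parameter."""
    value: float
    grad: float = 0.0
    requires_grad: bool = True
    momentum_buffer: Optional[float] = None


def sgd_momentum_step(params: List[Param], lr: float, momentum: float) -> None:
    """One SGD step with momentum (dampening=0, weight_decay=0, nesterov=False)."""
    for p in params:
        if not p.requires_grad:
            continue  # skipped here; train.py never passes such parameters to SGD
        if p.momentum_buffer is None:
            p.momentum_buffer = p.grad  # first step: buffer starts at the gradient
        else:
            p.momentum_buffer = momentum * p.momentum_buffer + p.grad
        p.value -= lr * p.momentum_buffer


w = Param(value=1.0, grad=2.0)
frozen = Param(value=1.0, grad=2.0, requires_grad=False)
sgd_momentum_step([w, frozen], lr=0.5, momentum=0.5)  # buffer 2.0, w = 1.0 - 1.0 = 0.0
sgd_momentum_step([w, frozen], lr=0.5, momentum=0.5)  # buffer 0.5*2 + 2 = 3.0, w = -1.5
print(w.value, frozen.value)  # -1.5 1.0
```

With `momentum=0.9`, a constant gradient accumulates toward `grad / (1 - 0.9)`, i.e. a step about ten times larger than plain SGD at the same learning rate, which is worth keeping in mind when tuning `--lr`.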
/README.md:
--------------------------------------------------------------------------------
1 | # LabTeam
2 |
3 | ## Publications
4 | - MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem [[code](MAGNeto)][[abs](https://arxiv.org/abs/2011.04349)][[pdf](https://arxiv.org/pdf/2011.04349)]
5 |
--------------------------------------------------------------------------------