├── LICENSE
├── MAGNeto
│   ├── .gitignore
│   ├── README.md
│   ├── data
│   │   └── nus_wide
│   │       └── notebooks
│   │           ├── Move Images.ipynb
│   │           └── Prepare Tag Data.ipynb
│   ├── infer.py
│   ├── magneto
│   │   ├── __init__.py
│   │   ├── augment_helper.py
│   │   ├── autoaugment.py
│   │   ├── data.py
│   │   ├── layers.py
│   │   ├── loss.py
│   │   ├── metrics.py
│   │   ├── model.py
│   │   └── utils.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── start_infer.sh
│   │   ├── start_preprocess.sh
│   │   ├── start_train.sh
│   │   └── start_train_usp.sh
│   └── train.py
└── README.md

/LICENSE:
--------------------------------------------------------------------------------
                    GNU GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.

                            Preamble

  The GNU General Public License is a free, copyleft license for software and other kinds of works.

  The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program--to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too.

  When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things.
  To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others.

  For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights.

  Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it.

  For the developers' and authors' protection, the GPL clearly explains that there is no warranty for this free software. For both users' and authors' sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions.

  Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users' freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users.

  Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free.

  The precise terms and conditions for copying, distribution and modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this License. Each licensee is addressed as "you". "Licensees" and "recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a "modified version" of the earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based on the Program.

  To "propagate" a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying.
  An interactive user interface displays "Appropriate Legal Notices" to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work for making modifications to it. "Object code" means any non-source form of a work.

  A "Standard Interface" means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A "Major Component", in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it.
  The "Corresponding Source" for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work's System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work.

  The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source.

  The Corresponding Source for a work in source code form is that same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force.
You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you.

  Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary.

  3. Protecting Users' Legal Rights From Anti-Circumvention Law.

  No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures.

  When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work's users, your or third parties' legal rights to forbid circumvention of technological measures.

  4. Conveying Verbatim Copies.
  You may convey verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program.

  You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee.

  5. Conveying Modified Source Versions.

  You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions:

    a) The work must carry prominent notices stating that you modified it, and giving a relevant date.

    b) The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to "keep intact all notices".

    c) You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it.
    d) If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so.

  A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an "aggregate" if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation's users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate.

  6. Conveying Non-Source Forms.

  You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways:

    a) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange.
    b) Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge.

    c) Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b.

    d) Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements.
    e) Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d.

  A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work.

  A "User Product" is either (1) a "consumer product", which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, "normally used" refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product.

  "Installation Information" for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made.
  If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM).

  The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network.

  Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying.

  7. Additional Terms.

  "Additional permissions" are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law.
If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions.

  When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission.

  Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms:

    a) Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or

    b) Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or

    c) Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or

    d) Limiting the use for publicity purposes of names of licensors or authors of the material; or

    e) Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or

    f) Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on
those licensors and authors.

  All other non-permissive additional terms are considered "further restrictions" within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying.

  If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms.

  Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way.

  8. Termination.

  You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11).

  However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation.
  Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice.

  Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10.

  9. Acceptance Not Required for Having Copies.

  You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so.

  10. Automatic Licensing of Downstream Recipients.

  Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License.

  An "entity transaction" is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations.
If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party's predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts.

  You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it.

  11. Patents.

  A "contributor" is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor's "contributor version".

  A contributor's "essential patent claims" are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, "control" includes the right to grant patent sublicenses in a manner consistent with the requirements of this License.
  Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor's essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version.

  In the following three paragraphs, a "patent license" is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To "grant" such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party.

  If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. "Knowingly relying" means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient's use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid.
  If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it.

  A patent license is "discriminatory" if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007.

  Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law.

  12. No Surrender of Others' Freedom.

  If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License.
If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program.

  13. Use with the GNU Affero General Public License.

  Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such.

  14. Revised Versions of this License.

  The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns.

  Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation.
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 
-------------------------------------------------------------------------------- /MAGNeto/.gitignore: -------------------------------------------------------------------------------- 1 | # Local files and directories 2 | data/nus_wide/images 3 | data/nus_wide/annotations 4 | data/nus_wide/raw_data 5 | runs 6 | snapshots 7 | tmp 8 | .vscode 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /MAGNeto/README.md: -------------------------------------------------------------------------------- 1 | # [MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem](https://arxiv.org/abs/2011.04349) 2 | 3 | ## Downloading NUS-WIDE dataset 4 | - Official: https://lms.comp.nus.edu.sg/wp-content/uploads/2019/research/nuswide/NUS-WIDE.html 5 | - Unofficial: http://cs-people.bu.edu/hekun/data/TALR/NUSWIDE.zip (recommended for downloading all images) 6 | 7 | ## Data preparation 8 | 9 | ### Moving images to a single directory 10 | 11 | ``` 12 | ./data/nus_wide/notebooks/Move\ Images.ipynb 13 | ``` 14 | 15 | ### Preparing tag data 16 | 17 | ``` 18 | ./data/nus_wide/notebooks/Prepare\ Tag\ Data.ipynb 19 | ``` 20 | 21 | ## Setting up the environment 22 | 23 | ```bash 24 | pip install -U pip 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Generating labels for raw data 29 | 30 | - Step 1: Reconfigure `scripts/start_preprocess.sh` 31 | 32 | To list all configurable parameters, run
33 | 34 | ```bash 35 | python preprocess.py -h 36 | ``` 37 | 38 | - Step 2: Run 39 | 40 | ```bash 41 | bash scripts/start_preprocess.sh 42 | ``` 43 | 44 | ## Training the model 45 | 46 | - Step 1: Reconfigure `scripts/start_train.sh` 47 | 48 | To list all configurable parameters, run 49 | 50 | ```bash 51 | python train.py -h 52 | ``` 53 | 54 | - Step 2: Run 55 | 56 | ```bash 57 | bash scripts/start_train.sh 58 | ``` 59 | 60 | ## Running inference on test data 61 | 62 | - Step 1: Reconfigure `scripts/start_infer.sh` 63 | 64 | To list all configurable parameters, run 65 | 66 | ```bash 67 | python infer.py -h 68 | ``` 69 | 70 | - Step 2: Run 71 | 72 | ```bash 73 | bash scripts/start_infer.sh 74 | ``` 75 | 76 | ## Reference 77 | 78 | Please cite the following paper if you use this code as part of any published research: 79 | 80 | **"MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem."** 81 | Hieu Trong Phung, Anh Tuan Vu, Tung Dinh Nguyen, Lam Thanh Do, Giang Nam Ngo, Trung Thanh Tran, Ngoc C. Lê. 82 | 83 | @article{Hieu2020, 84 | title={MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem}, 85 | author={Hieu Trong Phung and Anh Tuan Vu and Tung Dinh Nguyen and Lam Thanh Do and Giang Nam Ngo and Trung Thanh Tran and Ngoc C. L\^{e}}, 86 | journal={arXiv preprint arXiv:2011.04349}, 87 | year={2020} 88 | } 89 | 90 | ## License 91 | 92 | The code is released under the [GPLv3 License](https://www.gnu.org/licenses/gpl-3.0.en.html).
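For reference, the inference step keeps at least the `opt.top` highest-scoring tags per item and then extends the selection with any remaining tags scoring above `opt.threshold` (see `postprocess_prediction` in `infer.py`). A minimal sketch of that selection rule — the `select_tags` helper below is illustrative and not part of the repository:

```python
# Illustrative sketch of the tag-selection rule in infer.py's
# postprocess_prediction: always keep the `top` highest-scoring tags,
# then extend with any remaining tags whose score exceeds `threshold`.
def select_tags(tags, scores, top=3, threshold=0.5):
    ranked = sorted(zip(tags, scores), key=lambda x: x[1], reverse=True)
    selected = ranked[:top]
    for tag, score in ranked[top:]:
        if score > threshold:
            selected.append((tag, score))
        else:  # scores are sorted, so nothing below this point qualifies
            break
    return [tag for tag, _ in selected]


print(select_tags(['sky', 'cat', 'blur', 'tree'], [0.9, 0.2, 0.8, 0.6], top=2))
```

With `top=2` and the default `threshold=0.5`, the two best tags are always kept and `tree` (0.6) is added because it clears the threshold, while `cat` (0.2) is dropped.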
93 | -------------------------------------------------------------------------------- /MAGNeto/data/nus_wide/notebooks/Move Images.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pathlib import Path\n", 10 | "from shutil import copy\n", 11 | "\n", 12 | "from tqdm import tqdm" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "IMG_DIR = Path('../downloads/Flickr/') # Path to the directory that contains all images\n", 22 | "SAVE_DIR = Path('../images') # Path to the target directory" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stderr", 32 | "output_type": "stream", 33 | "text": [ 34 | "100%|██████████| 704/704 [00:51<00:00, 13.67it/s]\n" 35 | ] 36 | } 37 | ], 38 | "source": [ 39 | "for subdir in tqdm(list(IMG_DIR.glob('*'))):\n", 40 | " for img_path in subdir.glob('*.jpg'):\n", 41 | " trg = SAVE_DIR / str(img_path).split('_')[-1] # Only use the IDs of available images to name the new moved images\n", 42 | " copy(img_path, trg)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 4, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "269642" 54 | ] 55 | }, 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "len(list(SAVE_DIR.glob('*.jpg')))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "!rm -rf $IMG_DIR" 72 | ] 73 | } 74 | ], 75 | "metadata": { 76 | "kernelspec": { 77 | "display_name": "Python 3", 78 | "language": "python", 79 | "name": "python3" 80 | }, 81 | "language_info": { 82 | 
"codemirror_mode": { 83 | "name": "ipython", 84 | "version": 3 85 | }, 86 | "file_extension": ".py", 87 | "mimetype": "text/x-python", 88 | "name": "python", 89 | "nbconvert_exporter": "python", 90 | "pygments_lexer": "ipython3", 91 | "version": "3.6.12" 92 | } 93 | }, 94 | "nbformat": 4, 95 | "nbformat_minor": 4 96 | } 97 | -------------------------------------------------------------------------------- /MAGNeto/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import multiprocessing as mp 4 | import copy 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm import tqdm 11 | 12 | from magneto.model import MAGNeto 13 | from magneto.data import TagAndImageDataset 14 | from magneto.augment_helper import val_transform 15 | from magneto.utils import parse_infer_args 16 | 17 | 18 | def predict(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, opt: argparse.Namespace) -> (list, list):
19 | ''' 20 | input: 21 | + model: the trained model to run inference with. 22 | + dataloader: a dataloader over the test set. 23 | + opt: configuration. 24 | output: 25 | the raw predictions and the item ids of all items.
26 | ''' 27 | all_preds = [] 28 | all_item_ids = [] 29 | 30 | with torch.no_grad(): 31 | for batch_val_idx, data in enumerate(tqdm(dataloader)): 32 | if opt.has_label: 33 | image_batch, tags_batch, mask_batch, _, item_id_batch = data 34 | else: 35 | image_batch, tags_batch, mask_batch, item_id_batch = data 36 | image_batch = image_batch.to(opt.device) 37 | tags_batch = tags_batch.to(opt.device) 38 | mask_batch = mask_batch.to(opt.device) 39 | 40 | preds, _, _, _, _ = model(tags_batch, image_batch, mask_batch) 41 | 42 | preds = preds.detach().cpu().numpy() 43 | mask_batch = mask_batch.detach().cpu().numpy() 44 | preds = [tuple(pred[~mask].tolist()) for pred, mask in zip(preds, mask_batch)] 45 | 46 | all_preds.extend(preds) 47 | all_item_ids.extend(item_id_batch.detach().cpu().tolist()) 48 | 49 | return all_preds, all_item_ids 50 | 51 | 52 | def postprocess_prediction(row, opt: argparse.Namespace): 53 | ''' 54 | input: 55 | + row 56 | + opt: configuration. 57 | output: 58 | [important_tags,] post_prediction 59 | ''' 60 | tags = np.array(row['tags'].split(',')) 61 | tags = tags[:opt.max_len] 62 | 63 | if opt.has_label: 64 | label = np.array(row['label'].split(','), dtype=np.uint8) 65 | label = label[:opt.max_len] 66 | mask = label == 1 67 | important_tags = tags[mask] 68 | 69 | final_results = sorted( 70 | zip(tags, row.raw_prediction, mask), key=lambda x: x[1], reverse=True) 71 | else: 72 | final_results = sorted(zip(tags, row.raw_prediction), 73 | key=lambda x: x[1], reverse=True) 74 | 75 | # Get at least top n important tags 76 | post_prediction = final_results[:opt.top] 77 | # Get other accepted important tags based on threshold value 78 | for final_result in final_results[opt.top:]: 79 | if final_result[1] > opt.threshold: 80 | post_prediction.append(final_result) 81 | else: 82 | break 83 | 84 | if opt.has_label: 85 | return important_tags, post_prediction 86 | else: 87 | return post_prediction 88 | 89 | 90 | def postprocess_predictions(df: pd.DataFrame, opt: 
argparse.Namespace) -> pd.DataFrame: 91 | ''' 92 | input: 93 | + df: input pandas dataframe. 94 | + opt: configuration. 95 | output: 96 | postprocessed pandas dataframe. 97 | ''' 98 | post_predictions = [] 99 | if opt.has_label: 100 | list_of_important_tags = [] 101 | 102 | if opt.use_multiprocessing: 103 | import multiprocessing as mp 104 | 105 | # Apply a patch for the multiprocessing module 106 | import multiprocessing.pool as mpp 107 | from magneto.utils import istarmap 108 | mpp.Pool.istarmap = istarmap 109 | 110 | all_rows = [row for idx, row in df.iterrows()] 111 | 112 | inputs = list(zip( 113 | all_rows, 114 | [copy.deepcopy(opt) for _ in range(len(df))] 115 | )) 116 | 117 | with mp.Pool(opt.num_workers) as pool: 118 | for result in tqdm(pool.istarmap(postprocess_prediction, inputs), total=len(inputs)): 119 | if opt.has_label: 120 | important_tags, post_prediction = result 121 | list_of_important_tags.append(important_tags) 122 | else: 123 | post_prediction = result 124 | 125 | post_predictions.append(post_prediction) 126 | 127 | else: 128 | for idx, row in tqdm(list(df.iterrows())): 129 | if opt.has_label: 130 | important_tags, post_prediction = postprocess_prediction( 131 | row, opt) 132 | list_of_important_tags.append(important_tags) 133 | else: 134 | post_prediction = postprocess_prediction( 135 | row, opt) 136 | 137 | post_predictions.append(post_prediction) 138 | 139 | list_of_pred_tags = [] 140 | list_of_probs = [] 141 | 142 | for post_prediction in post_predictions: 143 | post_prediction = list(zip(*post_prediction)) 144 | if len(post_prediction) >= 2: 145 | # TODO we will take care of masks later. 
146 | pred_tags, probs = post_prediction[0], post_prediction[1] 147 | 148 | list_of_pred_tags.append('\n'.join(pred_tags)) 149 | probs = np.round(probs, decimals=3) 150 | probs = np.array(probs, dtype=str) 151 | list_of_probs.append('\n'.join(probs)) 152 | else: 153 | list_of_pred_tags.append('') 154 | list_of_probs.append('') 155 | 156 | df['pred_tags'] = list_of_pred_tags 157 | df['probs'] = list_of_probs 158 | 159 | if opt.has_label: 160 | list_of_important_tags = list( 161 | map(lambda x: '\n'.join(x), list_of_important_tags)) 162 | 163 | df['important_tags'] = list_of_important_tags 164 | 165 | return df 166 | 167 | 168 | def main(): 169 | opt = parse_infer_args() 170 | 171 | states = torch.load( 172 | opt.model_path, map_location=lambda storage, loc: storage) 173 | 174 | # Load model's configuration 175 | model_config = states['config'] 176 | opt.max_len = model_config['max_len'] 177 | opt.d_model = model_config['d_model'] 178 | opt.t_blocks = model_config['t_blocks'] 179 | opt.t_heads = model_config['t_heads'] 180 | opt.t_dim_feedforward = model_config['t_dim_feedforward'] 181 | opt.i_blocks = model_config['i_blocks'] 182 | opt.i_heads = model_config['i_heads'] 183 | opt.i_dim_feedforward = model_config['i_dim_feedforward'] 184 | opt.img_backbone = model_config['img_backbone'] 185 | opt.g_dim_feedforward = model_config['g_dim_feedforward'] 186 | 187 | test_dataset = TagAndImageDataset( 188 | csv_path=opt.csv_path, 189 | vocab_path=opt.vocab_path, 190 | img_dir=opt.img_dir, 191 | max_len=opt.max_len, 192 | has_label=opt.has_label, 193 | return_item_id=True, 194 | img_preprocess_fn=val_transform 195 | ) 196 | 197 | test_dataloader = DataLoader( 198 | dataset=test_dataset, 199 | batch_size=opt.batch_size, 200 | num_workers=opt.num_workers, 201 | pin_memory=True if not opt.no_cuda else False 202 | ) 203 | model = MAGNeto( 204 | d_model=opt.d_model, 205 | vocab_size=test_dataset.vocab_size, 206 | t_blocks=opt.t_blocks, 207 | t_heads=opt.t_heads, 208 | 
t_dim_feedforward=opt.t_dim_feedforward, 209 | i_blocks=opt.i_blocks, 210 | i_heads=opt.i_heads, 211 | i_dim_feedforward=opt.i_dim_feedforward, 212 | img_backbone=opt.img_backbone, 213 | g_dim_feedforward=opt.g_dim_feedforward, 214 | dropout=0 215 | ) 216 | model.load_state_dict(states['model']) 217 | model.to(opt.device) 218 | model.eval() 219 | 220 | all_preds, all_item_ids = predict(model, test_dataloader, opt) 221 | raw_prediction_df = pd.DataFrame({ 222 | 'item_id': all_item_ids, 223 | 'raw_prediction': all_preds 224 | }).drop_duplicates().set_index('item_id') 225 | 226 | base_df = pd.read_csv(opt.csv_path, index_col='item_id') 227 | 228 | # Log all error item ids 229 | error_item_ids = np.setdiff1d(base_df.index.unique(), raw_prediction_df.index.unique(), assume_unique=True).astype(str) 230 | if len(error_item_ids) > 0: 231 | print('Error item ids:', ', '.join(error_item_ids)) 232 | with open('error_item_ids.txt', 'w') as f: 233 | f.write('\n'.join(error_item_ids)) 234 | 235 | final_df = raw_prediction_df.join(base_df).reset_index() 236 | 237 | final_df = postprocess_predictions(final_df, opt) 238 | 239 | if opt.has_label: 240 | final_df.rename(columns={'important_tags': 'ground_truth'}, inplace=True) 241 | final_df[['item_id', 'tags', 'pred_tags', 'probs', 'ground_truth']].to_csv( 242 | 'prediction.csv', index=False) 243 | else: 244 | final_df[['item_id', 'tags', 'pred_tags', 'probs']].to_csv( 245 | 'prediction.csv', index=False) 246 | 247 | 248 | if __name__ == '__main__': 249 | main() 250 | -------------------------------------------------------------------------------- /MAGNeto/magneto/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pixta-dev/labteam/2c14e0605520c100eca24f92d79461167c765c2f/MAGNeto/magneto/__init__.py -------------------------------------------------------------------------------- /MAGNeto/magneto/augment_helper.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from scipy.special import softmax 6 | from torchvision import transforms 7 | 8 | from magneto.autoaugment import ImageNetPolicy 9 | 10 | 11 | MEAN = [0.485, 0.456, 0.406] 12 | STD = [0.229, 0.224, 0.225] 13 | INPUT_SHAPE = 112 14 | 15 | train_transform = transforms.Compose([ 16 | transforms.RandomResizedCrop(INPUT_SHAPE, scale=(0.3, 1.0)), 17 | transforms.RandomHorizontalFlip(), 18 | ImageNetPolicy(), 19 | transforms.ToTensor(), 20 | transforms.Normalize(mean=MEAN, std=STD) 21 | ]) 22 | 23 | val_transform = transforms.Compose([ 24 | transforms.Resize(INPUT_SHAPE), 25 | transforms.CenterCrop(INPUT_SHAPE), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=MEAN, std=STD) 28 | ]) 29 | 30 | 31 | class TagAugmentation(): 32 | def __init__( 33 | self, 34 | vocab_path: str, 35 | drop: float = 0.0, 36 | add: float = 0.0, 37 | path: str = None 38 | ): 39 | self.vocab = pd.read_csv( 40 | vocab_path, keep_default_na=False, na_values=['']).word.tolist() 41 | self.drop, self.add = drop, add 42 | 43 | def __call__(self, tags: np.array, label: np.array) -> (np.array, np.array): 44 | ''' 45 | input: 46 | + tags: raw tags. 47 | + label: raw label. 48 | output: 49 | Processed tags and corresponding label. 
50 | ''' 51 | self.tags, self.label = tags, label 52 | self._separate_indices() 53 | 54 | # NOTE: Dropping must be performed prior to the adding process 55 | if self.drop: 56 | self.tags, self.label = self._drop_tag() 57 | 58 | if self.add: 59 | self.tags, self.label = self._add_tag() 60 | 61 | return self.tags, self.label 62 | 63 | def _get_num(self, prob: float): 64 | return random.randint(0, min(int(prob * len(self.unimportant_indices)), len(self.vocab))) 65 | 66 | def _separate_indices(self): 67 | unimportant_mask = self.label == 0 68 | self.unimportant_indices = np.array(range(len(self.tags)))[ 69 | unimportant_mask] 70 | self.important_indices = np.array(range(len(self.tags)))[ 71 | np.logical_not(unimportant_mask)] 72 | 73 | def _drop_tag(self): 74 | # Randomly select the number of unimportant tags to drop 75 | num_unimportant_drop = self._get_num(self.drop) 76 | num_unimportant_keep = len( 77 | self.unimportant_indices) - num_unimportant_drop 78 | 79 | # Randomly choose indices of unimportant tags to keep based on the number above 80 | unimportant_keep_indices = np.array(random.sample( 81 | list(self.unimportant_indices), k=num_unimportant_keep), dtype=int) 82 | keep_indices = np.concatenate( 83 | (unimportant_keep_indices, self.important_indices)) 84 | keep_indices.sort() 85 | 86 | return self.tags[keep_indices], self.label[keep_indices] 87 | 88 | def _add_tag(self): 89 | num_add = self._get_num(self.add) 90 | 91 | tags_add = [] 92 | sampled_tags = 0 93 | while sampled_tags < num_add: 94 | noise_tags = random.sample(self.vocab, k=num_add - sampled_tags) 95 | valid_tags = [t for t in noise_tags if t not in self.tags] 96 | tags_add += valid_tags 97 | sampled_tags = len(tags_add) 98 | 99 | tags = np.concatenate((self.tags, np.asarray(tags_add))) 100 | label = np.concatenate((self.label, np.zeros(num_add))) 101 | 102 | return tags, label 103 | -------------------------------------------------------------------------------- /MAGNeto/magneto/autoaugment.py:
-------------------------------------------------------------------------------- 1 | # Source: https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 2 | from PIL import Image, ImageEnhance, ImageOps 3 | import numpy as np 4 | import random 5 | 6 | 7 | class ImageNetPolicy(object): 8 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 9 | 10 | Example: 11 | >>> policy = ImageNetPolicy() 12 | >>> transformed = policy(image) 13 | 14 | Example as a PyTorch Transform: 15 | >>> transform=transforms.Compose([ 16 | >>> transforms.Resize(256), 17 | >>> ImageNetPolicy(), 18 | >>> transforms.ToTensor()]) 19 | """ 20 | 21 | def __init__(self, fillcolor=(128, 128, 128)): 22 | self.policies = [ 23 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 24 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 25 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 26 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 27 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 28 | 29 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 30 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 31 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 32 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 33 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 34 | 35 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 36 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 37 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 38 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 39 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 40 | 41 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 42 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 43 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 44 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 45 | 
SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 46 | 47 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 48 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 49 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 50 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 51 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 52 | ] 53 | 54 | def __call__(self, img): 55 | policy_idx = random.randint(0, len(self.policies) - 1) 56 | return self.policies[policy_idx](img) 57 | 58 | def __repr__(self): 59 | return "AutoAugment ImageNet Policy" 60 | 61 | 62 | class CIFAR10Policy(object): 63 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 64 | 65 | Example: 66 | >>> policy = CIFAR10Policy() 67 | >>> transformed = policy(image) 68 | 69 | Example as a PyTorch Transform: 70 | >>> transform=transforms.Compose([ 71 | >>> transforms.Resize(256), 72 | >>> CIFAR10Policy(), 73 | >>> transforms.ToTensor()]) 74 | """ 75 | 76 | def __init__(self, fillcolor=(128, 128, 128)): 77 | self.policies = [ 78 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 79 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 80 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 81 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 82 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 83 | 84 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 85 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 86 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 87 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 88 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 89 | 90 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 91 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 92 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 93 | SubPolicy(0.9, 
"brightness", 6, 0.2, "color", 8, fillcolor), 94 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 95 | 96 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 97 | SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor), 98 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 99 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 100 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 101 | 102 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 103 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 104 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 105 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 106 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 107 | ] 108 | 109 | def __call__(self, img): 110 | policy_idx = random.randint(0, len(self.policies) - 1) 111 | return self.policies[policy_idx](img) 112 | 113 | def __repr__(self): 114 | return "AutoAugment CIFAR10 Policy" 115 | 116 | 117 | class SVHNPolicy(object): 118 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
119 | 120 | Example: 121 | >>> policy = SVHNPolicy() 122 | >>> transformed = policy(image) 123 | 124 | Example as a PyTorch Transform: 125 | >>> transform=transforms.Compose([ 126 | >>> transforms.Resize(256), 127 | >>> SVHNPolicy(), 128 | >>> transforms.ToTensor()]) 129 | """ 130 | 131 | def __init__(self, fillcolor=(128, 128, 128)): 132 | self.policies = [ 133 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 134 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 135 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 136 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 137 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 138 | 139 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 140 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 141 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 142 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 143 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 144 | 145 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 146 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 147 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 148 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 149 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 150 | 151 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 152 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 153 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 154 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 155 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 156 | 157 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 158 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 159 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 160 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 161 | 
SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 162 | ] 163 | 164 | def __call__(self, img): 165 | policy_idx = random.randint(0, len(self.policies) - 1) 166 | return self.policies[policy_idx](img) 167 | 168 | def __repr__(self): 169 | return "AutoAugment SVHN Policy" 170 | 171 | 172 | class SubPolicy(object): 173 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 174 | ranges = { 175 | "shearX": np.linspace(0, 0.3, 10), 176 | "shearY": np.linspace(0, 0.3, 10), 177 | "translateX": np.linspace(0, 150 / 331, 10), 178 | "translateY": np.linspace(0, 150 / 331, 10), 179 | "rotate": np.linspace(0, 30, 10), 180 | "color": np.linspace(0.0, 0.9, 10), 181 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),  # np.int was removed in NumPy 1.24 182 | "solarize": np.linspace(256, 0, 10), 183 | "contrast": np.linspace(0.0, 0.9, 10), 184 | "sharpness": np.linspace(0.0, 0.9, 10), 185 | "brightness": np.linspace(0.0, 0.9, 10), 186 | "autocontrast": [0] * 10, 187 | "equalize": [0] * 10, 188 | "invert": [0] * 10 189 | } 190 | 191 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 192 | def rotate_with_fill(img, magnitude): 193 | rot = img.convert("RGBA").rotate(magnitude) 194 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 195 | 196 | func = { 197 | "shearX": lambda img, magnitude: img.transform( 198 | img.size, Image.AFFINE, (1, magnitude * 199 | random.choice([-1, 1]), 0, 0, 1, 0), 200 | Image.BICUBIC, fillcolor=fillcolor), 201 | "shearY": lambda img, magnitude: img.transform( 202 | img.size, Image.AFFINE, (1, 0, 0, magnitude * 203 | random.choice([-1, 1]), 1, 0), 204 | Image.BICUBIC, fillcolor=fillcolor), 205 | "translateX": lambda img, magnitude: img.transform( 206 | img.size, Image.AFFINE, (1, 0, magnitude * 207 | img.size[0] * random.choice([-1, 1]), 0, 1, 0), 208 | fillcolor=fillcolor), 209 |
"translateY": lambda img, magnitude: img.transform( 210 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * 211 | img.size[1] * random.choice([-1, 1])), 212 | fillcolor=fillcolor), 213 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 214 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 215 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 216 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 217 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 218 | 1 + magnitude * random.choice([-1, 1])), 219 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 220 | 1 + magnitude * random.choice([-1, 1])), 221 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 222 | 1 + magnitude * random.choice([-1, 1])), 223 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 224 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 225 | "invert": lambda img, magnitude: ImageOps.invert(img) 226 | } 227 | 228 | self.p1 = p1 229 | self.operation1 = func[operation1] 230 | self.magnitude1 = ranges[operation1][magnitude_idx1] 231 | self.p2 = p2 232 | self.operation2 = func[operation2] 233 | self.magnitude2 = ranges[operation2][magnitude_idx2] 234 | def __call__(self, img): 235 | if random.random() < self.p1: 236 | img = self.operation1(img, self.magnitude1) 237 | if random.random() < self.p2: 238 | img = self.operation2(img, self.magnitude2) 239 | return img 240 | -------------------------------------------------------------------------------- /MAGNeto/magneto/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from torch.utils.data import Dataset, DataLoader 8 | from PIL import Image 9 | 10 | from magneto.augment_helper import 
train_transform, val_transform, TagAugmentation
11 |
12 |
13 | class TagAndImageDataset(Dataset):
14 | def __init__(
15 | self,
16 | csv_path: str,
17 | vocab_path: str,
18 | img_dir: str,
19 | max_len: int,
20 | has_label: bool = True,
21 | return_item_id: bool = False,
22 | tag_preprocess_fn: object = None,
23 | img_preprocess_fn: object = None,
24 | ):
25 | '''
26 | input:
27 | + csv_path: path to the csv file that contains "item_id", "tags"[, "label"].
28 | + vocab_path: path to the csv file that contains the vocabulary of the dataset.
29 | + img_dir: the directory that contains the corresponding images.
30 | + max_len: the maximum number of tags.
31 | + has_label: whether to prepare and return the label or not.
32 | + return_item_id: whether to return item_id or not.
33 | + tag_preprocess_fn: the preprocessing func for tags; only supported when a label is available.
34 | + img_preprocess_fn: the preprocessing func for the image.
35 | '''
36 | df = pd.read_csv(csv_path)
37 |
38 | self.has_label = has_label
39 | self.return_item_id = return_item_id
40 |
41 | self.list_of_tags = df['tags'].apply(
42 | lambda x: np.array(x.split(','))).tolist()
43 | self.list_of_image_path = df['item_id'].map(
44 | lambda x: os.path.join(img_dir, str(x) + '.jpg')).tolist()
45 | if self.has_label:
46 | self.list_of_label = df['label'].apply(
47 | lambda x: np.array(x.split(','), dtype=np.float32)).tolist()
48 | if self.return_item_id:
49 | self.list_of_item_id = df['item_id'].tolist()
50 |
51 | self.vocab = pd.read_csv(
52 | vocab_path, keep_default_na=False, na_values=[''])
53 | self.word_to_index = self.vocab.set_index('word')
54 | self.vocab_size = len(self.vocab)
55 |
56 | self.max_num_of_tags = max_len
57 | self.tag_preprocess_fn = tag_preprocess_fn
58 | self.img_preprocess_fn = img_preprocess_fn
59 |
60 | def __len__(self) -> int:
61 | return len(self.list_of_tags)
62 |
63 | def __getitem__(self, idx: object) -> (torch.tensor, torch.tensor, torch.tensor, torch.tensor):
64 | '''
65 | input:
66 | +
idx: item's index.
67 | output:
68 | + image: self-explanatory.
69 | + vectors: embedding vectors of tags.
70 | + label: corresponding label (only returned when provided).
71 | + mask: generated mask used to mask out padding positions.
72 | '''
73 | if torch.is_tensor(idx):
74 | idx = idx.tolist()
75 |
76 | # Get image
77 | image_path = self.list_of_image_path[idx]
78 | try:
79 | image = Image.open(image_path).convert('RGB')
80 | except Exception:  # a bare except would also swallow KeyboardInterrupt
81 | return self.__getitem__(random.randrange(self.__len__()))
82 |
83 | if self.img_preprocess_fn is not None:
84 | image = self.img_preprocess_fn(image)
85 |
86 | # Get indices of tags, corresponding mask and label (if provided)
87 | tags = self.list_of_tags[idx]
88 | if self.has_label:
89 | label = self.list_of_label[idx]
90 |
91 | assert len(tags) == len(label)
92 |
93 | if self.tag_preprocess_fn is not None:
94 | tags, label = self.tag_preprocess_fn(tags, label)
95 | if self.return_item_id:
96 | item_id = self.list_of_item_id[idx]
97 |
98 | # Create default mask
99 | mask = torch.zeros(self.max_num_of_tags, dtype=torch.bool)
100 |
101 | # Fix the number of tags
102 | if len(tags) >= self.max_num_of_tags:
103 | # Get top N
104 | tags = tags[:self.max_num_of_tags]
105 | indices = self.word_to_index.loc[tags, 'index']
106 | indices = torch.tensor(indices, dtype=torch.int64)
107 | if self.has_label:
108 | label = torch.tensor(
109 | label[:self.max_num_of_tags],
110 | dtype=torch.float32
111 | )
112 | else:
113 | indices = self.word_to_index.loc[tags, 'index']
114 |
115 | # Right-padding
116 | # Padding idx will be n where n = vocab_size
117 | padding_vector = np.ones(
118 | self.max_num_of_tags, dtype=np.int64) * (self.vocab_size)
119 | padding_vector[:len(tags)] = indices
120 | indices = torch.tensor(padding_vector, dtype=torch.int64)
121 |
122 | mask[len(tags):] = True
123 |
124 | if self.has_label:
125 | zeros_vector = np.zeros(
126 | self.max_num_of_tags, dtype=np.float32)
127 | zeros_vector[:len(tags)] += label
128 |
label = torch.tensor(zeros_vector, dtype=torch.float32) 129 | 130 | results = [image, indices, mask] 131 | if self.has_label: 132 | results.append(label) 133 | if self.return_item_id: 134 | results.append(item_id) 135 | 136 | return results 137 | 138 | 139 | def get_dataloaders( 140 | train_csv_path: str, 141 | val_csv_path: str, 142 | vocab_path: str, 143 | img_dir: str, 144 | tagaug_add_max_ratio: float, 145 | tagaug_drop_max_ratio: float, 146 | train_batch_size: int = 32, 147 | val_batch_size: int = 32, 148 | max_len: int = 100, 149 | num_workers: int = 0, 150 | pin_memory: bool = True 151 | ) -> (DataLoader, DataLoader): 152 | ''' 153 | input: 154 | + train_csv_path: path to the csv file of the training dataset. 155 | + val_csv_path: path to the csv file of the validation dataset. 156 | + vocab_path: path to the csv file that contains the vocabulary of the dataset. 157 | + img_dir: the directory that contains all images for training and validation sets. 158 | + tagaug_add_max_ratio: the maximum ratio between the number of adding tags and non-important ones. 159 | + tagaug_drop_max_ratio: the maximum ratio between the number of dropping tags and non-important ones. 160 | + train_batch_size: the batch-size of the training dataloader. 161 | + val_batch_size: the batch-size of the validation dataloader. 162 | + max_len: the maximum length for each set of tags. 163 | + num_workers: the number of workers used to load data. 164 | + pin_memory: the pin_memory param of PyTorch's DataLoader class. 165 | output: 166 | the dataloaders for the training and validation sets. 
167 | ''' 168 | train_dataset = TagAndImageDataset( 169 | csv_path=train_csv_path, 170 | vocab_path=vocab_path, 171 | img_dir=img_dir, 172 | max_len=max_len, 173 | tag_preprocess_fn=TagAugmentation( 174 | vocab_path=vocab_path, 175 | add=tagaug_add_max_ratio, 176 | drop=tagaug_drop_max_ratio 177 | ) if (tagaug_add_max_ratio or tagaug_drop_max_ratio) else None, # Only use when necessary 178 | img_preprocess_fn=train_transform 179 | ) 180 | val_dataset = TagAndImageDataset( 181 | csv_path=val_csv_path, 182 | vocab_path=vocab_path, 183 | img_dir=img_dir, 184 | max_len=max_len, 185 | img_preprocess_fn=val_transform 186 | ) 187 | 188 | train_dataloader = DataLoader( 189 | dataset=train_dataset, 190 | batch_size=train_batch_size, 191 | shuffle=True, 192 | num_workers=num_workers, 193 | pin_memory=pin_memory 194 | ) 195 | val_dataloader = DataLoader( 196 | dataset=val_dataset, 197 | batch_size=val_batch_size, 198 | num_workers=num_workers, 199 | pin_memory=pin_memory 200 | ) 201 | 202 | return train_dataloader, val_dataloader, train_dataset.vocab_size 203 | -------------------------------------------------------------------------------- /MAGNeto/magneto/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision import models 7 | 8 | 9 | def freeze_all_parameters(module: nn.Module): 10 | ''' Freeze all parameters of a PyTorch Module. 11 | input: 12 | + module: self-explanatory. 13 | ''' 14 | for param in module.parameters(): 15 | param.requires_grad = False 16 | 17 | 18 | def unfreeze_all_parameters(module: nn.Module): 19 | ''' Unfreeze all parameters of a PyTorch Module. 20 | input: 21 | + module: self-explanatory. 
22 | '''
23 | for param in module.parameters():
24 | param.requires_grad = True
25 |
26 |
27 | class TagEmbedder(nn.Module):
28 | def __init__(self, vocab_size, d_model):
29 | super(TagEmbedder, self).__init__()
30 | self.embed = nn.Embedding(
31 | num_embeddings=vocab_size+1, # Plus the padding
32 | embedding_dim=d_model,
33 | )
34 |
35 | def forward(self, x):
36 | return self.embed(x)
37 |
38 |
39 | class MultiHeadMaskedScaledDotProduct(nn.Module):
40 | def __init__(self, d_k: int):
41 | '''
42 | input:
43 | + d_k: the dimensionality of the subspace.
44 | '''
45 | super(MultiHeadMaskedScaledDotProduct, self).__init__()
46 |
47 | self.d_k = d_k
48 |
49 | def forward(
50 | self,
51 | q: torch.tensor,
52 | k: torch.tensor,
53 | mask: torch.tensor = None
54 | ) -> torch.tensor:
55 | '''
56 | input:
57 | + q: the matrix of queries.
58 | + k: the matrix of keys.
59 | + mask: used to mask out padding positions.
60 | output:
61 | The matrix of scores.
62 | '''
63 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
64 |
65 | if mask is not None:
66 | scores = scores.masked_fill(
67 | mask.unsqueeze(1).unsqueeze(2),
68 | float('-inf')
69 | )
70 |
71 | return scores
72 |
73 |
74 | class MultiHeadAttention(nn.Module):
75 | def __init__(self, heads: int, d_model: int, dropout: float = 0.1):
76 | '''
77 | input:
78 | + heads: the number of heads of each Multi-Head Attention layer.
79 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
80 | + dropout: the dropout value of the attention layer.
81 | '''
82 | super(MultiHeadAttention, self).__init__()
83 |
84 | self.d_model = d_model
85 | self.d_k = d_model // heads
86 | self.h = heads
87 |
88 | self.q_linear = nn.Linear(d_model, d_model)
89 | self.v_linear = nn.Linear(d_model, d_model)
90 | self.k_linear = nn.Linear(d_model, d_model)
91 |
92 | self.dp = MultiHeadMaskedScaledDotProduct(self.d_k)
93 | self.softmax = nn.Softmax(-1)
94 | self.dropout = nn.Dropout(dropout)
95 | self.out = nn.Linear(d_model, d_model)
96 |
97 | def forward(
98 | self,
99 | q: torch.tensor,
100 | k: torch.tensor,
101 | v: torch.tensor,
102 | mask: torch.tensor = None
103 | ) -> torch.tensor:
104 | '''
105 | input:
106 | + q: the matrix of queries.
107 | + k: the matrix of keys.
108 | + v: the matrix of values.
109 | + mask: used to mask out padding positions.
110 | output:
111 | context vectors.
112 | '''
113 | bs = q.size(0)
114 |
115 | # Perform linear operation and split into N heads
116 | k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
117 | q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
118 | v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
119 |
120 | # Transpose to get dimensions bs * h * sl * d_k
121 | k = k.transpose(1, 2)
122 | q = q.transpose(1, 2)
123 | v = v.transpose(1, 2)
124 |
125 | scores = self.dp(q, k, mask)
126 | scores = self.softmax(scores)
127 |
128 | scores = self.dropout(scores)
129 |
130 | # Compute context vectors based on calculated scores above
131 | context = torch.matmul(scores, v)
132 |
133 | # Concatenate heads and put through final linear layer
134 | concat = context.transpose(1, 2).contiguous()\
135 | .view(bs, -1, self.d_model)
136 | output = self.out(concat)
137 |
138 | return output
139 |
140 |
141 | class TagToImageLayer(nn.Module):
142 | def __init__(self, d_model: int, heads: int, dropout: float = 0.1):
143 | '''
144 | input:
145 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
146 | + heads: the number of heads of the Multi-Head Attention sub-layer.
147 | + dropout: the dropout value of the Multi-Head Attention sub-layer.
148 | '''
149 | super(TagToImageLayer, self).__init__()
150 |
151 | self.attn = MultiHeadAttention(
152 | heads, d_model, dropout=dropout)
153 |
154 | def forward(self, tag_vectors: torch.tensor, img_regions: torch.tensor) -> torch.tensor:
155 | '''
156 | input:
157 | + tag_vectors: self-explanatory.
158 | + img_regions: self-explanatory.
159 | output:
160 | output vectors.
161 | '''
162 | out = self.attn(tag_vectors, img_regions, img_regions)
163 |
164 | return out
165 |
166 |
167 | class GatingLayer(nn.Module):
168 | def __init__(self, in_features: int, dim_feedforward: int, dropout: float = 0.1):
169 | '''
170 | input:
171 | + in_features: the dimensionality of the input vectors.
172 | + dim_feedforward: the dimensionality of the hidden layer.
173 | + dropout: the dropout value of the Gating layer.
174 | '''
175 | super(GatingLayer, self).__init__()
176 |
177 | self.dropout_1 = nn.Dropout(dropout)
178 | self.linear_1 = nn.Linear(in_features, dim_feedforward)
179 | self.relu = nn.ReLU()
180 | self.dropout_2 = nn.Dropout(dropout)
181 | self.linear_2 = nn.Linear(dim_feedforward, 1)
182 | self.sigmoid = nn.Sigmoid()
183 |
184 | def forward(self, tag_vectors: torch.tensor) -> torch.tensor:
185 | '''
186 | input:
187 | + tag_vectors: self-explanatory.
188 | output:
189 | output gating values.
190 | '''
191 | out = self.dropout_1(tag_vectors)
192 | out = self.linear_1(out)
193 | out = self.relu(out)
194 | out = self.dropout_2(out)
195 | out = self.linear_2(out)
196 | out = self.sigmoid(out.squeeze(dim=-1))
197 |
198 | return out
199 |
200 |
201 | class ImageFeatureExtractor(nn.Module):
202 | def __init__(self, d_model: int, img_backbone: str):
203 | '''
204 | input:
205 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
206 | '''
207 | super(ImageFeatureExtractor, self).__init__()
208 |
209 | if img_backbone == 'resnet18':
210 | encoder = models.resnet18(pretrained=True)
211 | ex_out_dim = 512
212 | elif img_backbone == 'resnet50':
213 | encoder = models.resnet50(pretrained=True)
214 | ex_out_dim = 2048
215 | else: raise ValueError('img_backbone must be resnet18 or resnet50')  # avoid a NameError on encoder below
216 | # Get all layers
217 | encoder_children = list(encoder.children())
218 | # Drop the last avg & fc layers
219 | self.backbone = nn.Sequential(*encoder_children[:-2])
220 |
221 | self.conv_1x1 = nn.Conv2d(ex_out_dim, d_model, kernel_size=(
222 | 1, 1), stride=(1, 1), bias=True)
223 | self.bn = nn.BatchNorm2d(d_model)
224 | self.flatten = nn.Flatten(start_dim=1, end_dim=2)
225 |
226 | self.freeze_all_layers()
227 | self.unfreeze_top_layers()
228 | # self.unfreeze_the_fourth_block()
229 | # self.unfreeze_the_third_block()
230 | # self.unfreeze_the_second_block()
231 | # self.unfreeze_the_first_block()
232 | # self.unfreeze_the_bottom_layers()
233 |
234 | def freeze_all_layers(self):
235 | ''' Freeze all image encoder's layers.
236 | '''
237 | freeze_all_parameters(self)
238 |
239 | def unfreeze_top_layers(self):
240 | ''' Unfreeze the top layers of the image feature extractor.
241 | '''
242 | # conv_1x1
243 | unfreeze_all_parameters(self.conv_1x1)
244 |
245 | # bn
246 | unfreeze_all_parameters(self.bn)
247 |
248 | def unfreeze_the_first_block(self):
249 | ''' Unfreeze the first block of the image encoder.
250 | '''
251 | assert type(self.backbone[4]) is nn.Sequential
252 |
253 | unfreeze_all_parameters(self.backbone[4])
254 |
255 | def unfreeze_the_second_block(self):
256 | ''' Unfreeze the second block of the image encoder.
257 | '''
258 | assert type(self.backbone[5]) is nn.Sequential
259 |
260 | unfreeze_all_parameters(self.backbone[5])
261 |
262 | def unfreeze_the_third_block(self):
263 | ''' Unfreeze the third block of the image encoder.
264 | ''' 265 | assert type(self.backbone[6]) is nn.Sequential 266 | 267 | unfreeze_all_parameters(self.backbone[6]) 268 | 269 | def unfreeze_the_fourth_block(self): 270 | ''' Unfreeze the fourth block of the image encoder. 271 | ''' 272 | assert type(self.backbone[7]) is nn.Sequential 273 | 274 | unfreeze_all_parameters(self.backbone[7]) 275 | 276 | def unfreeze_the_bottom_layers(self): 277 | ''' Unfreeze the bottom layers of the image encoder. 278 | ''' 279 | # Unfreeze the first conv layer and bn 280 | assert type(self.backbone[0]) is nn.Conv2d 281 | assert type(self.backbone[1]) is nn.BatchNorm2d 282 | assert type(self.backbone[2]) is nn.ReLU 283 | assert type(self.backbone[3]) is nn.MaxPool2d 284 | 285 | # conv 286 | unfreeze_all_parameters(self.backbone[0]) 287 | 288 | # bn 289 | unfreeze_all_parameters(self.backbone[1]) 290 | 291 | def forward(self, x: torch.tensor) -> torch.tensor: 292 | ''' 293 | input: 294 | + x: input image. 295 | output: 296 | image's features. 297 | ''' 298 | features = self.backbone(x) 299 | 300 | out = self.conv_1x1(features) 301 | out = self.bn(out) 302 | 303 | # Convert the tensor from channel first to channel last 304 | out = out.permute(0, 2, 3, 1) 305 | 306 | out = self.flatten(out) 307 | 308 | return out 309 | -------------------------------------------------------------------------------- /MAGNeto/magneto/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def dice_loss( 7 | input: torch.tensor, 8 | target: torch.tensor, 9 | beta: float = 1., 10 | reduction: str = 'mean', 11 | smooth: float = 1. 12 | ) -> torch.tensor: 13 | intersection = input * target 14 | score = ((1. + beta**2) * torch.sum(intersection, dim=-1) + smooth) \ 15 | / (torch.sum(input, dim=-1) + (beta**2) * torch.sum(target, dim=-1) + smooth) 16 | 17 | loss = 1. 
- score 18 | 19 | if reduction == 'mean': 20 | return loss.mean() 21 | elif reduction == 'sum': 22 | return loss.sum() 23 | 24 | return loss 25 | 26 | 27 | class DiceLoss(nn.Module): 28 | """ 29 | Dice loss's implementation. 30 | """ 31 | 32 | def __init__(self, reduction: str = 'mean', beta: float = 1., smooth: float = 1.): 33 | """ 34 | input: 35 | + reduction: specifies the reduction to apply to the output: 36 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 37 | ``'mean'``: the sum of the output will be divided by the number of 38 | elements in the output, ``'sum'``: the output will be summed. 39 | + beta: β is chosen such that recall is considered β times as important as precision. 40 | + smooth: smooth value. 41 | """ 42 | super(DiceLoss, self).__init__() 43 | 44 | assert beta >= 0, 'β must be a positive real value!' 45 | assert reduction in ['none', 'mean', 'sum'] 46 | 47 | self.beta = beta 48 | self.reduction = reduction 49 | self.smooth = smooth 50 | 51 | def forward(self, input: torch.tensor, target: torch.tensor) -> torch.tensor: 52 | """ 53 | input: 54 | + input: prediction. 55 | + target: ground-truth. 56 | output: 57 | + loss value. 58 | """ 59 | return dice_loss( 60 | input, 61 | target, 62 | beta=self.beta, 63 | reduction=self.reduction, 64 | smooth=self.smooth 65 | ) 66 | 67 | 68 | class BCEDiceLoss(nn.Module): 69 | """ 70 | The combination of Binary Cross-Entropy & Dice losses. 71 | """ 72 | 73 | def __init__( 74 | self, 75 | reduction: str = 'mean', 76 | weight: torch.tensor = None, 77 | beta: float = 1., 78 | smooth: float = 1.0 79 | ): 80 | """ 81 | input: 82 | + reduction: specifies the reduction to apply to the output: 83 | ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, 84 | ``'mean'``: the sum of the output will be divided by the number of 85 | elements in the output, ``'sum'``: the output will be summed. 
86 | + weight: a manual rescaling weight given to the loss 87 | of each batch element. If given, has to be a Tensor of size `nbatch`, 88 | used for BCE part. 89 | + beta: β is chosen such that recall is considered β times as important as precision. 90 | + smooth: smooth value, used for Dice part. 91 | """ 92 | super(BCEDiceLoss, self).__init__() 93 | 94 | assert beta >= 0, 'β must be a positive real value!' 95 | assert reduction in ['none', 'mean', 'sum'] 96 | 97 | self.reduction = reduction 98 | 99 | # BCE's params 100 | self.weight = weight 101 | 102 | # Dice's params 103 | self.beta = beta 104 | self.smooth = smooth 105 | 106 | def forward(self, input: torch.tensor, target: torch.tensor) -> torch.tensor: 107 | """ 108 | input: 109 | + input: prediction. 110 | + target: ground-truth. 111 | output: 112 | + loss value. 113 | """ 114 | bce = F.binary_cross_entropy( 115 | input, 116 | target, 117 | weight=self.weight, 118 | reduction=self.reduction 119 | ) 120 | dice = dice_loss( 121 | input, 122 | target, 123 | beta=self.beta, 124 | reduction=self.reduction, 125 | smooth=self.smooth 126 | ) 127 | 128 | return bce + dice 129 | -------------------------------------------------------------------------------- /MAGNeto/magneto/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | 4 | 5 | logging.basicConfig(level=logging.INFO) 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | class PrecisionRecallFk: 10 | """ 11 | Calculate precision, recall, f1 for predictions. 
12 | """
13 |
14 | def __init__(self, enable_logger=False, threshold=0.5, eps=1e-9):
15 | if enable_logger:
16 | global logger
17 | self.logger = logger
18 |
19 | self.threshold = threshold
20 | self.eps = eps
21 |
22 | def __call__(self,
23 | prediction: np.array,
24 | ground_truth: np.array,
25 | betas: list = [1],
26 | top_ks: list = None) -> (float, float, float):
27 | """
28 | compute F-beta score
29 |
30 | input:
31 | + prediction: predictions of the model, np.array of shape [B, N] where B is the batch size
32 | and N is the number of classes
33 | + ground_truth: self-explanatory, must have the same shape as prediction
34 | + betas: a list of betas
35 | + top_ks: if specified, compute f_score over the top_k most confident predictions
36 | output:
37 | + f_score: (1+beta**2) * precision*recall/(beta**2 * precision+recall)
38 | """
39 | if top_ks is None:
40 | return self.f_k_score(prediction,
41 | ground_truth,
42 | betas)
43 | else:
44 | return self.f_k_score_top(prediction,
45 | ground_truth,
46 | betas,
47 | top_ks)
48 |
49 | def f_k_score(self,
50 | prediction: np.array,
51 | ground_truth: np.array,
52 | betas: list = [1],
53 | threshold: float = None):
54 | """
55 | compute F-beta score
56 |
57 | input:
58 | + prediction: unthresholded output of the model, np.array of shape [B, N] where B is the
59 | batch size and N is the number of classes
60 | + ground_truth: self-explanatory, must have the same shape as prediction
61 | + betas: a list of betas
62 | output:
63 | + f_scores: (1+beta**2) * precision*recall/(beta**2 * precision+recall)
64 | """
65 | assert prediction.shape == ground_truth.shape
66 |
67 | if threshold is None:
68 | prediction = prediction >= self.threshold
69 | else:
70 | prediction = (prediction >= threshold)
71 |
72 | prediction = prediction.astype(int)
73 |
74 | ground_truth = ground_truth.reshape(prediction.shape)
75 | num_prediction = np.count_nonzero(prediction, axis=1)
76 | num_ground_truth = np.count_nonzero(ground_truth, axis=1)
77 |
78 | if
hasattr(self, "logger"):
79 | self.logger.info(
80 | "Predictions per item: {}, Labels per item: {}".format(np.mean(num_prediction),
81 | np.mean(num_ground_truth))
82 | )
83 |
84 | num_true_positive_pred = np.count_nonzero(
85 | ground_truth.astype(int) & prediction, axis=1)  # cast so the bitwise AND also works on float labels
86 |
87 | precision = num_true_positive_pred/num_prediction + self.eps
88 | recall = num_true_positive_pred/num_ground_truth + self.eps
89 |
90 | f_scores = {}
91 | for beta in betas:
92 | beta_squared = beta ** 2
93 | f_score = np.nan_to_num(
94 | (1 + beta_squared)*precision*recall / (beta_squared * precision+recall))
95 | f_scores["F{}".format(beta)] = np.nanmean(f_score)
96 |
97 | if hasattr(self, "logger"):
98 | self.logger.info(
99 | "Can't give predictions to {} items".format(
100 | np.count_nonzero(np.isnan(precision)))
101 | )
102 |
103 | return {"precision": np.nanmean(precision), "recall": np.nanmean(recall), "f_score": f_scores}
104 |
105 | def f_k_score_top(self,
106 | prediction: np.array,
107 | ground_truth: np.array,
108 | betas: list,
109 | top_ks: list):
110 | """
111 | compute F-beta score
112 |
113 | input:
114 | + prediction: unthresholded output of the model, np.array of shape [B, N] where B is the
115 | batch size and N is the number of classes
116 | + ground_truth: self-explanatory, must have the same shape as prediction
117 | + betas: a list of betas
118 | + top_ks: list of top_ks to compute the f_score
119 | output:
120 | + f_scores: (1+beta**2) * precision*recall/(beta**2 * precision+recall)
121 | """
122 |
123 | assert len(top_ks) > 0, "please specify top_k"
124 |
125 | outputs = {}
126 |
127 | for top_k in top_ks:
128 | # compute threshold for every top_k
129 | k_indices = np.argsort(prediction)[:, ::-1][:, top_k - 1]
130 |
131 | k_thresh = prediction[np.arange(len(k_indices)), k_indices]  # tuple indexing; indexing with a list of sequences is an error in modern NumPy
132 | k_thresh = k_thresh[..., np.newaxis]
133 | outputs["top_{}".format(top_k)] = self.f_k_score(
134 | prediction, ground_truth, betas, k_thresh)
135 |
136 | return outputs
137 |
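For clarity, the per-item computation inside `f_k_score` boils down to the following standalone NumPy sketch (the toy arrays and the 0.5 threshold are illustrative, not taken from the repo):

```python
import numpy as np

# Toy batch: 2 items, 4 candidate tags each.
prediction = np.array([[0.9, 0.2, 0.7, 0.1],
                       [0.4, 0.8, 0.6, 0.3]])
ground_truth = np.array([[1, 0, 1, 0],
                         [0, 1, 0, 1]])

# Threshold the scores, then count per-item true positives.
pred = (prediction >= 0.5).astype(int)
tp = np.count_nonzero(ground_truth.astype(int) & pred, axis=1)

precision = tp / np.count_nonzero(pred, axis=1)
recall = tp / np.count_nonzero(ground_truth, axis=1)

# F-beta with beta = 1 reduces to the usual F1.
beta = 1
f_score = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
print(precision.mean(), recall.mean(), f_score.mean())
```

Item one is predicted perfectly (precision = recall = 1) and item two gets one of two tags right (precision = recall = 0.5), so all three batch means come out to 0.75.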
-------------------------------------------------------------------------------- /MAGNeto/magneto/model.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from magneto.layers import (
5 | TagEmbedder,
6 | ImageFeatureExtractor,
7 | TagToImageLayer,
8 | GatingLayer
9 | )
10 |
11 |
12 | class MAGNeto(nn.Module):
13 | def __init__(
14 | self,
15 | d_model: int,
16 | vocab_size: int,
17 | t_blocks: int,
18 | t_heads: int,
19 | i_blocks: int,
20 | i_heads: int,
21 | dropout: float,
22 | t_dim_feedforward: int = 2048,
23 | i_dim_feedforward: int = 2048,
24 | g_dim_feedforward: int = 2048,
25 | img_backbone: str = 'resnet50',
26 | ):
27 | '''
28 | input:
29 | + d_model: the dimensionality of a context vector, must be divisible by the number of heads.
30 | + vocab_size: self-explanatory.
31 | + t_blocks: the number of encoder layers, or blocks, for the tag branch.
32 | + t_heads: the number of heads of each Multi-Head Attention layer of the tag branch.
33 | + i_blocks: the number of encoder layers, or blocks, for the image branch.
34 | + i_heads: the number of heads of each Multi-Head Attention layer of the image branch.
35 | + dropout: dropout value of the whole network.
36 | + t_dim_feedforward: the dimension of the feedforward network model in the TransformerEncoderLayer class of the tag branch.
37 | + i_dim_feedforward: the dimension of the feedforward network model in the TransformerEncoderLayer class of the image branch.
38 | + g_dim_feedforward: the dimension of the feedforward network model in the GatingLayer class.
39 | + img_backbone: resnet18 or resnet50.
40 | '''
41 | super(MAGNeto, self).__init__()
42 |
43 | self.tag_embedder = TagEmbedder(vocab_size, d_model)
44 | self.tag_dropout = nn.Dropout(dropout)
45 | self.tag_encoder = nn.TransformerEncoder(
46 | nn.TransformerEncoderLayer(
47 | d_model=d_model, nhead=t_heads, dim_feedforward=t_dim_feedforward, dropout=dropout),
48 | num_layers=t_blocks
49 | )
50 | self.tag_linear = nn.Linear(d_model, 1)
51 | self.tag_sigmoid = nn.Sigmoid()
52 |
53 | self.img_feature_extractor = ImageFeatureExtractor(
54 | d_model, img_backbone)
55 | self.tag_to_img = TagToImageLayer(d_model, i_heads, dropout)
56 | self.img_dropout = nn.Dropout(dropout)
57 | self.img_encoder = nn.TransformerEncoder(
58 | nn.TransformerEncoderLayer(
59 | d_model=d_model, nhead=i_heads, dim_feedforward=i_dim_feedforward, dropout=dropout),
60 | num_layers=i_blocks
61 | )
62 | self.img_linear = nn.Linear(d_model, 1)
63 | self.img_sigmoid = nn.Sigmoid()
64 |
65 | self.gating = GatingLayer(
66 | d_model * 2, dim_feedforward=g_dim_feedforward, dropout=dropout)
67 |
68 | def forward(self, src: torch.tensor, img: torch.tensor, mask: torch.tensor) \
69 | -> (torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor):
70 | '''
71 | input:
72 | + src: input tag indices.
73 | + img: input image.
74 | + mask: used to mask out padding positions.
75 | output:
76 | the final prediction, the tag- and image-branch predictions, and their gating weights.
77 | ''' 78 | tag_vectors = self.tag_dropout(self.tag_embedder(src)) 79 | tag_out = self.tag_encoder(tag_vectors.permute( 80 | 1, 0, 2), src_key_padding_mask=mask) 81 | tag_out = torch.relu(tag_out.permute(1, 0, 2)) 82 | 83 | img_regions = self.img_feature_extractor(img) 84 | attn_out = self.img_dropout(torch.relu( 85 | self.tag_to_img(tag_vectors, img_regions) 86 | )) 87 | img_out = self.img_encoder(attn_out.permute( 88 | 1, 0, 2), src_key_padding_mask=mask) 89 | img_out = torch.relu(img_out.permute(1, 0, 2)) 90 | 91 | img_weight = self.gating(torch.cat((tag_out, img_out), dim=-1)) 92 | tag_weight = 1 - img_weight 93 | 94 | tag_out = self.tag_sigmoid(self.tag_linear(tag_out).squeeze(dim=-1)) 95 | img_out = self.img_sigmoid(self.img_linear(img_out).squeeze(dim=-1)) 96 | 97 | out = tag_weight * tag_out + img_weight * img_out 98 | 99 | return out, tag_out, img_out, tag_weight, img_weight 100 | -------------------------------------------------------------------------------- /MAGNeto/magneto/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import multiprocessing as mp 4 | import multiprocessing.pool as mpp 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | from torch import optim 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | from magneto.loss import BCEDiceLoss 15 | from magneto.metrics import PrecisionRecallFk 16 | 17 | 18 | def istarmap(self, func, iterable, chunksize=1): 19 | ''' starmap-version of imap 20 | ''' 21 | if self._state != mpp.RUN: 22 | raise ValueError("Pool not running") 23 | 24 | if chunksize < 1: 25 | raise ValueError( 26 | "Chunksize must be 1+, not {0:n}".format( 27 | chunksize)) 28 | 29 | task_batches = mpp.Pool._get_tasks(func, iterable, chunksize) 30 | result = mpp.IMapIterator(self._cache) 31 | self._taskqueue.put( 32 | ( 33 | 
self._guarded_task_generation(result._job,
34 | mpp.starmapstar,
35 | task_batches),
36 | result._set_length
37 | ))
38 | return (item for chunk in result for item in chunk)
39 |
40 |
41 | def moving_avg(avg, update, alpha):
42 | return (alpha * avg) + ((1 - alpha) * update)
43 |
44 |
45 | def parse_train_args() -> argparse.Namespace:
46 | '''
47 | output:
48 | parsed arguments.
49 | '''
50 | parser = argparse.ArgumentParser(description='MAGNeto training process.')
51 | parser.add_argument(
52 | '--train-csv-path',
53 | type=str,
54 | help='[/path/to/train_data.csv]',
55 | required=True
56 | )
57 | parser.add_argument(
58 | '--val-csv-path',
59 | type=str,
60 | help='[/path/to/val_data.csv]',
61 | required=True
62 | )
63 | parser.add_argument(
64 | '--vocab-path',
65 | type=str,
66 | help='[/path/to/vocab.csv]',
67 | required=True
68 | )
69 | parser.add_argument(
70 | '--img-dir',
71 | type=str,
72 | help='[/path/to/img_dir]',
73 | required=False
74 | )
75 | parser.add_argument(
76 | '--save-dir',
77 | type=str,
78 | help='[/path/to/save_dir]',
79 | required=True
80 | )
81 | parser.add_argument(
82 | '--checkpoint-path',
83 | type=str,
84 | help='[/path/to/checkpoint.pth]'
85 | )
86 | parser.add_argument(
87 | '--load-weights-only',
88 | action='store_true',
89 | help='Only load the model\'s weights from the checkpoint.'
90 | )
91 | parser.add_argument(
92 | '--exclude-top',
93 | action='store_true',
94 | help='Whether to exclude the top layers when loading the checkpoint.'
95 | ) 96 | parser.add_argument( 97 | '--start-from-epoch', 98 | type=int, 99 | help='(default: "0".)', 100 | default=0 101 | ) 102 | parser.add_argument( 103 | '--max-len', 104 | type=int, 105 | help='The maximum length for each set of tags (default: 100).', 106 | default=100 107 | ) 108 | parser.add_argument( 109 | '--t-heads', 110 | type=int, 111 | help='The number of heads of each Multi-Head Attention layer of the tag branch (default: 8).', 112 | default=8 113 | ) 114 | parser.add_argument( 115 | '--t-blocks', 116 | type=int, 117 | help='The number of encoder layers, or blocks, for the tag branch (default: 6).', 118 | default=6 119 | ) 120 | parser.add_argument( 121 | '--t-dim-feedforward', 122 | type=int, 123 | help='The dimension of the feedforward network model in the TransformerEncoderLayer class of the tag branch, (default: 2048).', 124 | default=2048 125 | ) 126 | parser.add_argument( 127 | '--i-heads', 128 | type=int, 129 | help='The number of heads of each Multi-Head Attention layer of the image branch (default: 8).', 130 | default=8 131 | ) 132 | parser.add_argument( 133 | '--i-blocks', 134 | type=int, 135 | help='The number of encoder layers, or blocks, for the image branch (default: 2).', 136 | default=2 137 | ) 138 | parser.add_argument( 139 | '--i-dim-feedforward', 140 | type=int, 141 | help='The dimension of the feedforward network model in the TransformerEncoderLayer class of the image branch, (default: 2048).', 142 | default=2048 143 | ) 144 | parser.add_argument( 145 | '--d-model', 146 | type=int, 147 | help='The dimensionality of a context vector, must be divisible by the number of heads, (default: 512).', 148 | default=512 149 | ) 150 | parser.add_argument( 151 | '--img-backbone', 152 | type=str, 153 | help='resnet18 or resnet50, (default: resnet50).', 154 | default='resnet50' 155 | ) 156 | parser.add_argument( 157 | '--g-dim-feedforward', 158 | type=int, 159 | help='The dimension of the feedforward network model in the GatingLayer class, (default:
2048).', 160 | default=2048 161 | ) 162 | parser.add_argument( 163 | '--dropout', 164 | type=float, 165 | help='Dropout value of the tag encoder layers (default: 0.1).', 166 | default=0.1 167 | ) 168 | parser.add_argument( 169 | '--tagaug-add-max-ratio', 170 | type=float, 171 | help='The maximum ratio between the number of added tags and non-important ones, (default: 0.3).', 172 | default=0.3 173 | ) 174 | parser.add_argument( 175 | '--tagaug-drop-max-ratio', 176 | type=float, 177 | help='The maximum ratio between the number of dropped tags and non-important ones, (default: 0.3).', 178 | default=0.3 179 | ) 180 | parser.add_argument( 181 | '--train-batch-size', 182 | type=int, 183 | help='The batch size used in the training process (default: 64).', 184 | default=64 185 | ) 186 | parser.add_argument( 187 | '--val-batch-size', 188 | type=int, 189 | help='The batch size used in the validation process (default: 128).', 190 | default=128 191 | ) 192 | parser.add_argument( 193 | '--num-workers', 194 | type=int, 195 | help='The number of workers used for data loaders, \ 196 | -1 means using all available processors, \ 197 | rules of thumb: num_workers ~ num_gpu * 4, \ 198 | (default: 4).', 199 | default=4 200 | ) 201 | parser.add_argument( 202 | '--epochs', 203 | type=int, 204 | required=True 205 | ) 206 | parser.add_argument( 207 | '--lr', 208 | type=float, 209 | help='(default: "3e-2".)', 210 | default=3e-2 211 | ) 212 | parser.add_argument( 213 | '--threshold', 214 | type=float, 215 | help='(default: "0.5".)', 216 | default=0.5 217 | ) 218 | parser.add_argument( 219 | '--no-cuda', 220 | action='store_true' 221 | ) 222 | parser.add_argument( 223 | '--gpu-id', 224 | type=int, 225 | help='The ID of selected GPU, --no-cuda must be disabled, (default: 0).', 226 | default=0 227 | ) 228 | parser.add_argument( 229 | '--use-steplr-scheduler', 230 | action='store_true', 231 | help='Whether or not to use StepLR scheduler. \ 232 | All other schedulers should be disabled.'
233 | ) 234 | parser.add_argument( 235 | '--sl-gamma', 236 | type=float, 237 | help='StepLR-scheduler\'s multiplicative factor of learning rate decay (default: 0.9).', 238 | default=0.9 239 | ) 240 | parser.add_argument( 241 | '--use-rop-scheduler', 242 | action='store_true', 243 | help='Whether or not to use ReduceLROnPlateau scheduler. \ 244 | All other schedulers should be disabled.' 245 | ) 246 | parser.add_argument( 247 | '--rop-factor', 248 | type=float, 249 | help='ReduceLROnPlateau scheduler\'s factor parameter (default: 0.3).', 250 | default=0.3 251 | ) 252 | parser.add_argument( 253 | '--rop-patience', 254 | type=int, 255 | help='ReduceLROnPlateau scheduler\'s patience parameter (default: 3).', 256 | default=3 257 | ) 258 | parser.add_argument( 259 | '--log-graph', 260 | action='store_true', 261 | help='Log the model graph.' 262 | ) 263 | parser.add_argument( 264 | '--save-latest', 265 | action='store_true', 266 | help='Save the latest checkpoint.' 267 | ) 268 | parser.add_argument( 269 | '--save-best-f1', 270 | action='store_true', 271 | help='Save the checkpoint based on val F1.' 272 | ) 273 | parser.add_argument( 274 | '--save-best-loss', 275 | action='store_true', 276 | help='Save the checkpoint based on val loss.' 277 | ) 278 | parser.add_argument( 279 | '--save-all-epochs', 280 | action='store_true', 281 | help='Save a checkpoint for each epoch.' 282 | ) 283 | parser.add_argument( 284 | '--log-weight-hist', 285 | action='store_true', 286 | help='Log the histogram of image and tag weights during the validation process.' 287 | ) 288 | 289 | opt = parser.parse_args() 290 | 291 | # Check configuration 292 | assert not (opt.use_steplr_scheduler and opt.use_rop_scheduler), \ 293 | 'Cannot use multiple schedulers at the same time!'
294 | 295 | if opt.num_workers == -1: 296 | opt.num_workers = mp.cpu_count() 297 | 298 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu' 299 | if not opt.no_cuda: 300 | assert torch.cuda.is_available() 301 | 302 | if not os.path.exists(opt.save_dir): 303 | os.makedirs(opt.save_dir) 304 | 305 | assert os.path.isfile(opt.train_csv_path) 306 | assert os.path.isfile(opt.val_csv_path) 307 | assert os.path.exists(opt.img_dir) 308 | 309 | return opt 310 | 311 | 312 | def parse_infer_args() -> argparse.Namespace: 313 | ''' 314 | output: 315 | parsed arguments. 316 | ''' 317 | parser = argparse.ArgumentParser(description='Inference module.') 318 | parser.add_argument( 319 | '--csv-path', 320 | type=str, 321 | help='[/path/to/data.csv]', 322 | required=True 323 | ) 324 | parser.add_argument( 325 | '--img-dir', 326 | type=str, 327 | help='[/path/to/img_dir]', 328 | required=True 329 | ) 330 | parser.add_argument( 331 | '--vocab-path', 332 | type=str, 333 | help='[/path/to/vocab.csv]', 334 | required=True 335 | ) 336 | parser.add_argument( 337 | '--model-path', 338 | type=str, 339 | help='[/path/to/model.pth]', 340 | required=True 341 | ) 342 | parser.add_argument( 343 | '--has-label', 344 | action='store_true' 345 | ) 346 | parser.add_argument( 347 | '--batch-size', 348 | type=int, 349 | help='The batch size used in the inference process (default: 64).', 350 | default=64 351 | ) 352 | parser.add_argument( 353 | '--num-workers', 354 | type=int, 355 | help='The number of workers used for data loaders, \ 356 | -1 means using all available processors, \ 357 | rules of thumb: num_workers ~ num_gpu * 4, \ 358 | (default: 4).', 359 | default=4 360 | ) 361 | parser.add_argument( 362 | '--threshold', 363 | type=float, 364 | help='(default: "0.5".)', 365 | default=0.5 366 | ) 367 | parser.add_argument( 368 | '--top', 369 | type=int, 370 | help='The minimum number of selected important tags for each item (default: 5).', 371 | default=5 372 | ) 373 | 
parser.add_argument( 374 | '--no-cuda', 375 | action='store_true' 376 | ) 377 | parser.add_argument( 378 | '--gpu-id', 379 | type=int, 380 | help='The ID of selected GPU, --no-cuda must be disabled, (default: 0).', 381 | default=0 382 | ) 383 | parser.add_argument( 384 | '-m', 385 | '--use-multiprocessing', 386 | action='store_true', 387 | help='Activate multiprocessing.' 388 | ) 389 | 390 | opt = parser.parse_args() 391 | 392 | # Check configuration 393 | if opt.num_workers == -1: 394 | opt.num_workers = mp.cpu_count() 395 | 396 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu' 397 | if not opt.no_cuda: 398 | assert torch.cuda.is_available() 399 | 400 | return opt 401 | 402 | 403 | def parse_preprocessing_args() -> argparse.Namespace: 404 | ''' 405 | output: 406 | parsed arguments. 407 | ''' 408 | parser = argparse.ArgumentParser( 409 | description='Generating labels for "tags" and "important_tags" pairs.') 410 | parser.add_argument( 411 | '-c', 412 | '--csv-path', 413 | type=str, 414 | help='/path/to/raw_data.csv', 415 | required=True 416 | ) 417 | parser.add_argument( 418 | '-s', 419 | '--save-path', 420 | type=str, 421 | help='/path/to/result.csv', 422 | default='./result.csv' 423 | ) 424 | parser.add_argument( 425 | '-tt', 426 | '--tags-field-type', 427 | type=str, 428 | help='str or list (default: str).', 429 | default='str' 430 | ) 431 | parser.add_argument( 432 | '-it', 433 | '--important-tags-field-type', 434 | type=str, 435 | help='str or list (default: str).', 436 | default='str' 437 | ) 438 | parser.add_argument( 439 | '-m', 440 | '--use-multiprocessing', 441 | action='store_true', 442 | help='Activate multiprocessing.' 
443 | ) 444 | parser.add_argument( 445 | '--num-workers', 446 | type=int, 447 | help='The number of workers used for data loaders, -1 means using all available processors, (default: -1).', 448 | default=-1 449 | ) 450 | 451 | return parser.parse_args() 452 | 453 | 454 | def parse_pseudo_label_args() -> argparse.Namespace: 455 | ''' 456 | output: 457 | parsed arguments. 458 | ''' 459 | parser = argparse.ArgumentParser(description='Pseudo labeling module.') 460 | parser.add_argument( 461 | '--csv-path', 462 | type=str, 463 | help='[/path/to/data.csv]', 464 | required=True 465 | ) 466 | parser.add_argument( 467 | '--img-dir', 468 | type=str, 469 | help='[/path/to/img_dir]', 470 | required=True 471 | ) 472 | parser.add_argument( 473 | '--model-path', 474 | type=str, 475 | help='[/path/to/model.pth]', 476 | required=True 477 | ) 478 | parser.add_argument( 479 | '--save-path', 480 | type=str, 481 | help='[/path/to/result.csv]', 482 | required=True 483 | ) 484 | parser.add_argument( 485 | '--item-id-field', 486 | type=str, 487 | help='(default: "item_id".)', 488 | default='item_id' 489 | ) 490 | parser.add_argument( 491 | '--tags-field', 492 | type=str, 493 | help='(default: "tags".)', 494 | default='tags' 495 | ) 496 | parser.add_argument( 497 | '--batch-size', 498 | type=int, 499 | help='The batch size used in the inference process (default: 64).', 500 | default=64 501 | ) 502 | parser.add_argument( 503 | '--num-workers', 504 | type=int, 505 | help='The number of workers used for data loaders, \ 506 | -1 means using all available processors, \ 507 | rules of thumb: num_workers ~ num_gpu * 4, \ 508 | (default: 4).', 509 | default=4 510 | ) 511 | parser.add_argument( 512 | '--threshold', 513 | type=float, 514 | help='(default: "0.5".)', 515 | default=0.5 516 | ) 517 | parser.add_argument( 518 | '--pos-threshold', 519 | type=float, 520 | help='The threshold used to classify an item into positive or non-positive class, \ 521 | an item with a score higher than the threshold 
will be considered a positive sample (default: 0.95).', 522 | default=0.95 523 | ) 524 | parser.add_argument( 525 | '--neg-threshold', 526 | type=float, 527 | help='The threshold used to classify an item into negative or non-negative class, \ 528 | an item with a score lower than the threshold will be considered a negative sample (default: 0.05).', 529 | default=0.05 530 | ) 531 | parser.add_argument( 532 | '--max-ratio', 533 | type=float, 534 | help='The maximum ratio of the number of confident tags to the number of all tags (default: 0.05).', 535 | default=0.05 536 | ) 537 | parser.add_argument( 538 | '--min-positive', 539 | type=int, 540 | help='The minimum number of positive tags in each item (default: 0).', 541 | default=0 542 | ) 543 | parser.add_argument( 544 | '--no-cuda', 545 | action='store_true' 546 | ) 547 | parser.add_argument( 548 | '--gpu-id', 549 | type=int, 550 | help='The ID of selected GPU, --no-cuda must be disabled, (default: 0).', 551 | default=0 552 | ) 553 | 554 | opt = parser.parse_args() 555 | 556 | # Check configuration 557 | if opt.num_workers == -1: 558 | opt.num_workers = mp.cpu_count() 559 | 560 | opt.device = 'cuda:{0}'.format(opt.gpu_id) if not opt.no_cuda else 'cpu' 561 | if not opt.no_cuda: 562 | assert torch.cuda.is_available() 563 | 564 | return opt 565 | 566 | 567 | class TensorBoardWriter(object): 568 | def __init__(self, log_dir: str, purge_step: int = 0): 569 | self.log_dir = log_dir 570 | self.purge_step = purge_step 571 | 572 | def __enter__(self): 573 | self.writer = SummaryWriter( 574 | log_dir=self.log_dir, 575 | purge_step=self.purge_step 576 | ) 577 | 578 | return self.writer 579 | 580 | def __exit__(self, type, value, traceback): 581 | self.writer.close() 582 | 583 | 584 | class Trainer(object): 585 | def __init__( 586 | self, 587 | model: nn.Module, 588 | optimizer: optim.Optimizer, 589 | opt: argparse.Namespace 590 | ): 591 | self.model = model 592 | self.optimizer = optimizer 593 | self.opt = opt
594 | 595 | self.start_from_epoch = self.opt.start_from_epoch 596 | self.stop_at_epoch = self.opt.start_from_epoch + self.opt.epochs 597 | self.log_dir = './runs/{0}'.format( 598 | "_".join(self.opt.save_dir.split("/")[-1].split("."))) 599 | 600 | self.criterion = { 601 | 'both': BCEDiceLoss(beta=1.0), 602 | 'tag': BCEDiceLoss(beta=1.0), 603 | 'img': BCEDiceLoss(beta=1.0) 604 | } 605 | 606 | # Initialize monitoring params 607 | self.best_val_loss = np.inf 608 | self.best_val_f1 = 0 609 | self.best_val_precision = 0 610 | self.best_val_recall = 0 611 | self.alpha = 0.9 # Mean over 10 iters 612 | 613 | self.fk_eval = PrecisionRecallFk( 614 | enable_logger=False, threshold=self.opt.threshold) 615 | 616 | if self.opt.use_rop_scheduler: 617 | self.rop_scheduler = optim.lr_scheduler.ReduceLROnPlateau( 618 | optimizer=self.optimizer, 619 | factor=self.opt.rop_factor, 620 | patience=self.opt.rop_patience, 621 | min_lr=1e-7, 622 | verbose=True 623 | ) 624 | elif self.opt.use_steplr_scheduler: 625 | self.steplr_scheduler = optim.lr_scheduler.StepLR( 626 | optimizer=self.optimizer, 627 | step_size=1, 628 | gamma=self.opt.sl_gamma 629 | ) 630 | 631 | if self.opt.checkpoint_path is not None: 632 | self._load_checkpoint() 633 | 634 | def fit( 635 | self, 636 | train_dataloader: DataLoader, 637 | val_dataloader: DataLoader 638 | ): 639 | with TensorBoardWriter(self.log_dir, purge_step=self.start_from_epoch) as writer: 640 | if self.opt.log_graph: 641 | self._log_graph(train_dataloader, writer) 642 | 643 | print('\nTraining model...') 644 | for epoch in range(self.start_from_epoch, self.stop_at_epoch): 645 | self._fit_an_epoch( 646 | train_dataloader, val_dataloader, writer, epoch) 647 | 648 | def _log_graph(self, dataloader, writer): 649 | image_batch, tags_batch, mask_batch, _ = next( 650 | iter(dataloader)) 651 | image_batch = image_batch.to(self.opt.device) 652 | tags_batch = tags_batch.to(self.opt.device) 653 | mask_batch = mask_batch.to(self.opt.device) 654 | 
writer.add_graph(self.model, (tags_batch, image_batch, mask_batch)) 655 | 656 | def _load_checkpoint(self): 657 | assert os.path.isfile(self.opt.checkpoint_path) 658 | 659 | print('\nLoading checkpoint...') 660 | states = torch.load(self.opt.checkpoint_path, 661 | map_location=lambda storage, loc: storage) 662 | print('|`-- Loading model...') 663 | print('+--------------------') 664 | model_dict = self.model.state_dict() 665 | excluding_layers = [ 666 | 'img_linear.weight', 667 | 'img_linear.bias', 668 | 'tag_linear.weight', 669 | 'tag_linear.bias', 670 | 'gating.linear_1.weight', 671 | 'gating.linear_1.bias', 672 | 'gating.linear_2.weight', 673 | 'gating.linear_2.bias' 674 | ] if self.opt.exclude_top else [] 675 | pretrained_dict = {k: v for k, v in states['model'].items() 676 | if k in model_dict and k not in excluding_layers} 677 | model_dict.update(pretrained_dict) 678 | self.model.load_state_dict(model_dict) 679 | if not self.opt.load_weights_only: 680 | print('|`-- Loading optimizer...') 681 | self.optimizer.load_state_dict(states['optimizer']) 682 | print('|`-- Loading best val loss...') 683 | self.best_val_loss = states['best_val_loss'] 684 | print('|`-- Loading best val f1...') 685 | self.best_val_f1 = states['best_val_f1'] 686 | print('|`-- Loading best val precision...') 687 | self.best_val_precision = states['best_val_precision'] 688 | print(' `-- Loading best val recall...') 689 | self.best_val_recall = states['best_val_recall'] 690 | 691 | def _save_checkpoint( 692 | self, 693 | new_loss, 694 | new_f1, 695 | new_precision, 696 | new_recall, 697 | epoch 698 | ): 699 | found_better_val_loss = new_loss < self.best_val_loss 700 | found_better_val_f1 = new_f1 > self.best_val_f1 701 | 702 | self.best_val_loss = np.minimum( 703 | self.best_val_loss, new_loss) 704 | self.best_val_f1 = np.maximum( 705 | self.best_val_f1, new_f1) 706 | self.best_val_precision = np.maximum( 707 | self.best_val_precision, new_precision) 708 | self.best_val_recall = np.maximum( 
709 | self.best_val_recall, new_recall) 710 | 711 | states = { 712 | 'model': self.model.state_dict(), 713 | 'optimizer': self.optimizer.state_dict(), 714 | 'best_val_loss': self.best_val_loss, 715 | 'best_val_f1': self.best_val_f1, 716 | 'best_val_precision': self.best_val_precision, 717 | 'best_val_recall': self.best_val_recall, 718 | 'config': vars(self.opt) 719 | } 720 | 721 | if self.opt.save_best_loss and found_better_val_loss: 722 | print(' \__ Found a better checkpoint based on val loss -> Saving...') 723 | torch.save(states, os.path.join( 724 | self.opt.save_dir, 'best_loss.pth')) 725 | 726 | if self.opt.save_best_f1 and found_better_val_f1: 727 | print(' \__ Found a better checkpoint based on val F1 -> Saving...') 728 | torch.save(states, os.path.join(self.opt.save_dir, 'best_f1.pth')) 729 | 730 | if self.opt.save_latest: 731 | torch.save(states, os.path.join(self.opt.save_dir, 'latest.pth')) 732 | 733 | if self.opt.save_all_epochs: 734 | torch.save(states, os.path.join( 735 | self.opt.save_dir, 'epoch_{0}.pth'.format(epoch+1))) 736 | 737 | def _compute_running_precision_recall_f1( 738 | self, 739 | pred, 740 | label, 741 | running_precision, 742 | running_recall, 743 | running_f1 744 | ): 745 | fk_eval_dict = self.fk_eval(pred, label, betas=[1]) 746 | running_precision = moving_avg( 747 | running_precision, np.nan_to_num(fk_eval_dict['precision']), self.alpha) 748 | running_recall = moving_avg( 749 | running_recall, np.nan_to_num(fk_eval_dict['recall']), self.alpha) 750 | running_f1 = moving_avg( 751 | running_f1, np.nan_to_num(fk_eval_dict['f_score']['F1']), self.alpha) 752 | 753 | return running_precision, running_recall, running_f1 754 | 755 | def _compute_batch_precision_recall_f1( 756 | self, 757 | pred, 758 | label, 759 | batch_val_idx, 760 | local_batch_size, 761 | batch_val_precision, 762 | batch_val_recall, 763 | batch_val_f1 764 | ): 765 | fk_eval_dict = self.fk_eval(pred, label, betas=[1]) 766 | batch_val_precision[batch_val_idx] = 
np.nan_to_num( 767 | fk_eval_dict['precision']) * local_batch_size 768 | batch_val_recall[batch_val_idx] = np.nan_to_num( 769 | fk_eval_dict['recall']) * local_batch_size 770 | batch_val_f1[batch_val_idx] = np.nan_to_num( 771 | fk_eval_dict['f_score']['F1']) * local_batch_size 772 | 773 | return batch_val_precision, batch_val_recall, batch_val_f1 774 | 775 | def _fit_an_epoch(self, train_dataloader, val_dataloader, writer, epoch): 776 | # Training process 777 | self.model.train() 778 | 779 | # Initialize a dictionary to store numeric values 780 | running = { 781 | 'loss': { 782 | 'both': 0, 783 | 'tag': 0, 784 | 'img': 0, 785 | 'sum': 0 786 | }, 787 | 'f1': { 788 | 'both': 0, 789 | 'tag': 0, 790 | 'img': 0 791 | }, 792 | 'precision': { 793 | 'both': 0, 794 | 'tag': 0, 795 | 'img': 0 796 | }, 797 | 'recall': { 798 | 'both': 0, 799 | 'tag': 0, 800 | 'img': 0 801 | }, 802 | 'weight': { 803 | 'tag': 0, 804 | 'img': 0 805 | } 806 | } 807 | 808 | train_pbar = tqdm(train_dataloader) 809 | train_pbar.desc = '* Epoch {0}'.format(epoch+1) 810 | 811 | for batch_idx, (image_batch, tags_batch, mask_batch, label_batch) in enumerate(train_pbar): 812 | image_batch = image_batch.to(self.opt.device) 813 | tags_batch = tags_batch.to(self.opt.device) 814 | label_batch = label_batch.to(self.opt.device) 815 | mask_batch = mask_batch.to(self.opt.device) 816 | 817 | preds = dict() 818 | weight = dict() 819 | preds['both'], preds['tag'], preds['img'], weight['tag'], weight['img'] = \ 820 | self.model(tags_batch, image_batch, mask_batch) 821 | 822 | for key in preds.keys(): 823 | preds[key] = preds[key].masked_fill( 824 | mask_batch, 825 | 0.0 826 | ) 827 | 828 | loss = dict() 829 | for key in preds.keys(): 830 | loss[key] = self.criterion[key](preds[key], label_batch) 831 | loss['sum'] = loss['both'] + loss['tag'] + loss['img'] 832 | 833 | self.optimizer.zero_grad() 834 | loss['sum'].backward() 835 | self.optimizer.step() 836 | 837 | processed_label = 
label_batch.detach().cpu().numpy().astype(np.uint8) 838 | processed_pred = dict() 839 | for key in preds.keys(): 840 | processed_pred[key] = preds[key].detach().cpu().numpy() 841 | 842 | # Compute running losses 843 | for key in running['loss'].keys(): 844 | running['loss'][key] = moving_avg( 845 | running['loss'][key], loss[key].item(), self.alpha) 846 | 847 | # Compute running weights 848 | for key in running['weight'].keys(): 849 | running['weight'][key] = moving_avg( 850 | running['weight'][key], weight[key].mean().item(), self.alpha) 851 | 852 | # Compute running precision, recall and f1 853 | for key in processed_pred.keys(): 854 | running['precision'][key], running['recall'][key], running['f1'][key] = \ 855 | self._compute_running_precision_recall_f1( 856 | processed_pred[key], 857 | processed_label, 858 | running['precision'][key], 859 | running['recall'][key], 860 | running['f1'][key] 861 | ) 862 | 863 | train_pbar.set_postfix({ 864 | 'loss': running['loss']['both'], 865 | 'f1': running['f1']['both'], 866 | 'prec': running['precision']['both'], 867 | 'recall': running['recall']['both'], 868 | }) 869 | 870 | # Log to TensorBoard 871 | for key in running.keys(): 872 | for subkey in running[key]: 873 | writer.add_scalar('{0}/train_{1}'.format(key, subkey), 874 | running[key][subkey], epoch) 875 | 876 | # Validation process 877 | self.model.eval() 878 | 879 | with torch.no_grad(): 880 | # Initialize a dictionary to store 1d arrays 881 | batch_val = { 882 | 'loss': { 883 | 'both': np.zeros(len(val_dataloader)), 884 | 'tag': np.zeros(len(val_dataloader)), 885 | 'img': np.zeros(len(val_dataloader)), 886 | 'sum': np.zeros(len(val_dataloader)) 887 | }, 888 | 'f1': { 889 | 'both': np.zeros(len(val_dataloader)), 890 | 'tag': np.zeros(len(val_dataloader)), 891 | 'img': np.zeros(len(val_dataloader)) 892 | }, 893 | 'precision': { 894 | 'both': np.zeros(len(val_dataloader)), 895 | 'tag': np.zeros(len(val_dataloader)), 896 | 'img': np.zeros(len(val_dataloader)) 897 | }, 
898 | 'recall': { 899 | 'both': np.zeros(len(val_dataloader)), 900 | 'tag': np.zeros(len(val_dataloader)), 901 | 'img': np.zeros(len(val_dataloader)) 902 | }, 903 | 'weight': { 904 | 'tag': np.zeros(len(val_dataloader)), 905 | 'img': np.zeros(len(val_dataloader)) 906 | }, 907 | } 908 | 909 | if self.opt.log_weight_hist: 910 | all_val_weights = { 911 | 'tag': [], 912 | 'img': [] 913 | } 914 | 915 | num_items = 0 916 | 917 | val_pbar = tqdm(val_dataloader) 918 | val_pbar.desc = '\__ Validating' 919 | 920 | for batch_val_idx, (image_batch, tags_batch, mask_batch, label_batch) in enumerate(val_pbar): 921 | image_batch = image_batch.to(self.opt.device) 922 | tags_batch = tags_batch.to(self.opt.device) 923 | label_batch = label_batch.to(self.opt.device) 924 | mask_batch = mask_batch.to(self.opt.device) 925 | 926 | preds = dict() 927 | weight = dict() 928 | preds['both'], preds['tag'], preds['img'], weight['tag'], weight['img'] = \ 929 | self.model(tags_batch, image_batch, mask_batch) 930 | 931 | for key in preds.keys(): 932 | preds[key] = preds[key].masked_fill( 933 | mask_batch, 934 | 0.0 935 | ) 936 | 937 | val_loss = dict() 938 | for key in preds.keys(): 939 | val_loss[key] = self.criterion[key]( 940 | preds[key], label_batch) 941 | val_loss['sum'] = val_loss['both'] + \ 942 | val_loss['tag'] + val_loss['img'] 943 | 944 | processed_label = label_batch.detach().cpu().numpy().astype(np.uint8) 945 | processed_pred = dict() 946 | for key in preds.keys(): 947 | processed_pred[key] = preds[key].detach().cpu().numpy() 948 | 949 | local_batch_size = label_batch.size(0) 950 | num_items += local_batch_size 951 | 952 | # Compute sum batch losses 953 | for key in batch_val['loss'].keys(): 954 | batch_val['loss'][key][batch_val_idx] = val_loss[key].item( 955 | ) * local_batch_size 956 | 957 | # Compute sum batch weights 958 | for key in batch_val['weight'].keys(): 959 | batch_val['weight'][key][batch_val_idx] = weight[key].mean().item( 960 | ) * local_batch_size 961 | 962 | # Keep 
all available weight values (if needed) 963 | if self.opt.log_weight_hist: 964 | for key in batch_val['weight'].keys(): 965 | all_val_weights[key].extend( 966 | weight[key].detach().cpu().numpy().reshape(-1)) 967 | 968 | # Compute sum batch precision, recall and f1 969 | for key in processed_pred.keys(): 970 | batch_val['precision'][key], batch_val['recall'][key], batch_val['f1'][key] = \ 971 | self._compute_batch_precision_recall_f1( 972 | processed_pred[key], 973 | processed_label, 974 | batch_val_idx, 975 | local_batch_size, 976 | batch_val['precision'][key], 977 | batch_val['recall'][key], 978 | batch_val['f1'][key] 979 | ) 980 | 981 | val_pbar.set_postfix({ 982 | 'loss': batch_val['loss']['both'][batch_val_idx] / local_batch_size, 983 | 'f1': batch_val['f1']['both'][batch_val_idx] / local_batch_size, 984 | 'prec': batch_val['precision']['both'][batch_val_idx] / local_batch_size, 985 | 'recall': batch_val['recall']['both'][batch_val_idx] / local_batch_size, 986 | }) 987 | 988 | mean_val = dict() 989 | for key in batch_val.keys(): 990 | mean_val[key] = dict() 991 | for subkey in batch_val[key].keys(): 992 | mean_val[key][subkey] = np.sum( 993 | batch_val[key][subkey]) / num_items 994 | 995 | if self.opt.use_rop_scheduler: 996 | self.rop_scheduler.step(mean_val['loss']['sum']) 997 | elif self.opt.use_steplr_scheduler: 998 | self.steplr_scheduler.step() 999 | 1000 | # Log to TensorBoard 1001 | for key in mean_val.keys(): 1002 | for subkey in mean_val[key].keys(): 1003 | writer.add_scalar('{0}/val_{1}'.format(key, subkey), 1004 | mean_val[key][subkey], epoch) 1005 | if self.opt.log_weight_hist: 1006 | for key in all_val_weights.keys(): 1007 | writer.add_histogram( 1008 | 'weight/{0}'.format(key), np.array(all_val_weights[key]), epoch) 1009 | 1010 | # Save checkpoint 1011 | self._save_checkpoint( 1012 | new_loss=mean_val['loss']['both'], 1013 | new_f1=mean_val['f1']['both'], 1014 | new_precision=mean_val['precision']['both'], 1015 | 
new_recall=mean_val['recall']['both'], 1016 | epoch=epoch 1017 | ) 1018 | -------------------------------------------------------------------------------- /MAGNeto/preprocess.py: -------------------------------------------------------------------------------- 1 | import ast 2 | 3 | import pandas as pd 4 | from tqdm import tqdm 5 | 6 | from magneto.utils import parse_preprocessing_args 7 | 8 | 9 | def make_label(tags, important_tags) -> list: 10 | ''' 11 | input: 12 | + tags: all available tags of an item. 13 | + important_tags: tags that are marked as important. 14 | output: 15 | a binary mask with 0 for unimportant tags and 1 for important ones. 16 | ''' 17 | return ['1' if tag in important_tags else '0' for tag in tags] 18 | 19 | 20 | def label_important_tags( 21 | item_id, 22 | tags, 23 | important_tags 24 | ) -> dict: 25 | ''' 26 | input: 27 | + item_id: the ID of an item. 28 | + tags: all available tags of an item. 29 | + important_tags: tags that are marked as important. 30 | output: 31 | a dictionary containing all the needed information for an item.
32 | ''' 33 | label = make_label(tags, important_tags) 34 | 35 | return { 36 | 'item_id': item_id, 37 | 'tags': ','.join(tags), 38 | 'important_tags': ','.join(important_tags), 39 | 'label': ','.join(label) 40 | } 41 | 42 | 43 | def main(): 44 | opt = parse_preprocessing_args() 45 | 46 | df = pd.read_csv(opt.csv_path) 47 | 48 | assert 'tags' in df.columns 49 | assert 'important_tags' in df.columns 50 | assert opt.tags_field_type in ['str', 'list'] 51 | assert opt.important_tags_field_type in ['str', 'list'] 52 | 53 | series_of_item_id = df['item_id'] 54 | series_of_tags = df['tags'] 55 | series_of_important_tags = df['important_tags'] 56 | 57 | if opt.tags_field_type == 'str': 58 | series_of_tags = series_of_tags.apply(lambda x: x.split(',')) 59 | elif opt.tags_field_type == 'list': 60 | series_of_tags = series_of_tags.apply(ast.literal_eval) 61 | 62 | if opt.important_tags_field_type == 'str': 63 | series_of_important_tags = series_of_important_tags.apply( 64 | lambda x: x.split(',')) 65 | elif opt.important_tags_field_type == 'list': 66 | series_of_important_tags = series_of_important_tags.apply(ast.literal_eval) 67 | 68 | rows_dict = dict() 69 | i = 0 70 | 71 | if opt.use_multiprocessing: 72 | import multiprocessing as mp 73 | 74 | # Apply a patch for the multiprocessing module 75 | import multiprocessing.pool as mpp 76 | from magneto.utils import istarmap 77 | mpp.Pool.istarmap = istarmap 78 | 79 | if opt.num_workers == -1: 80 | opt.num_workers = mp.cpu_count() 81 | 82 | inputs = list(zip( 83 | series_of_item_id, 84 | series_of_tags, 85 | series_of_important_tags 86 | )) 87 | with mp.Pool(opt.num_workers) as pool: 88 | for result in tqdm(pool.istarmap(label_important_tags, inputs), total=len(inputs)): 89 | rows_dict[i] = result 90 | i += 1 91 | else: 92 | for item_id, tags, important_tags \ 93 | in tqdm(list(zip( 94 | series_of_item_id, 95 | series_of_tags, 96 | series_of_important_tags 97 | ))): 98 | 99 | result = label_important_tags( 100 | item_id, 101 | 
tags, 102 | important_tags 103 | ) 104 | 105 | rows_dict[i] = result 106 | i += 1 107 | 108 | new_df = pd.DataFrame.from_dict(rows_dict, 'index') 109 | new_df.to_csv(opt.save_path, index=False) 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /MAGNeto/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.19.0 2 | scipy==1.5.2 3 | opencv-python==4.2.0.34 4 | pandas==1.0.5 5 | Pillow==6.1.0 6 | tensorboard==2.2.2 7 | torch==1.5.1 8 | torchvision==0.6.1 9 | tqdm==4.47.0 10 | matplotlib==3.2.2 11 | -------------------------------------------------------------------------------- /MAGNeto/scripts/start_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m infer \ 4 | --csv-path ./data/nus_wide/annotations/val_81_with_label.csv \ 5 | --img-dir ./data/nus_wide/images \ 6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \ 7 | --model-path ./snapshots/demo/best_f1.pth \ 8 | --batch-size 32 \ 9 | --num-workers 4 \ 10 | --threshold 0.5 \ 11 | --top 0 \ 12 | --gpu-id 0 \ 13 | --has-label \ 14 | -m 15 | -------------------------------------------------------------------------------- /MAGNeto/scripts/start_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m preprocess \ 4 | -c data/nus_wide/annotations/train_81.csv \ 5 | -s data/nus_wide/annotations/train_81_with_label.csv \ 6 | -tt str \ 7 | -it str \ 8 | -m \ 9 | --num-workers 4 -------------------------------------------------------------------------------- /MAGNeto/scripts/start_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m train \ 4 | --train-csv-path ./data/nus_wide/annotations/train_81_with_label.csv \ 5 | --val-csv-path 
./data/nus_wide/annotations/val_81_with_label.csv \ 6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \ 7 | --img-dir ./data/nus_wide/images \ 8 | --save-dir ./snapshots/demo \ 9 | --start-from-epoch 0 \ 10 | --t-heads 4 \ 11 | --t-blocks 2 \ 12 | --t-dim-feedforward 512 \ 13 | --i-heads 4 \ 14 | --i-blocks 1 \ 15 | --i-dim-feedforward 512 \ 16 | --img-backbone resnet18 \ 17 | --d-model 128 \ 18 | --max-len 16 \ 19 | --g-dim-feedforward 512 \ 20 | --dropout 0.3 \ 21 | --threshold 0.5 \ 22 | --tagaug-add-max-ratio 1.0 \ 23 | --tagaug-drop-max-ratio 0.0 \ 24 | --train-batch-size 32 \ 25 | --val-batch-size 32 \ 26 | --epochs 500 \ 27 | --gpu-id 0 \ 28 | --num-workers 8 \ 29 | --log-graph \ 30 | --save-best-loss \ 31 | --save-best-f1 \ 32 | --save-latest \ 33 | --lr 1e-2 \ 34 | --log-weight-hist -------------------------------------------------------------------------------- /MAGNeto/scripts/start_train_usp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m train \ 4 | --train-csv-path ./data/nus_wide/annotations/train_81_with_label.csv \ 5 | --val-csv-path ./data/nus_wide/annotations/val_81_with_label.csv \ 6 | --vocab-path ./data/nus_wide/annotations/vocab_81.csv \ 7 | --img-dir ./data/nus_wide/images \ 8 | --save-dir ./snapshots/nus_wide_81_add0p0_drop0p0_with_unsupervised_pretraining_Sep_12_20 \ 9 | --checkpoint-path ./snapshots/nuswide_top_81_ver_1_unsupervised_pretraining_Sep_10_20/ckpt.pth \ 10 | --load-weights-only \ 11 | --exclude-top \ 12 | --start-from-epoch 0 \ 13 | --t-heads 4 \ 14 | --t-blocks 2 \ 15 | --t-dim-feedforward 512 \ 16 | --i-heads 4 \ 17 | --i-blocks 1 \ 18 | --i-dim-feedforward 512 \ 19 | --img-backbone resnet18 \ 20 | --d-model 128 \ 21 | --max-len 16 \ 22 | --g-dim-feedforward 512 \ 23 | --dropout 0.3 \ 24 | --threshold 0.5 \ 25 | --tagaug-add-max-ratio 0.0 \ 26 | --tagaug-drop-max-ratio 0.0 \ 27 | --train-batch-size 32 \ 28 | --val-batch-size 32 \ 29 | --epochs 
500 \ 30 | --gpu-id 3 \ 31 | --num-workers 8 \ 32 | --log-graph \ 33 | --save-best-loss \ 34 | --lr 1e-2 \ 35 | --log-weight-hist -------------------------------------------------------------------------------- /MAGNeto/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from torch import optim 4 | 5 | from magneto.utils import parse_train_args 6 | from magneto.data import get_dataloaders 7 | from magneto.model import MAGNeto 8 | from magneto.utils import Trainer 9 | 10 | warnings.filterwarnings("ignore") 11 | 12 | 13 | def main(): 14 | ##### GET CONFIGURATION ##### 15 | opt = parse_train_args() 16 | 17 | ##### PREPARE DATASETS ##### 18 | print('\nPreparing datasets...') 19 | train_dataloader, val_dataloader, vocab_size = get_dataloaders( 20 | train_csv_path=opt.train_csv_path, 21 | val_csv_path=opt.val_csv_path, 22 | vocab_path=opt.vocab_path, 23 | img_dir=opt.img_dir, 24 | tagaug_add_max_ratio=opt.tagaug_add_max_ratio, 25 | tagaug_drop_max_ratio=opt.tagaug_drop_max_ratio, 26 | train_batch_size=opt.train_batch_size, 27 | val_batch_size=opt.val_batch_size, 28 | max_len=opt.max_len, 29 | num_workers=opt.num_workers, 30 | pin_memory=not opt.no_cuda 31 | ) 32 | 33 | ##### CREATE MODEL ##### 34 | model = MAGNeto( 35 | d_model=opt.d_model, 36 | vocab_size=vocab_size, 37 | t_blocks=opt.t_blocks, 38 | t_heads=opt.t_heads, 39 | t_dim_feedforward=opt.t_dim_feedforward, 40 | i_blocks=opt.i_blocks, 41 | i_heads=opt.i_heads, 42 | i_dim_feedforward=opt.i_dim_feedforward, 43 | img_backbone=opt.img_backbone, 44 | g_dim_feedforward=opt.g_dim_feedforward, 45 | dropout=opt.dropout, 46 | ) 47 | model = model.to(opt.device) 48 | 49 | ##### CREATE OPTIMIZER ##### 50 | optimizer = optim.SGD( 51 | filter(lambda p: p.requires_grad, model.parameters()), 52 | lr=opt.lr, 53 | momentum=0.9 54 | ) 55 | 56 | ##### CREATE TRAINER AND START THE TRAINING PROCESS ##### 57 | trainer = Trainer( 58 | model=model, 
59 | optimizer=optimizer, 60 | opt=opt 61 | ) 62 | trainer.fit(train_dataloader, val_dataloader) 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LabTeam 2 | 3 | ## Publications 4 | - MAGNeto: An Efficient Deep Learning Method for the Extractive Tags Summarization Problem [[code](MAGNeto)][[abs](https://arxiv.org/abs/2011.04349)][[pdf](https://arxiv.org/pdf/2011.04349)] 5 | --------------------------------------------------------------------------------
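For reference, the `label` column written by `preprocess.py` is a comma-joined, per-tag indicator produced by `make_label`, which is defined in `magneto/utils.py` and not included in this listing. The following is only a hypothetical sketch of such a function — assuming a binary `'1'`/`'0'` marker per tag — to illustrate the shape of the output CSV; the repo's actual implementation may differ:

```python
def make_label(tags, important_tags):
    """Hypothetical sketch of make_label (the real one lives in magneto/utils.py).

    Marks each tag in `tags` with '1' if it also appears in `important_tags`,
    else '0', preserving the order of `tags`.
    """
    important = set(important_tags)  # set membership keeps this O(len(tags))
    return ['1' if tag in important else '0' for tag in tags]


# Mirrors how label_important_tags in preprocess.py joins the result
# before writing it to the 'label' column of the output CSV.
tags = ['sky', 'clouds', 'person']
important_tags = ['sky']
print(','.join(make_label(tags, important_tags)))  # -> 1,0,0
```

Under this assumption, each row of `train_81_with_label.csv` carries a label string whose i-th field answers "is the i-th tag important?".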