├── .gitignore ├── LICENSE ├── README.md ├── coloredmnist └── train_coloredmnist.py ├── domainbed ├── __init__.py ├── algorithms.py ├── command_launchers.py ├── datasets.py ├── hparams_registry.py ├── lib │ ├── fast_data_loader.py │ ├── misc.py │ ├── query.py │ ├── reporting.py │ └── wide_resnet.py ├── model_selection.py ├── networks.py └── scripts │ ├── __init__.py │ ├── collect_results.py │ ├── download.py │ ├── list_top_hparams.py │ ├── save_images.py │ ├── sweep.py │ └── train.py ├── fig_intro.png └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Build and Release Folders 2 | bin-debug/ 3 | bin-release/ 4 | [Oo]bj/ 5 | [Bb]in/ 6 | 7 | # Other files and folders 8 | .settings/ 9 | *__pycache__* 10 | 11 | # Executables 12 | *.swf 13 | *.air 14 | *.ipa 15 | *.apk 16 | *.pyc 17 | 18 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties` 19 | # should NOT be excluded as they contain compiler settings and other important 20 | # information for Eclipse / Flash Builder. 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. 
We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 
49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. 
The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 
122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 
155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 
186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. 
This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 
256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 
287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 
317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 
386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 
421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. 
If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 
486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 
512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. 
If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fishr: Invariant Gradient Variances for Out-of-distribution Generalization 2 | 3 | Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization, ICML 2022 | [paper](https://arxiv.org/abs/2109.02934) 4 | 5 | [Alexandre Ramé](https://alexrame.github.io/), [Corentin Dancette](https://cdancette.fr/), [Matthieu Cord](http://webia.lip6.fr/~cord/) 6 | 7 | ![](./fig_intro.png) 8 | 9 | 10 | ## Abstract 11 | Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest to learn simultaneously from multiple training domains - while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols. 12 | 13 | In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights. 14 | 15 | Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than 16 | Empirical Risk Minimization. 
17 | 18 | # Installation 19 | 20 | ## Requirements overview 21 | 22 | Our implementation relies on the [BackPACK](https://github.com/f-dangel/backpack/) package in [PyTorch](https://pytorch.org/) to easily compute gradient variances. 23 | 24 | - python == 3.7.10 25 | - torch == 1.8.1 26 | - torchvision == 0.9.1 27 | - backpack-for-pytorch == 1.3.0 28 | - numpy == 1.20.2 29 | 30 | ## Procedure 31 | 32 | 1. Clone the repo: 33 | ```bash 34 | $ git clone https://github.com/alexrame/fishr.git 35 | ``` 36 | 37 | 2. Install this repository and the dependencies using pip: 38 | ```bash 39 | $ conda create --name fishr python=3.7.10 40 | $ conda activate fishr 41 | $ cd fishr 42 | $ pip install -r requirements.txt 43 | ``` 44 | 45 | With this, you can edit the Fishr code on the fly. 46 | 47 | # Overview 48 | 49 | This github enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by [IRM](https://github.com/facebookresearch/InvariantRiskMinimization/tree/master/code/colored_mnist) and (2) on the [DomainBed](https://github.com/facebookresearch/DomainBed/) benchmark. 50 | 51 | 52 | ## Colored MNIST in the IRM setup 53 | 54 | We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST. 55 | ### Main results (Table 2 in Section 6.A) 56 | 57 | To reproduce the results from Table 2, call ```python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm``` where `algorithm` is either: 58 | - ```erm``` for Empirical Risk Minimization 59 | - ```irm``` for [Invariant Risk Minimization](https://arxiv.org/abs/1907.02893) 60 | - ```rex``` for [Out-of-Distribution Generalization via Risk Extrapolation](https://icml.cc/virtual/2021/oral/9186) 61 | - ```fishr``` for our proposed Fishr 62 | 63 | Results will be printed at the end of the script, averaged over 10 runs. 
Note that all hyperparameters are taken from the seminal [IRM](https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/reproduce_paper_results.sh) implementation. 64 | 65 | Method | Train acc. | Test acc. | Gray test acc. 66 | --------|------------|------------|---------------- 67 | ERM | 86.4 ± 0.2 | 14.0 ± 0.7 | 71.0 ± 0.7 68 | IRM | 71.0 ± 0.5 | 65.6 ± 1.8 | 66.1 ± 0.2 69 | V-REx | 71.7 ± 1.5 | 67.2 ± 1.5 | 68.6 ± 2.2 70 | Fishr | 71.0 ± 0.9 | 69.5 ± 1.0 | 70.2 ± 1.1 71 | 72 | 73 | 74 | ### Without label flipping (Table 5 in Appendix C.2.3) 75 | The script ```coloredmnist.train_coloredmnist``` also accepts as input the argument `--label_flipping_prob` which defines the label flipping probability. By default, it's 0.25, so to reproduce the results from Table 5 you should set `--label_flipping_prob 0`. 76 | ### Fishr variants (Table 6 in Appendix C.2.4) 77 | This table considers two additional Fishr variants, reproduced with `algorithm` set to: 78 | - ```fishr_offdiagonal``` for Fishr but on the full covariance rather than only the diagonal 79 | - ```fishr_notcentered``` for Fishr but without centering the gradient variances 80 | 81 | ## DomainBed 82 | 83 | DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in [In Search of Lost Domain Generalization](https://arxiv.org/abs/2007.01434). Instructions below are copied and adapted from the official [github](https://github.com/facebookresearch/DomainBed/). 84 | 85 | ### Algorithms and hyperparameter grids 86 | 87 | We added Fishr as a new algorithm [here](domainbed/algorithms.py), and defined Fishr's hyperparameter grids [here](domainbed/hparams_registry.py), as defined in Table 7 in Appendix D.
88 | 89 | ### Datasets 90 | 91 | We ran Fishr on the following [datasets](domainbed/datasets.py): 92 | 93 | * Rotated MNIST ([Ghifary et al., 2015](https://arxiv.org/abs/1508.07680)) 94 | * Colored MNIST ([Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893)) 95 | * VLCS ([Fang et al., 2013](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf)) 96 | * PACS ([Li et al., 2017](https://arxiv.org/abs/1710.03077)) 97 | * OfficeHome ([Venkateswara et al., 2017](https://arxiv.org/abs/1706.07522)) 98 | * A TerraIncognita ([Beery et al., 2018](https://arxiv.org/abs/1807.04975)) subset 99 | * DomainNet ([Peng et al., 2019](http://ai.bu.edu/M3SDA/)) 100 | 101 | ### Launch training 102 | 103 | Download the datasets: 104 | 105 | ```sh 106 | python3 -m domainbed.scripts.download\ 107 | --data_dir=/my/data/dir 108 | ``` 109 | 110 | Train a model for debugging: 111 | 112 | ```sh 113 | python3 -m domainbed.scripts.train\ 114 | --data_dir=/my/data/dir/\ 115 | --algorithm Fishr\ 116 | --dataset ColoredMNIST\ 117 | --test_env 2 118 | ``` 119 | 120 | Launch a sweep for hyperparameter search: 121 | 122 | ```sh 123 | python -m domainbed.scripts.sweep launch\ 124 | --data_dir=/my/data/dir/\ 125 | --output_dir=/my/sweep/output/path\ 126 | --command_launcher MyLauncher\ 127 | --datasets ColoredMNIST\ 128 | --algorithms Fishr 129 | ``` 130 | Here, `MyLauncher` is your cluster's command launcher, as implemented in `command_launchers.py`.
131 | 132 | 133 | ### Performance inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G) 134 | 135 | To view the results of your sweep: 136 | 137 | ````sh 138 | python -m domainbed.scripts.collect_results\ 139 | --input_dir=/my/sweep/output/path 140 | ```` 141 | 142 | We inspect performances using the following [model selection criteria](domainbed/model_selection.py), that differ in what data is used to choose the best hyper-parameters for a given model: 143 | 144 | * `OracleSelectionMethod` (`Oracle`): A random subset from the data of the test domain. 145 | * `IIDAccuracySelectionMethod` (`Training`): A random subset from the data of the training domains. 146 | 147 | Critically, Fishr performs consistently better than Empirical Risk Minimization. 148 | 149 | Model selection | Algorithm | Colored MNIST | Rotated MNIST | VLCS | PACS |OfficeHome | TerraIncognita | DomainNet | Avg 150 | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- 151 | | | | | | | | | | 152 | Oracle | ERM | 57.8 ± 0.2 | 97.8 ± 0.1 | 77.6 ± 0.3 | 86.7 ± 0.3 | 66.4 ± 0.5 | 53.0 ± 0.3 | 41.3 ± 0.1 | 68.7 153 | Oracle | Fishr | 68.8 ± 1.4 | 97.8 ± 0.1 | 78.2 ± 0.2 | 86.9 ± 0.2 | 68.2 ± 0.2 | 53.6 ± 0.4 | 41.8 ± 0.2 | 70.8 154 | | | | | | | | | | 155 | Training | ERM | 51.5 ± 0.1 | 98.0 ± 0.0 | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 | 46.1 ± 1.8 | 40.9 ± 0.1 | 66.6 156 | Training | Fishr | 52.0 ± 0.2 | 97.8 ± 0.0 | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 | 47.4 ± 1.6 | 41.7 ± 0.0 | 67.1 157 | 158 | 159 | # Conclusion 160 | 161 | We addressed the task of out-of-distribution generalization for computer vision classification tasks. We derived a new and simple regularization - Fishr - that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performances on the DomainBed benchmark and performs better than ERM.
Our empirical experiments suggest that Fishr regularization would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help using Fishr, please open an issue or contact alexandre.rame@lip6.fr. 162 | 163 | # Citation 164 | 165 | If you find this code useful for your research, please consider citing our work: 166 | 167 | ``` 168 | @inproceedings{rame2021fishr, 169 | title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization}, 170 | author={Alexandre Rame and Corentin Dancette and Matthieu Cord}, 171 | year={2022}, 172 | booktitle={ICML} 173 | } 174 | ``` 175 | -------------------------------------------------------------------------------- /coloredmnist/train_coloredmnist.py: -------------------------------------------------------------------------------- 1 | # This script was first copied from https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/main.py under the license 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # Then we included our new regularization loss Fishr. To do so: 9 | # 1. we first compute gradients covariance on each domain (see compute_grads_variance method) using BackPACK package 10 | # 2.
then, we compute l2 distance between these gradient covariances (see l2_between_grads_variance method) 11 | 12 | import random 13 | import argparse 14 | import numpy as np 15 | from collections import OrderedDict 16 | 17 | import torch 18 | from torchvision import datasets 19 | from torch import nn, optim, autograd 20 | 21 | from backpack import backpack, extend 22 | from backpack.extensions import BatchGrad 23 | 24 | parser = argparse.ArgumentParser(description='Colored MNIST') 25 | 26 | # select your algorithm 27 | parser.add_argument( 28 | '--algorithm', 29 | type=str, 30 | default="fishr", 31 | choices=[ 32 | ## Four main methods, for Table 2 in Section 6.A 33 | 'erm', # Empirical Risk Minimization 34 | 'irm', # Invariant Risk Minimization (https://arxiv.org/abs/1907.02893) 35 | 'rex', # Out-of-Distribution Generalization via Risk Extrapolation (https://icml.cc/virtual/2021/oral/9186) 36 | 'fishr', # Our proposed Fishr 37 | ## two Fishr variants, for Table 6 in Appendix C.2.4 38 | 'fishr_offdiagonal' # Fishr but on the full covariance rather than only the diagonal 39 | 'fishr_notcentered', # Fishr but without centering the gradient variances 40 | ] 41 | ) 42 | # select whether you want to apply label flipping or not 43 | # Set to 0 in Table 5 in Appendix C.2.3 and in the right half of Table 6 in Appendix C.2.4 44 | parser.add_argument('--label_flipping_prob', type=float, default=0.25) 45 | 46 | # Following hyperparameters are directly taken from from https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/reproduce_paper_results.sh 47 | # They should not be modified except in case of a new proper hyperparameter search with an external validation dataset. 48 | # Overall, we compare all approaches using the hyperparameters optimized for IRM. 
49 | parser.add_argument('--hidden_dim', type=int, default=390) 50 | parser.add_argument('--l2_regularizer_weight', type=float, default=0.00110794568) 51 | parser.add_argument('--lr', type=float, default=0.0004898536566546834) 52 | parser.add_argument('--penalty_anneal_iters', type=int, default=190) 53 | parser.add_argument('--penalty_weight', type=float, default=91257.18613115903) 54 | parser.add_argument('--steps', type=int, default=501) 55 | # experimental setup 56 | parser.add_argument('--grayscale_model', action='store_true') 57 | parser.add_argument('--n_restarts', type=int, default=10) 58 | parser.add_argument('--seed', type=int, default=0, help='Seed for everything') 59 | 60 | flags = parser.parse_args() 61 | 62 | print('Flags:') 63 | for k, v in sorted(vars(flags).items()): 64 | print("\t{}: {}".format(k, v)) 65 | 66 | random.seed(flags.seed) 67 | np.random.seed(flags.seed) 68 | torch.manual_seed(flags.seed) 69 | torch.backends.cudnn.deterministic = True 70 | torch.backends.cudnn.benchmark = False 71 | 72 | final_train_accs = [] 73 | final_test_accs = [] 74 | final_graytest_accs = [] 75 | for restart in range(flags.n_restarts): 76 | print("Restart", restart) 77 | 78 | # Load MNIST, make train/val splits, and shuffle train set examples 79 | 80 | mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True) 81 | mnist_train = (mnist.data[:50000], mnist.targets[:50000]) 82 | mnist_val = (mnist.data[50000:], mnist.targets[50000:]) 83 | 84 | rng_state = np.random.get_state() 85 | np.random.shuffle(mnist_train[0].numpy()) 86 | np.random.set_state(rng_state) 87 | np.random.shuffle(mnist_train[1].numpy()) 88 | 89 | # Build environments 90 | 91 | 92 | def make_environment(images, labels, e, grayscale=False): 93 | 94 | def torch_bernoulli(p, size): 95 | return (torch.rand(size) < p).float() 96 | 97 | def torch_xor(a, b): 98 | return (a - b).abs() # Assumes both inputs are either 0 or 1 99 | 100 | # 2x subsample for computational convenience 101 | images = 
images.reshape((-1, 28, 28))[:, ::2, ::2] 102 | # Assign a binary label based on the digit; flip label with probability flags.label_flipping_prob (0.25 by default) 103 | labels = (labels < 5).float() 104 | labels = torch_xor(labels, torch_bernoulli(flags.label_flipping_prob, len(labels))) 105 | # Assign a color based on the label; flip the color with probability e 106 | colors = torch_xor(labels, torch_bernoulli(e, len(labels))) 107 | # Apply the color to the image by zeroing out the other color channel 108 | images = torch.stack([images, images], dim=1) 109 | if not grayscale: 110 | images[torch.tensor(range(len(images))), (1 - colors).long(), :, :] *= 0 111 | return {'images': (images.float() / 255.).cuda(), 'labels': labels[:, None].cuda()} 112 | 113 | envs = [ 114 | make_environment(mnist_train[0][::2], mnist_train[1][::2], 0.2), 115 | make_environment(mnist_train[0][1::2], mnist_train[1][1::2], 0.1), 116 | make_environment(mnist_val[0], mnist_val[1], 0.9), 117 | make_environment(mnist_val[0], mnist_val[1], 0.9, grayscale=True) 118 | ] 119 | 120 | # Define and instantiate the model 121 | 122 | 123 | class MLP(nn.Module): 124 | 125 | def __init__(self): 126 | super(MLP, self).__init__() 127 | if flags.grayscale_model: 128 | lin1 = nn.Linear(14 * 14, flags.hidden_dim) 129 | else: 130 | lin1 = nn.Linear(2 * 14 * 14, flags.hidden_dim) 131 | lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim) 132 | 133 | self.classifier = extend(nn.Linear(flags.hidden_dim, 1)) 134 | for lin in [lin1, lin2, self.classifier]: 135 | nn.init.xavier_uniform_(lin.weight) 136 | nn.init.zeros_(lin.bias) 137 | self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True)) 138 | 139 | def forward(self, input): 140 | if flags.grayscale_model: 141 | out = input.view(input.shape[0], 2, 14 * 14).sum(dim=1) 142 | else: 143 | out = input.view(input.shape[0], 2 * 14 * 14) 144 | features = self._main(out) 145 | logits = self.classifier(features) 146 | return features, logits 147 | 148 | mlp = MLP().cuda() 149 | 150 | # Define loss
function helpers 151 | 152 | 153 | def mean_nll(logits, y): 154 | return nn.functional.binary_cross_entropy_with_logits(logits, y) 155 | 156 | def mean_accuracy(logits, y): 157 | preds = (logits > 0.).float() 158 | return ((preds - y).abs() < 1e-2).float().mean() 159 | 160 | def compute_irm_penalty(logits, y): 161 | scale = torch.tensor(1.).cuda().requires_grad_() 162 | loss = mean_nll(logits * scale, y) 163 | grad = autograd.grad(loss, [scale], create_graph=True)[0] 164 | return torch.sum(grad**2) 165 | 166 | bce_extended = extend(nn.BCEWithLogitsLoss()) 167 | 168 | def compute_grads_variance(features, labels, classifier): 169 | logits = classifier(features) 170 | loss = bce_extended(logits, labels) 171 | with backpack(BatchGrad()): 172 | loss.backward( 173 | inputs=list(classifier.parameters()), retain_graph=True, create_graph=True 174 | ) 175 | 176 | dict_grads = OrderedDict( 177 | [ 178 | (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1)) 179 | for name, weights in classifier.named_parameters() 180 | ] 181 | ) 182 | dict_grads_variance = {} 183 | for name, _grads in dict_grads.items(): 184 | grads = _grads * labels.size(0) # multiply by batch size 185 | env_mean = grads.mean(dim=0, keepdim=True) 186 | if flags.algorithm != "fishr_notcentered": 187 | grads = grads - env_mean 188 | if flags.algorithm == "fishr_offdiagonal": 189 | dict_grads_variance[name] = torch.einsum("na,nb->ab", grads, 190 | grads) / (grads.size(0) * grads.size(1)) 191 | else: 192 | dict_grads_variance[name] = (grads).pow(2).mean(dim=0) 193 | 194 | return dict_grads_variance 195 | 196 | def l2_between_grads_variance(cov_1, cov_2): 197 | assert len(cov_1) == len(cov_2) 198 | cov_1_values = [cov_1[key] for key in sorted(cov_1.keys())] 199 | cov_2_values = [cov_2[key] for key in sorted(cov_2.keys())] 200 | return ( 201 | torch.cat(tuple([t.view(-1) for t in cov_1_values])) - 202 | torch.cat(tuple([t.view(-1) for t in cov_2_values])) 203 | ).pow(2).sum() 204 | 205 | # Train 
loop 206 | 207 | def pretty_print(*values): 208 | col_width = 13 209 | 210 | def format_val(v): 211 | if not isinstance(v, str): 212 | v = np.array2string(v, precision=5, floatmode='fixed') 213 | return v.ljust(col_width) 214 | 215 | str_values = [format_val(v) for v in values] 216 | print(" ".join(str_values)) 217 | 218 | optimizer = optim.Adam(mlp.parameters(), lr=flags.lr) 219 | 220 | pretty_print( 221 | 'step', 'train nll', 'train acc', 'fishr penalty', 'rex penalty', 'irm penalty', 'test acc', 222 | "gray test acc" 223 | ) 224 | for step in range(flags.steps): 225 | for edx, env in enumerate(envs): 226 | features, logits = mlp(env['images']) 227 | env['nll'] = mean_nll(logits, env['labels']) 228 | env['acc'] = mean_accuracy(logits, env['labels']) 229 | env['irm'] = compute_irm_penalty(logits, env['labels']) 230 | if edx in [0, 1]: 231 | # True when the dataset is in training 232 | optimizer.zero_grad() 233 | env["grads_variance"] = compute_grads_variance(features, env['labels'], mlp.classifier) 234 | 235 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean() 236 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean() 237 | 238 | weight_norm = torch.tensor(0.).cuda() 239 | for w in mlp.parameters(): 240 | weight_norm += w.norm().pow(2) 241 | 242 | loss = train_nll.clone() 243 | loss += flags.l2_regularizer_weight * weight_norm 244 | 245 | irm_penalty = torch.stack([envs[0]['irm'], envs[1]['irm']]).mean() 246 | rex_penalty = (envs[0]['nll'].mean() - envs[1]['nll'].mean())**2 247 | 248 | # Compute the variance averaged over the two training domains 249 | dict_grads_variance_averaged = OrderedDict( 250 | [ 251 | ( 252 | name, 253 | torch.stack([envs[0]["grads_variance"][name], envs[1]["grads_variance"][name]], 254 | dim=0).mean(dim=0) 255 | ) for name in envs[0]["grads_variance"] 256 | ] 257 | ) 258 | fishr_penalty = ( 259 | l2_between_grads_variance(envs[0]["grads_variance"], dict_grads_variance_averaged) + 260 | 
l2_between_grads_variance(envs[1]["grads_variance"], dict_grads_variance_averaged) 261 | ) 262 | 263 | # apply the selected regularization 264 | if flags.algorithm == "erm": 265 | pass 266 | else: 267 | if flags.algorithm.startswith("fishr"): 268 | train_penalty = fishr_penalty 269 | elif flags.algorithm == "rex": 270 | train_penalty = rex_penalty 271 | elif flags.algorithm == "irm": 272 | train_penalty = irm_penalty 273 | else: 274 | raise ValueError(flags.algorithm) 275 | penalty_weight = (flags.penalty_weight if step >= flags.penalty_anneal_iters else 1.0) 276 | loss += penalty_weight * train_penalty 277 | if penalty_weight > 1.0: 278 | # Rescale the entire loss to keep gradients in a reasonable range 279 | loss /= penalty_weight 280 | 281 | optimizer.zero_grad() 282 | loss.backward() 283 | optimizer.step() 284 | 285 | test_acc = envs[2]['acc'] 286 | grayscale_test_acc = envs[3]['acc'] 287 | if step % 100 == 0: 288 | pretty_print( 289 | np.int32(step), 290 | train_nll.detach().cpu().numpy(), 291 | train_acc.detach().cpu().numpy(), 292 | fishr_penalty.detach().cpu().numpy(), 293 | rex_penalty.detach().cpu().numpy(), 294 | irm_penalty.detach().cpu().numpy(), 295 | test_acc.detach().cpu().numpy(), 296 | grayscale_test_acc.detach().cpu().numpy(), 297 | ) 298 | 299 | final_train_accs.append(train_acc.detach().cpu().numpy()) 300 | final_test_accs.append(test_acc.detach().cpu().numpy()) 301 | final_graytest_accs.append(grayscale_test_acc.detach().cpu().numpy()) 302 | print('Final train acc (mean/std across restarts so far):') 303 | print(np.mean(final_train_accs), np.std(final_train_accs)) 304 | print('Final test acc (mean/std across restarts so far):') 305 | print(np.mean(final_test_accs), np.std(final_test_accs)) 306 | print('Final gray test acc (mean/std across restarts so far):') 307 | print(np.mean(final_graytest_accs), np.std(final_graytest_accs)) 308 | -------------------------------------------------------------------------------- /domainbed/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | -------------------------------------------------------------------------------- /domainbed/algorithms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.autograd as autograd 7 | from torch.autograd import Variable 8 | 9 | import copy 10 | import numpy as np 11 | from collections import defaultdict, OrderedDict 12 | try: 13 | from backpack import backpack, extend 14 | from backpack.extensions import BatchGrad 15 | except: 16 | backpack = None 17 | 18 | from domainbed import networks 19 | from domainbed.lib.misc import ( 20 | random_pairs_of_minibatches, ParamDict, MovingAverage, l2_between_dicts 21 | ) 22 | 23 | ALGORITHMS = [ 24 | 'ERM', 25 | 'Fish', 26 | 'IRM', 27 | 'GroupDRO', 28 | 'Mixup', 29 | 'MLDG', 30 | 'CORAL', 31 | 'MMD', 32 | 'DANN', 33 | 'CDANN', 34 | 'MTL', 35 | 'SagNet', 36 | 'ARM', 37 | 'VREx', 38 | 'RSC', 39 | 'SD', 40 | 'ANDMask', 41 | 'SANDMask', # SAND-mask 42 | 'IGA', 43 | 'SelfReg', 44 | "Fishr" 45 | ] 46 | 47 | 48 | def get_algorithm_class(algorithm_name): 49 | """Return the algorithm class with the given name.""" 50 | if algorithm_name not in globals(): 51 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name)) 52 | return globals()[algorithm_name] 53 | 54 | 55 | class Algorithm(torch.nn.Module): 56 | """ 57 | A subclass of Algorithm implements a domain generalization algorithm. 
58 | Subclasses should implement the following: 59 | - update() 60 | - predict() 61 | """ 62 | 63 | def __init__(self, input_shape, num_classes, num_domains, hparams): 64 | super(Algorithm, self).__init__() 65 | self.hparams = hparams 66 | 67 | def update(self, minibatches, unlabeled=None): 68 | """ 69 | Perform one update step, given a list of (x, y) tuples for all 70 | environments. 71 | 72 | Admits an optional list of unlabeled minibatches from the test domains, 73 | when task is domain_adaptation. 74 | """ 75 | raise NotImplementedError 76 | 77 | def predict(self, x): 78 | raise NotImplementedError 79 | 80 | 81 | class ERM(Algorithm): 82 | """ 83 | Empirical Risk Minimization (ERM) 84 | """ 85 | 86 | def __init__(self, input_shape, num_classes, num_domains, hparams): 87 | super(ERM, self).__init__(input_shape, num_classes, num_domains, hparams) 88 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 89 | self.classifier = networks.Classifier( 90 | self.featurizer.n_outputs, num_classes, self.hparams['nonlinear_classifier'] 91 | ) 92 | 93 | self.network = nn.Sequential(self.featurizer, self.classifier) 94 | self.optimizer = torch.optim.Adam( 95 | self.network.parameters(), 96 | lr=self.hparams["lr"], 97 | weight_decay=self.hparams['weight_decay'] 98 | ) 99 | 100 | def update(self, minibatches, unlabeled=None): 101 | all_x = torch.cat([x for x, y in minibatches]) 102 | all_y = torch.cat([y for x, y in minibatches]) 103 | loss = F.cross_entropy(self.predict(all_x), all_y) 104 | 105 | self.optimizer.zero_grad() 106 | loss.backward() 107 | self.optimizer.step() 108 | 109 | return {'loss': loss.item()} 110 | 111 | def predict(self, x): 112 | return self.network(x) 113 | 114 | 115 | class Fish(Algorithm): 116 | """ 117 | Implementation of Fish, as seen in Gradient Matching for Domain 118 | Generalization, Shi et al. 2021. 
119 | """ 120 | 121 | def __init__(self, input_shape, num_classes, num_domains, hparams): 122 | super(Fish, self).__init__(input_shape, num_classes, num_domains, hparams) 123 | self.input_shape = input_shape 124 | self.num_classes = num_classes 125 | 126 | self.network = networks.WholeFish(input_shape, num_classes, hparams) 127 | self.optimizer = torch.optim.Adam( 128 | self.network.parameters(), 129 | lr=self.hparams["lr"], 130 | weight_decay=self.hparams['weight_decay'] 131 | ) 132 | self.optimizer_inner_state = None 133 | 134 | def create_clone(self, device): 135 | self.network_inner = networks.WholeFish( 136 | self.input_shape, self.num_classes, self.hparams, weights=self.network.state_dict() 137 | ).to(device) 138 | self.optimizer_inner = torch.optim.Adam( 139 | self.network_inner.parameters(), 140 | lr=self.hparams["lr"], 141 | weight_decay=self.hparams['weight_decay'] 142 | ) 143 | if self.optimizer_inner_state is not None: 144 | self.optimizer_inner.load_state_dict(self.optimizer_inner_state) 145 | 146 | def fish(self, meta_weights, inner_weights, lr_meta): 147 | meta_weights = ParamDict(meta_weights) 148 | inner_weights = ParamDict(inner_weights) 149 | meta_weights += lr_meta * (inner_weights - meta_weights) 150 | return meta_weights 151 | 152 | def update(self, minibatches, unlabeled=None): 153 | self.create_clone(minibatches[0][0].device) 154 | 155 | for x, y in minibatches: 156 | loss = F.cross_entropy(self.network_inner(x), y) 157 | self.optimizer_inner.zero_grad() 158 | loss.backward() 159 | self.optimizer_inner.step() 160 | 161 | self.optimizer_inner_state = self.optimizer_inner.state_dict() 162 | meta_weights = self.fish( 163 | meta_weights=self.network.state_dict(), 164 | inner_weights=self.network_inner.state_dict(), 165 | lr_meta=self.hparams["meta_lr"] 166 | ) 167 | self.network.reset_weights(meta_weights) 168 | 169 | return {'loss': loss.item()} 170 | 171 | def predict(self, x): 172 | return self.network(x) 173 | 174 | 175 | class ARM(ERM): 176 
| """ Adaptive Risk Minimization (ARM) """ 177 | 178 | def __init__(self, input_shape, num_classes, num_domains, hparams): 179 | original_input_shape = input_shape 180 | input_shape = (1 + original_input_shape[0],) + original_input_shape[1:] 181 | super(ARM, self).__init__(input_shape, num_classes, num_domains, hparams) 182 | self.context_net = networks.ContextNet(original_input_shape) 183 | self.support_size = hparams['batch_size'] 184 | 185 | def predict(self, x): 186 | batch_size, c, h, w = x.shape 187 | if batch_size % self.support_size == 0: 188 | meta_batch_size = batch_size // self.support_size 189 | support_size = self.support_size 190 | else: 191 | meta_batch_size, support_size = 1, batch_size 192 | context = self.context_net(x) 193 | context = context.reshape((meta_batch_size, support_size, 1, h, w)) 194 | context = context.mean(dim=1) 195 | context = torch.repeat_interleave(context, repeats=support_size, dim=0) 196 | x = torch.cat([x, context], dim=1) 197 | return self.network(x) 198 | 199 | 200 | class AbstractDANN(Algorithm): 201 | """Domain-Adversarial Neural Networks (abstract class)""" 202 | 203 | def __init__(self, input_shape, num_classes, num_domains, hparams, conditional, class_balance): 204 | 205 | super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains, hparams) 206 | 207 | self.register_buffer('update_count', torch.tensor([0])) 208 | self.conditional = conditional 209 | self.class_balance = class_balance 210 | 211 | # Algorithms 212 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 213 | self.classifier = networks.Classifier( 214 | self.featurizer.n_outputs, num_classes, self.hparams['nonlinear_classifier'] 215 | ) 216 | self.discriminator = networks.MLP(self.featurizer.n_outputs, num_domains, self.hparams) 217 | self.class_embeddings = nn.Embedding(num_classes, self.featurizer.n_outputs) 218 | 219 | # Optimizers 220 | self.disc_opt = torch.optim.Adam( 221 | (list(self.discriminator.parameters()) + 
list(self.class_embeddings.parameters())), 222 | lr=self.hparams["lr_d"], 223 | weight_decay=self.hparams['weight_decay_d'], 224 | betas=(self.hparams['beta1'], 0.9) 225 | ) 226 | 227 | self.gen_opt = torch.optim.Adam( 228 | (list(self.featurizer.parameters()) + list(self.classifier.parameters())), 229 | lr=self.hparams["lr_g"], 230 | weight_decay=self.hparams['weight_decay_g'], 231 | betas=(self.hparams['beta1'], 0.9) 232 | ) 233 | 234 | def update(self, minibatches, unlabeled=None): 235 | device = "cuda" if minibatches[0][0].is_cuda else "cpu" 236 | self.update_count += 1 237 | all_x = torch.cat([x for x, y in minibatches]) 238 | all_y = torch.cat([y for x, y in minibatches]) 239 | all_z = self.featurizer(all_x) 240 | if self.conditional: 241 | disc_input = all_z + self.class_embeddings(all_y) 242 | else: 243 | disc_input = all_z 244 | disc_out = self.discriminator(disc_input) 245 | disc_labels = torch.cat( 246 | [ 247 | torch.full((x.shape[0],), i, dtype=torch.int64, device=device) 248 | for i, (x, y) in enumerate(minibatches) 249 | ] 250 | ) 251 | 252 | if self.class_balance: 253 | y_counts = F.one_hot(all_y).sum(dim=0) 254 | weights = 1. 
/ (y_counts[all_y] * y_counts.shape[0]).float() 255 | disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none') 256 | disc_loss = (weights * disc_loss).sum() 257 | else: 258 | disc_loss = F.cross_entropy(disc_out, disc_labels) 259 | 260 | disc_softmax = F.softmax(disc_out, dim=1) 261 | input_grad = autograd.grad( 262 | disc_softmax[:, disc_labels].sum(), [disc_input], create_graph=True 263 | )[0] 264 | grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0) 265 | disc_loss += self.hparams['grad_penalty'] * grad_penalty 266 | 267 | d_steps_per_g = self.hparams['d_steps_per_g_step'] 268 | if (self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g): 269 | 270 | self.disc_opt.zero_grad() 271 | disc_loss.backward() 272 | self.disc_opt.step() 273 | return {'disc_loss': disc_loss.item()} 274 | else: 275 | all_preds = self.classifier(all_z) 276 | classifier_loss = F.cross_entropy(all_preds, all_y) 277 | gen_loss = (classifier_loss + (self.hparams['lambda'] * -disc_loss)) 278 | self.disc_opt.zero_grad() 279 | self.gen_opt.zero_grad() 280 | gen_loss.backward() 281 | self.gen_opt.step() 282 | return {'gen_loss': gen_loss.item()} 283 | 284 | def predict(self, x): 285 | return self.classifier(self.featurizer(x)) 286 | 287 | 288 | class DANN(AbstractDANN): 289 | """Unconditional DANN""" 290 | 291 | def __init__(self, input_shape, num_classes, num_domains, hparams): 292 | super(DANN, self).__init__( 293 | input_shape, num_classes, num_domains, hparams, conditional=False, class_balance=False 294 | ) 295 | 296 | 297 | class CDANN(AbstractDANN): 298 | """Conditional DANN""" 299 | 300 | def __init__(self, input_shape, num_classes, num_domains, hparams): 301 | super(CDANN, self).__init__( 302 | input_shape, num_classes, num_domains, hparams, conditional=True, class_balance=True 303 | ) 304 | 305 | 306 | class IRM(ERM): 307 | """Invariant Risk Minimization""" 308 | 309 | def __init__(self, input_shape, num_classes, num_domains, hparams): 310 | super(IRM, 
self).__init__(input_shape, num_classes, num_domains, hparams) 311 | self.register_buffer('update_count', torch.tensor([0])) 312 | 313 | @staticmethod 314 | def _irm_penalty(logits, y): 315 | device = "cuda" if logits[0][0].is_cuda else "cpu" 316 | scale = torch.tensor(1.).to(device).requires_grad_() 317 | loss_1 = F.cross_entropy(logits[::2] * scale, y[::2]) 318 | loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2]) 319 | grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0] 320 | grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0] 321 | result = torch.sum(grad_1 * grad_2) 322 | return result 323 | 324 | def update(self, minibatches, unlabeled=None): 325 | device = "cuda" if minibatches[0][0].is_cuda else "cpu" 326 | penalty_weight = ( 327 | self.hparams['irm_lambda'] 328 | if self.update_count >= self.hparams['irm_penalty_anneal_iters'] else 1.0 329 | ) 330 | nll = 0. 331 | penalty = 0. 332 | 333 | all_x = torch.cat([x for x, y in minibatches]) 334 | all_logits = self.network(all_x) 335 | all_logits_idx = 0 336 | for i, (x, y) in enumerate(minibatches): 337 | logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]] 338 | all_logits_idx += x.shape[0] 339 | nll += F.cross_entropy(logits, y) 340 | penalty += self._irm_penalty(logits, y) 341 | nll /= len(minibatches) 342 | penalty /= len(minibatches) 343 | loss = nll + (penalty_weight * penalty) 344 | 345 | if self.update_count == self.hparams['irm_penalty_anneal_iters']: 346 | # Reset Adam, because it doesn't like the sharp jump in gradient 347 | # magnitudes that happens at this step. 
348 | self.optimizer = torch.optim.Adam( 349 | self.network.parameters(), 350 | lr=self.hparams["lr"], 351 | weight_decay=self.hparams['weight_decay'] 352 | ) 353 | 354 | self.optimizer.zero_grad() 355 | loss.backward() 356 | self.optimizer.step() 357 | 358 | self.update_count += 1 359 | return {'loss': loss.item(), 'nll': nll.item(), 'penalty': penalty.item()} 360 | 361 | 362 | class VREx(ERM): 363 | """V-REx algorithm from http://arxiv.org/abs/2003.00688""" 364 | 365 | def __init__(self, input_shape, num_classes, num_domains, hparams): 366 | super(VREx, self).__init__(input_shape, num_classes, num_domains, hparams) 367 | self.register_buffer('update_count', torch.tensor([0])) 368 | 369 | def update(self, minibatches, unlabeled=None): 370 | if self.update_count >= self.hparams["vrex_penalty_anneal_iters"]: 371 | penalty_weight = self.hparams["vrex_lambda"] 372 | else: 373 | penalty_weight = 1.0 374 | 375 | nll = 0. 376 | 377 | all_x = torch.cat([x for x, y in minibatches]) 378 | all_logits = self.network(all_x) 379 | all_logits_idx = 0 380 | losses = torch.zeros(len(minibatches)) 381 | for i, (x, y) in enumerate(minibatches): 382 | logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]] 383 | all_logits_idx += x.shape[0] 384 | nll = F.cross_entropy(logits, y) 385 | losses[i] = nll 386 | 387 | mean = losses.mean() 388 | penalty = ((losses - mean)**2).mean() 389 | loss = mean + penalty_weight * penalty 390 | 391 | if self.update_count == self.hparams['vrex_penalty_anneal_iters']: 392 | # Reset Adam (like IRM), because it doesn't like the sharp jump in 393 | # gradient magnitudes that happens at this step. 
394 | self.optimizer = torch.optim.Adam( 395 | self.network.parameters(), 396 | lr=self.hparams["lr"], 397 | weight_decay=self.hparams['weight_decay'] 398 | ) 399 | 400 | self.optimizer.zero_grad() 401 | loss.backward() 402 | self.optimizer.step() 403 | 404 | self.update_count += 1 405 | return {'loss': loss.item(), 'nll': nll.item(), 'penalty': penalty.item()} 406 | 407 | 408 | class Mixup(ERM): 409 | """ 410 | Mixup of minibatches from different domains 411 | https://arxiv.org/pdf/2001.00677.pdf 412 | https://arxiv.org/pdf/1912.01805.pdf 413 | """ 414 | 415 | def __init__(self, input_shape, num_classes, num_domains, hparams): 416 | super(Mixup, self).__init__(input_shape, num_classes, num_domains, hparams) 417 | 418 | def update(self, minibatches, unlabeled=None): 419 | objective = 0 420 | 421 | for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches): 422 | lam = np.random.beta(self.hparams["mixup_alpha"], self.hparams["mixup_alpha"]) 423 | 424 | x = lam * xi + (1 - lam) * xj 425 | predictions = self.predict(x) 426 | 427 | objective += lam * F.cross_entropy(predictions, yi) 428 | objective += (1 - lam) * F.cross_entropy(predictions, yj) 429 | 430 | objective /= len(minibatches) 431 | 432 | self.optimizer.zero_grad() 433 | objective.backward() 434 | self.optimizer.step() 435 | 436 | return {'loss': objective.item()} 437 | 438 | 439 | class GroupDRO(ERM): 440 | """ 441 | Robust ERM minimizes the error at the worst minibatch 442 | Algorithm 1 from [https://arxiv.org/pdf/1911.08731.pdf] 443 | """ 444 | 445 | def __init__(self, input_shape, num_classes, num_domains, hparams): 446 | super(GroupDRO, self).__init__(input_shape, num_classes, num_domains, hparams) 447 | self.register_buffer("q", torch.Tensor()) 448 | 449 | def update(self, minibatches, unlabeled=None): 450 | device = "cuda" if minibatches[0][0].is_cuda else "cpu" 451 | 452 | if not len(self.q): 453 | self.q = torch.ones(len(minibatches)).to(device) 454 | 455 | losses = 
torch.zeros(len(minibatches)).to(device) 456 | 457 | for m in range(len(minibatches)): 458 | x, y = minibatches[m] 459 | losses[m] = F.cross_entropy(self.predict(x), y) 460 | self.q[m] *= (self.hparams["groupdro_eta"] * losses[m].data).exp() 461 | 462 | self.q /= self.q.sum() 463 | 464 | loss = torch.dot(losses, self.q) 465 | 466 | self.optimizer.zero_grad() 467 | loss.backward() 468 | self.optimizer.step() 469 | 470 | return {'loss': loss.item()} 471 | 472 | 473 | class MLDG(ERM): 474 | """ 475 | Model-Agnostic Meta-Learning 476 | Algorithm 1 / Equation (3) from: https://arxiv.org/pdf/1710.03463.pdf 477 | Related: https://arxiv.org/pdf/1703.03400.pdf 478 | Related: https://arxiv.org/pdf/1910.13580.pdf 479 | """ 480 | 481 | def __init__(self, input_shape, num_classes, num_domains, hparams): 482 | super(MLDG, self).__init__(input_shape, num_classes, num_domains, hparams) 483 | 484 | def update(self, minibatches, unlabeled=None): 485 | """ 486 | Terms being computed: 487 | * Li = Loss(xi, yi, params) 488 | * Gi = Grad(Li, params) 489 | 490 | * Lj = Loss(xj, yj, Optimizer(params, grad(Li, params))) 491 | * Gj = Grad(Lj, params) 492 | 493 | * params = Optimizer(params, Grad(Li + beta * Lj, params)) 494 | * = Optimizer(params, Gi + beta * Gj) 495 | 496 | That is, when calling .step(), we want grads to be Gi + beta * Gj 497 | 498 | For computational efficiency, we do not compute second derivatives. 
499 | """ 500 | num_mb = len(minibatches) 501 | objective = 0 502 | 503 | self.optimizer.zero_grad() 504 | for p in self.network.parameters(): 505 | if p.grad is None: 506 | p.grad = torch.zeros_like(p) 507 | 508 | for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches): 509 | # fine tune clone-network on task "i" 510 | inner_net = copy.deepcopy(self.network) 511 | 512 | inner_opt = torch.optim.Adam( 513 | inner_net.parameters(), 514 | lr=self.hparams["lr"], 515 | weight_decay=self.hparams['weight_decay'] 516 | ) 517 | 518 | inner_obj = F.cross_entropy(inner_net(xi), yi) 519 | 520 | inner_opt.zero_grad() 521 | inner_obj.backward() 522 | inner_opt.step() 523 | 524 | # The network has now accumulated gradients Gi 525 | # The clone-network has now parameters P - lr * Gi 526 | for p_tgt, p_src in zip(self.network.parameters(), inner_net.parameters()): 527 | if p_src.grad is not None: 528 | p_tgt.grad.data.add_(p_src.grad.data / num_mb) 529 | 530 | # `objective` is populated for reporting purposes 531 | objective += inner_obj.item() 532 | 533 | # this computes Gj on the clone-network 534 | loss_inner_j = F.cross_entropy(inner_net(xj), yj) 535 | grad_inner_j = autograd.grad(loss_inner_j, inner_net.parameters(), allow_unused=True) 536 | 537 | # `objective` is populated for reporting purposes 538 | objective += (self.hparams['mldg_beta'] * loss_inner_j).item() 539 | 540 | for p, g_j in zip(self.network.parameters(), grad_inner_j): 541 | if g_j is not None: 542 | p.grad.data.add_(self.hparams['mldg_beta'] * g_j.data / num_mb) 543 | 544 | # The network has now accumulated gradients Gi + beta * Gj 545 | # Repeat for all train-test splits, do .step() 546 | 547 | objective /= len(minibatches) 548 | 549 | self.optimizer.step() 550 | 551 | return {'loss': objective} 552 | 553 | # This commented "update" method back-propagates through the gradients of 554 | # the inner update, as suggested in the original MAML paper. 
However, this 555 | # is twice as expensive as the uncommented "update" method, which does not 556 | # compute second-order derivatives, implementing the First-Order MAML 557 | # method (FOMAML) described in the original MAML paper. 558 | 559 | # def update(self, minibatches, unlabeled=None): 560 | # objective = 0 561 | # beta = self.hparams["beta"] 562 | # inner_iterations = self.hparams["inner_iterations"] 563 | 564 | # self.optimizer.zero_grad() 565 | 566 | # with higher.innerloop_ctx(self.network, self.optimizer, 567 | # copy_initial_weights=False) as (inner_network, inner_optimizer): 568 | 569 | # for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches): 570 | # for inner_iteration in range(inner_iterations): 571 | # li = F.cross_entropy(inner_network(xi), yi) 572 | # inner_optimizer.step(li) 573 | # 574 | # objective += F.cross_entropy(self.network(xi), yi) 575 | # objective += beta * F.cross_entropy(inner_network(xj), yj) 576 | 577 | # objective /= len(minibatches) 578 | # objective.backward() 579 | # 580 | # self.optimizer.step() 581 | # 582 | # return objective 583 | 584 | 585 | class AbstractMMD(ERM): 586 | """ 587 | Perform ERM while matching the pair-wise domain feature distributions 588 | using MMD (abstract class) 589 | """ 590 | 591 | def __init__(self, input_shape, num_classes, num_domains, hparams, gaussian): 592 | super(AbstractMMD, self).__init__(input_shape, num_classes, num_domains, hparams) 593 | if gaussian: 594 | self.kernel_type = "gaussian" 595 | else: 596 | self.kernel_type = "mean_cov" 597 | 598 | def my_cdist(self, x1, x2): 599 | x1_norm = x1.pow(2).sum(dim=-1, keepdim=True) 600 | x2_norm = x2.pow(2).sum(dim=-1, keepdim=True) 601 | res = torch.addmm( 602 | x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2 603 | ).add_(x1_norm) 604 | return res.clamp_min_(1e-30) 605 | 606 | def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100, 1000]): 607 | D = self.my_cdist(x, y) 608 | K = torch.zeros_like(D) 609 | 
610 | for g in gamma: 611 | K.add_(torch.exp(D.mul(-g))) 612 | 613 | return K 614 | 615 | def mmd(self, x, y): 616 | if self.kernel_type == "gaussian": 617 | Kxx = self.gaussian_kernel(x, x).mean() 618 | Kyy = self.gaussian_kernel(y, y).mean() 619 | Kxy = self.gaussian_kernel(x, y).mean() 620 | return Kxx + Kyy - 2 * Kxy 621 | else: 622 | mean_x = x.mean(0, keepdim=True) 623 | mean_y = y.mean(0, keepdim=True) 624 | cent_x = x - mean_x 625 | cent_y = y - mean_y 626 | cova_x = (cent_x.t() @ cent_x) / (len(x) - 1) 627 | cova_y = (cent_y.t() @ cent_y) / (len(y) - 1) 628 | 629 | mean_diff = (mean_x - mean_y).pow(2).mean() 630 | cova_diff = (cova_x - cova_y).pow(2).mean() 631 | 632 | return mean_diff + cova_diff 633 | 634 | def update(self, minibatches, unlabeled=None): 635 | objective = 0 636 | penalty = 0 637 | nmb = len(minibatches) 638 | 639 | features = [self.featurizer(xi) for xi, _ in minibatches] 640 | classifs = [self.classifier(fi) for fi in features] 641 | targets = [yi for _, yi in minibatches] 642 | 643 | for i in range(nmb): 644 | objective += F.cross_entropy(classifs[i], targets[i]) 645 | for j in range(i + 1, nmb): 646 | penalty += self.mmd(features[i], features[j]) 647 | 648 | objective /= nmb 649 | if nmb > 1: 650 | penalty /= (nmb * (nmb - 1) / 2) 651 | 652 | self.optimizer.zero_grad() 653 | (objective + (self.hparams['mmd_gamma'] * penalty)).backward() 654 | self.optimizer.step() 655 | 656 | if torch.is_tensor(penalty): 657 | penalty = penalty.item() 658 | 659 | return {'loss': objective.item(), 'penalty': penalty} 660 | 661 | 662 | class MMD(AbstractMMD): 663 | """ 664 | MMD using Gaussian kernel 665 | """ 666 | 667 | def __init__(self, input_shape, num_classes, num_domains, hparams): 668 | super(MMD, self).__init__(input_shape, num_classes, num_domains, hparams, gaussian=True) 669 | 670 | 671 | class CORAL(AbstractMMD): 672 | """ 673 | MMD using mean and covariance difference 674 | """ 675 | 676 | def __init__(self, input_shape, num_classes, 
num_domains, hparams): 677 | super(CORAL, self).__init__(input_shape, num_classes, num_domains, hparams, gaussian=False) 678 | 679 | 680 | class MTL(Algorithm): 681 | """ 682 | A neural network version of 683 | Domain Generalization by Marginal Transfer Learning 684 | (https://arxiv.org/abs/1711.07910) 685 | """ 686 | 687 | def __init__(self, input_shape, num_classes, num_domains, hparams): 688 | super(MTL, self).__init__(input_shape, num_classes, num_domains, hparams) 689 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 690 | self.classifier = networks.Classifier( 691 | self.featurizer.n_outputs * 2, num_classes, self.hparams['nonlinear_classifier'] 692 | ) 693 | self.optimizer = torch.optim.Adam( 694 | list(self.featurizer.parameters()) +\ 695 | list(self.classifier.parameters()), 696 | lr=self.hparams["lr"], 697 | weight_decay=self.hparams['weight_decay'] 698 | ) 699 | 700 | self.register_buffer('embeddings', torch.zeros(num_domains, self.featurizer.n_outputs)) 701 | 702 | self.ema = self.hparams['mtl_ema'] 703 | 704 | def update(self, minibatches, unlabeled=None): 705 | loss = 0 706 | for env, (x, y) in enumerate(minibatches): 707 | loss += F.cross_entropy(self.predict(x, env), y) 708 | 709 | self.optimizer.zero_grad() 710 | loss.backward() 711 | self.optimizer.step() 712 | 713 | return {'loss': loss.item()} 714 | 715 | def update_embeddings_(self, features, env=None): 716 | return_embedding = features.mean(0) 717 | 718 | if env is not None: 719 | return_embedding = self.ema * return_embedding +\ 720 | (1 - self.ema) * self.embeddings[env] 721 | 722 | self.embeddings[env] = return_embedding.clone().detach() 723 | 724 | return return_embedding.view(1, -1).repeat(len(features), 1) 725 | 726 | def predict(self, x, env=None): 727 | features = self.featurizer(x) 728 | embedding = self.update_embeddings_(features, env).normal_() 729 | return self.classifier(torch.cat((features, embedding), 1)) 730 | 731 | 732 | class SagNet(Algorithm): 733 | """ 734 
| Style Agnostic Network 735 | Algorithm 1 from: https://arxiv.org/abs/1910.11645 736 | """ 737 | 738 | def __init__(self, input_shape, num_classes, num_domains, hparams): 739 | super(SagNet, self).__init__(input_shape, num_classes, num_domains, hparams) 740 | # featurizer network 741 | self.network_f = networks.Featurizer(input_shape, self.hparams) 742 | # content network 743 | self.network_c = networks.Classifier( 744 | self.network_f.n_outputs, num_classes, self.hparams['nonlinear_classifier'] 745 | ) 746 | # style network 747 | self.network_s = networks.Classifier( 748 | self.network_f.n_outputs, num_classes, self.hparams['nonlinear_classifier'] 749 | ) 750 | 751 | # # This commented block of code implements something closer to the 752 | # # original paper, but is specific to ResNet and puts in disadvantage 753 | # # the other algorithms. 754 | # resnet_c = networks.Featurizer(input_shape, self.hparams) 755 | # resnet_s = networks.Featurizer(input_shape, self.hparams) 756 | # # featurizer network 757 | # self.network_f = torch.nn.Sequential( 758 | # resnet_c.network.conv1, 759 | # resnet_c.network.bn1, 760 | # resnet_c.network.relu, 761 | # resnet_c.network.maxpool, 762 | # resnet_c.network.layer1, 763 | # resnet_c.network.layer2, 764 | # resnet_c.network.layer3) 765 | # # content network 766 | # self.network_c = torch.nn.Sequential( 767 | # resnet_c.network.layer4, 768 | # resnet_c.network.avgpool, 769 | # networks.Flatten(), 770 | # resnet_c.network.fc) 771 | # # style network 772 | # self.network_s = torch.nn.Sequential( 773 | # resnet_s.network.layer4, 774 | # resnet_s.network.avgpool, 775 | # networks.Flatten(), 776 | # resnet_s.network.fc) 777 | 778 | def opt(p): 779 | return torch.optim.Adam(p, lr=hparams["lr"], weight_decay=hparams["weight_decay"]) 780 | 781 | self.optimizer_f = opt(self.network_f.parameters()) 782 | self.optimizer_c = opt(self.network_c.parameters()) 783 | self.optimizer_s = opt(self.network_s.parameters()) 784 | self.weight_adv = 
hparams["sag_w_adv"] 785 | 786 | def forward_c(self, x): 787 | # learning content network on randomized style 788 | return self.network_c(self.randomize(self.network_f(x), "style")) 789 | 790 | def forward_s(self, x): 791 | # learning style network on randomized content 792 | return self.network_s(self.randomize(self.network_f(x), "content")) 793 | 794 | def randomize(self, x, what="style", eps=1e-5): 795 | device = "cuda" if x.is_cuda else "cpu" 796 | sizes = x.size() 797 | alpha = torch.rand(sizes[0], 1).to(device) 798 | 799 | if len(sizes) == 4: 800 | x = x.view(sizes[0], sizes[1], -1) 801 | alpha = alpha.unsqueeze(-1) 802 | 803 | mean = x.mean(-1, keepdim=True) 804 | var = x.var(-1, keepdim=True) 805 | 806 | x = (x - mean) / (var + eps).sqrt() 807 | 808 | idx_swap = torch.randperm(sizes[0]) 809 | if what == "style": 810 | mean = alpha * mean + (1 - alpha) * mean[idx_swap] 811 | var = alpha * var + (1 - alpha) * var[idx_swap] 812 | else: 813 | x = x[idx_swap].detach() 814 | 815 | x = x * (var + eps).sqrt() + mean 816 | return x.view(*sizes) 817 | 818 | def update(self, minibatches, unlabeled=None): 819 | all_x = torch.cat([x for x, y in minibatches]) 820 | all_y = torch.cat([y for x, y in minibatches]) 821 | 822 | # learn content 823 | self.optimizer_f.zero_grad() 824 | self.optimizer_c.zero_grad() 825 | loss_c = F.cross_entropy(self.forward_c(all_x), all_y) 826 | loss_c.backward() 827 | self.optimizer_f.step() 828 | self.optimizer_c.step() 829 | 830 | # learn style 831 | self.optimizer_s.zero_grad() 832 | loss_s = F.cross_entropy(self.forward_s(all_x), all_y) 833 | loss_s.backward() 834 | self.optimizer_s.step() 835 | 836 | # learn adversary 837 | self.optimizer_f.zero_grad() 838 | loss_adv = -F.log_softmax(self.forward_s(all_x), dim=1).mean(1).mean() 839 | loss_adv = loss_adv * self.weight_adv 840 | loss_adv.backward() 841 | self.optimizer_f.step() 842 | 843 | return {'loss_c': loss_c.item(), 'loss_s': loss_s.item(), 'loss_adv': loss_adv.item()} 844 | 845 | 
def predict(self, x): 846 | return self.network_c(self.network_f(x)) 847 | 848 | 849 | class RSC(ERM): 850 | 851 | def __init__(self, input_shape, num_classes, num_domains, hparams): 852 | super(RSC, self).__init__(input_shape, num_classes, num_domains, hparams) 853 | self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100 854 | self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100 855 | self.num_classes = num_classes 856 | 857 | def update(self, minibatches, unlabeled=None): 858 | device = "cuda" if minibatches[0][0].is_cuda else "cpu" 859 | 860 | # inputs 861 | all_x = torch.cat([x for x, y in minibatches]) 862 | # labels 863 | all_y = torch.cat([y for _, y in minibatches]) 864 | # one-hot labels 865 | all_o = torch.nn.functional.one_hot(all_y, self.num_classes) 866 | # features 867 | all_f = self.featurizer(all_x) 868 | # predictions 869 | all_p = self.classifier(all_f) 870 | 871 | # Equation (1): compute gradients with respect to representation 872 | all_g = autograd.grad((all_p * all_o).sum(), all_f)[0] 873 | 874 | # Equation (2): compute top-gradient-percentile mask 875 | percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1) 876 | percentiles = torch.Tensor(percentiles) 877 | percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1)) 878 | mask_f = all_g.lt(percentiles.to(device)).float() 879 | 880 | # Equation (3): mute top-gradient-percentile activations 881 | all_f_muted = all_f * mask_f 882 | 883 | # Equation (4): compute muted predictions 884 | all_p_muted = self.classifier(all_f_muted) 885 | 886 | # Section 3.3: Batch Percentage 887 | all_s = F.softmax(all_p, dim=1) 888 | all_s_muted = F.softmax(all_p_muted, dim=1) 889 | changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1) 890 | percentile = np.percentile(changes.detach().cpu(), self.drop_b) 891 | mask_b = changes.lt(percentile).float().view(-1, 1) 892 | mask = torch.logical_or(mask_f, mask_b).float() 893 | 894 | # Equations (3) and (4) again, this time mutting over examples 895 
| all_p_muted_again = self.classifier(all_f * mask) 896 | 897 | # Equation (5): update 898 | loss = F.cross_entropy(all_p_muted_again, all_y) 899 | self.optimizer.zero_grad() 900 | loss.backward() 901 | self.optimizer.step() 902 | 903 | return {'loss': loss.item()} 904 | 905 | 906 | class SD(ERM): 907 | """ 908 | Gradient Starvation: A Learning Proclivity in Neural Networks 909 | Equation 25 from [https://arxiv.org/pdf/2011.09468.pdf] 910 | """ 911 | 912 | def __init__(self, input_shape, num_classes, num_domains, hparams): 913 | super(SD, self).__init__(input_shape, num_classes, num_domains, hparams) 914 | self.sd_reg = hparams["sd_reg"] 915 | 916 | def update(self, minibatches, unlabeled=None): 917 | all_x = torch.cat([x for x, y in minibatches]) 918 | all_y = torch.cat([y for x, y in minibatches]) 919 | all_p = self.predict(all_x) 920 | 921 | loss = F.cross_entropy(all_p, all_y) 922 | penalty = (all_p**2).mean() 923 | objective = loss + self.sd_reg * penalty 924 | 925 | self.optimizer.zero_grad() 926 | objective.backward() 927 | self.optimizer.step() 928 | 929 | return {'loss': loss.item(), 'penalty': penalty.item()} 930 | 931 | 932 | class ANDMask(ERM): 933 | """ 934 | Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329] 935 | AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary] 936 | """ 937 | 938 | def __init__(self, input_shape, num_classes, num_domains, hparams): 939 | super(ANDMask, self).__init__(input_shape, num_classes, num_domains, hparams) 940 | 941 | self.tau = hparams["tau"] 942 | 943 | def update(self, minibatches, unlabeled=None): 944 | mean_loss = 0 945 | param_gradients = [[] for _ in self.network.parameters()] 946 | for i, (x, y) in enumerate(minibatches): 947 | logits = self.network(x) 948 | 949 | env_loss = F.cross_entropy(logits, y) 950 | mean_loss += env_loss.item() / len(minibatches) 951 | 952 | env_grads = autograd.grad(env_loss, self.network.parameters()) 953 | for 
grads, env_grad in zip(param_gradients, env_grads): 954 | grads.append(env_grad) 955 | 956 | self.optimizer.zero_grad() 957 | self.mask_grads(self.tau, param_gradients, self.network.parameters()) 958 | self.optimizer.step() 959 | 960 | return {'loss': mean_loss} 961 | 962 | def mask_grads(self, tau, gradients, params): 963 | 964 | for param, grads in zip(params, gradients): 965 | grads = torch.stack(grads, dim=0) 966 | grad_signs = torch.sign(grads) 967 | mask = torch.mean(grad_signs, dim=0).abs() >= self.tau 968 | mask = mask.to(torch.float32) 969 | avg_grad = torch.mean(grads, dim=0) 970 | 971 | mask_t = (mask.sum() / mask.numel()) 972 | param.grad = mask * avg_grad 973 | param.grad *= (1. / (1e-10 + mask_t)) 974 | 975 | return 0 976 | 977 | 978 | class IGA(ERM): 979 | """ 980 | Inter-environmental Gradient Alignment 981 | From https://arxiv.org/abs/2008.01883v2 982 | """ 983 | 984 | def __init__(self, in_features, num_classes, num_domains, hparams): 985 | super(IGA, self).__init__(in_features, num_classes, num_domains, hparams) 986 | 987 | def update(self, minibatches, unlabeled=False): 988 | total_loss = 0 989 | grads = [] 990 | for i, (x, y) in enumerate(minibatches): 991 | logits = self.network(x) 992 | 993 | env_loss = F.cross_entropy(logits, y) 994 | total_loss += env_loss 995 | 996 | env_grad = autograd.grad(env_loss, self.network.parameters(), create_graph=True) 997 | 998 | grads.append(env_grad) 999 | 1000 | mean_loss = total_loss / len(minibatches) 1001 | mean_grad = autograd.grad(mean_loss, self.network.parameters(), retain_graph=True) 1002 | 1003 | # compute trace penalty 1004 | penalty_value = 0 1005 | for grad in grads: 1006 | for g, mean_g in zip(grad, mean_grad): 1007 | penalty_value += (g - mean_g).pow(2).sum() 1008 | 1009 | objective = mean_loss + self.hparams['penalty'] * penalty_value 1010 | 1011 | self.optimizer.zero_grad() 1012 | objective.backward() 1013 | self.optimizer.step() 1014 | 1015 | return {'loss': mean_loss.item(), 'penalty': 
penalty_value.item()} 1016 | 1017 | 1018 | class SelfReg(ERM): 1019 | 1020 | def __init__(self, input_shape, num_classes, num_domains, hparams): 1021 | super(SelfReg, self).__init__(input_shape, num_classes, num_domains, hparams) 1022 | self.num_classes = num_classes 1023 | self.MSEloss = nn.MSELoss() 1024 | input_feat_size = self.featurizer.n_outputs 1025 | hidden_size = input_feat_size if input_feat_size == 2048 else input_feat_size * 2 1026 | 1027 | self.cdpl = nn.Sequential( 1028 | nn.Linear(input_feat_size, hidden_size), nn.BatchNorm1d(hidden_size), 1029 | nn.ReLU(inplace=True), nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size), 1030 | nn.ReLU(inplace=True), nn.Linear(hidden_size, input_feat_size), 1031 | nn.BatchNorm1d(input_feat_size) 1032 | ) 1033 | 1034 | def update(self, minibatches, unlabeled=None): 1035 | 1036 | all_x = torch.cat([x for x, y in minibatches]) 1037 | all_y = torch.cat([y for _, y in minibatches]) 1038 | 1039 | lam = np.random.beta(0.5, 0.5) 1040 | 1041 | batch_size = all_y.size()[0] 1042 | 1043 | # cluster and order features into same-class group 1044 | with torch.no_grad(): 1045 | sorted_y, indices = torch.sort(all_y) 1046 | sorted_x = torch.zeros_like(all_x) 1047 | for idx, order in enumerate(indices): 1048 | sorted_x[idx] = all_x[order] 1049 | intervals = [] 1050 | ex = 0 1051 | for idx, val in enumerate(sorted_y): 1052 | if ex == val: 1053 | continue 1054 | intervals.append(idx) 1055 | ex = val 1056 | intervals.append(batch_size) 1057 | 1058 | all_x = sorted_x 1059 | all_y = sorted_y 1060 | 1061 | feat = self.featurizer(all_x) 1062 | proj = self.cdpl(feat) 1063 | 1064 | output = self.classifier(feat) 1065 | 1066 | # shuffle 1067 | output_2 = torch.zeros_like(output) 1068 | feat_2 = torch.zeros_like(proj) 1069 | output_3 = torch.zeros_like(output) 1070 | feat_3 = torch.zeros_like(proj) 1071 | ex = 0 1072 | for end in intervals: 1073 | shuffle_indices = torch.randperm(end - ex) + ex 1074 | shuffle_indices2 = 
torch.randperm(end - ex) + ex 1075 | for idx in range(end - ex): 1076 | output_2[idx + ex] = output[shuffle_indices[idx]] 1077 | feat_2[idx + ex] = proj[shuffle_indices[idx]] 1078 | output_3[idx + ex] = output[shuffle_indices2[idx]] 1079 | feat_3[idx + ex] = proj[shuffle_indices2[idx]] 1080 | ex = end 1081 | 1082 | # mixup 1083 | output_3 = lam * output_2 + (1 - lam) * output_3 1084 | feat_3 = lam * feat_2 + (1 - lam) * feat_3 1085 | 1086 | # regularization 1087 | L_ind_logit = self.MSEloss(output, output_2) 1088 | L_hdl_logit = self.MSEloss(output, output_3) 1089 | L_ind_feat = 0.3 * self.MSEloss(feat, feat_2) 1090 | L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3) 1091 | 1092 | cl_loss = F.cross_entropy(output, all_y) 1093 | C_scale = min(cl_loss.item(), 1.) 1094 | loss = cl_loss + C_scale * ( 1095 | lam * (L_ind_logit + L_ind_feat) + (1 - lam) * (L_hdl_logit + L_hdl_feat) 1096 | ) 1097 | 1098 | self.optimizer.zero_grad() 1099 | loss.backward() 1100 | self.optimizer.step() 1101 | 1102 | return {'loss': loss.item()} 1103 | 1104 | 1105 | class SANDMask(ERM): 1106 | """ 1107 | SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization 1108 | 1109 | """ 1110 | 1111 | def __init__(self, input_shape, num_classes, num_domains, hparams): 1112 | super(SANDMask, self).__init__(input_shape, num_classes, num_domains, hparams) 1113 | 1114 | self.tau = hparams["tau"] 1115 | self.k = hparams["k"] 1116 | betas = (0.9, 0.999) 1117 | self.optimizer = torch.optim.Adam( 1118 | self.network.parameters(), 1119 | lr=self.hparams["lr"], 1120 | weight_decay=self.hparams['weight_decay'], 1121 | betas=betas 1122 | ) 1123 | 1124 | self.register_buffer('update_count', torch.tensor([0])) 1125 | 1126 | def update(self, minibatches, unlabeled=None): 1127 | 1128 | mean_loss = 0 1129 | param_gradients = [[] for _ in self.network.parameters()] 1130 | for i, (x, y) in enumerate(minibatches): 1131 | logits = self.network(x) 1132 | 1133 | env_loss = 
F.cross_entropy(logits, y) 1134 | mean_loss += env_loss.item() / len(minibatches) 1135 | env_grads = autograd.grad(env_loss, self.network.parameters(), retain_graph=True) 1136 | for grads, env_grad in zip(param_gradients, env_grads): 1137 | grads.append(env_grad) 1138 | 1139 | self.optimizer.zero_grad() 1140 | # gradient masking applied here 1141 | self.mask_grads(param_gradients, self.network.parameters()) 1142 | self.optimizer.step() 1143 | self.update_count += 1 1144 | 1145 | return {'loss': mean_loss} 1146 | 1147 | def mask_grads(self, gradients, params): 1148 | ''' 1149 | Here a mask with continuous values in the range [0,1] is formed to control the amount of update for each 1150 | parameter based on the agreement of gradients coming from different environments. 1151 | ''' 1152 | device = gradients[0][0].device 1153 | for param, grads in zip(params, gradients): 1154 | grads = torch.stack(grads, dim=0) 1155 | avg_grad = torch.mean(grads, dim=0) 1156 | grad_signs = torch.sign(grads) 1157 | gamma = torch.tensor(1.0).to(device) 1158 | grads_var = grads.var(dim=0) 1159 | grads_var[torch.isnan(grads_var)] = 1e-17 1160 | lam = (gamma * grads_var).pow(-1) 1161 | mask = torch.tanh(self.k * lam * (torch.abs(grad_signs.mean(dim=0)) - self.tau)) 1162 | mask = torch.max(mask, torch.zeros_like(mask)) 1163 | mask[torch.isnan(mask)] = 1e-17 1164 | mask_t = (mask.sum() / mask.numel()) 1165 | param.grad = mask * avg_grad 1166 | param.grad *= (1. 
/ (1e-10 + mask_t)) 1167 | 1168 | 1169 | class Fishr(Algorithm): 1170 | "Invariant Gradients variances for Out-of-distribution Generalization" 1171 | 1172 | def __init__(self, input_shape, num_classes, num_domains, hparams): 1173 | assert backpack is not None, "Install backpack with: 'pip install backpack-for-pytorch==1.3.0'" 1174 | super(Fishr, self).__init__(input_shape, num_classes, num_domains, hparams) 1175 | self.num_domains = num_domains 1176 | 1177 | self.featurizer = networks.Featurizer(input_shape, self.hparams) 1178 | self.classifier = extend( 1179 | networks.Classifier( 1180 | self.featurizer.n_outputs, 1181 | num_classes, 1182 | self.hparams['nonlinear_classifier'], 1183 | ) 1184 | ) 1185 | self.network = nn.Sequential(self.featurizer, self.classifier) 1186 | 1187 | self.register_buffer("update_count", torch.tensor([0])) 1188 | self.bce_extended = extend(nn.CrossEntropyLoss(reduction='none')) 1189 | self.ema_per_domain = [ 1190 | MovingAverage(ema=self.hparams["ema"], oneminusema_correction=True) 1191 | for _ in range(self.num_domains) 1192 | ] 1193 | self._init_optimizer() 1194 | 1195 | def _init_optimizer(self): 1196 | self.optimizer = torch.optim.Adam( 1197 | list(self.featurizer.parameters()) + list(self.classifier.parameters()), 1198 | lr=self.hparams["lr"], 1199 | weight_decay=self.hparams["weight_decay"], 1200 | ) 1201 | 1202 | def update(self, minibatches, unlabeled=False): 1203 | assert len(minibatches) == self.num_domains 1204 | all_x = torch.cat([x for x, y in minibatches]) 1205 | all_y = torch.cat([y for x, y in minibatches]) 1206 | len_minibatches = [x.shape[0] for x, y in minibatches] 1207 | 1208 | all_z = self.featurizer(all_x) 1209 | all_logits = self.classifier(all_z) 1210 | 1211 | penalty = self.compute_fishr_penalty(all_logits, all_y, len_minibatches) 1212 | all_nll = F.cross_entropy(all_logits, all_y) 1213 | 1214 | penalty_weight = 0 1215 | if self.update_count >= self.hparams["penalty_anneal_iters"]: 1216 | penalty_weight = 
self.hparams["lambda"] 1217 | if self.update_count == self.hparams["penalty_anneal_iters"] != 0: 1218 | # Reset Adam as in IRM or V-REx, because it may not like the sharp jump in 1219 | # gradient magnitudes that happens at this step. 1220 | self._init_optimizer() 1221 | self.update_count += 1 1222 | 1223 | objective = all_nll + penalty_weight * penalty 1224 | self.optimizer.zero_grad() 1225 | objective.backward() 1226 | self.optimizer.step() 1227 | 1228 | return {'loss': objective.item(), 'nll': all_nll.item(), 'penalty': penalty.item()} 1229 | 1230 | def compute_fishr_penalty(self, all_logits, all_y, len_minibatches): 1231 | dict_grads = self._get_grads(all_logits, all_y) 1232 | grads_var_per_domain = self._get_grads_var_per_domain(dict_grads, len_minibatches) 1233 | return self._compute_distance_grads_var(grads_var_per_domain) 1234 | 1235 | def _get_grads(self, logits, y): 1236 | self.optimizer.zero_grad() 1237 | loss = self.bce_extended(logits, y).sum() 1238 | with backpack(BatchGrad()): 1239 | loss.backward( 1240 | inputs=list(self.classifier.parameters()), retain_graph=True, create_graph=True 1241 | ) 1242 | 1243 | # compute individual grads for all samples across all domains simultaneously 1244 | dict_grads = OrderedDict( 1245 | [ 1246 | (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1)) 1247 | for name, weights in self.classifier.named_parameters() 1248 | ] 1249 | ) 1250 | return dict_grads 1251 | 1252 | def _get_grads_var_per_domain(self, dict_grads, len_minibatches): 1253 | # grads var per domain 1254 | grads_var_per_domain = [{} for _ in range(self.num_domains)] 1255 | for name, _grads in dict_grads.items(): 1256 | all_idx = 0 1257 | for domain_id, bsize in enumerate(len_minibatches): 1258 | env_grads = _grads[all_idx:all_idx + bsize] 1259 | all_idx += bsize 1260 | env_mean = env_grads.mean(dim=0, keepdim=True) 1261 | env_grads_centered = env_grads - env_mean 1262 | grads_var_per_domain[domain_id][name] = 
(env_grads_centered).pow(2).mean(dim=0) 1263 | 1264 | # moving average 1265 | for domain_id in range(self.num_domains): 1266 | grads_var_per_domain[domain_id] = self.ema_per_domain[domain_id].update( 1267 | grads_var_per_domain[domain_id] 1268 | ) 1269 | 1270 | return grads_var_per_domain 1271 | 1272 | def _compute_distance_grads_var(self, grads_var_per_domain): 1273 | 1274 | # compute gradient variances averaged across domains 1275 | grads_var = OrderedDict( 1276 | [ 1277 | ( 1278 | name, 1279 | torch.stack( 1280 | [ 1281 | grads_var_per_domain[domain_id][name] 1282 | for domain_id in range(self.num_domains) 1283 | ], 1284 | dim=0 1285 | ).mean(dim=0) 1286 | ) 1287 | for name in grads_var_per_domain[0].keys() 1288 | ] 1289 | ) 1290 | 1291 | penalty = 0 1292 | for domain_id in range(self.num_domains): 1293 | penalty += l2_between_dicts(grads_var_per_domain[domain_id], grads_var) 1294 | return penalty / self.num_domains 1295 | 1296 | def predict(self, x): 1297 | return self.network(x) 1298 | -------------------------------------------------------------------------------- /domainbed/command_launchers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | A command launcher launches a list of commands on a cluster; implement your own 5 | launcher to add support for your cluster. We've provided an example launcher 6 | which runs all commands serially on the local machine. 7 | """ 8 | 9 | import subprocess 10 | import time 11 | import torch 12 | 13 | def local_launcher(commands): 14 | """Launch commands serially on the local machine.""" 15 | for cmd in commands: 16 | subprocess.call(cmd, shell=True) 17 | 18 | def dummy_launcher(commands): 19 | """ 20 | Doesn't run anything; instead, prints each command. 21 | Useful for testing. 
22 | """ 23 | for cmd in commands: 24 | print(f'Dummy launcher: {cmd}') 25 | 26 | def multi_gpu_launcher(commands): 27 | """ 28 | Launch commands on the local machine, using all GPUs in parallel. 29 | """ 30 | print('WARNING: using experimental multi_gpu_launcher.') 31 | n_gpus = torch.cuda.device_count() 32 | procs_by_gpu = [None]*n_gpus 33 | 34 | while len(commands) > 0: 35 | for gpu_idx in range(n_gpus): 36 | proc = procs_by_gpu[gpu_idx] 37 | if (proc is None) or (proc.poll() is not None): 38 | # Nothing is running on this GPU; launch a command. 39 | cmd = commands.pop(0) 40 | new_proc = subprocess.Popen( 41 | f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True) 42 | procs_by_gpu[gpu_idx] = new_proc 43 | break 44 | time.sleep(1) 45 | 46 | # Wait for the last few tasks to finish before returning 47 | for p in procs_by_gpu: 48 | if p is not None: 49 | p.wait() 50 | 51 | REGISTRY = { 52 | 'local': local_launcher, 53 | 'dummy': dummy_launcher, 54 | 'multi_gpu': multi_gpu_launcher 55 | } 56 | 57 | try: 58 | from domainbed import facebook 59 | facebook.register_command_launchers(REGISTRY) 60 | except ImportError: 61 | pass 62 | -------------------------------------------------------------------------------- /domainbed/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | import torch 5 | from PIL import Image, ImageFile 6 | from torchvision import transforms 7 | import torchvision.datasets.folder 8 | from torch.utils.data import TensorDataset, Subset 9 | from torchvision.datasets import MNIST, ImageFolder 10 | from torchvision.transforms.functional import rotate 11 | 12 | # from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 13 | # from wilds.datasets.fmow_dataset import FMoWDataset 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | DATASETS = [ 18 | # Debug 19 | "Debug28", 20 | "Debug224", 21 | # Small images 22 | "ColoredMNIST", 23 | "RotatedMNIST", 24 | # Big images 25 | "VLCS", 26 | "PACS", 27 | "OfficeHome", 28 | "TerraIncognita", 29 | "DomainNet", 30 | # "SVIRO", 31 | # # WILDS datasets 32 | # "WILDSCamelyon", 33 | # "WILDSFMoW" 34 | ] 35 | 36 | def get_dataset_class(dataset_name): 37 | """Return the dataset class with the given name.""" 38 | if dataset_name not in globals(): 39 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 40 | return globals()[dataset_name] 41 | 42 | 43 | def num_environments(dataset_name): 44 | return len(get_dataset_class(dataset_name).ENVIRONMENTS) 45 | 46 | 47 | class MultipleDomainDataset: 48 | N_STEPS = 5001 # Default, subclasses may override 49 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 50 | N_WORKERS = 8 # Default, subclasses may override 51 | ENVIRONMENTS = None # Subclasses should override 52 | INPUT_SHAPE = None # Subclasses should override 53 | 54 | def __getitem__(self, index): 55 | return self.datasets[index] 56 | 57 | def __len__(self): 58 | return len(self.datasets) 59 | 60 | 61 | class Debug(MultipleDomainDataset): 62 | def __init__(self, root, test_envs, hparams): 63 | super().__init__() 64 | self.input_shape = self.INPUT_SHAPE 65 | self.num_classes = 2 66 | self.datasets = [] 67 | for _ in [0, 1, 2]: 68 | self.datasets.append( 69 | TensorDataset( 70 | torch.randn(16, 
*self.INPUT_SHAPE), 71 | torch.randint(0, self.num_classes, (16,)) 72 | ) 73 | ) 74 | 75 | class Debug28(Debug): 76 | INPUT_SHAPE = (3, 28, 28) 77 | ENVIRONMENTS = ['0', '1', '2'] 78 | 79 | class Debug224(Debug): 80 | INPUT_SHAPE = (3, 224, 224) 81 | ENVIRONMENTS = ['0', '1', '2'] 82 | 83 | 84 | class MultipleEnvironmentMNIST(MultipleDomainDataset): 85 | def __init__(self, root, environments, dataset_transform, input_shape, 86 | num_classes): 87 | super().__init__() 88 | if root is None: 89 | raise ValueError('Data directory not specified!') 90 | 91 | original_dataset_tr = MNIST(root, train=True, download=True) 92 | original_dataset_te = MNIST(root, train=False, download=True) 93 | 94 | original_images = torch.cat((original_dataset_tr.data, 95 | original_dataset_te.data)) 96 | 97 | original_labels = torch.cat((original_dataset_tr.targets, 98 | original_dataset_te.targets)) 99 | 100 | shuffle = torch.randperm(len(original_images)) 101 | 102 | original_images = original_images[shuffle] 103 | original_labels = original_labels[shuffle] 104 | 105 | self.datasets = [] 106 | 107 | for i in range(len(environments)): 108 | images = original_images[i::len(environments)] 109 | labels = original_labels[i::len(environments)] 110 | self.datasets.append(dataset_transform(images, labels, environments[i])) 111 | 112 | self.input_shape = input_shape 113 | self.num_classes = num_classes 114 | 115 | 116 | class ColoredMNIST(MultipleEnvironmentMNIST): 117 | ENVIRONMENTS = ['+90%', '+80%', '-90%'] 118 | 119 | def __init__(self, root, test_envs, hparams): 120 | super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9], 121 | self.color_dataset, (2, 28, 28,), 2) 122 | 123 | self.input_shape = (2, 28, 28,) 124 | self.num_classes = 2 125 | 126 | def color_dataset(self, images, labels, environment): 127 | # # Subsample 2x for computational convenience 128 | # images = images.reshape((-1, 28, 28))[:, ::2, ::2] 129 | # Assign a binary label based on the digit 130 | labels = (labels < 5).float() 
131 | # Flip label with probability 0.25 132 | labels = self.torch_xor_(labels, 133 | self.torch_bernoulli_(0.25, len(labels))) 134 | 135 | # Assign a color based on the label; flip the color with probability e 136 | colors = self.torch_xor_(labels, 137 | self.torch_bernoulli_(environment, 138 | len(labels))) 139 | images = torch.stack([images, images], dim=1) 140 | # Apply the color to the image by zeroing out the other color channel 141 | images[torch.tensor(range(len(images))), ( 142 | 1 - colors).long(), :, :] *= 0 143 | 144 | x = images.float().div_(255.0) 145 | y = labels.view(-1).long() 146 | 147 | return TensorDataset(x, y) 148 | 149 | def torch_bernoulli_(self, p, size): 150 | return (torch.rand(size) < p).float() 151 | 152 | def torch_xor_(self, a, b): 153 | return (a - b).abs() 154 | 155 | 156 | class RotatedMNIST(MultipleEnvironmentMNIST): 157 | ENVIRONMENTS = ['0', '15', '30', '45', '60', '75'] 158 | 159 | def __init__(self, root, test_envs, hparams): 160 | super(RotatedMNIST, self).__init__(root, [0, 15, 30, 45, 60, 75], 161 | self.rotate_dataset, (1, 28, 28,), 10) 162 | 163 | def rotate_dataset(self, images, labels, angle): 164 | rotation = transforms.Compose([ 165 | transforms.ToPILImage(), 166 | transforms.Lambda(lambda x: rotate(x, angle, fill=(0,), 167 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR)), 168 | transforms.ToTensor()]) 169 | 170 | x = torch.zeros(len(images), 1, 28, 28) 171 | for i in range(len(images)): 172 | x[i] = rotation(images[i]) 173 | 174 | y = labels.view(-1) 175 | 176 | return TensorDataset(x, y) 177 | 178 | 179 | class MultipleEnvironmentImageFolder(MultipleDomainDataset): 180 | def __init__(self, root, test_envs, augment, hparams): 181 | super().__init__() 182 | environments = [f.name for f in os.scandir(root) if f.is_dir()] 183 | environments = sorted(environments) 184 | 185 | transform = transforms.Compose([ 186 | transforms.Resize((224,224)), 187 | transforms.ToTensor(), 188 | transforms.Normalize( 
189 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 190 | ]) 191 | 192 | augment_transform = transforms.Compose([ 193 | # transforms.Resize((224,224)), 194 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 195 | transforms.RandomHorizontalFlip(), 196 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 197 | transforms.RandomGrayscale(), 198 | transforms.ToTensor(), 199 | transforms.Normalize( 200 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 201 | ]) 202 | 203 | self.datasets = [] 204 | for i, environment in enumerate(environments): 205 | 206 | if augment and (i not in test_envs): 207 | env_transform = augment_transform 208 | else: 209 | env_transform = transform 210 | 211 | path = os.path.join(root, environment) 212 | env_dataset = ImageFolder(path, 213 | transform=env_transform) 214 | 215 | self.datasets.append(env_dataset) 216 | 217 | self.input_shape = (3, 224, 224,) 218 | self.num_classes = len(self.datasets[-1].classes) 219 | 220 | class VLCS(MultipleEnvironmentImageFolder): 221 | CHECKPOINT_FREQ = 300 222 | ENVIRONMENTS = ["C", "L", "S", "V"] 223 | def __init__(self, root, test_envs, hparams): 224 | self.dir = os.path.join(root, "VLCS/") 225 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 226 | 227 | class PACS(MultipleEnvironmentImageFolder): 228 | CHECKPOINT_FREQ = 300 229 | ENVIRONMENTS = ["A", "C", "P", "S"] 230 | def __init__(self, root, test_envs, hparams): 231 | self.dir = os.path.join(root, "PACS/") 232 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 233 | 234 | class DomainNet(MultipleEnvironmentImageFolder): 235 | CHECKPOINT_FREQ = 1000 236 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"] 237 | def __init__(self, root, test_envs, hparams): 238 | self.dir = os.path.join(root, "domain_net/") 239 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 240 | 241 | class OfficeHome(MultipleEnvironmentImageFolder): 242 | 
CHECKPOINT_FREQ = 300 243 | ENVIRONMENTS = ["A", "C", "P", "R"] 244 | def __init__(self, root, test_envs, hparams): 245 | self.dir = os.path.join(root, "office_home/") 246 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 247 | 248 | class TerraIncognita(MultipleEnvironmentImageFolder): 249 | CHECKPOINT_FREQ = 300 250 | ENVIRONMENTS = ["L100", "L38", "L43", "L46"] 251 | def __init__(self, root, test_envs, hparams): 252 | self.dir = os.path.join(root, "terra_incognita/") 253 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 254 | 255 | 256 | # class SVIRO(MultipleEnvironmentImageFolder): 257 | # CHECKPOINT_FREQ = 300 258 | # ENVIRONMENTS = [ 259 | # "aclass", "escape", "hilux", "i3", "lexus", "tesla", "tiguan", "tucson", "x5", "zoe" 260 | # ] 261 | 262 | # def __init__(self, root, test_envs, hparams): 263 | # self.dir = os.path.join(root, "sviro/") 264 | # super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams) 265 | 266 | 267 | # class WILDSEnvironment: 268 | 269 | # def __init__(self, wilds_dataset, metadata_name, metadata_value, transform=None): 270 | # self.name = metadata_name + "_" + str(metadata_value) 271 | 272 | # metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 273 | # metadata_array = wilds_dataset.metadata_array 274 | # subset_indices = torch.where(metadata_array[:, metadata_index] == metadata_value)[0] 275 | 276 | # self.dataset = wilds_dataset 277 | # self.indices = subset_indices 278 | # self.transform = transform 279 | 280 | # def __getitem__(self, i): 281 | # x = self.dataset.get_input(self.indices[i]) 282 | # if type(x).__name__ != "Image": 283 | # x = Image.fromarray(x) 284 | 285 | # y = self.dataset.y_array[self.indices[i]] 286 | # if self.transform is not None: 287 | # x = self.transform(x) 288 | # return x, y 289 | 290 | # def __len__(self): 291 | # return len(self.indices) 292 | 293 | 294 | # class WILDSDataset(MultipleDomainDataset): 295 | # 
INPUT_SHAPE = (3, 224, 224) 296 | 297 | # def __init__(self, dataset, metadata_name, test_envs, augment, hparams): 298 | # super().__init__() 299 | 300 | # transform = transforms.Compose( 301 | # [ 302 | # transforms.Resize((224, 224)), 303 | # transforms.ToTensor(), 304 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 305 | # ] 306 | # ) 307 | 308 | # augment_transform = transforms.Compose( 309 | # [ 310 | # transforms.Resize((224, 224)), 311 | # transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 312 | # transforms.RandomHorizontalFlip(), 313 | # transforms.ColorJitter(0.3, 0.3, 0.3, 0.3), 314 | # transforms.RandomGrayscale(), 315 | # transforms.ToTensor(), 316 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 317 | # ] 318 | # ) 319 | 320 | # self.datasets = [] 321 | 322 | # for i, metadata_value in enumerate(self.metadata_values(dataset, metadata_name)): 323 | # if augment and (i not in test_envs): 324 | # env_transform = augment_transform 325 | # else: 326 | # env_transform = transform 327 | 328 | # env_dataset = WILDSEnvironment(dataset, metadata_name, metadata_value, env_transform) 329 | 330 | # self.datasets.append(env_dataset) 331 | 332 | # self.input_shape = ( 333 | # 3, 334 | # 224, 335 | # 224, 336 | # ) 337 | # self.num_classes = dataset.n_classes 338 | 339 | # def metadata_values(self, wilds_dataset, metadata_name): 340 | # metadata_index = wilds_dataset.metadata_fields.index(metadata_name) 341 | # metadata_vals = wilds_dataset.metadata_array[:, metadata_index] 342 | # return sorted(list(set(metadata_vals.view(-1).tolist()))) 343 | 344 | 345 | # class WILDSCamelyon(WILDSDataset): 346 | # ENVIRONMENTS = ["hospital_0", "hospital_1", "hospital_2", "hospital_3", "hospital_4"] 347 | 348 | # def __init__(self, root, test_envs, hparams): 349 | # dataset = Camelyon17Dataset(root_dir=root) 350 | # super().__init__(dataset, "hospital", test_envs, hparams['data_augmentation'], hparams) 351 | 352 | 
353 | # class WILDSFMoW(WILDSDataset): 354 | # ENVIRONMENTS = ["region_0", "region_1", "region_2", "region_3", "region_4", "region_5"] 355 | 356 | # def __init__(self, root, test_envs, hparams): 357 | # dataset = FMoWDataset(root_dir=root) 358 | # super().__init__(dataset, "region", test_envs, hparams['data_augmentation'], hparams) 359 | -------------------------------------------------------------------------------- /domainbed/hparams_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import numpy as np 3 | from domainbed.lib import misc 4 | 5 | 6 | def _define_hparam(hparams, hparam_name, default_val, random_val_fn): 7 | hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn) 8 | 9 | 10 | def _hparams(algorithm, dataset, random_seed): 11 | """ 12 | Global registry of hyperparams. Each entry is a (default, random) tuple. 13 | New algorithms / networks / etc. should add entries here. 14 | """ 15 | SMALL_IMAGES = ['Debug28', 'RotatedMNIST', 'ColoredMNIST'] 16 | 17 | hparams = {} 18 | 19 | def _hparam(name, default_val, random_val_fn): 20 | """Define a hyperparameter. random_val_fn takes a RandomState and 21 | returns a random hyperparameter value.""" 22 | assert(name not in hparams) 23 | random_state = np.random.RandomState( 24 | misc.seed_hash(random_seed, name) 25 | ) 26 | hparams[name] = (default_val, random_val_fn(random_state)) 27 | 28 | # Unconditional hparam definitions. 29 | 30 | _hparam('data_augmentation', True, lambda r: True) 31 | _hparam('resnet18', False, lambda r: False) 32 | _hparam('resnet_dropout', 0., lambda r: r.choice([0., 0.1, 0.5])) 33 | _hparam('class_balanced', False, lambda r: False) 34 | # TODO: nonlinear classifiers disabled 35 | _hparam('nonlinear_classifier', False, 36 | lambda r: bool(r.choice([False, False]))) 37 | 38 | # Algorithm-specific hparam definitions. 
Each block of code below 39 | # corresponds to exactly one algorithm. 40 | 41 | if algorithm in ['DANN', 'CDANN']: 42 | _hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2)) 43 | _hparam('weight_decay_d', 0., lambda r: 10**r.uniform(-6, -2)) 44 | _hparam('d_steps_per_g_step', 1, lambda r: int(2**r.uniform(0, 3))) 45 | _hparam('grad_penalty', 0., lambda r: 10**r.uniform(-2, 1)) 46 | _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5])) 47 | _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10))) 48 | _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5]))) 49 | _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5])) 50 | 51 | elif algorithm == 'Fish': 52 | _hparam('meta_lr', 0.5, lambda r:r.choice([0.05, 0.1, 0.5])) 53 | 54 | elif algorithm == "RSC": 55 | _hparam('rsc_f_drop_factor', 1/3, lambda r: r.uniform(0, 0.5)) 56 | _hparam('rsc_b_drop_factor', 1/3, lambda r: r.uniform(0, 0.5)) 57 | 58 | elif algorithm == "SagNet": 59 | _hparam('sag_w_adv', 0.1, lambda r: 10**r.uniform(-2, 1)) 60 | 61 | elif algorithm == "IRM": 62 | _hparam('irm_lambda', 1e2, lambda r: 10**r.uniform(-1, 5)) 63 | _hparam('irm_penalty_anneal_iters', 500, 64 | lambda r: int(10**r.uniform(0, 4))) 65 | 66 | elif algorithm == "Mixup": 67 | _hparam('mixup_alpha', 0.2, lambda r: 10**r.uniform(-1, -1)) 68 | 69 | elif algorithm == "GroupDRO": 70 | _hparam('groupdro_eta', 1e-2, lambda r: 10**r.uniform(-3, -1)) 71 | 72 | elif algorithm == "MMD" or algorithm == "CORAL": 73 | _hparam('mmd_gamma', 1., lambda r: 10**r.uniform(-1, 1)) 74 | 75 | elif algorithm == "MLDG": 76 | _hparam('mldg_beta', 1., lambda r: 10**r.uniform(-1, 1)) 77 | 78 | elif algorithm == "MTL": 79 | _hparam('mtl_ema', .99, lambda r: r.choice([0.5, 0.9, 0.99, 1.])) 80 | 81 | elif algorithm == "VREx": 82 | _hparam('vrex_lambda', 1e1, lambda r: 10**r.uniform(-1, 5)) 83 | _hparam('vrex_penalty_anneal_iters', 500, 84 | lambda r: int(10**r.uniform(0, 4))) 85 | 86 | elif algorithm == "SD": 87 | _hparam('sd_reg', 0.1, lambda r: 
10**r.uniform(-5, -1)) 88 | 89 | elif algorithm == "ANDMask": 90 | _hparam('tau', 1, lambda r: r.uniform(0.5, 1.)) 91 | 92 | elif algorithm == "IGA": 93 | _hparam('penalty', 1000, lambda r: 10**r.uniform(1, 5)) 94 | 95 | elif algorithm == "SANDMask": 96 | _hparam('tau', 1.0, lambda r: r.uniform(0.0, 1.)) 97 | _hparam('k', 1e+1, lambda r: int(10**r.uniform(-3, 5))) 98 | 99 | elif algorithm == "Fishr": 100 | _hparam('lambda', 1000., lambda r: 10**r.uniform(1., 4.)) 101 | _hparam('penalty_anneal_iters', 1500, lambda r: int(r.uniform(0., 5000.))) 102 | _hparam('ema', 0.95, lambda r: r.uniform(0.90, 0.99)) 103 | 104 | # Dataset-and-algorithm-specific hparam definitions. Each block of code 105 | # below corresponds to exactly one hparam. Avoid nested conditionals. 106 | 107 | if dataset in SMALL_IMAGES: 108 | _hparam('lr', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 109 | else: 110 | _hparam('lr', 5e-5, lambda r: 10**r.uniform(-5, -3.5)) 111 | 112 | if dataset in SMALL_IMAGES: 113 | _hparam('weight_decay', 0., lambda r: 0.) 
114 | else: 115 | _hparam('weight_decay', 0., lambda r: 10**r.uniform(-6, -2)) 116 | 117 | if dataset in SMALL_IMAGES: 118 | _hparam('batch_size', 64, lambda r: int(2**r.uniform(3, 9))) 119 | elif algorithm == 'ARM': 120 | _hparam('batch_size', 8, lambda r: 8) 121 | elif dataset == 'DomainNet': 122 | _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5))) 123 | else: 124 | _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5.5))) 125 | 126 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 127 | _hparam('lr_g', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 128 | elif algorithm in ['DANN', 'CDANN']: 129 | _hparam('lr_g', 5e-5, lambda r: 10**r.uniform(-5, -3.5)) 130 | 131 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 132 | _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5)) 133 | elif algorithm in ['DANN', 'CDANN']: 134 | _hparam('lr_d', 5e-5, lambda r: 10**r.uniform(-5, -3.5)) 135 | 136 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES: 137 | _hparam('weight_decay_g', 0., lambda r: 0.) 138 | elif algorithm in ['DANN', 'CDANN']: 139 | _hparam('weight_decay_g', 0., lambda r: 10**r.uniform(-6, -2)) 140 | 141 | return hparams 142 | 143 | 144 | def default_hparams(algorithm, dataset): 145 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()} 146 | 147 | 148 | def random_hparams(algorithm, dataset, seed): 149 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()} 150 | -------------------------------------------------------------------------------- /domainbed/lib/fast_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | 5 | class _InfiniteSampler(torch.utils.data.Sampler): 6 | """Wraps another Sampler to yield an infinite stream.""" 7 | def __init__(self, sampler): 8 | self.sampler = sampler 9 | 10 | def __iter__(self): 11 | while True: 12 | for batch in self.sampler: 13 | yield batch 14 | 15 | class InfiniteDataLoader: 16 | def __init__(self, dataset, weights, batch_size, num_workers): 17 | super().__init__() 18 | 19 | if weights is not None: 20 | sampler = torch.utils.data.WeightedRandomSampler(weights, 21 | replacement=True, 22 | num_samples=batch_size) 23 | else: 24 | sampler = torch.utils.data.RandomSampler(dataset, 25 | replacement=True) 26 | 27 | if weights == None: 28 | weights = torch.ones(len(dataset)) 29 | 30 | batch_sampler = torch.utils.data.BatchSampler( 31 | sampler, 32 | batch_size=batch_size, 33 | drop_last=True) 34 | 35 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 36 | dataset, 37 | num_workers=num_workers, 38 | batch_sampler=_InfiniteSampler(batch_sampler) 39 | )) 40 | 41 | def __iter__(self): 42 | while True: 43 | yield next(self._infinite_iterator) 44 | 45 | def __len__(self): 46 | raise ValueError 47 | 48 | class FastDataLoader: 49 | """DataLoader wrapper with slightly improved speed by not respawning worker 50 | processes at every epoch.""" 51 | def __init__(self, dataset, batch_size, num_workers): 52 | super().__init__() 53 | 54 | batch_sampler = torch.utils.data.BatchSampler( 55 | torch.utils.data.RandomSampler(dataset, replacement=False), 56 | batch_size=batch_size, 57 | drop_last=False 58 | ) 59 | 60 | self._infinite_iterator = iter(torch.utils.data.DataLoader( 61 | dataset, 62 | num_workers=num_workers, 63 | batch_sampler=_InfiniteSampler(batch_sampler) 64 | )) 65 | 66 | self._length = len(batch_sampler) 67 | 68 | def __iter__(self): 69 | for _ in range(len(self)): 70 | yield next(self._infinite_iterator) 71 | 72 | def __len__(self): 73 | return self._length 74 | 
# --------------------------------------------------------------------------------
# /domainbed/lib/misc.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
Things that don't belong anywhere else
"""

import hashlib
import json
import os
import sys
from shutil import copyfile
from collections import OrderedDict, defaultdict
from numbers import Number
import operator

import numpy as np
import torch
import tqdm
from collections import Counter


def l2_between_dicts(dict_1, dict_2):
    """Return the mean squared difference between two dicts of tensors.

    Values are flattened and concatenated in the sorted-key order of dict_1.
    dict_2 is indexed with those same keys, so both dicts must share them.
    """
    assert len(dict_1) == len(dict_2)
    dict_1_values = [dict_1[key] for key in sorted(dict_1.keys())]
    dict_2_values = [dict_2[key] for key in sorted(dict_1.keys())]
    return (
        torch.cat(tuple([t.view(-1) for t in dict_1_values])) -
        torch.cat(tuple([t.view(-1) for t in dict_2_values]))
    ).pow(2).mean()


class MovingAverage:
    """Exponential moving average over a dict of tensors.

    update() folds a new dict of tensors (each viewed as shape (1, -1)) into
    the running average and returns the averaged dict. The stored state is
    always the uncorrected average.
    """

    def __init__(self, ema, oneminusema_correction=True):
        self.ema = ema
        self.named_parameters = {}
        self._updates = 0
        self._oneminusema_correction = oneminusema_correction

    def update(self, dict_data):
        ema_dict_data = {}
        for name, data in dict_data.items():
            data = data.view(1, -1)
            if self._updates == 0:
                # First update: the average starts from zeros.
                previous_data = torch.zeros_like(data)
            else:
                previous_data = self.named_parameters[name]

            ema_data = self.ema * previous_data + (1 - self.ema) * data
            if self._oneminusema_correction:
                # Dividing by (1 - ema) makes d(output)/d(data) equal 1, so the
                # gradient reaching `data` does not depend on self.ema.
                ema_dict_data[name] = ema_data / (1 - self.ema)
            else:
                ema_dict_data[name] = ema_data
            # Store a detached copy so the state carries no autograd history.
            self.named_parameters[name] = ema_data.clone().detach()

        self._updates += 1
        return ema_dict_data


def make_weights_for_balanced_classes(dataset):
    """Return per-example weights that give every class equal total weight.

    An example of class y gets weight 1 / (count[y] * n_classes). Each class
    therefore sums to 1 / n_classes, and all weights together sum to 1.
    """
    counts = Counter()
    classes = []
    for _, y in dataset:
        y = int(y)
        counts[y] += 1
        classes.append(y)

    n_classes = len(counts)

    weight_per_class = {}
    for y in counts:
        weight_per_class[y] = 1 / (counts[y] * n_classes)

    weights = torch.zeros(len(dataset))
    for i, y in enumerate(classes):
        weights[i] = weight_per_class[int(y)]

    return weights


def pdb():
    """Restore sys.stdout to the interpreter's original stream, then start pdb."""
    sys.stdout = sys.__stdout__
    import pdb
    print("Launching PDB, enter 'n' to step to parent function.")
    pdb.set_trace()


def seed_hash(*args):
    """
    Derive an integer hash from all args, for use as a random seed.

    The MD5 digest of str(args) is reduced modulo 2**31. Unlike the builtin
    hash() of strings, this is stable across processes.
    """
    args_str = str(args)
    return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31)


def print_separator():
    print("="*80)


def print_row(row, colwidth=10, latex=False):
    """Print row as fixed-width columns.

    Floating-point values are first formatted with 10 decimals. Every value is
    left-justified and truncated to colwidth characters. In LaTeX mode,
    columns are joined with " & " and the row ends with a LaTeX line break.
    """
    if latex:
        sep = " & "
        end_ = "\\\\"
    else:
        sep = " "
        end_ = ""

    def format_val(x):
        if np.issubdtype(type(x), np.floating):
            x = "{:.10f}".format(x)
        return str(x).ljust(colwidth)[:colwidth]
    print(sep.join([format_val(x) for x in row]), end_)


class _SplitDataset(torch.utils.data.Dataset):
    """Used by split_dataset"""
    def __init__(self, underlying_dataset, keys):
        super(_SplitDataset, self).__init__()
        self.underlying_dataset = underlying_dataset
        self.keys = keys

    def __getitem__(self, key):
        return self.underlying_dataset[self.keys[key]]

    def __len__(self):
        return len(self.keys)


def split_dataset(dataset, n, seed=0):
    """
    Return a pair of datasets corresponding to a random split of the given
    dataset, with n datapoints in the first dataset and the rest in the last,
    using the given random seed
    """
    assert(n <= len(dataset))
    keys = list(range(len(dataset)))
    np.random.RandomState(seed).shuffle(keys)
    keys_1 = keys[:n]
    keys_2 = keys[n:]
    return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)


def random_pairs_of_minibatches(minibatches):
    """Pair up minibatches in a random cyclic order.

    minibatches is a list of (x, y) pairs. After a random permutation, each
    minibatch is paired with the next one, and the last wraps around to the
    first. Both members of a pair are truncated to the shorter length.
    Returns one ((xi, yi), (xj, yj)) tuple per input minibatch.
    """
    perm = torch.randperm(len(minibatches)).tolist()
    pairs = []

    for i in range(len(minibatches)):
        j = i + 1 if i < (len(minibatches) - 1) else 0

        xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1]
        xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1]

        min_n = min(len(xi), len(xj))

        pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n])))

    return pairs


def accuracy(network, loader, weights, device):
    """Return the (optionally weighted) accuracy of network on loader.

    network.predict(x) returns per-class scores:
    - With a single output column, a prediction counts as positive when its
      score is > 0.
    - Otherwise the argmax class is used.

    When weights is given, weights[i] applies to the i-th example the loader
    yields, so both must be in the same order. The network is put in eval
    mode for the evaluation and in train mode afterwards.
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            p = network.predict(x)
            if weights is None:
                batch_weights = torch.ones(len(x))
            else:
                batch_weights = weights[weights_offset : weights_offset + len(x)]
                weights_offset += len(x)
            batch_weights = batch_weights.to(device)
            if p.size(1) == 1:
                correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item()
            else:
                correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item()
            total += batch_weights.sum().item()
    network.train()

    return correct / total


class Tee:
    """File-like object that mirrors writes to the current stdout and a file.

    Both streams are flushed after every write. The file is opened with mode
    and this class never closes it.
    """
    def __init__(self, fname, mode="a"):
        self.stdout = sys.stdout
        self.file = open(fname, mode)

    def write(self, message):
        self.stdout.write(message)
        self.file.write(message)
        self.flush()

    def flush(self):
        self.stdout.flush()
        self.file.flush()


class ParamDict(OrderedDict):
    """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile.
    A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly."""

    def __init__(self, *args, **kwargs):
        # Keyword arguments must be forwarded as keywords. A single star would
        # pass their *names* as extra positional arguments instead.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Apply op(self_value, other_value) entry-wise.

        other may be a Number, which is broadcast to every entry, or a dict
        containing at least this dict's keys. Anything else raises
        NotImplementedError.
        """
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    # Addition commutes. This also lets sum(param_dicts) start from 0.
    __radd__ = __add__

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # self - other
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # other - self: the reflected operation reverses operand order.
        return self._prototype(other, lambda mine, theirs: theirs - mine)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)

# --------------------------------------------------------------------------------
# /domainbed/lib/query.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""Small query library."""

import collections
import inspect
import json
import types
import unittest
import warnings
import math

import numpy as np


def make_selector_fn(selector):
    """
    If selector is callable, return selector.
    Otherwise, return a function corresponding to the selector string. Examples
    of valid selector strings and the corresponding functions:
        x       lambda obj: obj['x']
        x.y     lambda obj: obj['x']['y']
        x,y     lambda obj: (obj['x'], obj['y'])
    Commas bind looser than dots, and whitespace around each key is ignored,
    so "a.b, c" selects (obj['a']['b'], obj['c']).

    Raises TypeError for anything that is neither a string nor callable.
    """
    if isinstance(selector, str):
        if ',' in selector:
            parts = selector.split(',')
            part_selectors = [make_selector_fn(part) for part in parts]
            return lambda obj: tuple(sel(obj) for sel in part_selectors)
        elif '.' in selector:
            parts = selector.split('.')
            part_selectors = [make_selector_fn(part) for part in parts]
            def f(obj):
                for sel in part_selectors:
                    obj = sel(obj)
                return obj
            return f
        else:
            key = selector.strip()
            return lambda obj: obj[key]
    elif callable(selector):
        # Any callable (function, lambda, builtin, functools.partial, bound
        # method) is used as-is.
        return selector
    else:
        raise TypeError(
            "selector must be a string or a callable, got {!r}".format(
                type(selector)))


def hashable(obj):
    """Return obj if it is hashable, else a canonical JSON string for it.

    This lets unhashable values such as lists or dicts act as dict/set keys.
    """
    try:
        hash(obj)
        return obj
    except TypeError:
        return json.dumps({'_':obj}, sort_keys=True)


class Q(object):
    """A list wrapper with chainable, query-style helpers."""

    def __init__(self, list_):
        super(Q, self).__init__()
        self._list = list_

    def __len__(self):
        return len(self._list)

    def __getitem__(self, key):
        # Slicing returns a plain list, not a Q.
        return self._list[key]

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._list == other._list
        else:
            return self._list == other

    def __str__(self):
        return str(self._list)

    def __repr__(self):
        return repr(self._list)

    def _append(self, item):
        """Unsafe, be careful you know what you're doing."""
        self._list.append(item)

    def group(self, selector):
        """
        Group elements by selector and return a list of (group, group_records)
        tuples.

        Groups are ordered by their hashable() key. group_records is a Q of
        that group's elements, in their original order.
        """
        selector = make_selector_fn(selector)
        groups = {}
        for x in self._list:
            group = selector(x)
            group_key = hashable(group)
            if group_key not in groups:
                groups[group_key] = (group, Q([]))
            groups[group_key][1]._append(x)
        results = [groups[key] for key in sorted(groups.keys())]
        return Q(results)

    def group_map(self, selector, fn):
        """
        Group elements by selector, apply fn to each group, and return a list
        of the results.
        """
        return self.group(selector).map(fn)

    def map(self, fn):
        """
        map self onto fn. If fn takes multiple args, tuple-unpacking
        is applied.
        """
        if len(inspect.signature(fn).parameters) > 1:
            return Q([fn(*x) for x in self._list])
        else:
            return Q([fn(x) for x in self._list])

    def select(self, selector):
        """Return a Q of selector(x) for every element x."""
        selector = make_selector_fn(selector)
        return Q([selector(x) for x in self._list])

    def min(self):
        return min(self._list)

    def max(self):
        return max(self._list)

    def sum(self):
        return sum(self._list)

    def len(self):
        return len(self._list)

    def mean(self):
        # numpy warnings (e.g. for an empty list, which yields nan) are
        # silenced.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.mean(self._list))

    def std(self):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return float(np.std(self._list))

    def mean_std(self):
        return (self.mean(), self.std())

    def argmax(self, selector):
        """Return the element with the largest selector(x)."""
        selector = make_selector_fn(selector)
        return max(self._list, key=selector)

    def filter(self, fn):
        return Q([x for x in self._list if fn(x)])

    def filter_equals(self, selector, value):
        """like [x for x in y if x.selector == value]"""
        selector = make_selector_fn(selector)
        return self.filter(lambda r: selector(r) == value)

    def filter_not_none(self):
        return self.filter(lambda r: r is not None)

    def filter_not_nan(self):
        return self.filter(lambda r: not np.isnan(r))

    def flatten(self):
        return Q([y for x in self._list for y in x])

    def unique(self):
        """Drop duplicates (compared via hashable()), keeping first occurrences."""
        result = []
        result_set = set()
        for x in self._list:
            hashable_x = hashable(x)
            if hashable_x not in result_set:
                result_set.add(hashable_x)
                result.append(x)
        return Q(result)

    def sorted(self, key=None):
        """Return a sorted Q. Float NaN keys sort as -inf, i.e. first."""
        if key is None:
            key = lambda x: x
        def key2(x):
            x = key(x)
            if isinstance(x, (np.floating, float)) and np.isnan(x):
                return float('-inf')
            else:
                return x
        return Q(sorted(self._list, key=key2))

# --------------------------------------------------------------------------------
# /domainbed/lib/reporting.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import collections

import json
import os

import tqdm

from domainbed.lib.query import Q


def load_records(path):
    """Load every JSON line of <path>/<subdir>/results.jsonl into a Q.

    Entries without a readable results file (open raises an OSError) are
    skipped, as are blank lines.
    """
    records = []
    for subdir in tqdm.tqdm(list(os.listdir(path)),
                            ncols=80,
                            leave=False):
        results_path = os.path.join(path, subdir, "results.jsonl")
        try:
            with open(results_path, "r") as f:
                for line in f:
                    # json.loads tolerates the trailing newline. Slicing off
                    # the last character would instead eat the closing brace
                    # of a final line written without a newline.
                    if line.strip():
                        records.append(json.loads(line))
        except IOError:
            pass

    return Q(records)


def get_grouped_records(records):
    """Group records by (trial_seed, dataset, algorithm, test_env). Because
    records can have multiple test envs, a given record may appear in more than
    one group."""
    result = collections.defaultdict(lambda: [])
    for r in records:
        for test_env in r["args"]["test_envs"]:
            group = (r["args"]["trial_seed"],
                r["args"]["dataset"],
                r["args"]["algorithm"],
                test_env)
            result[group].append(r)
    return Q([{"trial_seed": t, "dataset": d, "algorithm": a, "test_env": e,
        "records": Q(r)} for (t,d,a,e),r in result.items()])

# --------------------------------------------------------------------------------
# /domainbed/lib/wide_resnet.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

"""
From https://github.com/meliketoy/wide-resnet.pytorch
"""

import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True)


def conv_init(m):
    """Xavier-uniform init for conv weights and constant init for BatchNorm."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        init.xavier_uniform_(m.weight, gain=np.sqrt(2))
        init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)


class wide_basic(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv-dropout, then BN-ReLU-conv.

    A 1x1 conv shortcut is used whenever the stride or width changes.
    """
    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, planes, kernel_size=1, stride=stride,
                    bias=True), )

    def forward(self, x):
        out = self.dropout(self.conv1(F.relu(self.bn1(x))))
        out = self.conv2(F.relu(self.bn2(out)))
        out += self.shortcut(x)

        return out


class Wide_ResNet(nn.Module):
    """Wide Resnet with the softmax layer chopped off"""
    def __init__(self, input_shape, depth, widen_factor, dropout_rate):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6
        k = widen_factor

        # print('| Wide-Resnet %dx%d' % (depth, k))
        nStages = [16, 16 * k, 32 * k, 64 * k]

        self.conv1 = conv3x3(input_shape[0], nStages[0])
        self.layer1 = self._wide_layer(
            wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(
            wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(
            wide_basic, nStages[3], n, dropout_rate, stride=2)
        # NOTE(review): in PyTorch, momentum=0.9 weights the current batch at
        # 90% of the running-stat update. The value may come from a framework
        # with the opposite convention; confirm it is intended.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)

        self.n_outputs = nStages[3]

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        # Only the first block of a stage changes the stride.
        strides = [stride] + [1] * (int(num_blocks) - 1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Assumes an 8x8 final map (e.g. 32x32 inputs after two stride-2
        # stages), so pooling leaves 1x1.
        out = F.avg_pool2d(out, 8)
        return out[:, :, 0, 0]

# --------------------------------------------------------------------------------
# /domainbed/model_selection.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
import numpy as np


def get_test_records(records):
    """Given records with a common test env, get the test records (i.e. the
    records with *only* that single test env and no other test envs)"""
    return records.filter(lambda r: len(r['args']['test_envs']) == 1)


class SelectionMethod:
    """Abstract class whose subclasses implement strategies for model
    selection across hparams and timesteps.

    Every method is a classmethod. The class itself cannot be instantiated.
    """

    def __init__(self):
        raise TypeError

    @classmethod
    def run_acc(cls, run_records):
        """
        Given records from a run, return a {val_acc, test_acc} dict representing
        the best val-acc and corresponding test-acc for that run.
        """
        raise NotImplementedError

    @classmethod
    def hparams_accs(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return a list of (run_acc, records) tuples.

        The list is sorted by descending val_acc. Runs whose run_acc is None
        are dropped.
        """
        return (records.group('args.hparams_seed')
            .map(lambda _, run_records:
                (
                    cls.run_acc(run_records),
                    run_records
                )
            ).filter(lambda x: x[0] is not None)
            .sorted(key=lambda x: x[0]['val_acc'])[::-1]
        )

    @classmethod
    def sweep_acc(cls, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return the test acc of the run with the best val acc, or None if no
        run qualifies.
        """
        _hparams_accs = cls.hparams_accs(records)
        if len(_hparams_accs):
            return _hparams_accs[0][0]['test_acc']
        else:
            return None


class OracleSelectionMethod(SelectionMethod):
    """Picks argmax(test_out_acc) across hparams.

    Unlike a full oracle, it does not take the argmax over checkpoints.
    Instead it uses each run's last checkpoint, i.e. no early stopping.
    """
    name = "test-domain validation set (oracle)"

    @classmethod
    def run_acc(cls, run_records):
        run_records = run_records.filter(lambda r:
            len(r['args']['test_envs']) == 1)
        if not len(run_records):
            return None
        test_env = run_records[0]['args']['test_envs'][0]
        test_out_acc_key = 'env{}_out_acc'.format(test_env)
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        chosen_record = run_records.sorted(lambda r: r['step'])[-1]
        return {
            'val_acc': chosen_record[test_out_acc_key],
            'test_acc': chosen_record[test_in_acc_key]
        }


class IIDAccuracySelectionMethod(SelectionMethod):
    """Picks argmax(mean(env_out_acc for env in train_envs))"""
    name = "training-domain validation set"

    @classmethod
    def _step_acc(cls, record):
        """Given a single record, return a {val_acc, test_acc} dict."""
        test_env = record['args']['test_envs'][0]
        val_env_keys = []
        # Environments are numbered contiguously from 0; stop at the first
        # missing env{i}_out_acc key.
        for i in itertools.count():
            if f'env{i}_out_acc' not in record:
                break
            if i != test_env:
                val_env_keys.append(f'env{i}_out_acc')
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        return {
            'val_acc': np.mean([record[key] for key in val_env_keys]),
            'test_acc': record[test_in_acc_key]
        }

    @classmethod
    def run_acc(cls, run_records):
        test_records = get_test_records(run_records)
        if not len(test_records):
            return None
        return test_records.map(cls._step_acc).argmax('val_acc')


class LeaveOneOutSelectionMethod(SelectionMethod):
    """Picks (hparams, step) by leave-one-out cross validation."""
    name = "leave-one-domain-out cross-validation"

    @classmethod
    def _step_acc(cls, records):
        """Return the {val_acc, test_acc} for a group of records corresponding
        to a single step.

        Needs exactly one record with only the test env, plus one record with
        test_envs == [test_env, v] for every other env v. Otherwise it
        returns None.
        """
        test_records = get_test_records(records)
        if len(test_records) != 1:
            return None

        test_env = test_records[0]['args']['test_envs'][0]
        n_envs = 0
        for i in itertools.count():
            if f'env{i}_out_acc' not in records[0]:
                break
            n_envs += 1
        # -1 marks environments for which no held-out record was found.
        val_accs = np.zeros(n_envs) - 1
        for r in records.filter(lambda r: len(r['args']['test_envs']) == 2):
            val_env = (set(r['args']['test_envs']) - set([test_env])).pop()
            val_accs[val_env] = r['env{}_in_acc'.format(val_env)]
        val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:])
        if any([v==-1 for v in val_accs]):
            return None
        val_acc = np.sum(val_accs) / (n_envs-1)
        return {
            'val_acc': val_acc,
            'test_acc': test_records[0]['env{}_in_acc'.format(test_env)]
        }

    @classmethod
    def run_acc(cls, records):
        step_accs = records.group('step').map(lambda step, step_records:
            cls._step_acc(step_records)
        ).filter_not_none()
        if len(step_accs):
            return step_accs.argmax('val_acc')
        else:
            return None

# --------------------------------------------------------------------------------
# /domainbed/networks.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All Rights Reserved

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models

from domainbed.lib import wide_resnet
import copy


def remove_batch_norm_from_resnet(model):
    """Fold the BatchNorm layers of a torchvision ResNet into its convs.

    The following conv/BN pairs are each replaced by a fused conv followed by
    an Identity:
    - the stem conv1/bn1;
    - every convK/bnK pair inside layer1..layer4;
    - each downsample conv/BN pair.

    The fusion runs with the model in eval mode. The model is switched back
    to train mode before it is returned.
    """
    fuse = torch.nn.utils.fusion.fuse_conv_bn_eval
    model.eval()

    model.conv1 = fuse(model.conv1, model.bn1)
    model.bn1 = Identity()

    for name, module in model.named_modules():
        # Top-level stages only: "layer1" .. "layer4" (six characters).
        if name.startswith("layer") and len(name) == 6:
            for bottleneck in module:
                for name2, module2 in bottleneck.named_modules():
                    if name2.startswith("conv"):
                        bn_name = "bn" + name2[-1]
                        setattr(bottleneck, name2,
                                fuse(module2, getattr(bottleneck, bn_name)))
                        setattr(bottleneck, bn_name, Identity())
                if isinstance(bottleneck.downsample, torch.nn.Sequential):
                    bottleneck.downsample[0] = fuse(bottleneck.downsample[0],
                                                    bottleneck.downsample[1])
                    bottleneck.downsample[1] = Identity()
    model.train()
    return model


class Identity(nn.Module):
    """An identity layer"""
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class MLP(nn.Module):
    """Just an MLP

    Structure: input Linear, then (mlp_depth - 2) hidden Linears, then an
    output Linear. Every layer except the last is followed by dropout and
    ReLU.
    """
    def __init__(self, n_inputs, n_outputs, hparams):
        super(MLP, self).__init__()
        self.input = nn.Linear(n_inputs, hparams['mlp_width'])
        self.dropout = nn.Dropout(hparams['mlp_dropout'])
        self.hiddens = nn.ModuleList([
            nn.Linear(hparams['mlp_width'], hparams['mlp_width'])
            for _ in range(hparams['mlp_depth']-2)])
        self.output = nn.Linear(hparams['mlp_width'], n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        x = self.input(x)
        x = self.dropout(x)
        x = F.relu(x)
        for hidden in self.hiddens:
            x = hidden(x)
            x = self.dropout(x)
            x = F.relu(x)
        x = self.output(x)
        return x


class ResNet(torch.nn.Module):
    """ResNet with the softmax chopped off and the batchnorm frozen"""
    def __init__(self, input_shape, hparams):
        super(ResNet, self).__init__()
        if hparams['resnet18']:
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048

        # self.network = remove_batch_norm_from_resnet(self.network)

        # adapt number of channels
        nc = input_shape[0]
        if nc != 3:
            tmp = self.network.conv1.weight.data.clone()

            self.network.conv1 = nn.Conv2d(
                nc, 64, kernel_size=(7, 7),
                stride=(2, 2), padding=(3, 3), bias=False)

            # Reuse the pretrained RGB filters, cycling through them for
            # each new input channel.
            for i in range(nc):
                self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :]

        # save memory
        del self.network.fc
        self.network.fc = Identity()

        self.freeze_bn()
        self.hparams = hparams
        self.dropout = nn.Dropout(hparams['resnet_dropout'])

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """
        Override the default train() to keep every BatchNorm2d in eval mode.

        Their running mean/var are therefore never updated. Their affine
        parameters still receive gradients. Returns self, as nn.Module.train
        does, so calls can be chained.
        """
        super().train(mode)
        self.freeze_bn()
        return self

    def freeze_bn(self):
        for m in self.network.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()


class MNIST_CNN(nn.Module):
    """
    Hand-tuned architecture for MNIST.
    Weirdness I've noticed so far with this architecture:
    - adding a linear layer after the mean-pool in features hurts
        RotatedMNIST-100 generalization severely.

    Outputs a 128-dimensional feature vector (global average pooling over
    the last conv's channels).
    """
    n_outputs = 128

    def __init__(self, input_shape):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn0(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn1(x)

        x = self.conv3(x)
        x = F.relu(x)
        x = self.bn2(x)

        x = self.conv4(x)
        x = F.relu(x)
        x = self.bn3(x)

        x = self.avgpool(x)
        x = x.view(len(x), -1)
        return x


class ContextNet(nn.Module):
    """Small conv net mapping an image to a one-channel map of the same
    spatial size (5x5 convolutions with padding 2, stride 1)."""
    def __init__(self, input_shape):
        super(ContextNet, self).__init__()

        # Keep same dimensions
        padding = (5 - 1) // 2
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 5, padding=padding),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 5, padding=padding),
        )

    def forward(self, x):
        return self.context_net(x)


def Featurizer(input_shape, hparams):
    """Auto-select an appropriate featurizer for the given input shape."""
    if len(input_shape) == 1:
        return MLP(input_shape[0], hparams["mlp_width"], hparams)
    elif input_shape[1:3] == (28, 28):
        return MNIST_CNN(input_shape)
    elif input_shape[1:3] == (32, 32):
        return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.)
    elif input_shape[1:3] == (224, 224):
        return ResNet(input_shape, hparams)
    else:
        raise NotImplementedError


def Classifier(in_features, out_features, is_nonlinear=False):
    """Linear head, or a 3-layer MLP head that halves the width twice."""
    if is_nonlinear:
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, in_features // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features // 2, in_features // 4),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features // 4, out_features))
    else:
        return torch.nn.Linear(in_features, out_features)


class WholeFish(nn.Module):
    """Featurizer and classifier in a single module.

    If weights (a state dict) is given, a deep copy of it is loaded.
    """
    def __init__(self, input_shape, num_classes, hparams, weights=None):
        super(WholeFish, self).__init__()
        featurizer = Featurizer(input_shape, hparams)
        classifier = Classifier(
            featurizer.n_outputs,
            num_classes,
            hparams['nonlinear_classifier'])
        self.net = nn.Sequential(
            featurizer, classifier
        )
        if weights is not None:
            self.load_state_dict(copy.deepcopy(weights))

    def reset_weights(self, weights):
        self.load_state_dict(copy.deepcopy(weights))

    def forward(self, x):
        return self.net(x)

# --------------------------------------------------------------------------------
# /domainbed/scripts/__init__.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# --------------------------------------------------------------------------------
# /domainbed/scripts/collect_results.py
# --------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import collections 4 | 5 | 6 | import argparse 7 | import functools 8 | import glob 9 | import pickle 10 | import itertools 11 | import json 12 | import os 13 | import random 14 | import sys 15 | 16 | import numpy as np 17 | import tqdm 18 | 19 | from domainbed import datasets 20 | from domainbed import algorithms 21 | from domainbed.lib import misc, reporting 22 | from domainbed import model_selection 23 | from domainbed.lib.query import Q 24 | import warnings 25 | 26 | def format_mean(data, latex): 27 | """Given a list of datapoints, return a string describing their mean and 28 | standard error""" 29 | if len(data) == 0: 30 | return None, None, "X" 31 | mean = 100 * np.mean(list(data)) 32 | err = 100 * np.std(list(data) / np.sqrt(len(data))) 33 | if latex: 34 | return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err) 35 | else: 36 | return mean, err, "{:.1f} +/- {:.1f}".format(mean, err) 37 | 38 | def print_table(table, header_text, row_labels, col_labels, colwidth=10, 39 | latex=True): 40 | """Pretty-print a 2D array of data, optionally with row/col labels""" 41 | print("") 42 | 43 | if latex: 44 | num_cols = len(table[0]) 45 | print("\\begin{center}") 46 | print("\\adjustbox{max width=\\textwidth}{%") 47 | print("\\begin{tabular}{l" + "c" * num_cols + "}") 48 | print("\\toprule") 49 | else: 50 | print("--------", header_text) 51 | 52 | for row, label in zip(table, row_labels): 53 | row.insert(0, label) 54 | 55 | if latex: 56 | col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}" 57 | for col_label in col_labels] 58 | table.insert(0, col_labels) 59 | 60 | for r, row in enumerate(table): 61 | misc.print_row(row, colwidth=colwidth, latex=latex) 62 | if latex and r == 0: 63 | print("\\midrule") 64 | if latex: 65 | print("\\bottomrule") 66 | print("\\end{tabular}}") 67 | print("\\end{center}") 68 | 69 | def print_results_tables(records, selection_method, latex): 70 | """Given all records, print a results table for each 
dataset.""" 71 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 72 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } 73 | ).filter(lambda g: g["sweep_acc"] is not None) 74 | 75 | # read algorithm names and sort (predefined order) 76 | alg_names = Q(records).select("args.algorithm").unique() 77 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 78 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 79 | 80 | # read dataset names and sort (lexicographic order) 81 | dataset_names = Q(records).select("args.dataset").unique().sorted() 82 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 83 | 84 | for dataset in dataset_names: 85 | if latex: 86 | print() 87 | print("\\subsubsection{{{}}}".format(dataset)) 88 | test_envs = range(datasets.num_environments(dataset)) 89 | 90 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names] 91 | for i, algorithm in enumerate(alg_names): 92 | means = [] 93 | for j, test_env in enumerate(test_envs): 94 | trial_accs = (grouped_records 95 | .filter_equals( 96 | "dataset, algorithm, test_env", 97 | (dataset, algorithm, test_env) 98 | ).select("sweep_acc")) 99 | mean, err, table[i][j] = format_mean(trial_accs, latex) 100 | means.append(mean) 101 | if None in means: 102 | table[i][-1] = "X" 103 | else: 104 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 105 | 106 | col_labels = [ 107 | "Algorithm", 108 | *datasets.get_dataset_class(dataset).ENVIRONMENTS, 109 | "Avg" 110 | ] 111 | header_text = (f"Dataset: {dataset}, " 112 | f"model selection method: {selection_method.name}") 113 | print_table(table, header_text, alg_names, list(col_labels), 114 | colwidth=20, latex=latex) 115 | 116 | # Print an "averages" table 117 | if latex: 118 | print() 119 | print("\\subsubsection{Averages}") 120 | 121 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names] 122 | for i, algorithm in enumerate(alg_names): 123 | means = [] 
124 | for j, dataset in enumerate(dataset_names): 125 | trial_averages = (grouped_records 126 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 127 | .group("trial_seed") 128 | .map(lambda trial_seed, group: 129 | group.select("sweep_acc").mean() 130 | ) 131 | ) 132 | mean, err, table[i][j] = format_mean(trial_averages, latex) 133 | means.append(mean) 134 | if None in means: 135 | table[i][-1] = "X" 136 | else: 137 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 138 | 139 | col_labels = ["Algorithm", *dataset_names, "Avg"] 140 | header_text = f"Averages, model selection method: {selection_method.name}" 141 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 142 | latex=latex) 143 | 144 | if __name__ == "__main__": 145 | np.set_printoptions(suppress=True) 146 | 147 | parser = argparse.ArgumentParser( 148 | description="Domain generalization testbed") 149 | parser.add_argument("--input_dir", type=str, required=True) 150 | parser.add_argument("--latex", action="store_true") 151 | args = parser.parse_args() 152 | 153 | results_file = "results.tex" if args.latex else "results.txt" 154 | 155 | sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w") 156 | 157 | records = reporting.load_records(args.input_dir) 158 | 159 | if args.latex: 160 | print("\\documentclass{article}") 161 | print("\\usepackage{booktabs}") 162 | print("\\usepackage{adjustbox}") 163 | print("\\begin{document}") 164 | print("\\section{Full DomainBed results}") 165 | print("% Total records:", len(records)) 166 | else: 167 | print("Total records:", len(records)) 168 | 169 | SELECTION_METHODS = [ 170 | model_selection.IIDAccuracySelectionMethod, 171 | model_selection.LeaveOneOutSelectionMethod, 172 | model_selection.OracleSelectionMethod, 173 | ] 174 | 175 | for selection_method in SELECTION_METHODS: 176 | if args.latex: 177 | print() 178 | print("\\subsection{{Model selection: {}}}".format( 179 | selection_method.name)) 180 | 
print_results_tables(records, selection_method, args.latex) 181 | 182 | if args.latex: 183 | print("\\end{document}") 184 | -------------------------------------------------------------------------------- /domainbed/scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | from torchvision.datasets import MNIST 4 | import xml.etree.ElementTree as ET 5 | from zipfile import ZipFile 6 | import argparse 7 | import tarfile 8 | import shutil 9 | import gdown 10 | import uuid 11 | import json 12 | import os 13 | 14 | # from wilds.datasets.camelyon17_dataset import Camelyon17Dataset 15 | # from wilds.datasets.fmow_dataset import FMoWDataset 16 | 17 | 18 | # utils ####################################################################### 19 | 20 | def stage_path(data_dir, name): 21 | full_path = os.path.join(data_dir, name) 22 | 23 | if not os.path.exists(full_path): 24 | os.makedirs(full_path) 25 | 26 | return full_path 27 | 28 | 29 | def download_and_extract(url, dst, remove=True): 30 | gdown.download(url, dst, quiet=False) 31 | 32 | if dst.endswith(".tar.gz"): 33 | tar = tarfile.open(dst, "r:gz") 34 | tar.extractall(os.path.dirname(dst)) 35 | tar.close() 36 | 37 | if dst.endswith(".tar"): 38 | tar = tarfile.open(dst, "r:") 39 | tar.extractall(os.path.dirname(dst)) 40 | tar.close() 41 | 42 | if dst.endswith(".zip"): 43 | zf = ZipFile(dst, "r") 44 | zf.extractall(os.path.dirname(dst)) 45 | zf.close() 46 | 47 | if remove: 48 | os.remove(dst) 49 | 50 | 51 | # VLCS ######################################################################## 52 | 53 | # Slower, but builds dataset from the original sources 54 | # 55 | # def download_vlcs(data_dir): 56 | # full_path = stage_path(data_dir, "VLCS") 57 | # 58 | # tmp_path = os.path.join(full_path, "tmp/") 59 | # if not os.path.exists(tmp_path): 60 | # os.makedirs(tmp_path) 61 | # 62 | # with 
open("domainbed/misc/vlcs_files.txt", "r") as f: 63 | # lines = f.readlines() 64 | # files = [line.strip().split() for line in lines] 65 | # 66 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar", 67 | # os.path.join(tmp_path, "voc2007_trainval.tar")) 68 | # 69 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz", 70 | # os.path.join(tmp_path, "caltech101.tar.gz")) 71 | # 72 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar", 73 | # os.path.join(tmp_path, "sun09_hcontext.tar")) 74 | # 75 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:") 76 | # tar.extractall(tmp_path) 77 | # tar.close() 78 | # 79 | # for src, dst in files: 80 | # class_folder = os.path.join(data_dir, dst) 81 | # 82 | # if not os.path.exists(class_folder): 83 | # os.makedirs(class_folder) 84 | # 85 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg") 86 | # 87 | # if "labelme" in src: 88 | # # download labelme from the web 89 | # gdown.download(src, dst, quiet=False) 90 | # else: 91 | # src = os.path.join(tmp_path, src) 92 | # shutil.copyfile(src, dst) 93 | # 94 | # shutil.rmtree(tmp_path) 95 | 96 | 97 | def download_vlcs(data_dir): 98 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 99 | full_path = stage_path(data_dir, "VLCS") 100 | 101 | download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8", 102 | os.path.join(data_dir, "VLCS.tar.gz")) 103 | 104 | 105 | # MNIST ####################################################################### 106 | 107 | def download_mnist(data_dir): 108 | # Original URL: http://yann.lecun.com/exdb/mnist/ 109 | full_path = stage_path(data_dir, "MNIST") 110 | MNIST(full_path, download=True) 111 | 112 | 113 | # PACS ######################################################################## 114 | 115 | def download_pacs(data_dir): 116 | # Original URL: 
http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017 117 | full_path = stage_path(data_dir, "PACS") 118 | 119 | download_and_extract("https://drive.google.com/uc?id=0B6x7gtvErXgfbF9CSk53UkRxVzg", 120 | os.path.join(data_dir, "PACS.zip")) 121 | 122 | os.rename(os.path.join(data_dir, "kfold"), 123 | full_path) 124 | 125 | 126 | # Office-Home ################################################################# 127 | 128 | def download_office_home(data_dir): 129 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/ 130 | full_path = stage_path(data_dir, "office_home") 131 | 132 | download_and_extract("https://drive.google.com/uc?id=0B81rNlvomiwed0V1YUxQdC1uOTg", 133 | os.path.join(data_dir, "office_home.zip")) 134 | 135 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), 136 | full_path) 137 | 138 | 139 | # DomainNET ################################################################### 140 | 141 | def download_domain_net(data_dir): 142 | # Original URL: http://ai.bu.edu/M3SDA/ 143 | full_path = stage_path(data_dir, "domain_net") 144 | 145 | urls = [ 146 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip", 147 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip", 148 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip", 149 | "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip", 150 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 151 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip" 152 | ] 153 | 154 | for url in urls: 155 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1])) 156 | 157 | with open("domainbed/misc/domain_net_duplicates.txt", "r") as f: 158 | for line in f.readlines(): 159 | try: 160 | os.remove(os.path.join(full_path, line.strip())) 161 | except OSError: 162 | pass 163 | 164 | 165 | # TerraIncognita ############################################################## 166 | 167 | def download_terra_incognita(data_dir): 168 | # Original URL: 
https://beerys.github.io/CaltechCameraTraps/ 169 | # New URL: http://lila.science/datasets/caltech-camera-traps 170 | 171 | full_path = stage_path(data_dir, "terra_incognita") 172 | 173 | download_and_extract( 174 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz", 175 | os.path.join(full_path, "terra_incognita_images.tar.gz")) 176 | 177 | download_and_extract( 178 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip", 179 | os.path.join(full_path, "caltech_camera_traps.json.zip")) 180 | 181 | include_locations = ["38", "46", "100", "43"] 182 | 183 | include_categories = [ 184 | "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit", 185 | "raccoon", "squirrel" 186 | ] 187 | 188 | images_folder = os.path.join(full_path, "eccv_18_all_images_sm/") 189 | annotations_file = os.path.join(full_path, "caltech_images_20210113.json") 190 | destination_folder = full_path 191 | 192 | stats = {} 193 | 194 | if not os.path.exists(destination_folder): 195 | os.mkdir(destination_folder) 196 | 197 | with open(annotations_file, "r") as f: 198 | data = json.load(f) 199 | 200 | category_dict = {} 201 | for item in data['categories']: 202 | category_dict[item['id']] = item['name'] 203 | 204 | for image in data['images']: 205 | image_location = image['location'] 206 | 207 | if image_location not in include_locations: 208 | continue 209 | 210 | loc_folder = os.path.join(destination_folder, 211 | 'location_' + str(image_location) + '/') 212 | 213 | if not os.path.exists(loc_folder): 214 | os.mkdir(loc_folder) 215 | 216 | image_id = image['id'] 217 | image_fname = image['file_name'] 218 | 219 | for annotation in data['annotations']: 220 | if annotation['image_id'] == image_id: 221 | if image_location not in stats: 222 | stats[image_location] = {} 223 | 224 | category = category_dict[annotation['category_id']] 225 | 226 | if category not in include_categories: 227 | continue 228 | 229 
| if category not in stats[image_location]: 230 | stats[image_location][category] = 0 231 | else: 232 | stats[image_location][category] += 1 233 | 234 | loc_cat_folder = os.path.join(loc_folder, category + '/') 235 | 236 | if not os.path.exists(loc_cat_folder): 237 | os.mkdir(loc_cat_folder) 238 | 239 | dst_path = os.path.join(loc_cat_folder, image_fname) 240 | src_path = os.path.join(images_folder, image_fname) 241 | 242 | shutil.copyfile(src_path, dst_path) 243 | 244 | shutil.rmtree(images_folder) 245 | os.remove(annotations_file) 246 | 247 | 248 | # # SVIRO ################################################################# 249 | 250 | # def download_sviro(data_dir): 251 | # # Original URL: https://sviro.kl.dfki.de 252 | # full_path = stage_path(data_dir, "sviro") 253 | 254 | # download_and_extract("https://sviro.kl.dfki.de/?wpdmdl=1731", 255 | # os.path.join(data_dir, "sviro_grayscale_rectangle_classification.zip")) 256 | 257 | # os.rename(os.path.join(data_dir, "SVIRO_DOMAINBED"), 258 | # full_path) 259 | 260 | 261 | if __name__ == "__main__": 262 | parser = argparse.ArgumentParser(description='Download datasets') 263 | parser.add_argument('--data_dir', type=str, required=True) 264 | args = parser.parse_args() 265 | 266 | download_mnist(args.data_dir) 267 | download_pacs(args.data_dir) 268 | download_office_home(args.data_dir) 269 | download_domain_net(args.data_dir) 270 | download_vlcs(args.data_dir) 271 | download_terra_incognita(args.data_dir) 272 | 273 | # download_sviro(args.data_dir) 274 | # Camelyon17Dataset(root_dir=args.data_dir, download=True) 275 | # FMoWDataset(root_dir=args.data_dir, download=True) 276 | -------------------------------------------------------------------------------- /domainbed/scripts/list_top_hparams.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Example usage: 5 | python -u -m domainbed.scripts.list_top_hparams \ 6 | --input_dir domainbed/misc/test_sweep_data --algorithm ERM \ 7 | --dataset VLCS --test_env 0 8 | """ 9 | 10 | import collections 11 | 12 | 13 | import argparse 14 | import functools 15 | import glob 16 | import pickle 17 | import itertools 18 | import json 19 | import os 20 | import random 21 | import sys 22 | 23 | import numpy as np 24 | import tqdm 25 | 26 | from domainbed import datasets 27 | from domainbed import algorithms 28 | from domainbed.lib import misc, reporting 29 | from domainbed import model_selection 30 | from domainbed.lib.query import Q 31 | import warnings 32 | 33 | def todo_rename(records, selection_method, latex): 34 | 35 | grouped_records = reporting.get_grouped_records(records).map(lambda group: 36 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) } 37 | ).filter(lambda g: g["sweep_acc"] is not None) 38 | 39 | # read algorithm names and sort (predefined order) 40 | alg_names = Q(records).select("args.algorithm").unique() 41 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] + 42 | [n for n in alg_names if n not in algorithms.ALGORITHMS]) 43 | 44 | # read dataset names and sort (lexicographic order) 45 | dataset_names = Q(records).select("args.dataset").unique().sorted() 46 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names] 47 | 48 | for dataset in dataset_names: 49 | if latex: 50 | print() 51 | print("\\subsubsection{{{}}}".format(dataset)) 52 | test_envs = range(datasets.num_environments(dataset)) 53 | 54 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names] 55 | for i, algorithm in enumerate(alg_names): 56 | means = [] 57 | for j, test_env in enumerate(test_envs): 58 | trial_accs = (grouped_records 59 | .filter_equals( 60 | "dataset, algorithm, test_env", 61 | (dataset, algorithm, test_env) 62 | ).select("sweep_acc")) 63 | mean, err, table[i][j] = 
format_mean(trial_accs, latex) 64 | means.append(mean) 65 | if None in means: 66 | table[i][-1] = "X" 67 | else: 68 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 69 | 70 | col_labels = [ 71 | "Algorithm", 72 | *datasets.get_dataset_class(dataset).ENVIRONMENTS, 73 | "Avg" 74 | ] 75 | header_text = (f"Dataset: {dataset}, " 76 | f"model selection method: {selection_method.name}") 77 | print_table(table, header_text, alg_names, list(col_labels), 78 | colwidth=20, latex=latex) 79 | 80 | # Print an "averages" table 81 | if latex: 82 | print() 83 | print("\\subsubsection{Averages}") 84 | 85 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names] 86 | for i, algorithm in enumerate(alg_names): 87 | means = [] 88 | for j, dataset in enumerate(dataset_names): 89 | trial_averages = (grouped_records 90 | .filter_equals("algorithm, dataset", (algorithm, dataset)) 91 | .group("trial_seed") 92 | .map(lambda trial_seed, group: 93 | group.select("sweep_acc").mean() 94 | ) 95 | ) 96 | mean, err, table[i][j] = format_mean(trial_averages, latex) 97 | means.append(mean) 98 | if None in means: 99 | table[i][-1] = "X" 100 | else: 101 | table[i][-1] = "{:.1f}".format(sum(means) / len(means)) 102 | 103 | col_labels = ["Algorithm", *dataset_names, "Avg"] 104 | header_text = f"Averages, model selection method: {selection_method.name}" 105 | print_table(table, header_text, alg_names, col_labels, colwidth=25, 106 | latex=latex) 107 | 108 | if __name__ == "__main__": 109 | np.set_printoptions(suppress=True) 110 | 111 | parser = argparse.ArgumentParser( 112 | description="Domain generalization testbed") 113 | parser.add_argument("--input_dir", required=True) 114 | parser.add_argument('--dataset', required=True) 115 | parser.add_argument('--algorithm', required=True) 116 | parser.add_argument('--test_env', type=int, required=True) 117 | args = parser.parse_args() 118 | 119 | records = reporting.load_records(args.input_dir) 120 | print("Total records:", len(records)) 121 
| 122 | records = reporting.get_grouped_records(records) 123 | records = records.filter( 124 | lambda r: 125 | r['dataset'] == args.dataset and 126 | r['algorithm'] == args.algorithm and 127 | r['test_env'] == args.test_env 128 | ) 129 | 130 | SELECTION_METHODS = [ 131 | model_selection.IIDAccuracySelectionMethod, 132 | model_selection.LeaveOneOutSelectionMethod, 133 | model_selection.OracleSelectionMethod, 134 | ] 135 | 136 | for selection_method in SELECTION_METHODS: 137 | print(f'Model selection: {selection_method.name}') 138 | 139 | for group in records: 140 | print(f"trial_seed: {group['trial_seed']}") 141 | best_hparams = selection_method.hparams_accs(group['records']) 142 | for run_acc, hparam_records in best_hparams: 143 | print(f"\t{run_acc}") 144 | for r in hparam_records: 145 | assert(r['hparams'] == hparam_records[0]['hparams']) 146 | print("\t\thparams:") 147 | for k, v in sorted(hparam_records[0]['hparams'].items()): 148 | print('\t\t\t{}: {}'.format(k, v)) 149 | print("\t\toutput_dirs:") 150 | output_dirs = hparam_records.select('args.output_dir').unique() 151 | for output_dir in output_dirs: 152 | print(f"\t\t\t{output_dir}") -------------------------------------------------------------------------------- /domainbed/scripts/save_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | """ 4 | Save some representative images from each dataset to disk. 
5 | """ 6 | import random 7 | import torch 8 | import argparse 9 | from domainbed import hparams_registry 10 | from domainbed import datasets 11 | import imageio 12 | import os 13 | from tqdm import tqdm 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser(description='Domain generalization') 17 | parser.add_argument('--data_dir', type=str) 18 | parser.add_argument('--output_dir', type=str) 19 | args = parser.parse_args() 20 | 21 | os.makedirs(args.output_dir, exist_ok=True) 22 | datasets_to_save = ['OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST', 'SVIRO'] 23 | 24 | for dataset_name in tqdm(datasets_to_save): 25 | hparams = hparams_registry.default_hparams('ERM', dataset_name) 26 | dataset = datasets.get_dataset_class(dataset_name)( 27 | args.data_dir, 28 | list(range(datasets.num_environments(dataset_name))), 29 | hparams) 30 | for env_idx, env in enumerate(tqdm(dataset)): 31 | for i in tqdm(range(50)): 32 | idx = random.choice(list(range(len(env)))) 33 | x, y = env[idx] 34 | while y > 10: 35 | idx = random.choice(list(range(len(env)))) 36 | x, y = env[idx] 37 | if x.shape[0] == 2: 38 | x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3,:,:] 39 | if x.min() < 0: 40 | mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None] 41 | std = torch.tensor([0.229, 0.224, 0.225])[:,None,None] 42 | x = (x * std) + mean 43 | assert(x.min() >= 0) 44 | assert(x.max() <= 1) 45 | x = (x * 255.99) 46 | x = x.numpy().astype('uint8').transpose(1,2,0) 47 | imageio.imwrite( 48 | os.path.join(args.output_dir, 49 | f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'), 50 | x) 51 | -------------------------------------------------------------------------------- /domainbed/scripts/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | """ 4 | Run sweeps 5 | """ 6 | 7 | import argparse 8 | import copy 9 | import getpass 10 | import hashlib 11 | import json 12 | import os 13 | import random 14 | import shutil 15 | import time 16 | import uuid 17 | 18 | import numpy as np 19 | import torch 20 | 21 | from domainbed import datasets 22 | from domainbed import hparams_registry 23 | from domainbed import algorithms 24 | from domainbed.lib import misc 25 | from domainbed import command_launchers 26 | 27 | import tqdm 28 | import shlex 29 | 30 | class Job: 31 | NOT_LAUNCHED = 'Not launched' 32 | INCOMPLETE = 'Incomplete' 33 | DONE = 'Done' 34 | 35 | def __init__(self, train_args, sweep_output_dir): 36 | args_str = json.dumps(train_args, sort_keys=True) 37 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest() 38 | self.output_dir = os.path.join(sweep_output_dir, args_hash) 39 | 40 | self.train_args = copy.deepcopy(train_args) 41 | self.train_args['output_dir'] = self.output_dir 42 | command = ['python', '-m', 'domainbed.scripts.train'] 43 | for k, v in sorted(self.train_args.items()): 44 | if isinstance(v, list): 45 | v = ' '.join([str(v_) for v_ in v]) 46 | elif isinstance(v, str): 47 | v = shlex.quote(v) 48 | command.append(f'--{k} {v}') 49 | self.command_str = ' '.join(command) 50 | 51 | if os.path.exists(os.path.join(self.output_dir, 'done')): 52 | self.state = Job.DONE 53 | elif os.path.exists(self.output_dir): 54 | self.state = Job.INCOMPLETE 55 | else: 56 | self.state = Job.NOT_LAUNCHED 57 | 58 | def __str__(self): 59 | job_info = (self.train_args['dataset'], 60 | self.train_args['algorithm'], 61 | self.train_args['test_envs'], 62 | self.train_args['hparams_seed']) 63 | return '{}: {} {}'.format( 64 | self.state, 65 | self.output_dir, 66 | job_info) 67 | 68 | @staticmethod 69 | def launch(jobs, launcher_fn): 70 | print('Launching...') 71 | jobs = jobs.copy() 72 | np.random.shuffle(jobs) 73 | print('Making job directories:') 74 | for job in tqdm.tqdm(jobs, 
leave=False): 75 | os.makedirs(job.output_dir, exist_ok=True) 76 | commands = [job.command_str for job in jobs] 77 | launcher_fn(commands) 78 | print(f'Launched {len(jobs)} jobs!') 79 | 80 | @staticmethod 81 | def delete(jobs): 82 | print('Deleting...') 83 | for job in jobs: 84 | shutil.rmtree(job.output_dir) 85 | print(f'Deleted {len(jobs)} jobs!') 86 | 87 | def all_test_env_combinations(n): 88 | """ 89 | For a dataset with n >= 3 envs, return all combinations of 1 and 2 test 90 | envs. 91 | """ 92 | assert(n >= 3) 93 | for i in range(n): 94 | yield [i] 95 | for j in range(i+1, n): 96 | yield [i, j] 97 | 98 | def make_args_list(n_trials, dataset_names, algorithms, n_hparams_from, n_hparams, steps, 99 | data_dir, task, holdout_fraction, single_test_envs, hparams): 100 | args_list = [] 101 | for trial_seed in range(n_trials): 102 | for dataset in dataset_names: 103 | for algorithm in algorithms: 104 | if single_test_envs: 105 | all_test_envs = [ 106 | [i] for i in range(datasets.num_environments(dataset))] 107 | else: 108 | all_test_envs = all_test_env_combinations( 109 | datasets.num_environments(dataset)) 110 | for test_envs in all_test_envs: 111 | for hparams_seed in range(n_hparams_from, n_hparams): 112 | train_args = {} 113 | train_args['dataset'] = dataset 114 | train_args['algorithm'] = algorithm 115 | train_args['test_envs'] = test_envs 116 | train_args['holdout_fraction'] = holdout_fraction 117 | train_args['hparams_seed'] = hparams_seed 118 | train_args['data_dir'] = data_dir 119 | train_args['task'] = task 120 | train_args['trial_seed'] = trial_seed 121 | train_args['seed'] = misc.seed_hash(dataset, 122 | algorithm, test_envs, hparams_seed, trial_seed) 123 | if steps is not None: 124 | train_args['steps'] = steps 125 | if hparams is not None: 126 | train_args['hparams'] = hparams 127 | args_list.append(train_args) 128 | return args_list 129 | 130 | def ask_for_confirmation(): 131 | response = input('Are you sure? 
(y/n) ') 132 | if not response.lower().strip()[:1] == "y": 133 | print('Nevermind!') 134 | exit(0) 135 | 136 | DATASETS = [d for d in datasets.DATASETS if "Debug" not in d] 137 | 138 | if __name__ == "__main__": 139 | parser = argparse.ArgumentParser(description='Run a sweep') 140 | parser.add_argument('command', choices=['launch', 'delete_incomplete']) 141 | parser.add_argument('--datasets', nargs='+', type=str, default=DATASETS) 142 | parser.add_argument('--algorithms', nargs='+', type=str, default=algorithms.ALGORITHMS) 143 | parser.add_argument('--task', type=str, default="domain_generalization") 144 | parser.add_argument('--n_hparams_from', type=int, default=0) 145 | parser.add_argument('--n_hparams', type=int, default=20) 146 | parser.add_argument('--output_dir', type=str, required=True) 147 | parser.add_argument('--data_dir', type=str, required=True) 148 | parser.add_argument('--seed', type=int, default=0) 149 | parser.add_argument('--n_trials', type=int, default=3) 150 | parser.add_argument('--command_launcher', type=str, required=True) 151 | parser.add_argument('--steps', type=int, default=None) 152 | parser.add_argument('--hparams', type=str, default=None) 153 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 154 | parser.add_argument('--single_test_envs', action='store_true') 155 | parser.add_argument('--skip_confirmation', action='store_true') 156 | args = parser.parse_args() 157 | 158 | args_list = make_args_list( 159 | n_trials=args.n_trials, 160 | dataset_names=args.datasets, 161 | algorithms=args.algorithms, 162 | n_hparams_from=args.n_hparams_from, 163 | n_hparams=args.n_hparams, 164 | steps=args.steps, 165 | data_dir=args.data_dir, 166 | task=args.task, 167 | holdout_fraction=args.holdout_fraction, 168 | single_test_envs=args.single_test_envs, 169 | hparams=args.hparams 170 | ) 171 | 172 | jobs = [Job(train_args, args.output_dir) for train_args in args_list] 173 | 174 | for job in jobs: 175 | print(job) 176 | print("{} jobs: {} 
done, {} incomplete, {} not launched.".format( 177 | len(jobs), 178 | len([j for j in jobs if j.state == Job.DONE]), 179 | len([j for j in jobs if j.state == Job.INCOMPLETE]), 180 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED])) 181 | ) 182 | 183 | if args.command == 'launch': 184 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED] 185 | print(f'About to launch {len(to_launch)} jobs.') 186 | if not args.skip_confirmation: 187 | ask_for_confirmation() 188 | launcher_fn = command_launchers.REGISTRY[args.command_launcher] 189 | Job.launch(to_launch, launcher_fn) 190 | 191 | elif args.command == 'delete_incomplete': 192 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE] 193 | print(f'About to delete {len(to_delete)} jobs.') 194 | if not args.skip_confirmation: 195 | ask_for_confirmation() 196 | Job.delete(to_delete) 197 | -------------------------------------------------------------------------------- /domainbed/scripts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import collections 5 | import json 6 | import os 7 | import random 8 | import sys 9 | import time 10 | import uuid 11 | 12 | import numpy as np 13 | import PIL 14 | import torch 15 | import torchvision 16 | import torch.utils.data 17 | 18 | from domainbed import datasets 19 | from domainbed import hparams_registry 20 | from domainbed import algorithms 21 | from domainbed.lib import misc 22 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description='Domain generalization') 26 | parser.add_argument('--data_dir', type=str) 27 | parser.add_argument('--dataset', type=str, default="RotatedMNIST") 28 | parser.add_argument('--algorithm', type=str, default="ERM") 29 | parser.add_argument('--task', type=str, default="domain_generalization", 30 | choices=["domain_generalization", "domain_adaptation"]) 31 | parser.add_argument('--hparams', type=str, 32 | help='JSON-serialized hparams dict') 33 | parser.add_argument('--hparams_seed', type=int, default=0, 34 | help='Seed for random hparams (0 means "default hparams")') 35 | parser.add_argument('--trial_seed', type=int, default=0, 36 | help='Trial number (used for seeding split_dataset and ' 37 | 'random_hparams).') 38 | parser.add_argument('--seed', type=int, default=0, 39 | help='Seed for everything else') 40 | parser.add_argument('--steps', type=int, default=None, 41 | help='Number of steps. Default is dataset-dependent.') 42 | parser.add_argument('--checkpoint_freq', type=int, default=None, 43 | help='Checkpoint every N steps. 
Default is dataset-dependent.') 44 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0]) 45 | parser.add_argument('--output_dir', type=str, default="train_output") 46 | parser.add_argument('--holdout_fraction', type=float, default=0.2) 47 | parser.add_argument('--uda_holdout_fraction', type=float, default=0, 48 | help="For domain adaptation, % of test to use unlabeled for training.") 49 | parser.add_argument('--skip_model_save', action='store_true') 50 | parser.add_argument('--save_model_every_checkpoint', action='store_true') 51 | args = parser.parse_args() 52 | 53 | # If we ever want to implement checkpointing, just persist these values 54 | # every once in a while, and then load them from disk here. 55 | start_step = 0 56 | algorithm_dict = None 57 | 58 | os.makedirs(args.output_dir, exist_ok=True) 59 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt')) 60 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt')) 61 | 62 | print("Environment:") 63 | print("\tPython: {}".format(sys.version.split(" ")[0])) 64 | print("\tPyTorch: {}".format(torch.__version__)) 65 | print("\tTorchvision: {}".format(torchvision.__version__)) 66 | print("\tCUDA: {}".format(torch.version.cuda)) 67 | print("\tCUDNN: {}".format(torch.backends.cudnn.version())) 68 | print("\tNumPy: {}".format(np.__version__)) 69 | print("\tPIL: {}".format(PIL.__version__)) 70 | 71 | print('Args:') 72 | for k, v in sorted(vars(args).items()): 73 | print('\t{}: {}'.format(k, v)) 74 | 75 | if args.hparams_seed == 0: 76 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset) 77 | else: 78 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset, 79 | misc.seed_hash(args.hparams_seed, args.trial_seed)) 80 | if args.hparams: 81 | hparams.update(json.loads(args.hparams)) 82 | 83 | print('HParams:') 84 | for k, v in sorted(hparams.items()): 85 | print('\t{}: {}'.format(k, v)) 86 | 87 | random.seed(args.seed) 88 | np.random.seed(args.seed) 89 
| torch.manual_seed(args.seed) 90 | torch.backends.cudnn.deterministic = True 91 | torch.backends.cudnn.benchmark = False 92 | 93 | if torch.cuda.is_available(): 94 | device = "cuda" 95 | else: 96 | device = "cpu" 97 | 98 | if args.dataset in vars(datasets): 99 | dataset = vars(datasets)[args.dataset](args.data_dir, 100 | args.test_envs, hparams) 101 | else: 102 | raise NotImplementedError 103 | 104 | # Split each env into an 'in-split' and an 'out-split'. We'll train on 105 | # each in-split except the test envs, and evaluate on all splits. 106 | 107 | # To allow unsupervised domain adaptation experiments, we split each test 108 | # env into 'in-split', 'uda-split' and 'out-split'. The 'in-split' is used 109 | # by collect_results.py to compute classification accuracies. The 110 | # 'out-split' is used by the Oracle model selectino method. The unlabeled 111 | # samples in 'uda-split' are passed to the algorithm at training time if 112 | # args.task == "domain_adaptation". If we are interested in comparing 113 | # domain generalization and domain adaptation results, then domain 114 | # generalization algorithms should create the same 'uda-splits', which will 115 | # be discared at training. 
116 | in_splits = [] 117 | out_splits = [] 118 | uda_splits = [] 119 | for env_i, env in enumerate(dataset): 120 | uda = [] 121 | 122 | out, in_ = misc.split_dataset(env, 123 | int(len(env)*args.holdout_fraction), 124 | misc.seed_hash(args.trial_seed, env_i)) 125 | 126 | if env_i in args.test_envs: 127 | uda, in_ = misc.split_dataset(in_, 128 | int(len(in_)*args.uda_holdout_fraction), 129 | misc.seed_hash(args.trial_seed, env_i)) 130 | 131 | if hparams['class_balanced']: 132 | in_weights = misc.make_weights_for_balanced_classes(in_) 133 | out_weights = misc.make_weights_for_balanced_classes(out) 134 | if uda is not None: 135 | uda_weights = misc.make_weights_for_balanced_classes(uda) 136 | else: 137 | in_weights, out_weights, uda_weights = None, None, None 138 | in_splits.append((in_, in_weights)) 139 | out_splits.append((out, out_weights)) 140 | if len(uda): 141 | uda_splits.append((uda, uda_weights)) 142 | 143 | if args.task == "domain_adaptation" and len(uda_splits) == 0: 144 | raise ValueError("Not enough unlabeled samples for domain adaptation.") 145 | 146 | train_loaders = [InfiniteDataLoader( 147 | dataset=env, 148 | weights=env_weights, 149 | batch_size=hparams['batch_size'], 150 | num_workers=dataset.N_WORKERS) 151 | for i, (env, env_weights) in enumerate(in_splits) 152 | if i not in args.test_envs] 153 | 154 | uda_loaders = [InfiniteDataLoader( 155 | dataset=env, 156 | weights=env_weights, 157 | batch_size=hparams['batch_size'], 158 | num_workers=dataset.N_WORKERS) 159 | for i, (env, env_weights) in enumerate(uda_splits) 160 | if i in args.test_envs] 161 | 162 | eval_loaders = [FastDataLoader( 163 | dataset=env, 164 | batch_size=64, 165 | num_workers=dataset.N_WORKERS) 166 | for env, _ in (in_splits + out_splits + uda_splits)] 167 | eval_weights = [None for _, weights in (in_splits + out_splits + uda_splits)] 168 | eval_loader_names = ['env{}_in'.format(i) 169 | for i in range(len(in_splits))] 170 | eval_loader_names += ['env{}_out'.format(i) 171 | for i 
in range(len(out_splits))] 172 | eval_loader_names += ['env{}_uda'.format(i) 173 | for i in range(len(uda_splits))] 174 | 175 | algorithm_class = algorithms.get_algorithm_class(args.algorithm) 176 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes, 177 | len(dataset) - len(args.test_envs), hparams) 178 | 179 | if algorithm_dict is not None: 180 | algorithm.load_state_dict(algorithm_dict) 181 | 182 | algorithm.to(device) 183 | 184 | train_minibatches_iterator = zip(*train_loaders) 185 | uda_minibatches_iterator = zip(*uda_loaders) 186 | checkpoint_vals = collections.defaultdict(lambda: []) 187 | 188 | steps_per_epoch = min([len(env)/hparams['batch_size'] for env,_ in in_splits]) 189 | 190 | n_steps = args.steps or dataset.N_STEPS 191 | checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ 192 | 193 | def save_checkpoint(filename): 194 | if args.skip_model_save: 195 | return 196 | save_dict = { 197 | "args": vars(args), 198 | "model_input_shape": dataset.input_shape, 199 | "model_num_classes": dataset.num_classes, 200 | "model_num_domains": len(dataset) - len(args.test_envs), 201 | "model_hparams": hparams, 202 | "model_dict": algorithm.cpu().state_dict() 203 | } 204 | torch.save(save_dict, os.path.join(args.output_dir, filename)) 205 | 206 | 207 | last_results_keys = None 208 | for step in range(start_step, n_steps): 209 | step_start_time = time.time() 210 | minibatches_device = [(x.to(device), y.to(device)) 211 | for x,y in next(train_minibatches_iterator)] 212 | if args.task == "domain_adaptation": 213 | uda_device = [x.to(device) 214 | for x,_ in next(uda_minibatches_iterator)] 215 | else: 216 | uda_device = None 217 | step_vals = algorithm.update(minibatches_device, uda_device) 218 | checkpoint_vals['step_time'].append(time.time() - step_start_time) 219 | 220 | for key, val in step_vals.items(): 221 | checkpoint_vals[key].append(val) 222 | 223 | if (step % checkpoint_freq == 0) or (step == n_steps - 1): 224 | results = { 225 | 
'step': step, 226 | 'epoch': step / steps_per_epoch, 227 | } 228 | 229 | for key, val in checkpoint_vals.items(): 230 | results[key] = np.mean(val) 231 | 232 | evals = zip(eval_loader_names, eval_loaders, eval_weights) 233 | for name, loader, weights in evals: 234 | acc = misc.accuracy(algorithm, loader, weights, device) 235 | results[name+'_acc'] = acc 236 | 237 | results['mem_gb'] = torch.cuda.max_memory_allocated() / (1024.*1024.*1024.) 238 | 239 | results_keys = sorted(results.keys()) 240 | if results_keys != last_results_keys: 241 | misc.print_row(results_keys, colwidth=12) 242 | last_results_keys = results_keys 243 | misc.print_row([results[key] for key in results_keys], 244 | colwidth=12) 245 | 246 | results.update({ 247 | 'hparams': hparams, 248 | 'args': vars(args) 249 | }) 250 | 251 | epochs_path = os.path.join(args.output_dir, 'results.jsonl') 252 | with open(epochs_path, 'a') as f: 253 | f.write(json.dumps(results, sort_keys=True) + "\n") 254 | 255 | algorithm_dict = algorithm.state_dict() 256 | start_step = step + 1 257 | checkpoint_vals = collections.defaultdict(lambda: []) 258 | 259 | if args.save_model_every_checkpoint: 260 | save_checkpoint(f'model_step{step}.pkl') 261 | 262 | save_checkpoint('model.pkl') 263 | 264 | with open(os.path.join(args.output_dir, 'done'), 'w') as f: 265 | f.write('done') 266 | -------------------------------------------------------------------------------- /fig_intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alexrame/fishr/7b8fdf1e0b15226ded9b58efd37698e74e616ab7/fig_intro.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.1 2 | torchvision==0.9.1 3 | backpack-for-pytorch ==1.3.0 4 | numpy==1.20.2 5 | --------------------------------------------------------------------------------