├── .gitignore
├── LICENSE
├── README.md
├── coloredmnist
└── train_coloredmnist.py
├── domainbed
├── __init__.py
├── algorithms.py
├── command_launchers.py
├── datasets.py
├── hparams_registry.py
├── lib
│ ├── fast_data_loader.py
│ ├── misc.py
│ ├── query.py
│ ├── reporting.py
│ └── wide_resnet.py
├── model_selection.py
├── networks.py
└── scripts
│ ├── __init__.py
│ ├── collect_results.py
│ ├── download.py
│ ├── list_top_hparams.py
│ ├── save_images.py
│ ├── sweep.py
│ └── train.py
├── fig_intro.png
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Build and Release Folders
2 | bin-debug/
3 | bin-release/
4 | [Oo]bj/
5 | [Bb]in/
6 |
7 | # Other files and folders
8 | .settings/
9 | *__pycache__*
10 |
11 | # Executables
12 | *.swf
13 | *.air
14 | *.ipa
15 | *.apk
16 | *.pyc
17 |
18 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties`
19 | # should NOT be excluded as they contain compiler settings and other important
20 | # information for Eclipse / Flash Builder.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc.
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 | Copyright (C)
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see .
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | Copyright (C)
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | .
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | .
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fishr: Invariant Gradient Variances for Out-of-distribution Generalization
2 |
3 | Official PyTorch implementation of the Fishr regularization for out-of-distribution generalization, ICML 2022 | [paper](https://arxiv.org/abs/2109.02934)
4 |
5 | [Alexandre Ramé](https://alexrame.github.io/), [Corentin Dancette](https://cdancette.fr/), [Matthieu Cord](http://webia.lip6.fr/~cord/)
6 |
7 | 
8 |
9 |
10 | ## Abstract
11 | Learning robust models that generalize well under changes in the data distribution is critical for real-world applications. To this end, there has been a growing surge of interest to learn simultaneously from multiple training domains - while enforcing different types of invariance across those domains. Yet, all existing approaches fail to show systematic benefits under fair evaluation protocols.
12 |
13 | In this paper, we propose a new learning scheme to enforce domain invariance in the space of the gradients of the loss function: specifically, we introduce a regularization term that matches the domain-level variances of gradients across training domains. Critically, our strategy, named Fishr, exhibits close relations with the Fisher Information and the Hessian of the loss. We show that forcing domain-level gradient covariances to be similar during the learning procedure eventually aligns the domain-level loss landscapes locally around the final weights.
14 |
15 | Extensive experiments demonstrate the effectiveness of Fishr for out-of-distribution generalization. In particular, Fishr improves the state of the art on the DomainBed benchmark and performs significantly better than
16 | Empirical Risk Minimization.
17 |
18 | # Installation
19 |
20 | ## Requirements overview
21 |
22 | Our implementation relies on the [BackPACK](https://github.com/f-dangel/backpack/) package in [PyTorch](https://pytorch.org/) to easily compute gradient variances.
23 |
24 | - python == 3.7.10
25 | - torch == 1.8.1
26 | - torchvision == 0.9.1
27 | - backpack-for-pytorch == 1.3.0
28 | - numpy == 1.20.2
29 |
30 | ## Procedure
31 |
32 | 1. Clone the repo:
33 | ```bash
34 | $ git clone https://github.com/alexrame/fishr.git
35 | ```
36 |
37 | 2. Install this repository and the dependencies using pip:
38 | ```bash
39 | $ conda create --name fishr python=3.7.10
40 | $ conda activate fishr
41 | $ cd fishr
42 | $ pip install -r requirements.txt
43 | ```
44 |
45 | With this, you can edit the Fishr code on the fly.
46 |
47 | # Overview
48 |
49 | This github enables the replication of our two main experiments: (1) on Colored MNIST in the setup defined by [IRM](https://github.com/facebookresearch/InvariantRiskMinimization/tree/master/code/colored_mnist) and (2) on the [DomainBed](https://github.com/facebookresearch/DomainBed/) benchmark.
50 |
51 |
52 | ## Colored MNIST in the IRM setup
53 |
54 | We first validate that Fishr tackles distribution shifts on the synthetic Colored MNIST.
55 | ### Main results (Table 2 in Section 6.A)
56 |
57 | To reproduce the results from Table 2, call ```python3 coloredmnist/train_coloredmnist.py --algorithm $algorithm``` where `algorithm` is either:
58 | - ```erm``` for Empirical Risk Minimization
59 | - ```irm``` for [Invariant Risk Minimization](https://arxiv.org/abs/1907.02893)
60 | - ```rex``` for [Out-of-Distribution Generalization via Risk Extrapolation](https://icml.cc/virtual/2021/oral/9186)
61 | - ```fishr``` for our proposed Fishr
62 |
63 | Results will be printed at the end of the script, averaged over 10 runs. Note that all hyperparameters are taken from the seminal [IRM](https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/reproduce_paper_results.sh) implementation.
64 |
65 | Method | Train acc. | Test acc. | Gray test acc.
66 | --------|------------|------------|----------------
67 | ERM | 86.4 ± 0.2 | 14.0 ± 0.7 | 71.0 ± 0.7
68 | IRM | 71.0 ± 0.5 | 65.6 ± 1.8 | 66.1 ± 0.2
69 | V-REx | 71.7 ± 1.5 | 67.2 ± 1.5 | 68.6 ± 2.2
70 | Fishr | 71.0 ± 0.9 | 69.5 ± 1.0 | 70.2 ± 1.1
71 |
72 |
73 |
74 | ### Without label flipping (Table 5 in Appendix C.2.3)
75 | The script ```coloredmnist.train_coloredmnist``` also accepts as input the argument `--label_flipping_prob` which defines the label flipping probability. By default, it's 0.25, so to reproduce the results from Table 5 you should set `--label_flipping_prob 0`.
76 | ### Fishr variants (Table 6 in Appendix C.2.4)
77 | This table considers two additional Fishr variants, reproduced with `algorithm` set to:
78 | - ```fishr_offdiagonal``` for Fishr but without centering the gradient variances
79 | - ```fishr_notcentered``` for Fishr but on the full covariance rather than only the diagonal
80 |
81 | ## DomainBed
82 |
83 | DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in [In Search of Lost Domain Generalization](https://arxiv.org/abs/2007.01434). Instructions below are copied and adapted from the official [github](https://github.com/facebookresearch/DomainBed/).
84 |
85 | ### Algorithms and hyperparameter grids
86 |
87 | We added Fishr as a new algorithm [here](domainbed/algorithms.py), and defined Fishr's hyperparameter grids [here](domainbed/hparams_registry.py), as defined in Table 7 in Appendix D.
88 |
89 | ### Datasets
90 |
91 | We ran Fishr on following [datasets](domainbed/datasets.py):
92 |
93 | * Rotated MNIST ([Ghifary et al., 2015](https://arxiv.org/abs/1508.07680))
94 | * Colored MNIST ([Arjovsky et al., 2019](https://arxiv.org/abs/1907.02893))
95 | * VLCS ([Fang et al., 2013](https://openaccess.thecvf.com/content_iccv_2013/papers/Fang_Unbiased_Metric_Learning_2013_ICCV_paper.pdf))
96 | * PACS ([Li et al., 2017](https://arxiv.org/abs/1710.03077))
97 | * OfficeHome ([Venkateswara et al., 2017](https://arxiv.org/abs/1706.07522))
98 | * A TerraIncognita ([Beery et al., 2018](https://arxiv.org/abs/1807.04975)) subset
99 | * DomainNet ([Peng et al., 2019](http://ai.bu.edu/M3SDA/))
100 |
101 | ### Launch training
102 |
103 | Download the datasets:
104 |
105 | ```sh
106 | python3 -m domainbed.scripts.download\
107 | --data_dir=/my/data/dir
108 | ```
109 |
110 | Train a model for debugging:
111 |
112 | ```sh
113 | python3 -m domainbed.scripts.train\
114 | --data_dir=/my/data/dir/\
115 | --algorithm Fishr\
116 | --dataset ColoredMNIST\
117 | --test_env 2
118 | ```
119 |
120 | Launch a sweep for hyperparameter search:
121 |
122 | ```sh
123 | python -m domainbed.scripts.sweep launch\
124 | --data_dir=/my/data/dir/\
125 | --output_dir=/my/sweep/output/path\
126 | --command_launcher MyLauncher
127 | --datasets ColoredMNIST\
128 | --algorithms Fishr
129 | ```
130 | Here, `MyLauncher` is your cluster's command launcher, as implemented in `command_launchers.py`.
131 |
132 |
133 | ### Performances inspection (Tables 3 and 4 in Section 6.B.2, Tables in Appendix G)
134 |
135 | To view the results of your sweep:
136 |
137 | ````sh
138 | python -m domainbed.scripts.collect_results\
139 | --input_dir=/my/sweep/output/path
140 | ````
141 |
142 | We inspect performances using following [model selection criteria](domainbed/model_selection.py), that differ in what data is used to choose the best hyper-parameters for a given model:
143 |
144 | * `OracleSelectionMethod` (`Oracle`): A random subset from the data of the test domain.
145 | * `IIDAccuracySelectionMethod` (`Training`): A random subset from the data of the training domains.
146 |
147 | Critically, Fishr performs consistently better than Empirical Risk Minimization.
148 |
149 | Model selection | Algorithm | Colored MNIST | Rotated MNIST | VLCS | PACS |OfficeHome | TerraIncognita | DomainNet | Avg
150 | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
151 | | | | | | | | | |
152 | Oracle | ERM | 57.8 ± 0.2 | 97.8 ± 0.1 | 77.6 ± 0.3 | 86.7 ± 0.3 | 66.4 ± 0.5 | 53.0 ± 0.3 | 41.3 ± 0.1 | 68.7
153 | Oracle | Fishr | 68.8 ± 1.4 | 97.8 ± 0.1 | 78.2 ± 0.2 | 86.9 ± 0.2 | 68.2 ± 0.2 | 53.6 ± 0.4 | 41.8 ± 0.2 | 70.8
154 | | | | | | | | | |
155 | Training | ERM | 51.5 ± 0.1 | 98.0 ± 0.0 | 77.5 ± 0.4 | 85.5 ± 0.2 | 66.5 ± 0.3 | 46.1 ± 1.8 | 40.9 ± 0.1 | 66.6
156 | Training | Fishr | 52.0 ± 0.2 | 97.8 ± 0.0 | 77.8 ± 0.1 | 85.5 ± 0.4 | 67.8 ± 0.1 | 47.4 ± 1.6 | 41.7 ± 0.0 | 67.1
157 |
158 |
159 | # Conclusion
160 |
161 | We addressed the task of out-of-distribution generalization for computer vision classification tasks. We derive a new and simple regularization - Fishr - that matches the gradient variances across domains as a proxy for matching domain-level Hessians. Our scalable strategy reaches state-of-the-art performances on the DomainBed benchmark and performs better than ERM. Our empirical experiments suggest that Fishr regularization would consistently improve a deep classifier in real-world applications when dealing with data from multiple domains. If you need help to use Fishr, please open an issue or contact alexandre.rame@lip6.fr.
162 |
163 | # Citation
164 |
165 | If you find this code useful for your research, please consider citing our work:
166 |
167 | ```
168 | @inproceedings{rame2021ishr,
169 | title={Fishr: Invariant Gradient Variances for Out-of-distribution Generalization},
170 | author={Alexandre Rame and Corentin Dancette and Matthieu Cord},
171 | year={2022},
172 | booktitle={ICML}
173 | }
174 | ```
175 |
--------------------------------------------------------------------------------
/coloredmnist/train_coloredmnist.py:
--------------------------------------------------------------------------------
1 | # This script was first copied from https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/main.py under the license
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | #
8 | # Then we included our new regularization loss Fishr. To do so:
9 | # 1. we first compute gradients covariance on each domain (see compute_grads_variance method) using BackPACK package
10 | # 2. then, we compute l2 distance between these gradient covariances (see l2_between_grads_variance method)
11 |
12 | import random
13 | import argparse
14 | import numpy as np
15 | from collections import OrderedDict
16 |
17 | import torch
18 | from torchvision import datasets
19 | from torch import nn, optim, autograd
20 |
21 | from backpack import backpack, extend
22 | from backpack.extensions import BatchGrad
23 |
24 | parser = argparse.ArgumentParser(description='Colored MNIST')
25 |
26 | # select your algorithm
27 | parser.add_argument(
28 | '--algorithm',
29 | type=str,
30 | default="fishr",
31 | choices=[
32 | ## Four main methods, for Table 2 in Section 6.A
33 | 'erm', # Empirical Risk Minimization
34 | 'irm', # Invariant Risk Minimization (https://arxiv.org/abs/1907.02893)
35 | 'rex', # Out-of-Distribution Generalization via Risk Extrapolation (https://icml.cc/virtual/2021/oral/9186)
36 | 'fishr', # Our proposed Fishr
37 | ## two Fishr variants, for Table 6 in Appendix C.2.4
38 | 'fishr_offdiagonal' # Fishr but on the full covariance rather than only the diagonal
39 | 'fishr_notcentered', # Fishr but without centering the gradient variances
40 | ]
41 | )
42 | # select whether you want to apply label flipping or not
43 | # Set to 0 in Table 5 in Appendix C.2.3 and in the right half of Table 6 in Appendix C.2.4
44 | parser.add_argument('--label_flipping_prob', type=float, default=0.25)
45 |
46 | # Following hyperparameters are directly taken from from https://github.com/facebookresearch/InvariantRiskMinimization/blob/master/code/colored_mnist/reproduce_paper_results.sh
47 | # They should not be modified except in case of a new proper hyperparameter search with an external validation dataset.
48 | # Overall, we compare all approaches using the hyperparameters optimized for IRM.
49 | parser.add_argument('--hidden_dim', type=int, default=390)
50 | parser.add_argument('--l2_regularizer_weight', type=float, default=0.00110794568)
51 | parser.add_argument('--lr', type=float, default=0.0004898536566546834)
52 | parser.add_argument('--penalty_anneal_iters', type=int, default=190)
53 | parser.add_argument('--penalty_weight', type=float, default=91257.18613115903)
54 | parser.add_argument('--steps', type=int, default=501)
55 | # experimental setup
56 | parser.add_argument('--grayscale_model', action='store_true')
57 | parser.add_argument('--n_restarts', type=int, default=10)
58 | parser.add_argument('--seed', type=int, default=0, help='Seed for everything')
59 |
60 | flags = parser.parse_args()
61 |
62 | print('Flags:')
63 | for k, v in sorted(vars(flags).items()):
64 | print("\t{}: {}".format(k, v))
65 |
66 | random.seed(flags.seed)
67 | np.random.seed(flags.seed)
68 | torch.manual_seed(flags.seed)
69 | torch.backends.cudnn.deterministic = True
70 | torch.backends.cudnn.benchmark = False
71 |
72 | final_train_accs = []
73 | final_test_accs = []
74 | final_graytest_accs = []
75 | for restart in range(flags.n_restarts):
76 | print("Restart", restart)
77 |
78 | # Load MNIST, make train/val splits, and shuffle train set examples
79 |
80 | mnist = datasets.MNIST('~/datasets/mnist', train=True, download=True)
81 | mnist_train = (mnist.data[:50000], mnist.targets[:50000])
82 | mnist_val = (mnist.data[50000:], mnist.targets[50000:])
83 |
84 | rng_state = np.random.get_state()
85 | np.random.shuffle(mnist_train[0].numpy())
86 | np.random.set_state(rng_state)
87 | np.random.shuffle(mnist_train[1].numpy())
88 |
89 | # Build environments
90 |
91 |
92 | def make_environment(images, labels, e, grayscale=False):
93 |
94 | def torch_bernoulli(p, size):
95 | return (torch.rand(size) < p).float()
96 |
97 | def torch_xor(a, b):
98 | return (a - b).abs() # Assumes both inputs are either 0 or 1
99 |
100 | # 2x subsample for computational convenience
101 | images = images.reshape((-1, 28, 28))[:, ::2, ::2]
102 | # Assign a binary label based on the digit; flip label with probability 0.25
103 | labels = (labels < 5).float()
104 | labels = torch_xor(labels, torch_bernoulli(flags.label_flipping_prob, len(labels)))
105 | # Assign a color based on the label; flip the color with probability e
106 | colors = torch_xor(labels, torch_bernoulli(e, len(labels)))
107 | # Apply the color to the image by zeroing out the other color channel
108 | images = torch.stack([images, images], dim=1)
109 | if not grayscale:
110 | images[torch.tensor(range(len(images))), (1 - colors).long(), :, :] *= 0
111 | return {'images': (images.float() / 255.).cuda(), 'labels': labels[:, None].cuda()}
112 |
113 | envs = [
114 | make_environment(mnist_train[0][::2], mnist_train[1][::2], 0.2),
115 | make_environment(mnist_train[0][1::2], mnist_train[1][1::2], 0.1),
116 | make_environment(mnist_val[0], mnist_val[1], 0.9),
117 | make_environment(mnist_val[0], mnist_val[1], 0.9, grayscale=True)
118 | ]
119 |
120 | # Define and instantiate the model
121 |
122 |
123 | class MLP(nn.Module):
124 |
125 | def __init__(self):
126 | super(MLP, self).__init__()
127 | if flags.grayscale_model:
128 | lin1 = nn.Linear(14 * 14, flags.hidden_dim)
129 | else:
130 | lin1 = nn.Linear(2 * 14 * 14, flags.hidden_dim)
131 | lin2 = nn.Linear(flags.hidden_dim, flags.hidden_dim)
132 |
133 | self.classifier = extend(nn.Linear(flags.hidden_dim, 1))
134 | for lin in [lin1, lin2, self.classifier]:
135 | nn.init.xavier_uniform_(lin.weight)
136 | nn.init.zeros_(lin.bias)
137 | self._main = nn.Sequential(lin1, nn.ReLU(True), lin2, nn.ReLU(True))
138 |
139 | def forward(self, input):
140 | if flags.grayscale_model:
141 | out = input.view(input.shape[0], 2, 14 * 14).sum(dim=1)
142 | else:
143 | out = input.view(input.shape[0], 2 * 14 * 14)
144 | features = self._main(out)
145 | logits = self.classifier(features)
146 | return features, logits
147 |
148 | mlp = MLP().cuda()
149 |
150 | # Define loss function helpers
151 |
152 |
153 | def mean_nll(logits, y):
154 | return nn.functional.binary_cross_entropy_with_logits(logits, y)
155 |
156 | def mean_accuracy(logits, y):
157 | preds = (logits > 0.).float()
158 | return ((preds - y).abs() < 1e-2).float().mean()
159 |
160 | def compute_irm_penalty(logits, y):
161 | scale = torch.tensor(1.).cuda().requires_grad_()
162 | loss = mean_nll(logits * scale, y)
163 | grad = autograd.grad(loss, [scale], create_graph=True)[0]
164 | return torch.sum(grad**2)
165 |
166 | bce_extended = extend(nn.BCEWithLogitsLoss())
167 |
168 | def compute_grads_variance(features, labels, classifier):
169 | logits = classifier(features)
170 | loss = bce_extended(logits, labels)
171 | with backpack(BatchGrad()):
172 | loss.backward(
173 | inputs=list(classifier.parameters()), retain_graph=True, create_graph=True
174 | )
175 |
176 | dict_grads = OrderedDict(
177 | [
178 | (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1))
179 | for name, weights in classifier.named_parameters()
180 | ]
181 | )
182 | dict_grads_variance = {}
183 | for name, _grads in dict_grads.items():
184 | grads = _grads * labels.size(0) # multiply by batch size
185 | env_mean = grads.mean(dim=0, keepdim=True)
186 | if flags.algorithm != "fishr_notcentered":
187 | grads = grads - env_mean
188 | if flags.algorithm == "fishr_offdiagonal":
189 | dict_grads_variance[name] = torch.einsum("na,nb->ab", grads,
190 | grads) / (grads.size(0) * grads.size(1))
191 | else:
192 | dict_grads_variance[name] = (grads).pow(2).mean(dim=0)
193 |
194 | return dict_grads_variance
195 |
196 | def l2_between_grads_variance(cov_1, cov_2):
197 | assert len(cov_1) == len(cov_2)
198 | cov_1_values = [cov_1[key] for key in sorted(cov_1.keys())]
199 | cov_2_values = [cov_2[key] for key in sorted(cov_2.keys())]
200 | return (
201 | torch.cat(tuple([t.view(-1) for t in cov_1_values])) -
202 | torch.cat(tuple([t.view(-1) for t in cov_2_values]))
203 | ).pow(2).sum()
204 |
205 | # Train loop
206 |
207 | def pretty_print(*values):
208 | col_width = 13
209 |
210 | def format_val(v):
211 | if not isinstance(v, str):
212 | v = np.array2string(v, precision=5, floatmode='fixed')
213 | return v.ljust(col_width)
214 |
215 | str_values = [format_val(v) for v in values]
216 | print(" ".join(str_values))
217 |
218 | optimizer = optim.Adam(mlp.parameters(), lr=flags.lr)
219 |
220 | pretty_print(
221 | 'step', 'train nll', 'train acc', 'fishr penalty', 'rex penalty', 'irm penalty', 'test acc',
222 | "gray test acc"
223 | )
224 | for step in range(flags.steps):
225 | for edx, env in enumerate(envs):
226 | features, logits = mlp(env['images'])
227 | env['nll'] = mean_nll(logits, env['labels'])
228 | env['acc'] = mean_accuracy(logits, env['labels'])
229 | env['irm'] = compute_irm_penalty(logits, env['labels'])
230 | if edx in [0, 1]:
231 | # True when the dataset is in training
232 | optimizer.zero_grad()
233 | env["grads_variance"] = compute_grads_variance(features, env['labels'], mlp.classifier)
234 |
235 | train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).mean()
236 | train_acc = torch.stack([envs[0]['acc'], envs[1]['acc']]).mean()
237 |
238 | weight_norm = torch.tensor(0.).cuda()
239 | for w in mlp.parameters():
240 | weight_norm += w.norm().pow(2)
241 |
242 | loss = train_nll.clone()
243 | loss += flags.l2_regularizer_weight * weight_norm
244 |
245 | irm_penalty = torch.stack([envs[0]['irm'], envs[1]['irm']]).mean()
246 | rex_penalty = (envs[0]['nll'].mean() - envs[1]['nll'].mean())**2
247 |
248 | # Compute the variance averaged over the two training domains
249 | dict_grads_variance_averaged = OrderedDict(
250 | [
251 | (
252 | name,
253 | torch.stack([envs[0]["grads_variance"][name], envs[1]["grads_variance"][name]],
254 | dim=0).mean(dim=0)
255 | ) for name in envs[0]["grads_variance"]
256 | ]
257 | )
258 | fishr_penalty = (
259 | l2_between_grads_variance(envs[0]["grads_variance"], dict_grads_variance_averaged) +
260 | l2_between_grads_variance(envs[1]["grads_variance"], dict_grads_variance_averaged)
261 | )
262 |
263 | # apply the selected regularization
264 | if flags.algorithm == "erm":
265 | pass
266 | else:
267 | if flags.algorithm.startswith("fishr"):
268 | train_penalty = fishr_penalty
269 | elif flags.algorithm == "rex":
270 | train_penalty = rex_penalty
271 | elif flags.algorithm == "irm":
272 | train_penalty = irm_penalty
273 | else:
274 | raise ValueError(flags.algorithm)
275 | penalty_weight = (flags.penalty_weight if step >= flags.penalty_anneal_iters else 1.0)
276 | loss += penalty_weight * train_penalty
277 | if penalty_weight > 1.0:
278 | # Rescale the entire loss to keep gradients in a reasonable range
279 | loss /= penalty_weight
280 |
281 | optimizer.zero_grad()
282 | loss.backward()
283 | optimizer.step()
284 |
285 | test_acc = envs[2]['acc']
286 | grayscale_test_acc = envs[3]['acc']
287 | if step % 100 == 0:
288 | pretty_print(
289 | np.int32(step),
290 | train_nll.detach().cpu().numpy(),
291 | train_acc.detach().cpu().numpy(),
292 | fishr_penalty.detach().cpu().numpy(),
293 | rex_penalty.detach().cpu().numpy(),
294 | irm_penalty.detach().cpu().numpy(),
295 | test_acc.detach().cpu().numpy(),
296 | grayscale_test_acc.detach().cpu().numpy(),
297 | )
298 |
299 | final_train_accs.append(train_acc.detach().cpu().numpy())
300 | final_test_accs.append(test_acc.detach().cpu().numpy())
301 | final_graytest_accs.append(grayscale_test_acc.detach().cpu().numpy())
302 | print('Final train acc (mean/std across restarts so far):')
303 | print(np.mean(final_train_accs), np.std(final_train_accs))
304 | print('Final test acc (mean/std across restarts so far):')
305 | print(np.mean(final_test_accs), np.std(final_test_accs))
306 | print('Final gray test acc (mean/std across restarts so far):')
307 | print(np.mean(final_graytest_accs), np.std(final_graytest_accs))
308 |
--------------------------------------------------------------------------------
/domainbed/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 |
--------------------------------------------------------------------------------
/domainbed/algorithms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.autograd as autograd
7 | from torch.autograd import Variable
8 |
9 | import copy
10 | import numpy as np
11 | from collections import defaultdict, OrderedDict
12 | try:
13 | from backpack import backpack, extend
14 | from backpack.extensions import BatchGrad
15 | except:
16 | backpack = None
17 |
18 | from domainbed import networks
19 | from domainbed.lib.misc import (
20 | random_pairs_of_minibatches, ParamDict, MovingAverage, l2_between_dicts
21 | )
22 |
23 | ALGORITHMS = [
24 | 'ERM',
25 | 'Fish',
26 | 'IRM',
27 | 'GroupDRO',
28 | 'Mixup',
29 | 'MLDG',
30 | 'CORAL',
31 | 'MMD',
32 | 'DANN',
33 | 'CDANN',
34 | 'MTL',
35 | 'SagNet',
36 | 'ARM',
37 | 'VREx',
38 | 'RSC',
39 | 'SD',
40 | 'ANDMask',
41 | 'SANDMask', # SAND-mask
42 | 'IGA',
43 | 'SelfReg',
44 | "Fishr"
45 | ]
46 |
47 |
48 | def get_algorithm_class(algorithm_name):
49 | """Return the algorithm class with the given name."""
50 | if algorithm_name not in globals():
51 | raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
52 | return globals()[algorithm_name]
53 |
54 |
55 | class Algorithm(torch.nn.Module):
56 | """
57 | A subclass of Algorithm implements a domain generalization algorithm.
58 | Subclasses should implement the following:
59 | - update()
60 | - predict()
61 | """
62 |
63 | def __init__(self, input_shape, num_classes, num_domains, hparams):
64 | super(Algorithm, self).__init__()
65 | self.hparams = hparams
66 |
67 | def update(self, minibatches, unlabeled=None):
68 | """
69 | Perform one update step, given a list of (x, y) tuples for all
70 | environments.
71 |
72 | Admits an optional list of unlabeled minibatches from the test domains,
73 | when task is domain_adaptation.
74 | """
75 | raise NotImplementedError
76 |
77 | def predict(self, x):
78 | raise NotImplementedError
79 |
80 |
81 | class ERM(Algorithm):
82 | """
83 | Empirical Risk Minimization (ERM)
84 | """
85 |
86 | def __init__(self, input_shape, num_classes, num_domains, hparams):
87 | super(ERM, self).__init__(input_shape, num_classes, num_domains, hparams)
88 | self.featurizer = networks.Featurizer(input_shape, self.hparams)
89 | self.classifier = networks.Classifier(
90 | self.featurizer.n_outputs, num_classes, self.hparams['nonlinear_classifier']
91 | )
92 |
93 | self.network = nn.Sequential(self.featurizer, self.classifier)
94 | self.optimizer = torch.optim.Adam(
95 | self.network.parameters(),
96 | lr=self.hparams["lr"],
97 | weight_decay=self.hparams['weight_decay']
98 | )
99 |
100 | def update(self, minibatches, unlabeled=None):
101 | all_x = torch.cat([x for x, y in minibatches])
102 | all_y = torch.cat([y for x, y in minibatches])
103 | loss = F.cross_entropy(self.predict(all_x), all_y)
104 |
105 | self.optimizer.zero_grad()
106 | loss.backward()
107 | self.optimizer.step()
108 |
109 | return {'loss': loss.item()}
110 |
111 | def predict(self, x):
112 | return self.network(x)
113 |
114 |
115 | class Fish(Algorithm):
116 | """
117 | Implementation of Fish, as seen in Gradient Matching for Domain
118 | Generalization, Shi et al. 2021.
119 | """
120 |
121 | def __init__(self, input_shape, num_classes, num_domains, hparams):
122 | super(Fish, self).__init__(input_shape, num_classes, num_domains, hparams)
123 | self.input_shape = input_shape
124 | self.num_classes = num_classes
125 |
126 | self.network = networks.WholeFish(input_shape, num_classes, hparams)
127 | self.optimizer = torch.optim.Adam(
128 | self.network.parameters(),
129 | lr=self.hparams["lr"],
130 | weight_decay=self.hparams['weight_decay']
131 | )
132 | self.optimizer_inner_state = None
133 |
134 | def create_clone(self, device):
135 | self.network_inner = networks.WholeFish(
136 | self.input_shape, self.num_classes, self.hparams, weights=self.network.state_dict()
137 | ).to(device)
138 | self.optimizer_inner = torch.optim.Adam(
139 | self.network_inner.parameters(),
140 | lr=self.hparams["lr"],
141 | weight_decay=self.hparams['weight_decay']
142 | )
143 | if self.optimizer_inner_state is not None:
144 | self.optimizer_inner.load_state_dict(self.optimizer_inner_state)
145 |
146 | def fish(self, meta_weights, inner_weights, lr_meta):
147 | meta_weights = ParamDict(meta_weights)
148 | inner_weights = ParamDict(inner_weights)
149 | meta_weights += lr_meta * (inner_weights - meta_weights)
150 | return meta_weights
151 |
152 | def update(self, minibatches, unlabeled=None):
153 | self.create_clone(minibatches[0][0].device)
154 |
155 | for x, y in minibatches:
156 | loss = F.cross_entropy(self.network_inner(x), y)
157 | self.optimizer_inner.zero_grad()
158 | loss.backward()
159 | self.optimizer_inner.step()
160 |
161 | self.optimizer_inner_state = self.optimizer_inner.state_dict()
162 | meta_weights = self.fish(
163 | meta_weights=self.network.state_dict(),
164 | inner_weights=self.network_inner.state_dict(),
165 | lr_meta=self.hparams["meta_lr"]
166 | )
167 | self.network.reset_weights(meta_weights)
168 |
169 | return {'loss': loss.item()}
170 |
171 | def predict(self, x):
172 | return self.network(x)
173 |
174 |
175 | class ARM(ERM):
176 | """ Adaptive Risk Minimization (ARM) """
177 |
178 | def __init__(self, input_shape, num_classes, num_domains, hparams):
179 | original_input_shape = input_shape
180 | input_shape = (1 + original_input_shape[0],) + original_input_shape[1:]
181 | super(ARM, self).__init__(input_shape, num_classes, num_domains, hparams)
182 | self.context_net = networks.ContextNet(original_input_shape)
183 | self.support_size = hparams['batch_size']
184 |
185 | def predict(self, x):
186 | batch_size, c, h, w = x.shape
187 | if batch_size % self.support_size == 0:
188 | meta_batch_size = batch_size // self.support_size
189 | support_size = self.support_size
190 | else:
191 | meta_batch_size, support_size = 1, batch_size
192 | context = self.context_net(x)
193 | context = context.reshape((meta_batch_size, support_size, 1, h, w))
194 | context = context.mean(dim=1)
195 | context = torch.repeat_interleave(context, repeats=support_size, dim=0)
196 | x = torch.cat([x, context], dim=1)
197 | return self.network(x)
198 |
199 |
200 | class AbstractDANN(Algorithm):
201 | """Domain-Adversarial Neural Networks (abstract class)"""
202 |
203 | def __init__(self, input_shape, num_classes, num_domains, hparams, conditional, class_balance):
204 |
205 | super(AbstractDANN, self).__init__(input_shape, num_classes, num_domains, hparams)
206 |
207 | self.register_buffer('update_count', torch.tensor([0]))
208 | self.conditional = conditional
209 | self.class_balance = class_balance
210 |
211 | # Algorithms
212 | self.featurizer = networks.Featurizer(input_shape, self.hparams)
213 | self.classifier = networks.Classifier(
214 | self.featurizer.n_outputs, num_classes, self.hparams['nonlinear_classifier']
215 | )
216 | self.discriminator = networks.MLP(self.featurizer.n_outputs, num_domains, self.hparams)
217 | self.class_embeddings = nn.Embedding(num_classes, self.featurizer.n_outputs)
218 |
219 | # Optimizers
220 | self.disc_opt = torch.optim.Adam(
221 | (list(self.discriminator.parameters()) + list(self.class_embeddings.parameters())),
222 | lr=self.hparams["lr_d"],
223 | weight_decay=self.hparams['weight_decay_d'],
224 | betas=(self.hparams['beta1'], 0.9)
225 | )
226 |
227 | self.gen_opt = torch.optim.Adam(
228 | (list(self.featurizer.parameters()) + list(self.classifier.parameters())),
229 | lr=self.hparams["lr_g"],
230 | weight_decay=self.hparams['weight_decay_g'],
231 | betas=(self.hparams['beta1'], 0.9)
232 | )
233 |
234 | def update(self, minibatches, unlabeled=None):
235 | device = "cuda" if minibatches[0][0].is_cuda else "cpu"
236 | self.update_count += 1
237 | all_x = torch.cat([x for x, y in minibatches])
238 | all_y = torch.cat([y for x, y in minibatches])
239 | all_z = self.featurizer(all_x)
240 | if self.conditional:
241 | disc_input = all_z + self.class_embeddings(all_y)
242 | else:
243 | disc_input = all_z
244 | disc_out = self.discriminator(disc_input)
245 | disc_labels = torch.cat(
246 | [
247 | torch.full((x.shape[0],), i, dtype=torch.int64, device=device)
248 | for i, (x, y) in enumerate(minibatches)
249 | ]
250 | )
251 |
252 | if self.class_balance:
253 | y_counts = F.one_hot(all_y).sum(dim=0)
254 | weights = 1. / (y_counts[all_y] * y_counts.shape[0]).float()
255 | disc_loss = F.cross_entropy(disc_out, disc_labels, reduction='none')
256 | disc_loss = (weights * disc_loss).sum()
257 | else:
258 | disc_loss = F.cross_entropy(disc_out, disc_labels)
259 |
260 | disc_softmax = F.softmax(disc_out, dim=1)
261 | input_grad = autograd.grad(
262 | disc_softmax[:, disc_labels].sum(), [disc_input], create_graph=True
263 | )[0]
264 | grad_penalty = (input_grad**2).sum(dim=1).mean(dim=0)
265 | disc_loss += self.hparams['grad_penalty'] * grad_penalty
266 |
267 | d_steps_per_g = self.hparams['d_steps_per_g_step']
268 | if (self.update_count.item() % (1 + d_steps_per_g) < d_steps_per_g):
269 |
270 | self.disc_opt.zero_grad()
271 | disc_loss.backward()
272 | self.disc_opt.step()
273 | return {'disc_loss': disc_loss.item()}
274 | else:
275 | all_preds = self.classifier(all_z)
276 | classifier_loss = F.cross_entropy(all_preds, all_y)
277 | gen_loss = (classifier_loss + (self.hparams['lambda'] * -disc_loss))
278 | self.disc_opt.zero_grad()
279 | self.gen_opt.zero_grad()
280 | gen_loss.backward()
281 | self.gen_opt.step()
282 | return {'gen_loss': gen_loss.item()}
283 |
284 | def predict(self, x):
285 | return self.classifier(self.featurizer(x))
286 |
287 |
288 | class DANN(AbstractDANN):
289 | """Unconditional DANN"""
290 |
291 | def __init__(self, input_shape, num_classes, num_domains, hparams):
292 | super(DANN, self).__init__(
293 | input_shape, num_classes, num_domains, hparams, conditional=False, class_balance=False
294 | )
295 |
296 |
297 | class CDANN(AbstractDANN):
298 | """Conditional DANN"""
299 |
300 | def __init__(self, input_shape, num_classes, num_domains, hparams):
301 | super(CDANN, self).__init__(
302 | input_shape, num_classes, num_domains, hparams, conditional=True, class_balance=True
303 | )
304 |
305 |
306 | class IRM(ERM):
307 | """Invariant Risk Minimization"""
308 |
309 | def __init__(self, input_shape, num_classes, num_domains, hparams):
310 | super(IRM, self).__init__(input_shape, num_classes, num_domains, hparams)
311 | self.register_buffer('update_count', torch.tensor([0]))
312 |
313 | @staticmethod
314 | def _irm_penalty(logits, y):
315 | device = "cuda" if logits[0][0].is_cuda else "cpu"
316 | scale = torch.tensor(1.).to(device).requires_grad_()
317 | loss_1 = F.cross_entropy(logits[::2] * scale, y[::2])
318 | loss_2 = F.cross_entropy(logits[1::2] * scale, y[1::2])
319 | grad_1 = autograd.grad(loss_1, [scale], create_graph=True)[0]
320 | grad_2 = autograd.grad(loss_2, [scale], create_graph=True)[0]
321 | result = torch.sum(grad_1 * grad_2)
322 | return result
323 |
324 | def update(self, minibatches, unlabeled=None):
325 | device = "cuda" if minibatches[0][0].is_cuda else "cpu"
326 | penalty_weight = (
327 | self.hparams['irm_lambda']
328 | if self.update_count >= self.hparams['irm_penalty_anneal_iters'] else 1.0
329 | )
330 | nll = 0.
331 | penalty = 0.
332 |
333 | all_x = torch.cat([x for x, y in minibatches])
334 | all_logits = self.network(all_x)
335 | all_logits_idx = 0
336 | for i, (x, y) in enumerate(minibatches):
337 | logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]]
338 | all_logits_idx += x.shape[0]
339 | nll += F.cross_entropy(logits, y)
340 | penalty += self._irm_penalty(logits, y)
341 | nll /= len(minibatches)
342 | penalty /= len(minibatches)
343 | loss = nll + (penalty_weight * penalty)
344 |
345 | if self.update_count == self.hparams['irm_penalty_anneal_iters']:
346 | # Reset Adam, because it doesn't like the sharp jump in gradient
347 | # magnitudes that happens at this step.
348 | self.optimizer = torch.optim.Adam(
349 | self.network.parameters(),
350 | lr=self.hparams["lr"],
351 | weight_decay=self.hparams['weight_decay']
352 | )
353 |
354 | self.optimizer.zero_grad()
355 | loss.backward()
356 | self.optimizer.step()
357 |
358 | self.update_count += 1
359 | return {'loss': loss.item(), 'nll': nll.item(), 'penalty': penalty.item()}
360 |
361 |
362 | class VREx(ERM):
363 | """V-REx algorithm from http://arxiv.org/abs/2003.00688"""
364 |
365 | def __init__(self, input_shape, num_classes, num_domains, hparams):
366 | super(VREx, self).__init__(input_shape, num_classes, num_domains, hparams)
367 | self.register_buffer('update_count', torch.tensor([0]))
368 |
369 | def update(self, minibatches, unlabeled=None):
370 | if self.update_count >= self.hparams["vrex_penalty_anneal_iters"]:
371 | penalty_weight = self.hparams["vrex_lambda"]
372 | else:
373 | penalty_weight = 1.0
374 |
375 | nll = 0.
376 |
377 | all_x = torch.cat([x for x, y in minibatches])
378 | all_logits = self.network(all_x)
379 | all_logits_idx = 0
380 | losses = torch.zeros(len(minibatches))
381 | for i, (x, y) in enumerate(minibatches):
382 | logits = all_logits[all_logits_idx:all_logits_idx + x.shape[0]]
383 | all_logits_idx += x.shape[0]
384 | nll = F.cross_entropy(logits, y)
385 | losses[i] = nll
386 |
387 | mean = losses.mean()
388 | penalty = ((losses - mean)**2).mean()
389 | loss = mean + penalty_weight * penalty
390 |
391 | if self.update_count == self.hparams['vrex_penalty_anneal_iters']:
392 | # Reset Adam (like IRM), because it doesn't like the sharp jump in
393 | # gradient magnitudes that happens at this step.
394 | self.optimizer = torch.optim.Adam(
395 | self.network.parameters(),
396 | lr=self.hparams["lr"],
397 | weight_decay=self.hparams['weight_decay']
398 | )
399 |
400 | self.optimizer.zero_grad()
401 | loss.backward()
402 | self.optimizer.step()
403 |
404 | self.update_count += 1
405 | return {'loss': loss.item(), 'nll': nll.item(), 'penalty': penalty.item()}
406 |
407 |
408 | class Mixup(ERM):
409 | """
410 | Mixup of minibatches from different domains
411 | https://arxiv.org/pdf/2001.00677.pdf
412 | https://arxiv.org/pdf/1912.01805.pdf
413 | """
414 |
415 | def __init__(self, input_shape, num_classes, num_domains, hparams):
416 | super(Mixup, self).__init__(input_shape, num_classes, num_domains, hparams)
417 |
418 | def update(self, minibatches, unlabeled=None):
419 | objective = 0
420 |
421 | for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
422 | lam = np.random.beta(self.hparams["mixup_alpha"], self.hparams["mixup_alpha"])
423 |
424 | x = lam * xi + (1 - lam) * xj
425 | predictions = self.predict(x)
426 |
427 | objective += lam * F.cross_entropy(predictions, yi)
428 | objective += (1 - lam) * F.cross_entropy(predictions, yj)
429 |
430 | objective /= len(minibatches)
431 |
432 | self.optimizer.zero_grad()
433 | objective.backward()
434 | self.optimizer.step()
435 |
436 | return {'loss': objective.item()}
437 |
438 |
439 | class GroupDRO(ERM):
440 | """
441 | Robust ERM minimizes the error at the worst minibatch
442 | Algorithm 1 from [https://arxiv.org/pdf/1911.08731.pdf]
443 | """
444 |
445 | def __init__(self, input_shape, num_classes, num_domains, hparams):
446 | super(GroupDRO, self).__init__(input_shape, num_classes, num_domains, hparams)
447 | self.register_buffer("q", torch.Tensor())
448 |
449 | def update(self, minibatches, unlabeled=None):
450 | device = "cuda" if minibatches[0][0].is_cuda else "cpu"
451 |
452 | if not len(self.q):
453 | self.q = torch.ones(len(minibatches)).to(device)
454 |
455 | losses = torch.zeros(len(minibatches)).to(device)
456 |
457 | for m in range(len(minibatches)):
458 | x, y = minibatches[m]
459 | losses[m] = F.cross_entropy(self.predict(x), y)
460 | self.q[m] *= (self.hparams["groupdro_eta"] * losses[m].data).exp()
461 |
462 | self.q /= self.q.sum()
463 |
464 | loss = torch.dot(losses, self.q)
465 |
466 | self.optimizer.zero_grad()
467 | loss.backward()
468 | self.optimizer.step()
469 |
470 | return {'loss': loss.item()}
471 |
472 |
473 | class MLDG(ERM):
474 | """
475 | Model-Agnostic Meta-Learning
476 | Algorithm 1 / Equation (3) from: https://arxiv.org/pdf/1710.03463.pdf
477 | Related: https://arxiv.org/pdf/1703.03400.pdf
478 | Related: https://arxiv.org/pdf/1910.13580.pdf
479 | """
480 |
481 | def __init__(self, input_shape, num_classes, num_domains, hparams):
482 | super(MLDG, self).__init__(input_shape, num_classes, num_domains, hparams)
483 |
484 | def update(self, minibatches, unlabeled=None):
485 | """
486 | Terms being computed:
487 | * Li = Loss(xi, yi, params)
488 | * Gi = Grad(Li, params)
489 |
490 | * Lj = Loss(xj, yj, Optimizer(params, grad(Li, params)))
491 | * Gj = Grad(Lj, params)
492 |
493 | * params = Optimizer(params, Grad(Li + beta * Lj, params))
494 | * = Optimizer(params, Gi + beta * Gj)
495 |
496 | That is, when calling .step(), we want grads to be Gi + beta * Gj
497 |
498 | For computational efficiency, we do not compute second derivatives.
499 | """
500 | num_mb = len(minibatches)
501 | objective = 0
502 |
503 | self.optimizer.zero_grad()
504 | for p in self.network.parameters():
505 | if p.grad is None:
506 | p.grad = torch.zeros_like(p)
507 |
508 | for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
509 | # fine tune clone-network on task "i"
510 | inner_net = copy.deepcopy(self.network)
511 |
512 | inner_opt = torch.optim.Adam(
513 | inner_net.parameters(),
514 | lr=self.hparams["lr"],
515 | weight_decay=self.hparams['weight_decay']
516 | )
517 |
518 | inner_obj = F.cross_entropy(inner_net(xi), yi)
519 |
520 | inner_opt.zero_grad()
521 | inner_obj.backward()
522 | inner_opt.step()
523 |
524 | # The network has now accumulated gradients Gi
525 | # The clone-network has now parameters P - lr * Gi
526 | for p_tgt, p_src in zip(self.network.parameters(), inner_net.parameters()):
527 | if p_src.grad is not None:
528 | p_tgt.grad.data.add_(p_src.grad.data / num_mb)
529 |
530 | # `objective` is populated for reporting purposes
531 | objective += inner_obj.item()
532 |
533 | # this computes Gj on the clone-network
534 | loss_inner_j = F.cross_entropy(inner_net(xj), yj)
535 | grad_inner_j = autograd.grad(loss_inner_j, inner_net.parameters(), allow_unused=True)
536 |
537 | # `objective` is populated for reporting purposes
538 | objective += (self.hparams['mldg_beta'] * loss_inner_j).item()
539 |
540 | for p, g_j in zip(self.network.parameters(), grad_inner_j):
541 | if g_j is not None:
542 | p.grad.data.add_(self.hparams['mldg_beta'] * g_j.data / num_mb)
543 |
544 | # The network has now accumulated gradients Gi + beta * Gj
545 | # Repeat for all train-test splits, do .step()
546 |
547 | objective /= len(minibatches)
548 |
549 | self.optimizer.step()
550 |
551 | return {'loss': objective}
552 |
553 | # This commented "update" method back-propagates through the gradients of
554 | # the inner update, as suggested in the original MAML paper. However, this
555 | # is twice as expensive as the uncommented "update" method, which does not
556 | # compute second-order derivatives, implementing the First-Order MAML
557 | # method (FOMAML) described in the original MAML paper.
558 |
559 | # def update(self, minibatches, unlabeled=None):
560 | # objective = 0
561 | # beta = self.hparams["beta"]
562 | # inner_iterations = self.hparams["inner_iterations"]
563 |
564 | # self.optimizer.zero_grad()
565 |
566 | # with higher.innerloop_ctx(self.network, self.optimizer,
567 | # copy_initial_weights=False) as (inner_network, inner_optimizer):
568 |
569 | # for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
570 | # for inner_iteration in range(inner_iterations):
571 | # li = F.cross_entropy(inner_network(xi), yi)
572 | # inner_optimizer.step(li)
573 | #
574 | # objective += F.cross_entropy(self.network(xi), yi)
575 | # objective += beta * F.cross_entropy(inner_network(xj), yj)
576 |
577 | # objective /= len(minibatches)
578 | # objective.backward()
579 | #
580 | # self.optimizer.step()
581 | #
582 | # return objective
583 |
584 |
585 | class AbstractMMD(ERM):
586 | """
587 | Perform ERM while matching the pair-wise domain feature distributions
588 | using MMD (abstract class)
589 | """
590 |
591 | def __init__(self, input_shape, num_classes, num_domains, hparams, gaussian):
592 | super(AbstractMMD, self).__init__(input_shape, num_classes, num_domains, hparams)
593 | if gaussian:
594 | self.kernel_type = "gaussian"
595 | else:
596 | self.kernel_type = "mean_cov"
597 |
598 | def my_cdist(self, x1, x2):
599 | x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)
600 | x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)
601 | res = torch.addmm(
602 | x2_norm.transpose(-2, -1), x1, x2.transpose(-2, -1), alpha=-2
603 | ).add_(x1_norm)
604 | return res.clamp_min_(1e-30)
605 |
606 | def gaussian_kernel(self, x, y, gamma=[0.001, 0.01, 0.1, 1, 10, 100, 1000]):
607 | D = self.my_cdist(x, y)
608 | K = torch.zeros_like(D)
609 |
610 | for g in gamma:
611 | K.add_(torch.exp(D.mul(-g)))
612 |
613 | return K
614 |
615 | def mmd(self, x, y):
616 | if self.kernel_type == "gaussian":
617 | Kxx = self.gaussian_kernel(x, x).mean()
618 | Kyy = self.gaussian_kernel(y, y).mean()
619 | Kxy = self.gaussian_kernel(x, y).mean()
620 | return Kxx + Kyy - 2 * Kxy
621 | else:
622 | mean_x = x.mean(0, keepdim=True)
623 | mean_y = y.mean(0, keepdim=True)
624 | cent_x = x - mean_x
625 | cent_y = y - mean_y
626 | cova_x = (cent_x.t() @ cent_x) / (len(x) - 1)
627 | cova_y = (cent_y.t() @ cent_y) / (len(y) - 1)
628 |
629 | mean_diff = (mean_x - mean_y).pow(2).mean()
630 | cova_diff = (cova_x - cova_y).pow(2).mean()
631 |
632 | return mean_diff + cova_diff
633 |
634 | def update(self, minibatches, unlabeled=None):
635 | objective = 0
636 | penalty = 0
637 | nmb = len(minibatches)
638 |
639 | features = [self.featurizer(xi) for xi, _ in minibatches]
640 | classifs = [self.classifier(fi) for fi in features]
641 | targets = [yi for _, yi in minibatches]
642 |
643 | for i in range(nmb):
644 | objective += F.cross_entropy(classifs[i], targets[i])
645 | for j in range(i + 1, nmb):
646 | penalty += self.mmd(features[i], features[j])
647 |
648 | objective /= nmb
649 | if nmb > 1:
650 | penalty /= (nmb * (nmb - 1) / 2)
651 |
652 | self.optimizer.zero_grad()
653 | (objective + (self.hparams['mmd_gamma'] * penalty)).backward()
654 | self.optimizer.step()
655 |
656 | if torch.is_tensor(penalty):
657 | penalty = penalty.item()
658 |
659 | return {'loss': objective.item(), 'penalty': penalty}
660 |
661 |
662 | class MMD(AbstractMMD):
663 | """
664 | MMD using Gaussian kernel
665 | """
666 |
667 | def __init__(self, input_shape, num_classes, num_domains, hparams):
668 | super(MMD, self).__init__(input_shape, num_classes, num_domains, hparams, gaussian=True)
669 |
670 |
671 | class CORAL(AbstractMMD):
672 | """
673 | MMD using mean and covariance difference
674 | """
675 |
676 | def __init__(self, input_shape, num_classes, num_domains, hparams):
677 | super(CORAL, self).__init__(input_shape, num_classes, num_domains, hparams, gaussian=False)
678 |
679 |
680 | class MTL(Algorithm):
681 | """
682 | A neural network version of
683 | Domain Generalization by Marginal Transfer Learning
684 | (https://arxiv.org/abs/1711.07910)
685 | """
686 |
687 | def __init__(self, input_shape, num_classes, num_domains, hparams):
688 | super(MTL, self).__init__(input_shape, num_classes, num_domains, hparams)
689 | self.featurizer = networks.Featurizer(input_shape, self.hparams)
690 | self.classifier = networks.Classifier(
691 | self.featurizer.n_outputs * 2, num_classes, self.hparams['nonlinear_classifier']
692 | )
693 | self.optimizer = torch.optim.Adam(
694 | list(self.featurizer.parameters()) +\
695 | list(self.classifier.parameters()),
696 | lr=self.hparams["lr"],
697 | weight_decay=self.hparams['weight_decay']
698 | )
699 |
700 | self.register_buffer('embeddings', torch.zeros(num_domains, self.featurizer.n_outputs))
701 |
702 | self.ema = self.hparams['mtl_ema']
703 |
704 | def update(self, minibatches, unlabeled=None):
705 | loss = 0
706 | for env, (x, y) in enumerate(minibatches):
707 | loss += F.cross_entropy(self.predict(x, env), y)
708 |
709 | self.optimizer.zero_grad()
710 | loss.backward()
711 | self.optimizer.step()
712 |
713 | return {'loss': loss.item()}
714 |
715 | def update_embeddings_(self, features, env=None):
716 | return_embedding = features.mean(0)
717 |
718 | if env is not None:
719 | return_embedding = self.ema * return_embedding +\
720 | (1 - self.ema) * self.embeddings[env]
721 |
722 | self.embeddings[env] = return_embedding.clone().detach()
723 |
724 | return return_embedding.view(1, -1).repeat(len(features), 1)
725 |
726 | def predict(self, x, env=None):
727 | features = self.featurizer(x)
728 | embedding = self.update_embeddings_(features, env).normal_()
729 | return self.classifier(torch.cat((features, embedding), 1))
730 |
731 |
732 | class SagNet(Algorithm):
733 | """
734 | Style Agnostic Network
735 | Algorithm 1 from: https://arxiv.org/abs/1910.11645
736 | """
737 |
738 | def __init__(self, input_shape, num_classes, num_domains, hparams):
739 | super(SagNet, self).__init__(input_shape, num_classes, num_domains, hparams)
740 | # featurizer network
741 | self.network_f = networks.Featurizer(input_shape, self.hparams)
742 | # content network
743 | self.network_c = networks.Classifier(
744 | self.network_f.n_outputs, num_classes, self.hparams['nonlinear_classifier']
745 | )
746 | # style network
747 | self.network_s = networks.Classifier(
748 | self.network_f.n_outputs, num_classes, self.hparams['nonlinear_classifier']
749 | )
750 |
751 | # # This commented block of code implements something closer to the
752 | # # original paper, but is specific to ResNet and puts in disadvantage
753 | # # the other algorithms.
754 | # resnet_c = networks.Featurizer(input_shape, self.hparams)
755 | # resnet_s = networks.Featurizer(input_shape, self.hparams)
756 | # # featurizer network
757 | # self.network_f = torch.nn.Sequential(
758 | # resnet_c.network.conv1,
759 | # resnet_c.network.bn1,
760 | # resnet_c.network.relu,
761 | # resnet_c.network.maxpool,
762 | # resnet_c.network.layer1,
763 | # resnet_c.network.layer2,
764 | # resnet_c.network.layer3)
765 | # # content network
766 | # self.network_c = torch.nn.Sequential(
767 | # resnet_c.network.layer4,
768 | # resnet_c.network.avgpool,
769 | # networks.Flatten(),
770 | # resnet_c.network.fc)
771 | # # style network
772 | # self.network_s = torch.nn.Sequential(
773 | # resnet_s.network.layer4,
774 | # resnet_s.network.avgpool,
775 | # networks.Flatten(),
776 | # resnet_s.network.fc)
777 |
778 | def opt(p):
779 | return torch.optim.Adam(p, lr=hparams["lr"], weight_decay=hparams["weight_decay"])
780 |
781 | self.optimizer_f = opt(self.network_f.parameters())
782 | self.optimizer_c = opt(self.network_c.parameters())
783 | self.optimizer_s = opt(self.network_s.parameters())
784 | self.weight_adv = hparams["sag_w_adv"]
785 |
786 | def forward_c(self, x):
787 | # learning content network on randomized style
788 | return self.network_c(self.randomize(self.network_f(x), "style"))
789 |
790 | def forward_s(self, x):
791 | # learning style network on randomized content
792 | return self.network_s(self.randomize(self.network_f(x), "content"))
793 |
794 | def randomize(self, x, what="style", eps=1e-5):
795 | device = "cuda" if x.is_cuda else "cpu"
796 | sizes = x.size()
797 | alpha = torch.rand(sizes[0], 1).to(device)
798 |
799 | if len(sizes) == 4:
800 | x = x.view(sizes[0], sizes[1], -1)
801 | alpha = alpha.unsqueeze(-1)
802 |
803 | mean = x.mean(-1, keepdim=True)
804 | var = x.var(-1, keepdim=True)
805 |
806 | x = (x - mean) / (var + eps).sqrt()
807 |
808 | idx_swap = torch.randperm(sizes[0])
809 | if what == "style":
810 | mean = alpha * mean + (1 - alpha) * mean[idx_swap]
811 | var = alpha * var + (1 - alpha) * var[idx_swap]
812 | else:
813 | x = x[idx_swap].detach()
814 |
815 | x = x * (var + eps).sqrt() + mean
816 | return x.view(*sizes)
817 |
818 | def update(self, minibatches, unlabeled=None):
819 | all_x = torch.cat([x for x, y in minibatches])
820 | all_y = torch.cat([y for x, y in minibatches])
821 |
822 | # learn content
823 | self.optimizer_f.zero_grad()
824 | self.optimizer_c.zero_grad()
825 | loss_c = F.cross_entropy(self.forward_c(all_x), all_y)
826 | loss_c.backward()
827 | self.optimizer_f.step()
828 | self.optimizer_c.step()
829 |
830 | # learn style
831 | self.optimizer_s.zero_grad()
832 | loss_s = F.cross_entropy(self.forward_s(all_x), all_y)
833 | loss_s.backward()
834 | self.optimizer_s.step()
835 |
836 | # learn adversary
837 | self.optimizer_f.zero_grad()
838 | loss_adv = -F.log_softmax(self.forward_s(all_x), dim=1).mean(1).mean()
839 | loss_adv = loss_adv * self.weight_adv
840 | loss_adv.backward()
841 | self.optimizer_f.step()
842 |
843 | return {'loss_c': loss_c.item(), 'loss_s': loss_s.item(), 'loss_adv': loss_adv.item()}
844 |
845 | def predict(self, x):
846 | return self.network_c(self.network_f(x))
847 |
848 |
849 | class RSC(ERM):
850 |
851 | def __init__(self, input_shape, num_classes, num_domains, hparams):
852 | super(RSC, self).__init__(input_shape, num_classes, num_domains, hparams)
853 | self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100
854 | self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100
855 | self.num_classes = num_classes
856 |
857 | def update(self, minibatches, unlabeled=None):
858 | device = "cuda" if minibatches[0][0].is_cuda else "cpu"
859 |
860 | # inputs
861 | all_x = torch.cat([x for x, y in minibatches])
862 | # labels
863 | all_y = torch.cat([y for _, y in minibatches])
864 | # one-hot labels
865 | all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
866 | # features
867 | all_f = self.featurizer(all_x)
868 | # predictions
869 | all_p = self.classifier(all_f)
870 |
871 | # Equation (1): compute gradients with respect to representation
872 | all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]
873 |
874 | # Equation (2): compute top-gradient-percentile mask
875 | percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
876 | percentiles = torch.Tensor(percentiles)
877 | percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
878 | mask_f = all_g.lt(percentiles.to(device)).float()
879 |
880 | # Equation (3): mute top-gradient-percentile activations
881 | all_f_muted = all_f * mask_f
882 |
883 | # Equation (4): compute muted predictions
884 | all_p_muted = self.classifier(all_f_muted)
885 |
886 | # Section 3.3: Batch Percentage
887 | all_s = F.softmax(all_p, dim=1)
888 | all_s_muted = F.softmax(all_p_muted, dim=1)
889 | changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
890 | percentile = np.percentile(changes.detach().cpu(), self.drop_b)
891 | mask_b = changes.lt(percentile).float().view(-1, 1)
892 | mask = torch.logical_or(mask_f, mask_b).float()
893 |
894 | # Equations (3) and (4) again, this time mutting over examples
895 | all_p_muted_again = self.classifier(all_f * mask)
896 |
897 | # Equation (5): update
898 | loss = F.cross_entropy(all_p_muted_again, all_y)
899 | self.optimizer.zero_grad()
900 | loss.backward()
901 | self.optimizer.step()
902 |
903 | return {'loss': loss.item()}
904 |
905 |
906 | class SD(ERM):
907 | """
908 | Gradient Starvation: A Learning Proclivity in Neural Networks
909 | Equation 25 from [https://arxiv.org/pdf/2011.09468.pdf]
910 | """
911 |
912 | def __init__(self, input_shape, num_classes, num_domains, hparams):
913 | super(SD, self).__init__(input_shape, num_classes, num_domains, hparams)
914 | self.sd_reg = hparams["sd_reg"]
915 |
916 | def update(self, minibatches, unlabeled=None):
917 | all_x = torch.cat([x for x, y in minibatches])
918 | all_y = torch.cat([y for x, y in minibatches])
919 | all_p = self.predict(all_x)
920 |
921 | loss = F.cross_entropy(all_p, all_y)
922 | penalty = (all_p**2).mean()
923 | objective = loss + self.sd_reg * penalty
924 |
925 | self.optimizer.zero_grad()
926 | objective.backward()
927 | self.optimizer.step()
928 |
929 | return {'loss': loss.item(), 'penalty': penalty.item()}
930 |
931 |
932 | class ANDMask(ERM):
933 | """
934 | Learning Explanations that are Hard to Vary [https://arxiv.org/abs/2009.00329]
935 | AND-Mask implementation from [https://github.com/gibipara92/learning-explanations-hard-to-vary]
936 | """
937 |
938 | def __init__(self, input_shape, num_classes, num_domains, hparams):
939 | super(ANDMask, self).__init__(input_shape, num_classes, num_domains, hparams)
940 |
941 | self.tau = hparams["tau"]
942 |
943 | def update(self, minibatches, unlabeled=None):
944 | mean_loss = 0
945 | param_gradients = [[] for _ in self.network.parameters()]
946 | for i, (x, y) in enumerate(minibatches):
947 | logits = self.network(x)
948 |
949 | env_loss = F.cross_entropy(logits, y)
950 | mean_loss += env_loss.item() / len(minibatches)
951 |
952 | env_grads = autograd.grad(env_loss, self.network.parameters())
953 | for grads, env_grad in zip(param_gradients, env_grads):
954 | grads.append(env_grad)
955 |
956 | self.optimizer.zero_grad()
957 | self.mask_grads(self.tau, param_gradients, self.network.parameters())
958 | self.optimizer.step()
959 |
960 | return {'loss': mean_loss}
961 |
962 | def mask_grads(self, tau, gradients, params):
963 |
964 | for param, grads in zip(params, gradients):
965 | grads = torch.stack(grads, dim=0)
966 | grad_signs = torch.sign(grads)
967 | mask = torch.mean(grad_signs, dim=0).abs() >= self.tau
968 | mask = mask.to(torch.float32)
969 | avg_grad = torch.mean(grads, dim=0)
970 |
971 | mask_t = (mask.sum() / mask.numel())
972 | param.grad = mask * avg_grad
973 | param.grad *= (1. / (1e-10 + mask_t))
974 |
975 | return 0
976 |
977 |
978 | class IGA(ERM):
979 | """
980 | Inter-environmental Gradient Alignment
981 | From https://arxiv.org/abs/2008.01883v2
982 | """
983 |
984 | def __init__(self, in_features, num_classes, num_domains, hparams):
985 | super(IGA, self).__init__(in_features, num_classes, num_domains, hparams)
986 |
987 | def update(self, minibatches, unlabeled=False):
988 | total_loss = 0
989 | grads = []
990 | for i, (x, y) in enumerate(minibatches):
991 | logits = self.network(x)
992 |
993 | env_loss = F.cross_entropy(logits, y)
994 | total_loss += env_loss
995 |
996 | env_grad = autograd.grad(env_loss, self.network.parameters(), create_graph=True)
997 |
998 | grads.append(env_grad)
999 |
1000 | mean_loss = total_loss / len(minibatches)
1001 | mean_grad = autograd.grad(mean_loss, self.network.parameters(), retain_graph=True)
1002 |
1003 | # compute trace penalty
1004 | penalty_value = 0
1005 | for grad in grads:
1006 | for g, mean_g in zip(grad, mean_grad):
1007 | penalty_value += (g - mean_g).pow(2).sum()
1008 |
1009 | objective = mean_loss + self.hparams['penalty'] * penalty_value
1010 |
1011 | self.optimizer.zero_grad()
1012 | objective.backward()
1013 | self.optimizer.step()
1014 |
1015 | return {'loss': mean_loss.item(), 'penalty': penalty_value.item()}
1016 |
1017 |
1018 | class SelfReg(ERM):
1019 |
1020 | def __init__(self, input_shape, num_classes, num_domains, hparams):
1021 | super(SelfReg, self).__init__(input_shape, num_classes, num_domains, hparams)
1022 | self.num_classes = num_classes
1023 | self.MSEloss = nn.MSELoss()
1024 | input_feat_size = self.featurizer.n_outputs
1025 | hidden_size = input_feat_size if input_feat_size == 2048 else input_feat_size * 2
1026 |
1027 | self.cdpl = nn.Sequential(
1028 | nn.Linear(input_feat_size, hidden_size), nn.BatchNorm1d(hidden_size),
1029 | nn.ReLU(inplace=True), nn.Linear(hidden_size, hidden_size), nn.BatchNorm1d(hidden_size),
1030 | nn.ReLU(inplace=True), nn.Linear(hidden_size, input_feat_size),
1031 | nn.BatchNorm1d(input_feat_size)
1032 | )
1033 |
1034 | def update(self, minibatches, unlabeled=None):
1035 |
1036 | all_x = torch.cat([x for x, y in minibatches])
1037 | all_y = torch.cat([y for _, y in minibatches])
1038 |
1039 | lam = np.random.beta(0.5, 0.5)
1040 |
1041 | batch_size = all_y.size()[0]
1042 |
1043 | # cluster and order features into same-class group
1044 | with torch.no_grad():
1045 | sorted_y, indices = torch.sort(all_y)
1046 | sorted_x = torch.zeros_like(all_x)
1047 | for idx, order in enumerate(indices):
1048 | sorted_x[idx] = all_x[order]
1049 | intervals = []
1050 | ex = 0
1051 | for idx, val in enumerate(sorted_y):
1052 | if ex == val:
1053 | continue
1054 | intervals.append(idx)
1055 | ex = val
1056 | intervals.append(batch_size)
1057 |
1058 | all_x = sorted_x
1059 | all_y = sorted_y
1060 |
1061 | feat = self.featurizer(all_x)
1062 | proj = self.cdpl(feat)
1063 |
1064 | output = self.classifier(feat)
1065 |
1066 | # shuffle
1067 | output_2 = torch.zeros_like(output)
1068 | feat_2 = torch.zeros_like(proj)
1069 | output_3 = torch.zeros_like(output)
1070 | feat_3 = torch.zeros_like(proj)
1071 | ex = 0
1072 | for end in intervals:
1073 | shuffle_indices = torch.randperm(end - ex) + ex
1074 | shuffle_indices2 = torch.randperm(end - ex) + ex
1075 | for idx in range(end - ex):
1076 | output_2[idx + ex] = output[shuffle_indices[idx]]
1077 | feat_2[idx + ex] = proj[shuffle_indices[idx]]
1078 | output_3[idx + ex] = output[shuffle_indices2[idx]]
1079 | feat_3[idx + ex] = proj[shuffle_indices2[idx]]
1080 | ex = end
1081 |
1082 | # mixup
1083 | output_3 = lam * output_2 + (1 - lam) * output_3
1084 | feat_3 = lam * feat_2 + (1 - lam) * feat_3
1085 |
1086 | # regularization
1087 | L_ind_logit = self.MSEloss(output, output_2)
1088 | L_hdl_logit = self.MSEloss(output, output_3)
1089 | L_ind_feat = 0.3 * self.MSEloss(feat, feat_2)
1090 | L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3)
1091 |
1092 | cl_loss = F.cross_entropy(output, all_y)
1093 | C_scale = min(cl_loss.item(), 1.)
1094 | loss = cl_loss + C_scale * (
1095 | lam * (L_ind_logit + L_ind_feat) + (1 - lam) * (L_hdl_logit + L_hdl_feat)
1096 | )
1097 |
1098 | self.optimizer.zero_grad()
1099 | loss.backward()
1100 | self.optimizer.step()
1101 |
1102 | return {'loss': loss.item()}
1103 |
1104 |
1105 | class SANDMask(ERM):
1106 | """
1107 | SAND-mask: An Enhanced Gradient Masking Strategy for the Discovery of Invariances in Domain Generalization
1108 |
1109 | """
1110 |
1111 | def __init__(self, input_shape, num_classes, num_domains, hparams):
1112 | super(SANDMask, self).__init__(input_shape, num_classes, num_domains, hparams)
1113 |
1114 | self.tau = hparams["tau"]
1115 | self.k = hparams["k"]
1116 | betas = (0.9, 0.999)
1117 | self.optimizer = torch.optim.Adam(
1118 | self.network.parameters(),
1119 | lr=self.hparams["lr"],
1120 | weight_decay=self.hparams['weight_decay'],
1121 | betas=betas
1122 | )
1123 |
1124 | self.register_buffer('update_count', torch.tensor([0]))
1125 |
1126 | def update(self, minibatches, unlabeled=None):
1127 |
1128 | mean_loss = 0
1129 | param_gradients = [[] for _ in self.network.parameters()]
1130 | for i, (x, y) in enumerate(minibatches):
1131 | logits = self.network(x)
1132 |
1133 | env_loss = F.cross_entropy(logits, y)
1134 | mean_loss += env_loss.item() / len(minibatches)
1135 | env_grads = autograd.grad(env_loss, self.network.parameters(), retain_graph=True)
1136 | for grads, env_grad in zip(param_gradients, env_grads):
1137 | grads.append(env_grad)
1138 |
1139 | self.optimizer.zero_grad()
1140 | # gradient masking applied here
1141 | self.mask_grads(param_gradients, self.network.parameters())
1142 | self.optimizer.step()
1143 | self.update_count += 1
1144 |
1145 | return {'loss': mean_loss}
1146 |
1147 | def mask_grads(self, gradients, params):
1148 | '''
1149 | Here a mask with continuous values in the range [0,1] is formed to control the amount of update for each
1150 | parameter based on the agreement of gradients coming from different environments.
1151 | '''
1152 | device = gradients[0][0].device
1153 | for param, grads in zip(params, gradients):
1154 | grads = torch.stack(grads, dim=0)
1155 | avg_grad = torch.mean(grads, dim=0)
1156 | grad_signs = torch.sign(grads)
1157 | gamma = torch.tensor(1.0).to(device)
1158 | grads_var = grads.var(dim=0)
1159 | grads_var[torch.isnan(grads_var)] = 1e-17
1160 | lam = (gamma * grads_var).pow(-1)
1161 | mask = torch.tanh(self.k * lam * (torch.abs(grad_signs.mean(dim=0)) - self.tau))
1162 | mask = torch.max(mask, torch.zeros_like(mask))
1163 | mask[torch.isnan(mask)] = 1e-17
1164 | mask_t = (mask.sum() / mask.numel())
1165 | param.grad = mask * avg_grad
1166 | param.grad *= (1. / (1e-10 + mask_t))
1167 |
1168 |
1169 | class Fishr(Algorithm):
1170 | "Invariant Gradients variances for Out-of-distribution Generalization"
1171 |
1172 | def __init__(self, input_shape, num_classes, num_domains, hparams):
1173 | assert backpack is not None, "Install backpack with: 'pip install backpack-for-pytorch==1.3.0'"
1174 | super(Fishr, self).__init__(input_shape, num_classes, num_domains, hparams)
1175 | self.num_domains = num_domains
1176 |
1177 | self.featurizer = networks.Featurizer(input_shape, self.hparams)
1178 | self.classifier = extend(
1179 | networks.Classifier(
1180 | self.featurizer.n_outputs,
1181 | num_classes,
1182 | self.hparams['nonlinear_classifier'],
1183 | )
1184 | )
1185 | self.network = nn.Sequential(self.featurizer, self.classifier)
1186 |
1187 | self.register_buffer("update_count", torch.tensor([0]))
1188 | self.bce_extended = extend(nn.CrossEntropyLoss(reduction='none'))
1189 | self.ema_per_domain = [
1190 | MovingAverage(ema=self.hparams["ema"], oneminusema_correction=True)
1191 | for _ in range(self.num_domains)
1192 | ]
1193 | self._init_optimizer()
1194 |
1195 | def _init_optimizer(self):
1196 | self.optimizer = torch.optim.Adam(
1197 | list(self.featurizer.parameters()) + list(self.classifier.parameters()),
1198 | lr=self.hparams["lr"],
1199 | weight_decay=self.hparams["weight_decay"],
1200 | )
1201 |
1202 | def update(self, minibatches, unlabeled=False):
1203 | assert len(minibatches) == self.num_domains
1204 | all_x = torch.cat([x for x, y in minibatches])
1205 | all_y = torch.cat([y for x, y in minibatches])
1206 | len_minibatches = [x.shape[0] for x, y in minibatches]
1207 |
1208 | all_z = self.featurizer(all_x)
1209 | all_logits = self.classifier(all_z)
1210 |
1211 | penalty = self.compute_fishr_penalty(all_logits, all_y, len_minibatches)
1212 | all_nll = F.cross_entropy(all_logits, all_y)
1213 |
1214 | penalty_weight = 0
1215 | if self.update_count >= self.hparams["penalty_anneal_iters"]:
1216 | penalty_weight = self.hparams["lambda"]
1217 | if self.update_count == self.hparams["penalty_anneal_iters"] != 0:
1218 | # Reset Adam as in IRM or V-REx, because it may not like the sharp jump in
1219 | # gradient magnitudes that happens at this step.
1220 | self._init_optimizer()
1221 | self.update_count += 1
1222 |
1223 | objective = all_nll + penalty_weight * penalty
1224 | self.optimizer.zero_grad()
1225 | objective.backward()
1226 | self.optimizer.step()
1227 |
1228 | return {'loss': objective.item(), 'nll': all_nll.item(), 'penalty': penalty.item()}
1229 |
1230 | def compute_fishr_penalty(self, all_logits, all_y, len_minibatches):
1231 | dict_grads = self._get_grads(all_logits, all_y)
1232 | grads_var_per_domain = self._get_grads_var_per_domain(dict_grads, len_minibatches)
1233 | return self._compute_distance_grads_var(grads_var_per_domain)
1234 |
1235 | def _get_grads(self, logits, y):
1236 | self.optimizer.zero_grad()
1237 | loss = self.bce_extended(logits, y).sum()
1238 | with backpack(BatchGrad()):
1239 | loss.backward(
1240 | inputs=list(self.classifier.parameters()), retain_graph=True, create_graph=True
1241 | )
1242 |
1243 | # compute individual grads for all samples across all domains simultaneously
1244 | dict_grads = OrderedDict(
1245 | [
1246 | (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1))
1247 | for name, weights in self.classifier.named_parameters()
1248 | ]
1249 | )
1250 | return dict_grads
1251 |
1252 | def _get_grads_var_per_domain(self, dict_grads, len_minibatches):
1253 | # grads var per domain
1254 | grads_var_per_domain = [{} for _ in range(self.num_domains)]
1255 | for name, _grads in dict_grads.items():
1256 | all_idx = 0
1257 | for domain_id, bsize in enumerate(len_minibatches):
1258 | env_grads = _grads[all_idx:all_idx + bsize]
1259 | all_idx += bsize
1260 | env_mean = env_grads.mean(dim=0, keepdim=True)
1261 | env_grads_centered = env_grads - env_mean
1262 | grads_var_per_domain[domain_id][name] = (env_grads_centered).pow(2).mean(dim=0)
1263 |
1264 | # moving average
1265 | for domain_id in range(self.num_domains):
1266 | grads_var_per_domain[domain_id] = self.ema_per_domain[domain_id].update(
1267 | grads_var_per_domain[domain_id]
1268 | )
1269 |
1270 | return grads_var_per_domain
1271 |
1272 | def _compute_distance_grads_var(self, grads_var_per_domain):
1273 |
1274 | # compute gradient variances averaged across domains
1275 | grads_var = OrderedDict(
1276 | [
1277 | (
1278 | name,
1279 | torch.stack(
1280 | [
1281 | grads_var_per_domain[domain_id][name]
1282 | for domain_id in range(self.num_domains)
1283 | ],
1284 | dim=0
1285 | ).mean(dim=0)
1286 | )
1287 | for name in grads_var_per_domain[0].keys()
1288 | ]
1289 | )
1290 |
1291 | penalty = 0
1292 | for domain_id in range(self.num_domains):
1293 | penalty += l2_between_dicts(grads_var_per_domain[domain_id], grads_var)
1294 | return penalty / self.num_domains
1295 |
1296 | def predict(self, x):
1297 | return self.network(x)
1298 |
--------------------------------------------------------------------------------
/domainbed/command_launchers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | A command launcher launches a list of commands on a cluster; implement your own
5 | launcher to add support for your cluster. We've provided an example launcher
6 | which runs all commands serially on the local machine.
7 | """
8 |
9 | import subprocess
10 | import time
11 | import torch
12 |
13 | def local_launcher(commands):
14 | """Launch commands serially on the local machine."""
15 | for cmd in commands:
16 | subprocess.call(cmd, shell=True)
17 |
18 | def dummy_launcher(commands):
19 | """
20 | Doesn't run anything; instead, prints each command.
21 | Useful for testing.
22 | """
23 | for cmd in commands:
24 | print(f'Dummy launcher: {cmd}')
25 |
26 | def multi_gpu_launcher(commands):
27 | """
28 | Launch commands on the local machine, using all GPUs in parallel.
29 | """
30 | print('WARNING: using experimental multi_gpu_launcher.')
31 | n_gpus = torch.cuda.device_count()
32 | procs_by_gpu = [None]*n_gpus
33 |
34 | while len(commands) > 0:
35 | for gpu_idx in range(n_gpus):
36 | proc = procs_by_gpu[gpu_idx]
37 | if (proc is None) or (proc.poll() is not None):
38 | # Nothing is running on this GPU; launch a command.
39 | cmd = commands.pop(0)
40 | new_proc = subprocess.Popen(
41 | f'CUDA_VISIBLE_DEVICES={gpu_idx} {cmd}', shell=True)
42 | procs_by_gpu[gpu_idx] = new_proc
43 | break
44 | time.sleep(1)
45 |
46 | # Wait for the last few tasks to finish before returning
47 | for p in procs_by_gpu:
48 | if p is not None:
49 | p.wait()
50 |
51 | REGISTRY = {
52 | 'local': local_launcher,
53 | 'dummy': dummy_launcher,
54 | 'multi_gpu': multi_gpu_launcher
55 | }
56 |
57 | try:
58 | from domainbed import facebook
59 | facebook.register_command_launchers(REGISTRY)
60 | except ImportError:
61 | pass
62 |
--------------------------------------------------------------------------------
/domainbed/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import torch
5 | from PIL import Image, ImageFile
6 | from torchvision import transforms
7 | import torchvision.datasets.folder
8 | from torch.utils.data import TensorDataset, Subset
9 | from torchvision.datasets import MNIST, ImageFolder
10 | from torchvision.transforms.functional import rotate
11 |
12 | # from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
13 | # from wilds.datasets.fmow_dataset import FMoWDataset
14 |
15 | ImageFile.LOAD_TRUNCATED_IMAGES = True
16 |
17 | DATASETS = [
18 | # Debug
19 | "Debug28",
20 | "Debug224",
21 | # Small images
22 | "ColoredMNIST",
23 | "RotatedMNIST",
24 | # Big images
25 | "VLCS",
26 | "PACS",
27 | "OfficeHome",
28 | "TerraIncognita",
29 | "DomainNet",
30 | # "SVIRO",
31 | # # WILDS datasets
32 | # "WILDSCamelyon",
33 | # "WILDSFMoW"
34 | ]
35 |
36 | def get_dataset_class(dataset_name):
37 | """Return the dataset class with the given name."""
38 | if dataset_name not in globals():
39 | raise NotImplementedError("Dataset not found: {}".format(dataset_name))
40 | return globals()[dataset_name]
41 |
42 |
43 | def num_environments(dataset_name):
44 | return len(get_dataset_class(dataset_name).ENVIRONMENTS)
45 |
46 |
47 | class MultipleDomainDataset:
48 | N_STEPS = 5001 # Default, subclasses may override
49 | CHECKPOINT_FREQ = 100 # Default, subclasses may override
50 | N_WORKERS = 8 # Default, subclasses may override
51 | ENVIRONMENTS = None # Subclasses should override
52 | INPUT_SHAPE = None # Subclasses should override
53 |
54 | def __getitem__(self, index):
55 | return self.datasets[index]
56 |
57 | def __len__(self):
58 | return len(self.datasets)
59 |
60 |
61 | class Debug(MultipleDomainDataset):
62 | def __init__(self, root, test_envs, hparams):
63 | super().__init__()
64 | self.input_shape = self.INPUT_SHAPE
65 | self.num_classes = 2
66 | self.datasets = []
67 | for _ in [0, 1, 2]:
68 | self.datasets.append(
69 | TensorDataset(
70 | torch.randn(16, *self.INPUT_SHAPE),
71 | torch.randint(0, self.num_classes, (16,))
72 | )
73 | )
74 |
75 | class Debug28(Debug):
76 | INPUT_SHAPE = (3, 28, 28)
77 | ENVIRONMENTS = ['0', '1', '2']
78 |
79 | class Debug224(Debug):
80 | INPUT_SHAPE = (3, 224, 224)
81 | ENVIRONMENTS = ['0', '1', '2']
82 |
83 |
84 | class MultipleEnvironmentMNIST(MultipleDomainDataset):
85 | def __init__(self, root, environments, dataset_transform, input_shape,
86 | num_classes):
87 | super().__init__()
88 | if root is None:
89 | raise ValueError('Data directory not specified!')
90 |
91 | original_dataset_tr = MNIST(root, train=True, download=True)
92 | original_dataset_te = MNIST(root, train=False, download=True)
93 |
94 | original_images = torch.cat((original_dataset_tr.data,
95 | original_dataset_te.data))
96 |
97 | original_labels = torch.cat((original_dataset_tr.targets,
98 | original_dataset_te.targets))
99 |
100 | shuffle = torch.randperm(len(original_images))
101 |
102 | original_images = original_images[shuffle]
103 | original_labels = original_labels[shuffle]
104 |
105 | self.datasets = []
106 |
107 | for i in range(len(environments)):
108 | images = original_images[i::len(environments)]
109 | labels = original_labels[i::len(environments)]
110 | self.datasets.append(dataset_transform(images, labels, environments[i]))
111 |
112 | self.input_shape = input_shape
113 | self.num_classes = num_classes
114 |
115 |
116 | class ColoredMNIST(MultipleEnvironmentMNIST):
117 | ENVIRONMENTS = ['+90%', '+80%', '-90%']
118 |
119 | def __init__(self, root, test_envs, hparams):
120 | super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9],
121 | self.color_dataset, (2, 28, 28,), 2)
122 |
123 | self.input_shape = (2, 28, 28,)
124 | self.num_classes = 2
125 |
126 | def color_dataset(self, images, labels, environment):
127 | # # Subsample 2x for computational convenience
128 | # images = images.reshape((-1, 28, 28))[:, ::2, ::2]
129 | # Assign a binary label based on the digit
130 | labels = (labels < 5).float()
131 | # Flip label with probability 0.25
132 | labels = self.torch_xor_(labels,
133 | self.torch_bernoulli_(0.25, len(labels)))
134 |
135 | # Assign a color based on the label; flip the color with probability e
136 | colors = self.torch_xor_(labels,
137 | self.torch_bernoulli_(environment,
138 | len(labels)))
139 | images = torch.stack([images, images], dim=1)
140 | # Apply the color to the image by zeroing out the other color channel
141 | images[torch.tensor(range(len(images))), (
142 | 1 - colors).long(), :, :] *= 0
143 |
144 | x = images.float().div_(255.0)
145 | y = labels.view(-1).long()
146 |
147 | return TensorDataset(x, y)
148 |
149 | def torch_bernoulli_(self, p, size):
150 | return (torch.rand(size) < p).float()
151 |
152 | def torch_xor_(self, a, b):
153 | return (a - b).abs()
154 |
155 |
156 | class RotatedMNIST(MultipleEnvironmentMNIST):
157 | ENVIRONMENTS = ['0', '15', '30', '45', '60', '75']
158 |
159 | def __init__(self, root, test_envs, hparams):
160 | super(RotatedMNIST, self).__init__(root, [0, 15, 30, 45, 60, 75],
161 | self.rotate_dataset, (1, 28, 28,), 10)
162 |
163 | def rotate_dataset(self, images, labels, angle):
164 | rotation = transforms.Compose([
165 | transforms.ToPILImage(),
166 | transforms.Lambda(lambda x: rotate(x, angle, fill=(0,),
167 | interpolation=torchvision.transforms.InterpolationMode.BILINEAR)),
168 | transforms.ToTensor()])
169 |
170 | x = torch.zeros(len(images), 1, 28, 28)
171 | for i in range(len(images)):
172 | x[i] = rotation(images[i])
173 |
174 | y = labels.view(-1)
175 |
176 | return TensorDataset(x, y)
177 |
178 |
179 | class MultipleEnvironmentImageFolder(MultipleDomainDataset):
180 | def __init__(self, root, test_envs, augment, hparams):
181 | super().__init__()
182 | environments = [f.name for f in os.scandir(root) if f.is_dir()]
183 | environments = sorted(environments)
184 |
185 | transform = transforms.Compose([
186 | transforms.Resize((224,224)),
187 | transforms.ToTensor(),
188 | transforms.Normalize(
189 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
190 | ])
191 |
192 | augment_transform = transforms.Compose([
193 | # transforms.Resize((224,224)),
194 | transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
195 | transforms.RandomHorizontalFlip(),
196 | transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
197 | transforms.RandomGrayscale(),
198 | transforms.ToTensor(),
199 | transforms.Normalize(
200 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
201 | ])
202 |
203 | self.datasets = []
204 | for i, environment in enumerate(environments):
205 |
206 | if augment and (i not in test_envs):
207 | env_transform = augment_transform
208 | else:
209 | env_transform = transform
210 |
211 | path = os.path.join(root, environment)
212 | env_dataset = ImageFolder(path,
213 | transform=env_transform)
214 |
215 | self.datasets.append(env_dataset)
216 |
217 | self.input_shape = (3, 224, 224,)
218 | self.num_classes = len(self.datasets[-1].classes)
219 |
220 | class VLCS(MultipleEnvironmentImageFolder):
221 | CHECKPOINT_FREQ = 300
222 | ENVIRONMENTS = ["C", "L", "S", "V"]
223 | def __init__(self, root, test_envs, hparams):
224 | self.dir = os.path.join(root, "VLCS/")
225 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
226 |
227 | class PACS(MultipleEnvironmentImageFolder):
228 | CHECKPOINT_FREQ = 300
229 | ENVIRONMENTS = ["A", "C", "P", "S"]
230 | def __init__(self, root, test_envs, hparams):
231 | self.dir = os.path.join(root, "PACS/")
232 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
233 |
234 | class DomainNet(MultipleEnvironmentImageFolder):
235 | CHECKPOINT_FREQ = 1000
236 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"]
237 | def __init__(self, root, test_envs, hparams):
238 | self.dir = os.path.join(root, "domain_net/")
239 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
240 |
241 | class OfficeHome(MultipleEnvironmentImageFolder):
242 | CHECKPOINT_FREQ = 300
243 | ENVIRONMENTS = ["A", "C", "P", "R"]
244 | def __init__(self, root, test_envs, hparams):
245 | self.dir = os.path.join(root, "office_home/")
246 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
247 |
248 | class TerraIncognita(MultipleEnvironmentImageFolder):
249 | CHECKPOINT_FREQ = 300
250 | ENVIRONMENTS = ["L100", "L38", "L43", "L46"]
251 | def __init__(self, root, test_envs, hparams):
252 | self.dir = os.path.join(root, "terra_incognita/")
253 | super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
254 |
255 |
256 | # class SVIRO(MultipleEnvironmentImageFolder):
257 | # CHECKPOINT_FREQ = 300
258 | # ENVIRONMENTS = [
259 | # "aclass", "escape", "hilux", "i3", "lexus", "tesla", "tiguan", "tucson", "x5", "zoe"
260 | # ]
261 |
262 | # def __init__(self, root, test_envs, hparams):
263 | # self.dir = os.path.join(root, "sviro/")
264 | # super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
265 |
266 |
267 | # class WILDSEnvironment:
268 |
269 | # def __init__(self, wilds_dataset, metadata_name, metadata_value, transform=None):
270 | # self.name = metadata_name + "_" + str(metadata_value)
271 |
272 | # metadata_index = wilds_dataset.metadata_fields.index(metadata_name)
273 | # metadata_array = wilds_dataset.metadata_array
274 | # subset_indices = torch.where(metadata_array[:, metadata_index] == metadata_value)[0]
275 |
276 | # self.dataset = wilds_dataset
277 | # self.indices = subset_indices
278 | # self.transform = transform
279 |
280 | # def __getitem__(self, i):
281 | # x = self.dataset.get_input(self.indices[i])
282 | # if type(x).__name__ != "Image":
283 | # x = Image.fromarray(x)
284 |
285 | # y = self.dataset.y_array[self.indices[i]]
286 | # if self.transform is not None:
287 | # x = self.transform(x)
288 | # return x, y
289 |
290 | # def __len__(self):
291 | # return len(self.indices)
292 |
293 |
294 | # class WILDSDataset(MultipleDomainDataset):
295 | # INPUT_SHAPE = (3, 224, 224)
296 |
297 | # def __init__(self, dataset, metadata_name, test_envs, augment, hparams):
298 | # super().__init__()
299 |
300 | # transform = transforms.Compose(
301 | # [
302 | # transforms.Resize((224, 224)),
303 | # transforms.ToTensor(),
304 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
305 | # ]
306 | # )
307 |
308 | # augment_transform = transforms.Compose(
309 | # [
310 | # transforms.Resize((224, 224)),
311 | # transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
312 | # transforms.RandomHorizontalFlip(),
313 | # transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
314 | # transforms.RandomGrayscale(),
315 | # transforms.ToTensor(),
316 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
317 | # ]
318 | # )
319 |
320 | # self.datasets = []
321 |
322 | # for i, metadata_value in enumerate(self.metadata_values(dataset, metadata_name)):
323 | # if augment and (i not in test_envs):
324 | # env_transform = augment_transform
325 | # else:
326 | # env_transform = transform
327 |
328 | # env_dataset = WILDSEnvironment(dataset, metadata_name, metadata_value, env_transform)
329 |
330 | # self.datasets.append(env_dataset)
331 |
332 | # self.input_shape = (
333 | # 3,
334 | # 224,
335 | # 224,
336 | # )
337 | # self.num_classes = dataset.n_classes
338 |
339 | # def metadata_values(self, wilds_dataset, metadata_name):
340 | # metadata_index = wilds_dataset.metadata_fields.index(metadata_name)
341 | # metadata_vals = wilds_dataset.metadata_array[:, metadata_index]
342 | # return sorted(list(set(metadata_vals.view(-1).tolist())))
343 |
344 |
345 | # class WILDSCamelyon(WILDSDataset):
346 | # ENVIRONMENTS = ["hospital_0", "hospital_1", "hospital_2", "hospital_3", "hospital_4"]
347 |
348 | # def __init__(self, root, test_envs, hparams):
349 | # dataset = Camelyon17Dataset(root_dir=root)
350 | # super().__init__(dataset, "hospital", test_envs, hparams['data_augmentation'], hparams)
351 |
352 |
353 | # class WILDSFMoW(WILDSDataset):
354 | # ENVIRONMENTS = ["region_0", "region_1", "region_2", "region_3", "region_4", "region_5"]
355 |
356 | # def __init__(self, root, test_envs, hparams):
357 | # dataset = FMoWDataset(root_dir=root)
358 | # super().__init__(dataset, "region", test_envs, hparams['data_augmentation'], hparams)
359 |
--------------------------------------------------------------------------------
/domainbed/hparams_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import numpy as np
3 | from domainbed.lib import misc
4 |
5 |
6 | def _define_hparam(hparams, hparam_name, default_val, random_val_fn):
7 | hparams[hparam_name] = (hparams, hparam_name, default_val, random_val_fn)
8 |
9 |
10 | def _hparams(algorithm, dataset, random_seed):
11 | """
12 | Global registry of hyperparams. Each entry is a (default, random) tuple.
13 | New algorithms / networks / etc. should add entries here.
14 | """
15 | SMALL_IMAGES = ['Debug28', 'RotatedMNIST', 'ColoredMNIST']
16 |
17 | hparams = {}
18 |
19 | def _hparam(name, default_val, random_val_fn):
20 | """Define a hyperparameter. random_val_fn takes a RandomState and
21 | returns a random hyperparameter value."""
22 | assert(name not in hparams)
23 | random_state = np.random.RandomState(
24 | misc.seed_hash(random_seed, name)
25 | )
26 | hparams[name] = (default_val, random_val_fn(random_state))
27 |
28 | # Unconditional hparam definitions.
29 |
30 | _hparam('data_augmentation', True, lambda r: True)
31 | _hparam('resnet18', False, lambda r: False)
32 | _hparam('resnet_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))
33 | _hparam('class_balanced', False, lambda r: False)
34 | # TODO: nonlinear classifiers disabled
35 | _hparam('nonlinear_classifier', False,
36 | lambda r: bool(r.choice([False, False])))
37 |
38 | # Algorithm-specific hparam definitions. Each block of code below
39 | # corresponds to exactly one algorithm.
40 |
41 | if algorithm in ['DANN', 'CDANN']:
42 | _hparam('lambda', 1.0, lambda r: 10**r.uniform(-2, 2))
43 | _hparam('weight_decay_d', 0., lambda r: 10**r.uniform(-6, -2))
44 | _hparam('d_steps_per_g_step', 1, lambda r: int(2**r.uniform(0, 3)))
45 | _hparam('grad_penalty', 0., lambda r: 10**r.uniform(-2, 1))
46 | _hparam('beta1', 0.5, lambda r: r.choice([0., 0.5]))
47 | _hparam('mlp_width', 256, lambda r: int(2 ** r.uniform(6, 10)))
48 | _hparam('mlp_depth', 3, lambda r: int(r.choice([3, 4, 5])))
49 | _hparam('mlp_dropout', 0., lambda r: r.choice([0., 0.1, 0.5]))
50 |
51 | elif algorithm == 'Fish':
52 | _hparam('meta_lr', 0.5, lambda r:r.choice([0.05, 0.1, 0.5]))
53 |
54 | elif algorithm == "RSC":
55 | _hparam('rsc_f_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))
56 | _hparam('rsc_b_drop_factor', 1/3, lambda r: r.uniform(0, 0.5))
57 |
58 | elif algorithm == "SagNet":
59 | _hparam('sag_w_adv', 0.1, lambda r: 10**r.uniform(-2, 1))
60 |
61 | elif algorithm == "IRM":
62 | _hparam('irm_lambda', 1e2, lambda r: 10**r.uniform(-1, 5))
63 | _hparam('irm_penalty_anneal_iters', 500,
64 | lambda r: int(10**r.uniform(0, 4)))
65 |
66 | elif algorithm == "Mixup":
67 | _hparam('mixup_alpha', 0.2, lambda r: 10**r.uniform(-1, -1))
68 |
69 | elif algorithm == "GroupDRO":
70 | _hparam('groupdro_eta', 1e-2, lambda r: 10**r.uniform(-3, -1))
71 |
72 | elif algorithm == "MMD" or algorithm == "CORAL":
73 | _hparam('mmd_gamma', 1., lambda r: 10**r.uniform(-1, 1))
74 |
75 | elif algorithm == "MLDG":
76 | _hparam('mldg_beta', 1., lambda r: 10**r.uniform(-1, 1))
77 |
78 | elif algorithm == "MTL":
79 | _hparam('mtl_ema', .99, lambda r: r.choice([0.5, 0.9, 0.99, 1.]))
80 |
81 | elif algorithm == "VREx":
82 | _hparam('vrex_lambda', 1e1, lambda r: 10**r.uniform(-1, 5))
83 | _hparam('vrex_penalty_anneal_iters', 500,
84 | lambda r: int(10**r.uniform(0, 4)))
85 |
86 | elif algorithm == "SD":
87 | _hparam('sd_reg', 0.1, lambda r: 10**r.uniform(-5, -1))
88 |
89 | elif algorithm == "ANDMask":
90 | _hparam('tau', 1, lambda r: r.uniform(0.5, 1.))
91 |
92 | elif algorithm == "IGA":
93 | _hparam('penalty', 1000, lambda r: 10**r.uniform(1, 5))
94 |
95 | elif algorithm == "SANDMask":
96 | _hparam('tau', 1.0, lambda r: r.uniform(0.0, 1.))
97 | _hparam('k', 1e+1, lambda r: int(10**r.uniform(-3, 5)))
98 |
99 | elif algorithm == "Fishr":
100 | _hparam('lambda', 1000., lambda r: 10**r.uniform(1., 4.))
101 | _hparam('penalty_anneal_iters', 1500, lambda r: int(r.uniform(0., 5000.)))
102 | _hparam('ema', 0.95, lambda r: r.uniform(0.90, 0.99))
103 |
104 | # Dataset-and-algorithm-specific hparam definitions. Each block of code
105 | # below corresponds to exactly one hparam. Avoid nested conditionals.
106 |
107 | if dataset in SMALL_IMAGES:
108 | _hparam('lr', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
109 | else:
110 | _hparam('lr', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
111 |
112 | if dataset in SMALL_IMAGES:
113 | _hparam('weight_decay', 0., lambda r: 0.)
114 | else:
115 | _hparam('weight_decay', 0., lambda r: 10**r.uniform(-6, -2))
116 |
117 | if dataset in SMALL_IMAGES:
118 | _hparam('batch_size', 64, lambda r: int(2**r.uniform(3, 9)))
119 | elif algorithm == 'ARM':
120 | _hparam('batch_size', 8, lambda r: 8)
121 | elif dataset == 'DomainNet':
122 | _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5)))
123 | else:
124 | _hparam('batch_size', 32, lambda r: int(2**r.uniform(3, 5.5)))
125 |
126 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
127 | _hparam('lr_g', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
128 | elif algorithm in ['DANN', 'CDANN']:
129 | _hparam('lr_g', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
130 |
131 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
132 | _hparam('lr_d', 1e-3, lambda r: 10**r.uniform(-4.5, -2.5))
133 | elif algorithm in ['DANN', 'CDANN']:
134 | _hparam('lr_d', 5e-5, lambda r: 10**r.uniform(-5, -3.5))
135 |
136 | if algorithm in ['DANN', 'CDANN'] and dataset in SMALL_IMAGES:
137 | _hparam('weight_decay_g', 0., lambda r: 0.)
138 | elif algorithm in ['DANN', 'CDANN']:
139 | _hparam('weight_decay_g', 0., lambda r: 10**r.uniform(-6, -2))
140 |
141 | return hparams
142 |
143 |
144 | def default_hparams(algorithm, dataset):
145 | return {a: b for a, (b, c) in _hparams(algorithm, dataset, 0).items()}
146 |
147 |
148 | def random_hparams(algorithm, dataset, seed):
149 | return {a: c for a, (b, c) in _hparams(algorithm, dataset, seed).items()}
150 |
--------------------------------------------------------------------------------
/domainbed/lib/fast_data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 |
5 | class _InfiniteSampler(torch.utils.data.Sampler):
6 | """Wraps another Sampler to yield an infinite stream."""
7 | def __init__(self, sampler):
8 | self.sampler = sampler
9 |
10 | def __iter__(self):
11 | while True:
12 | for batch in self.sampler:
13 | yield batch
14 |
15 | class InfiniteDataLoader:
16 | def __init__(self, dataset, weights, batch_size, num_workers):
17 | super().__init__()
18 |
19 | if weights is not None:
20 | sampler = torch.utils.data.WeightedRandomSampler(weights,
21 | replacement=True,
22 | num_samples=batch_size)
23 | else:
24 | sampler = torch.utils.data.RandomSampler(dataset,
25 | replacement=True)
26 |
27 | if weights == None:
28 | weights = torch.ones(len(dataset))
29 |
30 | batch_sampler = torch.utils.data.BatchSampler(
31 | sampler,
32 | batch_size=batch_size,
33 | drop_last=True)
34 |
35 | self._infinite_iterator = iter(torch.utils.data.DataLoader(
36 | dataset,
37 | num_workers=num_workers,
38 | batch_sampler=_InfiniteSampler(batch_sampler)
39 | ))
40 |
41 | def __iter__(self):
42 | while True:
43 | yield next(self._infinite_iterator)
44 |
45 | def __len__(self):
46 | raise ValueError
47 |
48 | class FastDataLoader:
49 | """DataLoader wrapper with slightly improved speed by not respawning worker
50 | processes at every epoch."""
51 | def __init__(self, dataset, batch_size, num_workers):
52 | super().__init__()
53 |
54 | batch_sampler = torch.utils.data.BatchSampler(
55 | torch.utils.data.RandomSampler(dataset, replacement=False),
56 | batch_size=batch_size,
57 | drop_last=False
58 | )
59 |
60 | self._infinite_iterator = iter(torch.utils.data.DataLoader(
61 | dataset,
62 | num_workers=num_workers,
63 | batch_sampler=_InfiniteSampler(batch_sampler)
64 | ))
65 |
66 | self._length = len(batch_sampler)
67 |
68 | def __iter__(self):
69 | for _ in range(len(self)):
70 | yield next(self._infinite_iterator)
71 |
72 | def __len__(self):
73 | return self._length
74 |
--------------------------------------------------------------------------------
/domainbed/lib/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | Things that don't belong anywhere else
5 | """
6 |
7 | import hashlib
8 | import json
9 | import os
10 | import sys
11 | from shutil import copyfile
12 | from collections import OrderedDict, defaultdict
13 | from numbers import Number
14 | import operator
15 |
16 | import numpy as np
17 | import torch
18 | import tqdm
19 | from collections import Counter
20 |
21 |
22 | def l2_between_dicts(dict_1, dict_2):
23 | assert len(dict_1) == len(dict_2)
24 | dict_1_values = [dict_1[key] for key in sorted(dict_1.keys())]
25 | dict_2_values = [dict_2[key] for key in sorted(dict_1.keys())]
26 | return (
27 | torch.cat(tuple([t.view(-1) for t in dict_1_values])) -
28 | torch.cat(tuple([t.view(-1) for t in dict_2_values]))
29 | ).pow(2).mean()
30 |
31 | class MovingAverage:
32 |
33 | def __init__(self, ema, oneminusema_correction=True):
34 | self.ema = ema
35 | self.named_parameters = {}
36 | self._updates = 0
37 | self._oneminusema_correction = oneminusema_correction
38 |
39 | def update(self, dict_data):
40 | ema_dict_data = {}
41 | for name, data in dict_data.items():
42 | data = data.view(1, -1)
43 | if self._updates == 0:
44 | previous_data = torch.zeros_like(data)
45 | else:
46 | previous_data = self.named_parameters[name]
47 |
48 | ema_data = self.ema * previous_data + (1 - self.ema) * data
49 | if self._oneminusema_correction:
50 | ema_dict_data[name] = ema_data / (1 - self.ema)
51 | else:
52 | ema_dict_data[name] = ema_data
53 | self.named_parameters[name] = ema_data.clone().detach()
54 |
55 | self._updates += 1
56 | return ema_dict_data
57 |
58 |
59 |
60 | def make_weights_for_balanced_classes(dataset):
61 | counts = Counter()
62 | classes = []
63 | for _, y in dataset:
64 | y = int(y)
65 | counts[y] += 1
66 | classes.append(y)
67 |
68 | n_classes = len(counts)
69 |
70 | weight_per_class = {}
71 | for y in counts:
72 | weight_per_class[y] = 1 / (counts[y] * n_classes)
73 |
74 | weights = torch.zeros(len(dataset))
75 | for i, y in enumerate(classes):
76 | weights[i] = weight_per_class[int(y)]
77 |
78 | return weights
79 |
80 | def pdb():
81 | sys.stdout = sys.__stdout__
82 | import pdb
83 | print("Launching PDB, enter 'n' to step to parent function.")
84 | pdb.set_trace()
85 |
86 | def seed_hash(*args):
87 | """
88 | Derive an integer hash from all args, for use as a random seed.
89 | """
90 | args_str = str(args)
91 | return int(hashlib.md5(args_str.encode("utf-8")).hexdigest(), 16) % (2**31)
92 |
93 | def print_separator():
94 | print("="*80)
95 |
96 | def print_row(row, colwidth=10, latex=False):
97 | if latex:
98 | sep = " & "
99 | end_ = "\\\\"
100 | else:
101 | sep = " "
102 | end_ = ""
103 |
104 | def format_val(x):
105 | if np.issubdtype(type(x), np.floating):
106 | x = "{:.10f}".format(x)
107 | return str(x).ljust(colwidth)[:colwidth]
108 | print(sep.join([format_val(x) for x in row]), end_)
109 |
110 | class _SplitDataset(torch.utils.data.Dataset):
111 | """Used by split_dataset"""
112 | def __init__(self, underlying_dataset, keys):
113 | super(_SplitDataset, self).__init__()
114 | self.underlying_dataset = underlying_dataset
115 | self.keys = keys
116 | def __getitem__(self, key):
117 | return self.underlying_dataset[self.keys[key]]
118 | def __len__(self):
119 | return len(self.keys)
120 |
121 | def split_dataset(dataset, n, seed=0):
122 | """
123 | Return a pair of datasets corresponding to a random split of the given
124 | dataset, with n datapoints in the first dataset and the rest in the last,
125 | using the given random seed
126 | """
127 | assert(n <= len(dataset))
128 | keys = list(range(len(dataset)))
129 | np.random.RandomState(seed).shuffle(keys)
130 | keys_1 = keys[:n]
131 | keys_2 = keys[n:]
132 | return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)
133 |
134 | def random_pairs_of_minibatches(minibatches):
135 | perm = torch.randperm(len(minibatches)).tolist()
136 | pairs = []
137 |
138 | for i in range(len(minibatches)):
139 | j = i + 1 if i < (len(minibatches) - 1) else 0
140 |
141 | xi, yi = minibatches[perm[i]][0], minibatches[perm[i]][1]
142 | xj, yj = minibatches[perm[j]][0], minibatches[perm[j]][1]
143 |
144 | min_n = min(len(xi), len(xj))
145 |
146 | pairs.append(((xi[:min_n], yi[:min_n]), (xj[:min_n], yj[:min_n])))
147 |
148 | return pairs
149 |
150 | def accuracy(network, loader, weights, device):
151 | correct = 0
152 | total = 0
153 | weights_offset = 0
154 |
155 | network.eval()
156 | with torch.no_grad():
157 | for x, y in loader:
158 | x = x.to(device)
159 | y = y.to(device)
160 | p = network.predict(x)
161 | if weights is None:
162 | batch_weights = torch.ones(len(x))
163 | else:
164 | batch_weights = weights[weights_offset : weights_offset + len(x)]
165 | weights_offset += len(x)
166 | batch_weights = batch_weights.to(device)
167 | if p.size(1) == 1:
168 | correct += (p.gt(0).eq(y).float() * batch_weights.view(-1, 1)).sum().item()
169 | else:
170 | correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item()
171 | total += batch_weights.sum().item()
172 | network.train()
173 |
174 | return correct / total
175 |
176 | class Tee:
177 | def __init__(self, fname, mode="a"):
178 | self.stdout = sys.stdout
179 | self.file = open(fname, mode)
180 |
181 | def write(self, message):
182 | self.stdout.write(message)
183 | self.file.write(message)
184 | self.flush()
185 |
186 | def flush(self):
187 | self.stdout.flush()
188 | self.file.flush()
189 |
190 | class ParamDict(OrderedDict):
191 | """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile.
192 | A dictionary where the values are Tensors, meant to represent weights of
193 | a model. This subclass lets you perform arithmetic on weights directly."""
194 |
195 | def __init__(self, *args, **kwargs):
196 | super().__init__(*args, *kwargs)
197 |
198 | def _prototype(self, other, op):
199 | if isinstance(other, Number):
200 | return ParamDict({k: op(v, other) for k, v in self.items()})
201 | elif isinstance(other, dict):
202 | return ParamDict({k: op(self[k], other[k]) for k in self})
203 | else:
204 | raise NotImplementedError
205 |
206 | def __add__(self, other):
207 | return self._prototype(other, operator.add)
208 |
209 | def __rmul__(self, other):
210 | return self._prototype(other, operator.mul)
211 |
212 | __mul__ = __rmul__
213 |
214 | def __neg__(self):
215 | return ParamDict({k: -v for k, v in self.items()})
216 |
217 | def __rsub__(self, other):
218 | # a- b := a + (-b)
219 | return self.__add__(other.__neg__())
220 |
221 | __sub__ = __rsub__
222 |
223 | def __truediv__(self, other):
224 | return self._prototype(other, operator.truediv)
225 |
--------------------------------------------------------------------------------
/domainbed/lib/query.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """Small query library."""
4 |
5 | import collections
6 | import inspect
7 | import json
8 | import types
9 | import unittest
10 | import warnings
11 | import math
12 |
13 | import numpy as np
14 |
15 |
16 | def make_selector_fn(selector):
17 | """
18 | If selector is a function, return selector.
19 | Otherwise, return a function corresponding to the selector string. Examples
20 | of valid selector strings and the corresponding functions:
21 | x lambda obj: obj['x']
22 | x.y lambda obj: obj['x']['y']
23 | x,y lambda obj: (obj['x'], obj['y'])
24 | """
25 | if isinstance(selector, str):
26 | if ',' in selector:
27 | parts = selector.split(',')
28 | part_selectors = [make_selector_fn(part) for part in parts]
29 | return lambda obj: tuple(sel(obj) for sel in part_selectors)
30 | elif '.' in selector:
31 | parts = selector.split('.')
32 | part_selectors = [make_selector_fn(part) for part in parts]
33 | def f(obj):
34 | for sel in part_selectors:
35 | obj = sel(obj)
36 | return obj
37 | return f
38 | else:
39 | key = selector.strip()
40 | return lambda obj: obj[key]
41 | elif isinstance(selector, types.FunctionType):
42 | return selector
43 | else:
44 | raise TypeError
45 |
46 | def hashable(obj):
47 | try:
48 | hash(obj)
49 | return obj
50 | except TypeError:
51 | return json.dumps({'_':obj}, sort_keys=True)
52 |
53 | class Q(object):
54 | def __init__(self, list_):
55 | super(Q, self).__init__()
56 | self._list = list_
57 |
58 | def __len__(self):
59 | return len(self._list)
60 |
61 | def __getitem__(self, key):
62 | return self._list[key]
63 |
64 | def __eq__(self, other):
65 | if isinstance(other, self.__class__):
66 | return self._list == other._list
67 | else:
68 | return self._list == other
69 |
70 | def __str__(self):
71 | return str(self._list)
72 |
73 | def __repr__(self):
74 | return repr(self._list)
75 |
76 | def _append(self, item):
77 | """Unsafe, be careful you know what you're doing."""
78 | self._list.append(item)
79 |
80 | def group(self, selector):
81 | """
82 | Group elements by selector and return a list of (group, group_records)
83 | tuples.
84 | """
85 | selector = make_selector_fn(selector)
86 | groups = {}
87 | for x in self._list:
88 | group = selector(x)
89 | group_key = hashable(group)
90 | if group_key not in groups:
91 | groups[group_key] = (group, Q([]))
92 | groups[group_key][1]._append(x)
93 | results = [groups[key] for key in sorted(groups.keys())]
94 | return Q(results)
95 |
96 | def group_map(self, selector, fn):
97 | """
98 | Group elements by selector, apply fn to each group, and return a list
99 | of the results.
100 | """
101 | return self.group(selector).map(fn)
102 |
103 | def map(self, fn):
104 | """
105 | map self onto fn. If fn takes multiple args, tuple-unpacking
106 | is applied.
107 | """
108 | if len(inspect.signature(fn).parameters) > 1:
109 | return Q([fn(*x) for x in self._list])
110 | else:
111 | return Q([fn(x) for x in self._list])
112 |
113 | def select(self, selector):
114 | selector = make_selector_fn(selector)
115 | return Q([selector(x) for x in self._list])
116 |
117 | def min(self):
118 | return min(self._list)
119 |
120 | def max(self):
121 | return max(self._list)
122 |
123 | def sum(self):
124 | return sum(self._list)
125 |
126 | def len(self):
127 | return len(self._list)
128 |
129 | def mean(self):
130 | with warnings.catch_warnings():
131 | warnings.simplefilter("ignore")
132 | return float(np.mean(self._list))
133 |
134 | def std(self):
135 | with warnings.catch_warnings():
136 | warnings.simplefilter("ignore")
137 | return float(np.std(self._list))
138 |
139 | def mean_std(self):
140 | return (self.mean(), self.std())
141 |
142 | def argmax(self, selector):
143 | selector = make_selector_fn(selector)
144 | return max(self._list, key=selector)
145 |
146 | def filter(self, fn):
147 | return Q([x for x in self._list if fn(x)])
148 |
149 | def filter_equals(self, selector, value):
150 | """like [x for x in y if x.selector == value]"""
151 | selector = make_selector_fn(selector)
152 | return self.filter(lambda r: selector(r) == value)
153 |
154 | def filter_not_none(self):
155 | return self.filter(lambda r: r is not None)
156 |
157 | def filter_not_nan(self):
158 | return self.filter(lambda r: not np.isnan(r))
159 |
160 | def flatten(self):
161 | return Q([y for x in self._list for y in x])
162 |
163 | def unique(self):
164 | result = []
165 | result_set = set()
166 | for x in self._list:
167 | hashable_x = hashable(x)
168 | if hashable_x not in result_set:
169 | result_set.add(hashable_x)
170 | result.append(x)
171 | return Q(result)
172 |
173 | def sorted(self, key=None):
174 | if key is None:
175 | key = lambda x: x
176 | def key2(x):
177 | x = key(x)
178 | if isinstance(x, (np.floating, float)) and np.isnan(x):
179 | return float('-inf')
180 | else:
181 | return x
182 | return Q(sorted(self._list, key=key2))
183 |
--------------------------------------------------------------------------------
/domainbed/lib/reporting.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import collections
4 |
5 | import json
6 | import os
7 |
8 | import tqdm
9 |
10 | from domainbed.lib.query import Q
11 |
12 | def load_records(path):
13 | records = []
14 | for i, subdir in tqdm.tqdm(list(enumerate(os.listdir(path))),
15 | ncols=80,
16 | leave=False):
17 | results_path = os.path.join(path, subdir, "results.jsonl")
18 | try:
19 | with open(results_path, "r") as f:
20 | for line in f:
21 | records.append(json.loads(line[:-1]))
22 | except IOError:
23 | pass
24 |
25 | return Q(records)
26 |
27 | def get_grouped_records(records):
28 | """Group records by (trial_seed, dataset, algorithm, test_env). Because
29 | records can have multiple test envs, a given record may appear in more than
30 | one group."""
31 | result = collections.defaultdict(lambda: [])
32 | for r in records:
33 | for test_env in r["args"]["test_envs"]:
34 | group = (r["args"]["trial_seed"],
35 | r["args"]["dataset"],
36 | r["args"]["algorithm"],
37 | test_env)
38 | result[group].append(r)
39 | return Q([{"trial_seed": t, "dataset": d, "algorithm": a, "test_env": e,
40 | "records": Q(r)} for (t,d,a,e),r in result.items()])
41 |
--------------------------------------------------------------------------------
/domainbed/lib/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | From https://github.com/meliketoy/wide-resnet.pytorch
5 | """
6 |
7 | import sys
8 |
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import torch.nn.init as init
14 | from torch.autograd import Variable
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | return nn.Conv2d(
19 | in_planes,
20 | out_planes,
21 | kernel_size=3,
22 | stride=stride,
23 | padding=1,
24 | bias=True)
25 |
26 |
27 | def conv_init(m):
28 | classname = m.__class__.__name__
29 | if classname.find('Conv') != -1:
30 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
31 | init.constant_(m.bias, 0)
32 | elif classname.find('BatchNorm') != -1:
33 | init.constant_(m.weight, 1)
34 | init.constant_(m.bias, 0)
35 |
36 |
37 | class wide_basic(nn.Module):
38 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
39 | super(wide_basic, self).__init__()
40 | self.bn1 = nn.BatchNorm2d(in_planes)
41 | self.conv1 = nn.Conv2d(
42 | in_planes, planes, kernel_size=3, padding=1, bias=True)
43 | self.dropout = nn.Dropout(p=dropout_rate)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.conv2 = nn.Conv2d(
46 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)
47 |
48 | self.shortcut = nn.Sequential()
49 | if stride != 1 or in_planes != planes:
50 | self.shortcut = nn.Sequential(
51 | nn.Conv2d(
52 | in_planes, planes, kernel_size=1, stride=stride,
53 | bias=True), )
54 |
55 | def forward(self, x):
56 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
57 | out = self.conv2(F.relu(self.bn2(out)))
58 | out += self.shortcut(x)
59 |
60 | return out
61 |
62 |
63 | class Wide_ResNet(nn.Module):
64 | """Wide Resnet with the softmax layer chopped off"""
65 | def __init__(self, input_shape, depth, widen_factor, dropout_rate):
66 | super(Wide_ResNet, self).__init__()
67 | self.in_planes = 16
68 |
69 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
70 | n = (depth - 4) / 6
71 | k = widen_factor
72 |
73 | # print('| Wide-Resnet %dx%d' % (depth, k))
74 | nStages = [16, 16 * k, 32 * k, 64 * k]
75 |
76 | self.conv1 = conv3x3(input_shape[0], nStages[0])
77 | self.layer1 = self._wide_layer(
78 | wide_basic, nStages[1], n, dropout_rate, stride=1)
79 | self.layer2 = self._wide_layer(
80 | wide_basic, nStages[2], n, dropout_rate, stride=2)
81 | self.layer3 = self._wide_layer(
82 | wide_basic, nStages[3], n, dropout_rate, stride=2)
83 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
84 |
85 | self.n_outputs = nStages[3]
86 |
87 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
88 | strides = [stride] + [1] * (int(num_blocks) - 1)
89 | layers = []
90 |
91 | for stride in strides:
92 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
93 | self.in_planes = planes
94 |
95 | return nn.Sequential(*layers)
96 |
97 | def forward(self, x):
98 | out = self.conv1(x)
99 | out = self.layer1(out)
100 | out = self.layer2(out)
101 | out = self.layer3(out)
102 | out = F.relu(self.bn1(out))
103 | out = F.avg_pool2d(out, 8)
104 | return out[:, :, 0, 0]
105 |
--------------------------------------------------------------------------------
/domainbed/model_selection.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import itertools
4 | import numpy as np
5 |
6 | def get_test_records(records):
7 | """Given records with a common test env, get the test records (i.e. the
8 | records with *only* that single test env and no other test envs)"""
9 | return records.filter(lambda r: len(r['args']['test_envs']) == 1)
10 |
11 | class SelectionMethod:
12 | """Abstract class whose subclasses implement strategies for model
13 | selection across hparams and timesteps."""
14 |
15 | def __init__(self):
16 | raise TypeError
17 |
18 | @classmethod
19 | def run_acc(self, run_records):
20 | """
21 | Given records from a run, return a {val_acc, test_acc} dict representing
22 | the best val-acc and corresponding test-acc for that run.
23 | """
24 | raise NotImplementedError
25 |
26 | @classmethod
27 | def hparams_accs(self, records):
28 | """
29 | Given all records from a single (dataset, algorithm, test env) pair,
30 | return a sorted list of (run_acc, records) tuples.
31 | """
32 | return (records.group('args.hparams_seed')
33 | .map(lambda _, run_records:
34 | (
35 | self.run_acc(run_records),
36 | run_records
37 | )
38 | ).filter(lambda x: x[0] is not None)
39 | .sorted(key=lambda x: x[0]['val_acc'])[::-1]
40 | )
41 |
42 | @classmethod
43 | def sweep_acc(self, records):
44 | """
45 | Given all records from a single (dataset, algorithm, test env) pair,
46 | return the mean test acc of the k runs with the top val accs.
47 | """
48 | _hparams_accs = self.hparams_accs(records)
49 | if len(_hparams_accs):
50 | return _hparams_accs[0][0]['test_acc']
51 | else:
52 | return None
53 |
54 | class OracleSelectionMethod(SelectionMethod):
55 | """Like Selection method which picks argmax(test_out_acc) across all hparams
56 | and checkpoints, but instead of taking the argmax over all
57 | checkpoints, we pick the last checkpoint, i.e. no early stopping."""
58 | name = "test-domain validation set (oracle)"
59 |
60 | @classmethod
61 | def run_acc(self, run_records):
62 | run_records = run_records.filter(lambda r:
63 | len(r['args']['test_envs']) == 1)
64 | if not len(run_records):
65 | return None
66 | test_env = run_records[0]['args']['test_envs'][0]
67 | test_out_acc_key = 'env{}_out_acc'.format(test_env)
68 | test_in_acc_key = 'env{}_in_acc'.format(test_env)
69 | chosen_record = run_records.sorted(lambda r: r['step'])[-1]
70 | return {
71 | 'val_acc': chosen_record[test_out_acc_key],
72 | 'test_acc': chosen_record[test_in_acc_key]
73 | }
74 |
75 | class IIDAccuracySelectionMethod(SelectionMethod):
76 | """Picks argmax(mean(env_out_acc for env in train_envs))"""
77 | name = "training-domain validation set"
78 |
79 | @classmethod
80 | def _step_acc(self, record):
81 | """Given a single record, return a {val_acc, test_acc} dict."""
82 | test_env = record['args']['test_envs'][0]
83 | val_env_keys = []
84 | for i in itertools.count():
85 | if f'env{i}_out_acc' not in record:
86 | break
87 | if i != test_env:
88 | val_env_keys.append(f'env{i}_out_acc')
89 | test_in_acc_key = 'env{}_in_acc'.format(test_env)
90 | return {
91 | 'val_acc': np.mean([record[key] for key in val_env_keys]),
92 | 'test_acc': record[test_in_acc_key]
93 | }
94 |
95 | @classmethod
96 | def run_acc(self, run_records):
97 | test_records = get_test_records(run_records)
98 | if not len(test_records):
99 | return None
100 | return test_records.map(self._step_acc).argmax('val_acc')
101 |
102 | class LeaveOneOutSelectionMethod(SelectionMethod):
103 | """Picks (hparams, step) by leave-one-out cross validation."""
104 | name = "leave-one-domain-out cross-validation"
105 |
106 | @classmethod
107 | def _step_acc(self, records):
108 | """Return the {val_acc, test_acc} for a group of records corresponding
109 | to a single step."""
110 | test_records = get_test_records(records)
111 | if len(test_records) != 1:
112 | return None
113 |
114 | test_env = test_records[0]['args']['test_envs'][0]
115 | n_envs = 0
116 | for i in itertools.count():
117 | if f'env{i}_out_acc' not in records[0]:
118 | break
119 | n_envs += 1
120 | val_accs = np.zeros(n_envs) - 1
121 | for r in records.filter(lambda r: len(r['args']['test_envs']) == 2):
122 | val_env = (set(r['args']['test_envs']) - set([test_env])).pop()
123 | val_accs[val_env] = r['env{}_in_acc'.format(val_env)]
124 | val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:])
125 | if any([v==-1 for v in val_accs]):
126 | return None
127 | val_acc = np.sum(val_accs) / (n_envs-1)
128 | return {
129 | 'val_acc': val_acc,
130 | 'test_acc': test_records[0]['env{}_in_acc'.format(test_env)]
131 | }
132 |
133 | @classmethod
134 | def run_acc(self, records):
135 | step_accs = records.group('step').map(lambda step, step_records:
136 | self._step_acc(step_records)
137 | ).filter_not_none()
138 | if len(step_accs):
139 | return step_accs.argmax('val_acc')
140 | else:
141 | return None
142 |
--------------------------------------------------------------------------------
/domainbed/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision.models
7 |
8 | from domainbed.lib import wide_resnet
9 | import copy
10 |
11 |
12 | def remove_batch_norm_from_resnet(model):
13 | fuse = torch.nn.utils.fusion.fuse_conv_bn_eval
14 | model.eval()
15 |
16 | model.conv1 = fuse(model.conv1, model.bn1)
17 | model.bn1 = Identity()
18 |
19 | for name, module in model.named_modules():
20 | if name.startswith("layer") and len(name) == 6:
21 | for b, bottleneck in enumerate(module):
22 | for name2, module2 in bottleneck.named_modules():
23 | if name2.startswith("conv"):
24 | bn_name = "bn" + name2[-1]
25 | setattr(bottleneck, name2,
26 | fuse(module2, getattr(bottleneck, bn_name)))
27 | setattr(bottleneck, bn_name, Identity())
28 | if isinstance(bottleneck.downsample, torch.nn.Sequential):
29 | bottleneck.downsample[0] = fuse(bottleneck.downsample[0],
30 | bottleneck.downsample[1])
31 | bottleneck.downsample[1] = Identity()
32 | model.train()
33 | return model
34 |
35 |
36 | class Identity(nn.Module):
37 | """An identity layer"""
38 | def __init__(self):
39 | super(Identity, self).__init__()
40 |
41 | def forward(self, x):
42 | return x
43 |
44 |
45 | class MLP(nn.Module):
46 | """Just an MLP"""
47 | def __init__(self, n_inputs, n_outputs, hparams):
48 | super(MLP, self).__init__()
49 | self.input = nn.Linear(n_inputs, hparams['mlp_width'])
50 | self.dropout = nn.Dropout(hparams['mlp_dropout'])
51 | self.hiddens = nn.ModuleList([
52 | nn.Linear(hparams['mlp_width'], hparams['mlp_width'])
53 | for _ in range(hparams['mlp_depth']-2)])
54 | self.output = nn.Linear(hparams['mlp_width'], n_outputs)
55 | self.n_outputs = n_outputs
56 |
57 | def forward(self, x):
58 | x = self.input(x)
59 | x = self.dropout(x)
60 | x = F.relu(x)
61 | for hidden in self.hiddens:
62 | x = hidden(x)
63 | x = self.dropout(x)
64 | x = F.relu(x)
65 | x = self.output(x)
66 | return x
67 |
68 |
69 | class ResNet(torch.nn.Module):
70 | """ResNet with the softmax chopped off and the batchnorm frozen"""
71 | def __init__(self, input_shape, hparams):
72 | super(ResNet, self).__init__()
73 | if hparams['resnet18']:
74 | self.network = torchvision.models.resnet18(pretrained=True)
75 | self.n_outputs = 512
76 | else:
77 | self.network = torchvision.models.resnet50(pretrained=True)
78 | self.n_outputs = 2048
79 |
80 | # self.network = remove_batch_norm_from_resnet(self.network)
81 |
82 | # adapt number of channels
83 | nc = input_shape[0]
84 | if nc != 3:
85 | tmp = self.network.conv1.weight.data.clone()
86 |
87 | self.network.conv1 = nn.Conv2d(
88 | nc, 64, kernel_size=(7, 7),
89 | stride=(2, 2), padding=(3, 3), bias=False)
90 |
91 | for i in range(nc):
92 | self.network.conv1.weight.data[:, i, :, :] = tmp[:, i % 3, :, :]
93 |
94 | # save memory
95 | del self.network.fc
96 | self.network.fc = Identity()
97 |
98 | self.freeze_bn()
99 | self.hparams = hparams
100 | self.dropout = nn.Dropout(hparams['resnet_dropout'])
101 |
102 | def forward(self, x):
103 | """Encode x into a feature vector of size n_outputs."""
104 | return self.dropout(self.network(x))
105 |
106 | def train(self, mode=True):
107 | """
108 | Override the default train() to freeze the BN parameters
109 | """
110 | super().train(mode)
111 | self.freeze_bn()
112 |
113 | def freeze_bn(self):
114 | for m in self.network.modules():
115 | if isinstance(m, nn.BatchNorm2d):
116 | m.eval()
117 |
118 |
119 | class MNIST_CNN(nn.Module):
120 | """
121 | Hand-tuned architecture for MNIST.
122 | Weirdness I've noticed so far with this architecture:
123 | - adding a linear layer after the mean-pool in features hurts
124 | RotatedMNIST-100 generalization severely.
125 | """
126 | n_outputs = 128
127 |
128 | def __init__(self, input_shape):
129 | super(MNIST_CNN, self).__init__()
130 | self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
131 | self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
132 | self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
133 | self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)
134 |
135 | self.bn0 = nn.GroupNorm(8, 64)
136 | self.bn1 = nn.GroupNorm(8, 128)
137 | self.bn2 = nn.GroupNorm(8, 128)
138 | self.bn3 = nn.GroupNorm(8, 128)
139 |
140 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
141 |
142 | def forward(self, x):
143 | x = self.conv1(x)
144 | x = F.relu(x)
145 | x = self.bn0(x)
146 |
147 | x = self.conv2(x)
148 | x = F.relu(x)
149 | x = self.bn1(x)
150 |
151 | x = self.conv3(x)
152 | x = F.relu(x)
153 | x = self.bn2(x)
154 |
155 | x = self.conv4(x)
156 | x = F.relu(x)
157 | x = self.bn3(x)
158 |
159 | x = self.avgpool(x)
160 | x = x.view(len(x), -1)
161 | return x
162 |
163 |
164 | class ContextNet(nn.Module):
165 | def __init__(self, input_shape):
166 | super(ContextNet, self).__init__()
167 |
168 | # Keep same dimensions
169 | padding = (5 - 1) // 2
170 | self.context_net = nn.Sequential(
171 | nn.Conv2d(input_shape[0], 64, 5, padding=padding),
172 | nn.BatchNorm2d(64),
173 | nn.ReLU(),
174 | nn.Conv2d(64, 64, 5, padding=padding),
175 | nn.BatchNorm2d(64),
176 | nn.ReLU(),
177 | nn.Conv2d(64, 1, 5, padding=padding),
178 | )
179 |
180 | def forward(self, x):
181 | return self.context_net(x)
182 |
183 |
184 | def Featurizer(input_shape, hparams):
185 | """Auto-select an appropriate featurizer for the given input shape."""
186 | if len(input_shape) == 1:
187 | return MLP(input_shape[0], hparams["mlp_width"], hparams)
188 | elif input_shape[1:3] == (28, 28):
189 | return MNIST_CNN(input_shape)
190 | elif input_shape[1:3] == (32, 32):
191 | return wide_resnet.Wide_ResNet(input_shape, 16, 2, 0.)
192 | elif input_shape[1:3] == (224, 224):
193 | return ResNet(input_shape, hparams)
194 | else:
195 | raise NotImplementedError
196 |
197 |
198 | def Classifier(in_features, out_features, is_nonlinear=False):
199 | if is_nonlinear:
200 | return torch.nn.Sequential(
201 | torch.nn.Linear(in_features, in_features // 2),
202 | torch.nn.ReLU(),
203 | torch.nn.Linear(in_features // 2, in_features // 4),
204 | torch.nn.ReLU(),
205 | torch.nn.Linear(in_features // 4, out_features))
206 | else:
207 | return torch.nn.Linear(in_features, out_features)
208 |
209 |
210 | class WholeFish(nn.Module):
211 | def __init__(self, input_shape, num_classes, hparams, weights=None):
212 | super(WholeFish, self).__init__()
213 | featurizer = Featurizer(input_shape, hparams)
214 | classifier = Classifier(
215 | featurizer.n_outputs,
216 | num_classes,
217 | hparams['nonlinear_classifier'])
218 | self.net = nn.Sequential(
219 | featurizer, classifier
220 | )
221 | if weights is not None:
222 | self.load_state_dict(copy.deepcopy(weights))
223 |
224 | def reset_weights(self, weights):
225 | self.load_state_dict(copy.deepcopy(weights))
226 |
227 | def forward(self, x):
228 | return self.net(x)
229 |
--------------------------------------------------------------------------------
/domainbed/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 |
--------------------------------------------------------------------------------
/domainbed/scripts/collect_results.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import collections
4 |
5 |
6 | import argparse
7 | import functools
8 | import glob
9 | import pickle
10 | import itertools
11 | import json
12 | import os
13 | import random
14 | import sys
15 |
16 | import numpy as np
17 | import tqdm
18 |
19 | from domainbed import datasets
20 | from domainbed import algorithms
21 | from domainbed.lib import misc, reporting
22 | from domainbed import model_selection
23 | from domainbed.lib.query import Q
24 | import warnings
25 |
26 | def format_mean(data, latex):
27 | """Given a list of datapoints, return a string describing their mean and
28 | standard error"""
29 | if len(data) == 0:
30 | return None, None, "X"
31 | mean = 100 * np.mean(list(data))
32 | err = 100 * np.std(list(data) / np.sqrt(len(data)))
33 | if latex:
34 | return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err)
35 | else:
36 | return mean, err, "{:.1f} +/- {:.1f}".format(mean, err)
37 |
38 | def print_table(table, header_text, row_labels, col_labels, colwidth=10,
39 | latex=True):
40 | """Pretty-print a 2D array of data, optionally with row/col labels"""
41 | print("")
42 |
43 | if latex:
44 | num_cols = len(table[0])
45 | print("\\begin{center}")
46 | print("\\adjustbox{max width=\\textwidth}{%")
47 | print("\\begin{tabular}{l" + "c" * num_cols + "}")
48 | print("\\toprule")
49 | else:
50 | print("--------", header_text)
51 |
52 | for row, label in zip(table, row_labels):
53 | row.insert(0, label)
54 |
55 | if latex:
56 | col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
57 | for col_label in col_labels]
58 | table.insert(0, col_labels)
59 |
60 | for r, row in enumerate(table):
61 | misc.print_row(row, colwidth=colwidth, latex=latex)
62 | if latex and r == 0:
63 | print("\\midrule")
64 | if latex:
65 | print("\\bottomrule")
66 | print("\\end{tabular}}")
67 | print("\\end{center}")
68 |
69 | def print_results_tables(records, selection_method, latex):
70 | """Given all records, print a results table for each dataset."""
71 | grouped_records = reporting.get_grouped_records(records).map(lambda group:
72 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
73 | ).filter(lambda g: g["sweep_acc"] is not None)
74 |
75 | # read algorithm names and sort (predefined order)
76 | alg_names = Q(records).select("args.algorithm").unique()
77 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
78 | [n for n in alg_names if n not in algorithms.ALGORITHMS])
79 |
80 | # read dataset names and sort (lexicographic order)
81 | dataset_names = Q(records).select("args.dataset").unique().sorted()
82 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names]
83 |
84 | for dataset in dataset_names:
85 | if latex:
86 | print()
87 | print("\\subsubsection{{{}}}".format(dataset))
88 | test_envs = range(datasets.num_environments(dataset))
89 |
90 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
91 | for i, algorithm in enumerate(alg_names):
92 | means = []
93 | for j, test_env in enumerate(test_envs):
94 | trial_accs = (grouped_records
95 | .filter_equals(
96 | "dataset, algorithm, test_env",
97 | (dataset, algorithm, test_env)
98 | ).select("sweep_acc"))
99 | mean, err, table[i][j] = format_mean(trial_accs, latex)
100 | means.append(mean)
101 | if None in means:
102 | table[i][-1] = "X"
103 | else:
104 | table[i][-1] = "{:.1f}".format(sum(means) / len(means))
105 |
106 | col_labels = [
107 | "Algorithm",
108 | *datasets.get_dataset_class(dataset).ENVIRONMENTS,
109 | "Avg"
110 | ]
111 | header_text = (f"Dataset: {dataset}, "
112 | f"model selection method: {selection_method.name}")
113 | print_table(table, header_text, alg_names, list(col_labels),
114 | colwidth=20, latex=latex)
115 |
116 | # Print an "averages" table
117 | if latex:
118 | print()
119 | print("\\subsubsection{Averages}")
120 |
121 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
122 | for i, algorithm in enumerate(alg_names):
123 | means = []
124 | for j, dataset in enumerate(dataset_names):
125 | trial_averages = (grouped_records
126 | .filter_equals("algorithm, dataset", (algorithm, dataset))
127 | .group("trial_seed")
128 | .map(lambda trial_seed, group:
129 | group.select("sweep_acc").mean()
130 | )
131 | )
132 | mean, err, table[i][j] = format_mean(trial_averages, latex)
133 | means.append(mean)
134 | if None in means:
135 | table[i][-1] = "X"
136 | else:
137 | table[i][-1] = "{:.1f}".format(sum(means) / len(means))
138 |
139 | col_labels = ["Algorithm", *dataset_names, "Avg"]
140 | header_text = f"Averages, model selection method: {selection_method.name}"
141 | print_table(table, header_text, alg_names, col_labels, colwidth=25,
142 | latex=latex)
143 |
144 | if __name__ == "__main__":
145 | np.set_printoptions(suppress=True)
146 |
147 | parser = argparse.ArgumentParser(
148 | description="Domain generalization testbed")
149 | parser.add_argument("--input_dir", type=str, required=True)
150 | parser.add_argument("--latex", action="store_true")
151 | args = parser.parse_args()
152 |
153 | results_file = "results.tex" if args.latex else "results.txt"
154 |
155 | sys.stdout = misc.Tee(os.path.join(args.input_dir, results_file), "w")
156 |
157 | records = reporting.load_records(args.input_dir)
158 |
159 | if args.latex:
160 | print("\\documentclass{article}")
161 | print("\\usepackage{booktabs}")
162 | print("\\usepackage{adjustbox}")
163 | print("\\begin{document}")
164 | print("\\section{Full DomainBed results}")
165 | print("% Total records:", len(records))
166 | else:
167 | print("Total records:", len(records))
168 |
169 | SELECTION_METHODS = [
170 | model_selection.IIDAccuracySelectionMethod,
171 | model_selection.LeaveOneOutSelectionMethod,
172 | model_selection.OracleSelectionMethod,
173 | ]
174 |
175 | for selection_method in SELECTION_METHODS:
176 | if args.latex:
177 | print()
178 | print("\\subsection{{Model selection: {}}}".format(
179 | selection_method.name))
180 | print_results_tables(records, selection_method, args.latex)
181 |
182 | if args.latex:
183 | print("\\end{document}")
184 |
--------------------------------------------------------------------------------
/domainbed/scripts/download.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | from torchvision.datasets import MNIST
4 | import xml.etree.ElementTree as ET
5 | from zipfile import ZipFile
6 | import argparse
7 | import tarfile
8 | import shutil
9 | import gdown
10 | import uuid
11 | import json
12 | import os
13 |
14 | # from wilds.datasets.camelyon17_dataset import Camelyon17Dataset
15 | # from wilds.datasets.fmow_dataset import FMoWDataset
16 |
17 |
18 | # utils #######################################################################
19 |
20 | def stage_path(data_dir, name):
21 | full_path = os.path.join(data_dir, name)
22 |
23 | if not os.path.exists(full_path):
24 | os.makedirs(full_path)
25 |
26 | return full_path
27 |
28 |
29 | def download_and_extract(url, dst, remove=True):
30 | gdown.download(url, dst, quiet=False)
31 |
32 | if dst.endswith(".tar.gz"):
33 | tar = tarfile.open(dst, "r:gz")
34 | tar.extractall(os.path.dirname(dst))
35 | tar.close()
36 |
37 | if dst.endswith(".tar"):
38 | tar = tarfile.open(dst, "r:")
39 | tar.extractall(os.path.dirname(dst))
40 | tar.close()
41 |
42 | if dst.endswith(".zip"):
43 | zf = ZipFile(dst, "r")
44 | zf.extractall(os.path.dirname(dst))
45 | zf.close()
46 |
47 | if remove:
48 | os.remove(dst)
49 |
50 |
51 | # VLCS ########################################################################
52 |
53 | # Slower, but builds dataset from the original sources
54 | #
55 | # def download_vlcs(data_dir):
56 | # full_path = stage_path(data_dir, "VLCS")
57 | #
58 | # tmp_path = os.path.join(full_path, "tmp/")
59 | # if not os.path.exists(tmp_path):
60 | # os.makedirs(tmp_path)
61 | #
62 | # with open("domainbed/misc/vlcs_files.txt", "r") as f:
63 | # lines = f.readlines()
64 | # files = [line.strip().split() for line in lines]
65 | #
66 | # download_and_extract("http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar",
67 | # os.path.join(tmp_path, "voc2007_trainval.tar"))
68 | #
69 | # download_and_extract("https://drive.google.com/uc?id=1I8ydxaAQunz9R_qFFdBFtw6rFTUW9goz",
70 | # os.path.join(tmp_path, "caltech101.tar.gz"))
71 | #
72 | # download_and_extract("http://groups.csail.mit.edu/vision/Hcontext/data/sun09_hcontext.tar",
73 | # os.path.join(tmp_path, "sun09_hcontext.tar"))
74 | #
75 | # tar = tarfile.open(os.path.join(tmp_path, "sun09.tar"), "r:")
76 | # tar.extractall(tmp_path)
77 | # tar.close()
78 | #
79 | # for src, dst in files:
80 | # class_folder = os.path.join(data_dir, dst)
81 | #
82 | # if not os.path.exists(class_folder):
83 | # os.makedirs(class_folder)
84 | #
85 | # dst = os.path.join(class_folder, uuid.uuid4().hex + ".jpg")
86 | #
87 | # if "labelme" in src:
88 | # # download labelme from the web
89 | # gdown.download(src, dst, quiet=False)
90 | # else:
91 | # src = os.path.join(tmp_path, src)
92 | # shutil.copyfile(src, dst)
93 | #
94 | # shutil.rmtree(tmp_path)
95 |
96 |
97 | def download_vlcs(data_dir):
98 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
99 | full_path = stage_path(data_dir, "VLCS")
100 |
101 | download_and_extract("https://drive.google.com/uc?id=1skwblH1_okBwxWxmRsp9_qi15hyPpxg8",
102 | os.path.join(data_dir, "VLCS.tar.gz"))
103 |
104 |
105 | # MNIST #######################################################################
106 |
107 | def download_mnist(data_dir):
108 | # Original URL: http://yann.lecun.com/exdb/mnist/
109 | full_path = stage_path(data_dir, "MNIST")
110 | MNIST(full_path, download=True)
111 |
112 |
113 | # PACS ########################################################################
114 |
115 | def download_pacs(data_dir):
116 | # Original URL: http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017
117 | full_path = stage_path(data_dir, "PACS")
118 |
119 | download_and_extract("https://drive.google.com/uc?id=0B6x7gtvErXgfbF9CSk53UkRxVzg",
120 | os.path.join(data_dir, "PACS.zip"))
121 |
122 | os.rename(os.path.join(data_dir, "kfold"),
123 | full_path)
124 |
125 |
126 | # Office-Home #################################################################
127 |
128 | def download_office_home(data_dir):
129 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/
130 | full_path = stage_path(data_dir, "office_home")
131 |
132 | download_and_extract("https://drive.google.com/uc?id=0B81rNlvomiwed0V1YUxQdC1uOTg",
133 | os.path.join(data_dir, "office_home.zip"))
134 |
135 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"),
136 | full_path)
137 |
138 |
139 | # DomainNET ###################################################################
140 |
141 | def download_domain_net(data_dir):
142 | # Original URL: http://ai.bu.edu/M3SDA/
143 | full_path = stage_path(data_dir, "domain_net")
144 |
145 | urls = [
146 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip",
147 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip",
148 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip",
149 | "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip",
150 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip",
151 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"
152 | ]
153 |
154 | for url in urls:
155 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1]))
156 |
157 | with open("domainbed/misc/domain_net_duplicates.txt", "r") as f:
158 | for line in f.readlines():
159 | try:
160 | os.remove(os.path.join(full_path, line.strip()))
161 | except OSError:
162 | pass
163 |
164 |
165 | # TerraIncognita ##############################################################
166 |
167 | def download_terra_incognita(data_dir):
168 | # Original URL: https://beerys.github.io/CaltechCameraTraps/
169 | # New URL: http://lila.science/datasets/caltech-camera-traps
170 |
171 | full_path = stage_path(data_dir, "terra_incognita")
172 |
173 | download_and_extract(
174 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/eccv_18_all_images_sm.tar.gz",
175 | os.path.join(full_path, "terra_incognita_images.tar.gz"))
176 |
177 | download_and_extract(
178 | "https://lilablobssc.blob.core.windows.net/caltechcameratraps/labels/caltech_camera_traps.json.zip",
179 | os.path.join(full_path, "caltech_camera_traps.json.zip"))
180 |
181 | include_locations = ["38", "46", "100", "43"]
182 |
183 | include_categories = [
184 | "bird", "bobcat", "cat", "coyote", "dog", "empty", "opossum", "rabbit",
185 | "raccoon", "squirrel"
186 | ]
187 |
188 | images_folder = os.path.join(full_path, "eccv_18_all_images_sm/")
189 | annotations_file = os.path.join(full_path, "caltech_images_20210113.json")
190 | destination_folder = full_path
191 |
192 | stats = {}
193 |
194 | if not os.path.exists(destination_folder):
195 | os.mkdir(destination_folder)
196 |
197 | with open(annotations_file, "r") as f:
198 | data = json.load(f)
199 |
200 | category_dict = {}
201 | for item in data['categories']:
202 | category_dict[item['id']] = item['name']
203 |
204 | for image in data['images']:
205 | image_location = image['location']
206 |
207 | if image_location not in include_locations:
208 | continue
209 |
210 | loc_folder = os.path.join(destination_folder,
211 | 'location_' + str(image_location) + '/')
212 |
213 | if not os.path.exists(loc_folder):
214 | os.mkdir(loc_folder)
215 |
216 | image_id = image['id']
217 | image_fname = image['file_name']
218 |
219 | for annotation in data['annotations']:
220 | if annotation['image_id'] == image_id:
221 | if image_location not in stats:
222 | stats[image_location] = {}
223 |
224 | category = category_dict[annotation['category_id']]
225 |
226 | if category not in include_categories:
227 | continue
228 |
229 | if category not in stats[image_location]:
230 | stats[image_location][category] = 0
231 | else:
232 | stats[image_location][category] += 1
233 |
234 | loc_cat_folder = os.path.join(loc_folder, category + '/')
235 |
236 | if not os.path.exists(loc_cat_folder):
237 | os.mkdir(loc_cat_folder)
238 |
239 | dst_path = os.path.join(loc_cat_folder, image_fname)
240 | src_path = os.path.join(images_folder, image_fname)
241 |
242 | shutil.copyfile(src_path, dst_path)
243 |
244 | shutil.rmtree(images_folder)
245 | os.remove(annotations_file)
246 |
247 |
248 | # # SVIRO #################################################################
249 |
250 | # def download_sviro(data_dir):
251 | # # Original URL: https://sviro.kl.dfki.de
252 | # full_path = stage_path(data_dir, "sviro")
253 |
254 | # download_and_extract("https://sviro.kl.dfki.de/?wpdmdl=1731",
255 | # os.path.join(data_dir, "sviro_grayscale_rectangle_classification.zip"))
256 |
257 | # os.rename(os.path.join(data_dir, "SVIRO_DOMAINBED"),
258 | # full_path)
259 |
260 |
261 | if __name__ == "__main__":
262 | parser = argparse.ArgumentParser(description='Download datasets')
263 | parser.add_argument('--data_dir', type=str, required=True)
264 | args = parser.parse_args()
265 |
266 | download_mnist(args.data_dir)
267 | download_pacs(args.data_dir)
268 | download_office_home(args.data_dir)
269 | download_domain_net(args.data_dir)
270 | download_vlcs(args.data_dir)
271 | download_terra_incognita(args.data_dir)
272 |
273 | # download_sviro(args.data_dir)
274 | # Camelyon17Dataset(root_dir=args.data_dir, download=True)
275 | # FMoWDataset(root_dir=args.data_dir, download=True)
276 |
--------------------------------------------------------------------------------
/domainbed/scripts/list_top_hparams.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | Example usage:
5 | python -u -m domainbed.scripts.list_top_hparams \
6 | --input_dir domainbed/misc/test_sweep_data --algorithm ERM \
7 | --dataset VLCS --test_env 0
8 | """
9 |
10 | import collections
11 |
12 |
13 | import argparse
14 | import functools
15 | import glob
16 | import pickle
17 | import itertools
18 | import json
19 | import os
20 | import random
21 | import sys
22 |
23 | import numpy as np
24 | import tqdm
25 |
26 | from domainbed import datasets
27 | from domainbed import algorithms
28 | from domainbed.lib import misc, reporting
29 | from domainbed import model_selection
30 | from domainbed.lib.query import Q
31 | import warnings
32 |
33 | def todo_rename(records, selection_method, latex):
34 |
35 | grouped_records = reporting.get_grouped_records(records).map(lambda group:
36 | { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
37 | ).filter(lambda g: g["sweep_acc"] is not None)
38 |
39 | # read algorithm names and sort (predefined order)
40 | alg_names = Q(records).select("args.algorithm").unique()
41 | alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
42 | [n for n in alg_names if n not in algorithms.ALGORITHMS])
43 |
44 | # read dataset names and sort (lexicographic order)
45 | dataset_names = Q(records).select("args.dataset").unique().sorted()
46 | dataset_names = [d for d in datasets.DATASETS if d in dataset_names]
47 |
48 | for dataset in dataset_names:
49 | if latex:
50 | print()
51 | print("\\subsubsection{{{}}}".format(dataset))
52 | test_envs = range(datasets.num_environments(dataset))
53 |
54 | table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
55 | for i, algorithm in enumerate(alg_names):
56 | means = []
57 | for j, test_env in enumerate(test_envs):
58 | trial_accs = (grouped_records
59 | .filter_equals(
60 | "dataset, algorithm, test_env",
61 | (dataset, algorithm, test_env)
62 | ).select("sweep_acc"))
63 | mean, err, table[i][j] = format_mean(trial_accs, latex)
64 | means.append(mean)
65 | if None in means:
66 | table[i][-1] = "X"
67 | else:
68 | table[i][-1] = "{:.1f}".format(sum(means) / len(means))
69 |
70 | col_labels = [
71 | "Algorithm",
72 | *datasets.get_dataset_class(dataset).ENVIRONMENTS,
73 | "Avg"
74 | ]
75 | header_text = (f"Dataset: {dataset}, "
76 | f"model selection method: {selection_method.name}")
77 | print_table(table, header_text, alg_names, list(col_labels),
78 | colwidth=20, latex=latex)
79 |
80 | # Print an "averages" table
81 | if latex:
82 | print()
83 | print("\\subsubsection{Averages}")
84 |
85 | table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
86 | for i, algorithm in enumerate(alg_names):
87 | means = []
88 | for j, dataset in enumerate(dataset_names):
89 | trial_averages = (grouped_records
90 | .filter_equals("algorithm, dataset", (algorithm, dataset))
91 | .group("trial_seed")
92 | .map(lambda trial_seed, group:
93 | group.select("sweep_acc").mean()
94 | )
95 | )
96 | mean, err, table[i][j] = format_mean(trial_averages, latex)
97 | means.append(mean)
98 | if None in means:
99 | table[i][-1] = "X"
100 | else:
101 | table[i][-1] = "{:.1f}".format(sum(means) / len(means))
102 |
103 | col_labels = ["Algorithm", *dataset_names, "Avg"]
104 | header_text = f"Averages, model selection method: {selection_method.name}"
105 | print_table(table, header_text, alg_names, col_labels, colwidth=25,
106 | latex=latex)
107 |
108 | if __name__ == "__main__":
109 | np.set_printoptions(suppress=True)
110 |
111 | parser = argparse.ArgumentParser(
112 | description="Domain generalization testbed")
113 | parser.add_argument("--input_dir", required=True)
114 | parser.add_argument('--dataset', required=True)
115 | parser.add_argument('--algorithm', required=True)
116 | parser.add_argument('--test_env', type=int, required=True)
117 | args = parser.parse_args()
118 |
119 | records = reporting.load_records(args.input_dir)
120 | print("Total records:", len(records))
121 |
122 | records = reporting.get_grouped_records(records)
123 | records = records.filter(
124 | lambda r:
125 | r['dataset'] == args.dataset and
126 | r['algorithm'] == args.algorithm and
127 | r['test_env'] == args.test_env
128 | )
129 |
130 | SELECTION_METHODS = [
131 | model_selection.IIDAccuracySelectionMethod,
132 | model_selection.LeaveOneOutSelectionMethod,
133 | model_selection.OracleSelectionMethod,
134 | ]
135 |
136 | for selection_method in SELECTION_METHODS:
137 | print(f'Model selection: {selection_method.name}')
138 |
139 | for group in records:
140 | print(f"trial_seed: {group['trial_seed']}")
141 | best_hparams = selection_method.hparams_accs(group['records'])
142 | for run_acc, hparam_records in best_hparams:
143 | print(f"\t{run_acc}")
144 | for r in hparam_records:
145 | assert(r['hparams'] == hparam_records[0]['hparams'])
146 | print("\t\thparams:")
147 | for k, v in sorted(hparam_records[0]['hparams'].items()):
148 | print('\t\t\t{}: {}'.format(k, v))
149 | print("\t\toutput_dirs:")
150 | output_dirs = hparam_records.select('args.output_dir').unique()
151 | for output_dir in output_dirs:
152 | print(f"\t\t\t{output_dir}")
--------------------------------------------------------------------------------
/domainbed/scripts/save_images.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | Save some representative images from each dataset to disk.
5 | """
6 | import random
7 | import torch
8 | import argparse
9 | from domainbed import hparams_registry
10 | from domainbed import datasets
11 | import imageio
12 | import os
13 | from tqdm import tqdm
14 |
15 | if __name__ == '__main__':
16 | parser = argparse.ArgumentParser(description='Domain generalization')
17 | parser.add_argument('--data_dir', type=str)
18 | parser.add_argument('--output_dir', type=str)
19 | args = parser.parse_args()
20 |
21 | os.makedirs(args.output_dir, exist_ok=True)
22 | datasets_to_save = ['OfficeHome', 'TerraIncognita', 'DomainNet', 'RotatedMNIST', 'ColoredMNIST', 'SVIRO']
23 |
24 | for dataset_name in tqdm(datasets_to_save):
25 | hparams = hparams_registry.default_hparams('ERM', dataset_name)
26 | dataset = datasets.get_dataset_class(dataset_name)(
27 | args.data_dir,
28 | list(range(datasets.num_environments(dataset_name))),
29 | hparams)
30 | for env_idx, env in enumerate(tqdm(dataset)):
31 | for i in tqdm(range(50)):
32 | idx = random.choice(list(range(len(env))))
33 | x, y = env[idx]
34 | while y > 10:
35 | idx = random.choice(list(range(len(env))))
36 | x, y = env[idx]
37 | if x.shape[0] == 2:
38 | x = torch.cat([x, torch.zeros_like(x)], dim=0)[:3,:,:]
39 | if x.min() < 0:
40 | mean = torch.tensor([0.485, 0.456, 0.406])[:,None,None]
41 | std = torch.tensor([0.229, 0.224, 0.225])[:,None,None]
42 | x = (x * std) + mean
43 | assert(x.min() >= 0)
44 | assert(x.max() <= 1)
45 | x = (x * 255.99)
46 | x = x.numpy().astype('uint8').transpose(1,2,0)
47 | imageio.imwrite(
48 | os.path.join(args.output_dir,
49 | f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'),
50 | x)
51 |
--------------------------------------------------------------------------------
/domainbed/scripts/sweep.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | """
4 | Run sweeps
5 | """
6 |
7 | import argparse
8 | import copy
9 | import getpass
10 | import hashlib
11 | import json
12 | import os
13 | import random
14 | import shutil
15 | import time
16 | import uuid
17 |
18 | import numpy as np
19 | import torch
20 |
21 | from domainbed import datasets
22 | from domainbed import hparams_registry
23 | from domainbed import algorithms
24 | from domainbed.lib import misc
25 | from domainbed import command_launchers
26 |
27 | import tqdm
28 | import shlex
29 |
30 | class Job:
31 | NOT_LAUNCHED = 'Not launched'
32 | INCOMPLETE = 'Incomplete'
33 | DONE = 'Done'
34 |
35 | def __init__(self, train_args, sweep_output_dir):
36 | args_str = json.dumps(train_args, sort_keys=True)
37 | args_hash = hashlib.md5(args_str.encode('utf-8')).hexdigest()
38 | self.output_dir = os.path.join(sweep_output_dir, args_hash)
39 |
40 | self.train_args = copy.deepcopy(train_args)
41 | self.train_args['output_dir'] = self.output_dir
42 | command = ['python', '-m', 'domainbed.scripts.train']
43 | for k, v in sorted(self.train_args.items()):
44 | if isinstance(v, list):
45 | v = ' '.join([str(v_) for v_ in v])
46 | elif isinstance(v, str):
47 | v = shlex.quote(v)
48 | command.append(f'--{k} {v}')
49 | self.command_str = ' '.join(command)
50 |
51 | if os.path.exists(os.path.join(self.output_dir, 'done')):
52 | self.state = Job.DONE
53 | elif os.path.exists(self.output_dir):
54 | self.state = Job.INCOMPLETE
55 | else:
56 | self.state = Job.NOT_LAUNCHED
57 |
58 | def __str__(self):
59 | job_info = (self.train_args['dataset'],
60 | self.train_args['algorithm'],
61 | self.train_args['test_envs'],
62 | self.train_args['hparams_seed'])
63 | return '{}: {} {}'.format(
64 | self.state,
65 | self.output_dir,
66 | job_info)
67 |
68 | @staticmethod
69 | def launch(jobs, launcher_fn):
70 | print('Launching...')
71 | jobs = jobs.copy()
72 | np.random.shuffle(jobs)
73 | print('Making job directories:')
74 | for job in tqdm.tqdm(jobs, leave=False):
75 | os.makedirs(job.output_dir, exist_ok=True)
76 | commands = [job.command_str for job in jobs]
77 | launcher_fn(commands)
78 | print(f'Launched {len(jobs)} jobs!')
79 |
80 | @staticmethod
81 | def delete(jobs):
82 | print('Deleting...')
83 | for job in jobs:
84 | shutil.rmtree(job.output_dir)
85 | print(f'Deleted {len(jobs)} jobs!')
86 |
87 | def all_test_env_combinations(n):
88 | """
89 | For a dataset with n >= 3 envs, return all combinations of 1 and 2 test
90 | envs.
91 | """
92 | assert(n >= 3)
93 | for i in range(n):
94 | yield [i]
95 | for j in range(i+1, n):
96 | yield [i, j]
97 |
98 | def make_args_list(n_trials, dataset_names, algorithms, n_hparams_from, n_hparams, steps,
99 | data_dir, task, holdout_fraction, single_test_envs, hparams):
100 | args_list = []
101 | for trial_seed in range(n_trials):
102 | for dataset in dataset_names:
103 | for algorithm in algorithms:
104 | if single_test_envs:
105 | all_test_envs = [
106 | [i] for i in range(datasets.num_environments(dataset))]
107 | else:
108 | all_test_envs = all_test_env_combinations(
109 | datasets.num_environments(dataset))
110 | for test_envs in all_test_envs:
111 | for hparams_seed in range(n_hparams_from, n_hparams):
112 | train_args = {}
113 | train_args['dataset'] = dataset
114 | train_args['algorithm'] = algorithm
115 | train_args['test_envs'] = test_envs
116 | train_args['holdout_fraction'] = holdout_fraction
117 | train_args['hparams_seed'] = hparams_seed
118 | train_args['data_dir'] = data_dir
119 | train_args['task'] = task
120 | train_args['trial_seed'] = trial_seed
121 | train_args['seed'] = misc.seed_hash(dataset,
122 | algorithm, test_envs, hparams_seed, trial_seed)
123 | if steps is not None:
124 | train_args['steps'] = steps
125 | if hparams is not None:
126 | train_args['hparams'] = hparams
127 | args_list.append(train_args)
128 | return args_list
129 |
130 | def ask_for_confirmation():
131 | response = input('Are you sure? (y/n) ')
132 | if not response.lower().strip()[:1] == "y":
133 | print('Nevermind!')
134 | exit(0)
135 |
136 | DATASETS = [d for d in datasets.DATASETS if "Debug" not in d]
137 |
138 | if __name__ == "__main__":
139 | parser = argparse.ArgumentParser(description='Run a sweep')
140 | parser.add_argument('command', choices=['launch', 'delete_incomplete'])
141 | parser.add_argument('--datasets', nargs='+', type=str, default=DATASETS)
142 | parser.add_argument('--algorithms', nargs='+', type=str, default=algorithms.ALGORITHMS)
143 | parser.add_argument('--task', type=str, default="domain_generalization")
144 | parser.add_argument('--n_hparams_from', type=int, default=0)
145 | parser.add_argument('--n_hparams', type=int, default=20)
146 | parser.add_argument('--output_dir', type=str, required=True)
147 | parser.add_argument('--data_dir', type=str, required=True)
148 | parser.add_argument('--seed', type=int, default=0)
149 | parser.add_argument('--n_trials', type=int, default=3)
150 | parser.add_argument('--command_launcher', type=str, required=True)
151 | parser.add_argument('--steps', type=int, default=None)
152 | parser.add_argument('--hparams', type=str, default=None)
153 | parser.add_argument('--holdout_fraction', type=float, default=0.2)
154 | parser.add_argument('--single_test_envs', action='store_true')
155 | parser.add_argument('--skip_confirmation', action='store_true')
156 | args = parser.parse_args()
157 |
158 | args_list = make_args_list(
159 | n_trials=args.n_trials,
160 | dataset_names=args.datasets,
161 | algorithms=args.algorithms,
162 | n_hparams_from=args.n_hparams_from,
163 | n_hparams=args.n_hparams,
164 | steps=args.steps,
165 | data_dir=args.data_dir,
166 | task=args.task,
167 | holdout_fraction=args.holdout_fraction,
168 | single_test_envs=args.single_test_envs,
169 | hparams=args.hparams
170 | )
171 |
172 | jobs = [Job(train_args, args.output_dir) for train_args in args_list]
173 |
174 | for job in jobs:
175 | print(job)
176 | print("{} jobs: {} done, {} incomplete, {} not launched.".format(
177 | len(jobs),
178 | len([j for j in jobs if j.state == Job.DONE]),
179 | len([j for j in jobs if j.state == Job.INCOMPLETE]),
180 | len([j for j in jobs if j.state == Job.NOT_LAUNCHED]))
181 | )
182 |
183 | if args.command == 'launch':
184 | to_launch = [j for j in jobs if j.state == Job.NOT_LAUNCHED]
185 | print(f'About to launch {len(to_launch)} jobs.')
186 | if not args.skip_confirmation:
187 | ask_for_confirmation()
188 | launcher_fn = command_launchers.REGISTRY[args.command_launcher]
189 | Job.launch(to_launch, launcher_fn)
190 |
191 | elif args.command == 'delete_incomplete':
192 | to_delete = [j for j in jobs if j.state == Job.INCOMPLETE]
193 | print(f'About to delete {len(to_delete)} jobs.')
194 | if not args.skip_confirmation:
195 | ask_for_confirmation()
196 | Job.delete(to_delete)
197 |
--------------------------------------------------------------------------------
/domainbed/scripts/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import argparse
4 | import collections
5 | import json
6 | import os
7 | import random
8 | import sys
9 | import time
10 | import uuid
11 |
12 | import numpy as np
13 | import PIL
14 | import torch
15 | import torchvision
16 | import torch.utils.data
17 |
18 | from domainbed import datasets
19 | from domainbed import hparams_registry
20 | from domainbed import algorithms
21 | from domainbed.lib import misc
22 | from domainbed.lib.fast_data_loader import InfiniteDataLoader, FastDataLoader
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser(description='Domain generalization')
26 | parser.add_argument('--data_dir', type=str)
27 | parser.add_argument('--dataset', type=str, default="RotatedMNIST")
28 | parser.add_argument('--algorithm', type=str, default="ERM")
29 | parser.add_argument('--task', type=str, default="domain_generalization",
30 | choices=["domain_generalization", "domain_adaptation"])
31 | parser.add_argument('--hparams', type=str,
32 | help='JSON-serialized hparams dict')
33 | parser.add_argument('--hparams_seed', type=int, default=0,
34 | help='Seed for random hparams (0 means "default hparams")')
35 | parser.add_argument('--trial_seed', type=int, default=0,
36 | help='Trial number (used for seeding split_dataset and '
37 | 'random_hparams).')
38 | parser.add_argument('--seed', type=int, default=0,
39 | help='Seed for everything else')
40 | parser.add_argument('--steps', type=int, default=None,
41 | help='Number of steps. Default is dataset-dependent.')
42 | parser.add_argument('--checkpoint_freq', type=int, default=None,
43 | help='Checkpoint every N steps. Default is dataset-dependent.')
44 | parser.add_argument('--test_envs', type=int, nargs='+', default=[0])
45 | parser.add_argument('--output_dir', type=str, default="train_output")
46 | parser.add_argument('--holdout_fraction', type=float, default=0.2)
47 | parser.add_argument('--uda_holdout_fraction', type=float, default=0,
48 | help="For domain adaptation, % of test to use unlabeled for training.")
49 | parser.add_argument('--skip_model_save', action='store_true')
50 | parser.add_argument('--save_model_every_checkpoint', action='store_true')
51 | args = parser.parse_args()
52 |
53 | # If we ever want to implement checkpointing, just persist these values
54 | # every once in a while, and then load them from disk here.
55 | start_step = 0
56 | algorithm_dict = None
57 |
58 | os.makedirs(args.output_dir, exist_ok=True)
59 | sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt'))
60 | sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt'))
61 |
62 | print("Environment:")
63 | print("\tPython: {}".format(sys.version.split(" ")[0]))
64 | print("\tPyTorch: {}".format(torch.__version__))
65 | print("\tTorchvision: {}".format(torchvision.__version__))
66 | print("\tCUDA: {}".format(torch.version.cuda))
67 | print("\tCUDNN: {}".format(torch.backends.cudnn.version()))
68 | print("\tNumPy: {}".format(np.__version__))
69 | print("\tPIL: {}".format(PIL.__version__))
70 |
71 | print('Args:')
72 | for k, v in sorted(vars(args).items()):
73 | print('\t{}: {}'.format(k, v))
74 |
75 | if args.hparams_seed == 0:
76 | hparams = hparams_registry.default_hparams(args.algorithm, args.dataset)
77 | else:
78 | hparams = hparams_registry.random_hparams(args.algorithm, args.dataset,
79 | misc.seed_hash(args.hparams_seed, args.trial_seed))
80 | if args.hparams:
81 | hparams.update(json.loads(args.hparams))
82 |
83 | print('HParams:')
84 | for k, v in sorted(hparams.items()):
85 | print('\t{}: {}'.format(k, v))
86 |
87 | random.seed(args.seed)
88 | np.random.seed(args.seed)
89 | torch.manual_seed(args.seed)
90 | torch.backends.cudnn.deterministic = True
91 | torch.backends.cudnn.benchmark = False
92 |
93 | if torch.cuda.is_available():
94 | device = "cuda"
95 | else:
96 | device = "cpu"
97 |
98 | if args.dataset in vars(datasets):
99 | dataset = vars(datasets)[args.dataset](args.data_dir,
100 | args.test_envs, hparams)
101 | else:
102 | raise NotImplementedError
103 |
104 | # Split each env into an 'in-split' and an 'out-split'. We'll train on
105 | # each in-split except the test envs, and evaluate on all splits.
106 |
107 | # To allow unsupervised domain adaptation experiments, we split each test
108 | # env into 'in-split', 'uda-split' and 'out-split'. The 'in-split' is used
109 | # by collect_results.py to compute classification accuracies. The
110 | # 'out-split' is used by the Oracle model selectino method. The unlabeled
111 | # samples in 'uda-split' are passed to the algorithm at training time if
112 | # args.task == "domain_adaptation". If we are interested in comparing
113 | # domain generalization and domain adaptation results, then domain
114 | # generalization algorithms should create the same 'uda-splits', which will
115 | # be discared at training.
116 | in_splits = []
117 | out_splits = []
118 | uda_splits = []
119 | for env_i, env in enumerate(dataset):
120 | uda = []
121 |
122 | out, in_ = misc.split_dataset(env,
123 | int(len(env)*args.holdout_fraction),
124 | misc.seed_hash(args.trial_seed, env_i))
125 |
126 | if env_i in args.test_envs:
127 | uda, in_ = misc.split_dataset(in_,
128 | int(len(in_)*args.uda_holdout_fraction),
129 | misc.seed_hash(args.trial_seed, env_i))
130 |
131 | if hparams['class_balanced']:
132 | in_weights = misc.make_weights_for_balanced_classes(in_)
133 | out_weights = misc.make_weights_for_balanced_classes(out)
134 | if uda is not None:
135 | uda_weights = misc.make_weights_for_balanced_classes(uda)
136 | else:
137 | in_weights, out_weights, uda_weights = None, None, None
138 | in_splits.append((in_, in_weights))
139 | out_splits.append((out, out_weights))
140 | if len(uda):
141 | uda_splits.append((uda, uda_weights))
142 |
143 | if args.task == "domain_adaptation" and len(uda_splits) == 0:
144 | raise ValueError("Not enough unlabeled samples for domain adaptation.")
145 |
146 | train_loaders = [InfiniteDataLoader(
147 | dataset=env,
148 | weights=env_weights,
149 | batch_size=hparams['batch_size'],
150 | num_workers=dataset.N_WORKERS)
151 | for i, (env, env_weights) in enumerate(in_splits)
152 | if i not in args.test_envs]
153 |
154 | uda_loaders = [InfiniteDataLoader(
155 | dataset=env,
156 | weights=env_weights,
157 | batch_size=hparams['batch_size'],
158 | num_workers=dataset.N_WORKERS)
159 | for i, (env, env_weights) in enumerate(uda_splits)
160 | if i in args.test_envs]
161 |
162 | eval_loaders = [FastDataLoader(
163 | dataset=env,
164 | batch_size=64,
165 | num_workers=dataset.N_WORKERS)
166 | for env, _ in (in_splits + out_splits + uda_splits)]
167 | eval_weights = [None for _, weights in (in_splits + out_splits + uda_splits)]
168 | eval_loader_names = ['env{}_in'.format(i)
169 | for i in range(len(in_splits))]
170 | eval_loader_names += ['env{}_out'.format(i)
171 | for i in range(len(out_splits))]
172 | eval_loader_names += ['env{}_uda'.format(i)
173 | for i in range(len(uda_splits))]
174 |
175 | algorithm_class = algorithms.get_algorithm_class(args.algorithm)
176 | algorithm = algorithm_class(dataset.input_shape, dataset.num_classes,
177 | len(dataset) - len(args.test_envs), hparams)
178 |
179 | if algorithm_dict is not None:
180 | algorithm.load_state_dict(algorithm_dict)
181 |
182 | algorithm.to(device)
183 |
184 | train_minibatches_iterator = zip(*train_loaders)
185 | uda_minibatches_iterator = zip(*uda_loaders)
186 | checkpoint_vals = collections.defaultdict(lambda: [])
187 |
188 | steps_per_epoch = min([len(env)/hparams['batch_size'] for env,_ in in_splits])
189 |
190 | n_steps = args.steps or dataset.N_STEPS
191 | checkpoint_freq = args.checkpoint_freq or dataset.CHECKPOINT_FREQ
192 |
193 | def save_checkpoint(filename):
194 | if args.skip_model_save:
195 | return
196 | save_dict = {
197 | "args": vars(args),
198 | "model_input_shape": dataset.input_shape,
199 | "model_num_classes": dataset.num_classes,
200 | "model_num_domains": len(dataset) - len(args.test_envs),
201 | "model_hparams": hparams,
202 | "model_dict": algorithm.cpu().state_dict()
203 | }
204 | torch.save(save_dict, os.path.join(args.output_dir, filename))
205 |
206 |
207 | last_results_keys = None
208 | for step in range(start_step, n_steps):
209 | step_start_time = time.time()
210 | minibatches_device = [(x.to(device), y.to(device))
211 | for x,y in next(train_minibatches_iterator)]
212 | if args.task == "domain_adaptation":
213 | uda_device = [x.to(device)
214 | for x,_ in next(uda_minibatches_iterator)]
215 | else:
216 | uda_device = None
217 | step_vals = algorithm.update(minibatches_device, uda_device)
218 | checkpoint_vals['step_time'].append(time.time() - step_start_time)
219 |
220 | for key, val in step_vals.items():
221 | checkpoint_vals[key].append(val)
222 |
223 | if (step % checkpoint_freq == 0) or (step == n_steps - 1):
224 | results = {
225 | 'step': step,
226 | 'epoch': step / steps_per_epoch,
227 | }
228 |
229 | for key, val in checkpoint_vals.items():
230 | results[key] = np.mean(val)
231 |
232 | evals = zip(eval_loader_names, eval_loaders, eval_weights)
233 | for name, loader, weights in evals:
234 | acc = misc.accuracy(algorithm, loader, weights, device)
235 | results[name+'_acc'] = acc
236 |
237 | results['mem_gb'] = torch.cuda.max_memory_allocated() / (1024.*1024.*1024.)
238 |
239 | results_keys = sorted(results.keys())
240 | if results_keys != last_results_keys:
241 | misc.print_row(results_keys, colwidth=12)
242 | last_results_keys = results_keys
243 | misc.print_row([results[key] for key in results_keys],
244 | colwidth=12)
245 |
246 | results.update({
247 | 'hparams': hparams,
248 | 'args': vars(args)
249 | })
250 |
251 | epochs_path = os.path.join(args.output_dir, 'results.jsonl')
252 | with open(epochs_path, 'a') as f:
253 | f.write(json.dumps(results, sort_keys=True) + "\n")
254 |
255 | algorithm_dict = algorithm.state_dict()
256 | start_step = step + 1
257 | checkpoint_vals = collections.defaultdict(lambda: [])
258 |
259 | if args.save_model_every_checkpoint:
260 | save_checkpoint(f'model_step{step}.pkl')
261 |
262 | save_checkpoint('model.pkl')
263 |
264 | with open(os.path.join(args.output_dir, 'done'), 'w') as f:
265 | f.write('done')
266 |
--------------------------------------------------------------------------------
/fig_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alexrame/fishr/7b8fdf1e0b15226ded9b58efd37698e74e616ab7/fig_intro.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | torchvision==0.9.1
3 | backpack-for-pytorch ==1.3.0
4 | numpy==1.20.2
5 |
--------------------------------------------------------------------------------