├── .gitignore ├── LICENSE ├── README.md ├── assets └── bpe_simple_vocab_16e6.txt.gz ├── clip ├── clip.py ├── model.py └── simple_tokenizer.py ├── config.py ├── download-weights.sh ├── generator.py ├── gpt2 ├── config.py ├── encoder.py ├── model.py ├── sample.py ├── utils.py └── weights │ ├── encoder.json │ └── vocab.bpe ├── gpt2_images ├── dog.jpeg ├── goldfish.jpeg ├── harmonica.jpeg ├── harp.jpeg ├── knot.jpeg ├── radio_telescope.jpeg ├── teapot.jpeg ├── telephone.jpeg └── zebra.jpeg ├── latent.py ├── models.py ├── operators.py ├── problem.py ├── requirements.txt ├── run.py ├── stylegan2 ├── __init__.py ├── convert_from_tf.py ├── external_models │ ├── __init__.py │ ├── inception.py │ └── lpips.py ├── loss_fns.py ├── metrics │ ├── __init__.py │ ├── fid.py │ └── ppl.py ├── models.py ├── modules.py ├── project.py ├── train.py └── utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | /env 2 | __pycache__/ 3 | /tmp 4 | 5 | /stylegan2/weights 6 | /gpt2/weights/gpt2-pytorch_model.bin 7 | 8 | /.vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 
43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CLIP-GLaSS 2 | 3 | Repository for the paper [Generating images from caption and vice versa via CLIP-Guided Generative Latent Space Search](https://arxiv.org/abs/2102.01645) 4 | 5 | 6 | ### **An in-browser demo is available [here](https://colab.research.google.com/drive/1fWka_U56NhCegbbrQPt4PWpHPtNRdU49?usp=sharing)** 7 | 8 | 9 | ## Installation 10 | 11 | Clone this repository 12 | 13 | ``` 14 | git clone https://github.com/galatolofederico/clip-glass && cd clip-glass 15 | ``` 16 | 17 | Create a virtual environment and install the requirements 18 | 19 | ``` 20 | virtualenv --python=python3.6 env && . 
./env/bin/activate 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## Run CLIP-GLaSS 25 | 26 | You can run `CLIP-GLaSS` with: 27 | 28 | ``` 29 | python run.py --config <config> --target <target> 30 | ``` 31 | 32 | Specifying `<config>` and `<target>` according to the following table: 33 | 34 | | Config | Meaning | Target Type | 35 | |:--------------------:|:--------------------------------------------------------------------------:|:-----------:| 36 | | GPT2 | Use GPT2 to solve the Image-to-Text task | Image | 37 | | DeepMindBigGAN512 | Use DeepMind's BigGAN 512x512 to solve the Text-to-Image task | Text | 38 | | DeepMindBigGAN256 | Use DeepMind's BigGAN 256x256 to solve the Text-to-Image task | Text | 39 | | StyleGAN2_ffhq_d | Use StyleGAN2-ffhq to solve the Text-to-Image task | Text | 40 | | StyleGAN2_ffhq_nod | Use StyleGAN2-ffhq without Discriminator to solve the Text-to-Image task | Text | 41 | | StyleGAN2_church_d | Use StyleGAN2-church to solve the Text-to-Image task | Text | 42 | | StyleGAN2_church_nod | Use StyleGAN2-church without Discriminator to solve the Text-to-Image task | Text | 43 | | StyleGAN2_car_d | Use StyleGAN2-car to solve the Text-to-Image task | Text | 44 | | StyleGAN2_car_nod | Use StyleGAN2-car without Discriminator to solve the Text-to-Image task | Text | 45 | 46 | 47 | If you have not downloaded the model weights, you will be prompted to run `./download-weights.sh`. 48 | You will find the results in the folder `./tmp`; a different output folder can be specified with `--tmp-folder`. 49 | 50 | #### Examples 51 | 52 | ``` 53 | python run.py --config StyleGAN2_ffhq_d --target "the face of a man with brown eyes and stubble beard" 54 | python run.py --config GPT2 --target gpt2_images/dog.jpeg 55 | ``` 56 | 57 | 58 | ## Acknowledgments and licensing 59 | 60 | This work heavily relies on the following amazing repositories and would not have been possible without them: 61 | 62 | * [CLIP](https://github.com/openai/CLIP) from [openai](https://github.com/openai) (included in the folder `clip`) 63 | * [pytorch-pretrained-BigGAN](https://github.com/huggingface/pytorch-pretrained-BigGAN) from [huggingface](https://github.com/huggingface) 64 | * [stylegan2-pytorch](https://github.com/Tetratrio/stylegan2_pytorch) from [Adrian Sahlman](https://github.com/Tetratrio) (included in the folder `stylegan2`) 65 | * [gpt-2-pytorch](https://github.com/graykode/gpt-2-Pytorch) from [Tae-Hwan Jung](https://github.com/graykode) (included in the folder `gpt2`) 66 | 67 | All their work can be shared under the terms of the respective original licenses. 68 | 69 | All my original work (everything except the content of the folders `clip`, `stylegan2` and `gpt2`) is released under the terms of the [GNU/GPLv3](https://choosealicense.com/licenses/gpl-3.0/) license. Copying, adapting, and republishing it is not only permitted but also encouraged. 70 | 71 | ## Citing 72 | 73 | If you want to cite us, you can use this BibTeX: 74 | 75 | ``` 76 | @article{generating2021, 77 | author={Federico Galatolo. and Mario Cimino. 
and Gigliola Vaglini}, 78 | title={Generating Images from Caption and Vice Versa via CLIP-Guided Generative Latent Space Search}, 79 | journal={Proceedings of the International Conference on Image Processing and Vision Engineering}, 80 | year={2021}, 81 | volume={}, 82 | pages={}, 83 | publisher={SCITEPRESS - Science and Technology Publications}, 84 | doi={10.5220/0010503701660174}, 85 | issn={}, 86 | } 87 | ``` 88 | 89 | ## Contacts 90 | 91 | For any further question feel free to reach me at [federico.galatolo@ing.unipi.it](mailto:federico.galatolo@ing.unipi.it) or on Telegram [@galatolo](https://t.me/galatolo) 92 | -------------------------------------------------------------------------------- /assets/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/assets/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from clip.model import build_model 13 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | __all__ = ["available_models", "load", "tokenize"] 16 | _tokenizer = _Tokenizer() 17 | 18 | _MODELS = { 19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 20 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 21 | } 22 | 23 | 24 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 25 | os.makedirs(root, exist_ok=True) 26 | filename = os.path.basename(url) 27 | 28 | expected_sha256 = url.split("/")[-2] 29 | download_target = os.path.join(root, filename) 30 | 31 | if os.path.exists(download_target) and not os.path.isfile(download_target): 32 | raise RuntimeError(f"{download_target} exists and is not a regular file") 33 | 34 | if os.path.isfile(download_target): 35 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 36 | return download_target 37 | else: 38 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 39 | 40 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 41 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop: 42 | while True: 43 | buffer = source.read(8192) 44 | if not buffer: 45 | break 46 | 47 | output.write(buffer) 48 | loop.update(len(buffer)) 49 | 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 51 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 52 | 53 | return download_target 54 | 55 | 56 | def available_models(): 57 | return list(_MODELS.keys()) 58 | 59 | 60 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 61 | if name not in _MODELS: 62 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 63 | 64 | model_path = _download(_MODELS[name]) 
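    # The downloaded checkpoint is a TorchScript archive: it is loaded on the target
    # device when jit=True, or on CPU when jit=False, in which case its state_dict is
    # used below to rebuild an ordinary nn.Module via build_model(). The remainder of
    # this function patches the hard-coded device (and, on CPU, the fp16 dtype)
    # constants in the traced graphs so the JIT model runs on the requested device.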
65 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 66 | n_px = model.input_resolution.item() 67 | 68 | transform = Compose([ 69 | Resize(n_px, interpolation=Image.BICUBIC), 70 | CenterCrop(n_px), 71 | lambda image: image.convert("RGB"), 72 | ToTensor(), 73 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 74 | ]) 75 | 76 | if not jit: 77 | model = build_model(model.state_dict()).to(device) 78 | return model, transform 79 | 80 | # patch the device names 81 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 82 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 83 | 84 | def patch_device(module): 85 | graphs = [module.graph] if hasattr(module, "graph") else [] 86 | if hasattr(module, "forward1"): 87 | graphs.append(module.forward1.graph) 88 | 89 | for graph in graphs: 90 | for node in graph.findAllNodes("prim::Constant"): 91 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 92 | node.copyAttributes(device_node) 93 | 94 | model.apply(patch_device) 95 | patch_device(model.encode_image) 96 | patch_device(model.encode_text) 97 | 98 | # patch dtype to float32 on CPU 99 | if device == "cpu": 100 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 101 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 102 | float_node = float_input.node() 103 | 104 | def patch_float(module): 105 | graphs = [module.graph] if hasattr(module, "graph") else [] 106 | if hasattr(module, "forward1"): 107 | graphs.append(module.forward1.graph) 108 | 109 | for graph in graphs: 110 | for node in graph.findAllNodes("aten::to"): 111 | inputs = list(node.inputs()) 112 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 113 | if inputs[i].node()["value"] == 5: 114 | inputs[i].node().copyAttributes(float_node) 115 | 116 | model.apply(patch_float) 117 | patch_float(model.encode_image) 118 | patch_float(model.encode_text) 119 | 120 | model.float() 121 | 122 | return model, transform 123 | 124 | 125 | def tokenize(texts: Union[str, List[str]], context_length: int = 77): 126 | if isinstance(texts, str): 127 | texts = [texts] 128 | 129 | sot_token = _tokenizer.encoder["<|startoftext|>"] 130 | eot_token = _tokenizer.encoder["<|endoftext|>"] 131 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 132 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 133 | 134 | for i, tokens in enumerate(all_tokens): 135 | if len(tokens) > context_length: 136 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 137 | result[i, :len(tokens)] = torch.tensor(tokens) 138 | 139 | return result -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | expansion = 4 11 | 12 | def __init__(self, inplanes, planes, stride=1): 13 | super().__init__() 14 | 15 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 16 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | 19 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 23 | 24 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 26 | 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | self.stride = stride 30 | 31 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 32 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 33 | self.downsample = nn.Sequential(OrderedDict([ 34 | ("-1", nn.AvgPool2d(stride)), 35 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 36 | ("1", nn.BatchNorm2d(planes * self.expansion)) 37 | ])) 38 | 39 | def forward(self, x: torch.Tensor): 40 | identity = x 41 | 42 | out = self.relu(self.bn1(self.conv1(x))) 43 | out = self.relu(self.bn2(self.conv2(out))) 44 | out = self.avgpool(out) 45 | out = self.bn3(self.conv3(out)) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | return out 53 | 54 | 55 | class AttentionPool2d(nn.Module): 56 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 57 | super().__init__() 58 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 59 | self.k_proj = nn.Linear(embed_dim, embed_dim) 60 | self.q_proj = nn.Linear(embed_dim, embed_dim) 61 | self.v_proj = nn.Linear(embed_dim, embed_dim) 62 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 63 | self.num_heads = num_heads 64 | 65 | def forward(self, x): 66 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 67 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 68 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 69 | x, _ = F.multi_head_attention_forward( 70 | query=x, key=x, value=x, 71 | embed_dim_to_check=x.shape[-1], 72 | num_heads=self.num_heads, 73 | q_proj_weight=self.q_proj.weight, 74 | k_proj_weight=self.k_proj.weight, 75 | v_proj_weight=self.v_proj.weight, 76 | in_proj_weight=None, 77 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 78 | bias_k=None, 79 | bias_v=None, 80 | add_zero_attn=False, 81 | dropout_p=0, 82 | out_proj_weight=self.c_proj.weight, 83 | out_proj_bias=self.c_proj.bias, 84 | use_separate_proj_weight=True, 85 | training=self.training, 86 | need_weights=False 87 | ) 88 | 89 | return x[0] 90 | 91 | 92 | class ModifiedResNet(nn.Module): 93 | """ 94 | A ResNet class that is similar to torchvision's but contains the following changes: 95 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
96 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 97 | - The final pooling layer is a QKV attention instead of an average pool 98 | """ 99 | 100 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 101 | super().__init__() 102 | self.output_dim = output_dim 103 | self.input_resolution = input_resolution 104 | 105 | # the 3-layer stem 106 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 107 | self.bn1 = nn.BatchNorm2d(width // 2) 108 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 109 | self.bn2 = nn.BatchNorm2d(width // 2) 110 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 111 | self.bn3 = nn.BatchNorm2d(width) 112 | self.avgpool = nn.AvgPool2d(2) 113 | self.relu = nn.ReLU(inplace=True) 114 | 115 | # residual layers 116 | self._inplanes = width # this is a *mutable* variable used during construction 117 | self.layer1 = self._make_layer(width, layers[0]) 118 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 119 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 120 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 121 | 122 | embed_dim = width * 32 # the ResNet feature dimension 123 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 124 | 125 | def _make_layer(self, planes, blocks, stride=1): 126 | layers = [Bottleneck(self._inplanes, planes, stride)] 127 | 128 | self._inplanes = planes * Bottleneck.expansion 129 | for _ in range(1, blocks): 130 | layers.append(Bottleneck(self._inplanes, planes)) 131 | 132 | return nn.Sequential(*layers) 133 | 134 | def forward(self, x): 135 | def stem(x): 136 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 137 | x = self.relu(bn(conv(x))) 138 | x = self.avgpool(x) 139 | return x 140 | 141 | x = x.type(self.conv1.weight.dtype) 142 | x = stem(x) 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | x = self.attnpool(x) 148 | 149 | return x 150 | 151 | 152 | class LayerNorm(nn.LayerNorm): 153 | """Subclass torch's LayerNorm to handle fp16.""" 154 | 155 | def forward(self, x: torch.Tensor): 156 | orig_type = x.dtype 157 | ret = super().forward(x.type(torch.float32)) 158 | return ret.type(orig_type) 159 | 160 | 161 | class QuickGELU(nn.Module): 162 | def forward(self, x: torch.Tensor): 163 | return x * torch.sigmoid(1.702 * x) 164 | 165 | 166 | class ResidualAttentionBlock(nn.Module): 167 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 168 | super().__init__() 169 | 170 | self.attn = nn.MultiheadAttention(d_model, n_head) 171 | self.ln_1 = LayerNorm(d_model) 172 | self.mlp = nn.Sequential(OrderedDict([ 173 | ("c_fc", nn.Linear(d_model, d_model * 4)), 174 | ("gelu", QuickGELU()), 175 | ("c_proj", nn.Linear(d_model * 4, d_model)) 176 | ])) 177 | self.ln_2 = LayerNorm(d_model) 178 | self.attn_mask = attn_mask 179 | 180 | def attention(self, x: torch.Tensor): 181 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 182 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 183 | 184 | def forward(self, x: torch.Tensor): 185 | x = x + self.attention(self.ln_1(x)) 186 | x = x + self.mlp(self.ln_2(x)) 187 | return x 188 | 189 | 190 | class Transformer(nn.Module): 191 | def 
__init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 192 | super().__init__() 193 | self.width = width 194 | self.layers = layers 195 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 196 | 197 | def forward(self, x: torch.Tensor): 198 | return self.resblocks(x) 199 | 200 | 201 | class VisualTransformer(nn.Module): 202 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 203 | super().__init__() 204 | self.input_resolution = input_resolution 205 | self.output_dim = output_dim 206 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 207 | 208 | scale = width ** -0.5 209 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 210 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 211 | self.ln_pre = LayerNorm(width) 212 | 213 | self.transformer = Transformer(width, layers, heads) 214 | 215 | self.ln_post = LayerNorm(width) 216 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 217 | 218 | def forward(self, x: torch.Tensor): 219 | x = self.conv1(x) # shape = [*, width, grid, grid] 220 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 221 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 222 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 223 | x = x + self.positional_embedding.to(x.dtype) 224 | x = self.ln_pre(x) 225 | 226 | x = x.permute(1, 0, 2) # NLD -> LND 227 | x = self.transformer(x) 228 | x = x.permute(1, 0, 2) # LND -> NLD 229 | 230 | x = self.ln_post(x[:, 0, :]) 231 | 232 | if self.proj is not None: 233 | x = x @ self.proj 234 | 235 | return x 236 | 237 | 238 | class CLIP(nn.Module): 239 | def __init__(self, 240 | embed_dim: int, 241 | # vision 242 | image_resolution: int, 243 | vision_layers: Union[Tuple[int, int, int, int], int], 244 | vision_width: int, 245 | vision_patch_size: int, 246 | # text 247 | context_length: int, 248 | vocab_size: int, 249 | transformer_width: int, 250 | transformer_heads: int, 251 | transformer_layers: int 252 | ): 253 | super().__init__() 254 | 255 | self.context_length = context_length 256 | 257 | if isinstance(vision_layers, (tuple, list)): 258 | vision_heads = vision_width * 32 // 64 259 | self.visual = ModifiedResNet( 260 | layers=vision_layers, 261 | output_dim=embed_dim, 262 | heads=vision_heads, 263 | input_resolution=image_resolution, 264 | width=vision_width 265 | ) 266 | else: 267 | vision_heads = vision_width // 64 268 | self.visual = VisualTransformer( 269 | input_resolution=image_resolution, 270 | patch_size=vision_patch_size, 271 | width=vision_width, 272 | layers=vision_layers, 273 | heads=vision_heads, 274 | output_dim=embed_dim 275 | ) 276 | 277 | self.transformer = Transformer( 278 | width=transformer_width, 279 | layers=transformer_layers, 280 | heads=transformer_heads, 281 | attn_mask=self.build_attention_mask() 282 | ) 283 | 284 | self.vocab_size = vocab_size 285 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 286 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 287 | self.ln_final = LayerNorm(transformer_width) 288 | 289 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 290 | 
self.logit_scale = nn.Parameter(torch.ones([])) 291 | 292 | def build_attention_mask(self): 293 | # lazily create causal attention mask, with full attention between the vision tokens 294 | # pytorch uses additive attention mask; fill with -inf 295 | mask = torch.empty(self.context_length, self.context_length) 296 | mask.fill_(float("-inf")) 297 | mask.triu_(1) # zero out the lower diagonal 298 | return mask 299 | 300 | @property 301 | def dtype(self): 302 | return self.visual.conv1.weight.dtype 303 | 304 | def encode_image(self, image): 305 | return self.visual(image.type(self.dtype)) 306 | 307 | def encode_text(self, text): 308 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 309 | 310 | x = x + self.positional_embedding.type(self.dtype) 311 | x = x.permute(1, 0, 2) # NLD -> LND 312 | x = self.transformer(x) 313 | x = x.permute(1, 0, 2) # LND -> NLD 314 | x = self.ln_final(x).type(self.dtype) 315 | 316 | # x.shape = [batch_size, n_ctx, transformer.width] 317 | # take features from the eot embedding (eot_token is the highest number in each sequence) 318 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 319 | 320 | return x 321 | 322 | def forward(self, image, text): 323 | image_features = self.encode_image(image) 324 | text_features = self.encode_text(text) 325 | 326 | # normalized features 327 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 328 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 329 | 330 | # cosine similarity as logits 331 | logit_scale = self.logit_scale.exp() 332 | logits_per_image = logit_scale * image_features @ text_features.t() 333 | logits_per_text = logit_scale * text_features @ image_features.t() 334 | 335 | # shape = [global_batch_size, global_batch_size] 336 | return logits_per_image, logits_per_text 337 | 338 | 339 | def convert_weights(model: nn.Module): 340 | """Convert applicable model parameters to fp16""" 341 | 342 | def _convert_weights_to_fp16(l): 343 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 344 | l.weight.data = l.weight.data.half() 345 | if l.bias is not None: 346 | l.bias.data = l.bias.data.half() 347 | 348 | if isinstance(l, nn.MultiheadAttention): 349 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 350 | tensor = getattr(l, attr) 351 | if tensor is not None: 352 | tensor.data = tensor.data.half() 353 | 354 | for name in ["text_projection", "proj"]: 355 | if hasattr(l, name): 356 | attr = getattr(l, name) 357 | if attr is not None: 358 | attr.data = attr.data.half() 359 | 360 | model.apply(_convert_weights_to_fp16) 361 | 362 | 363 | def build_model(state_dict: dict): 364 | vit = "visual.proj" in state_dict 365 | 366 | if vit: 367 | vision_width = state_dict["visual.conv1.weight"].shape[0] 368 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 369 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 370 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 371 | image_resolution = vision_patch_size * grid_size 372 | else: 373 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 374 | vision_layers = tuple(counts) 375 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 376 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 377
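# Annotation (not part of clip/model.py): in this ResNet branch the hyperparameters are recovered
# from tensor shapes alone -- output_width is the spatial size of the final feature map, the
# attnpool positional embedding has output_width**2 + 1 rows (one per spatial position plus one
# for the pooled query), which the assert below double-checks, and image_resolution is
# output_width * 32 because the stem and the three strided stages downsample the input by 32.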
| vision_patch_size = None 378 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 379 | image_resolution = output_width * 32 380 | 381 | embed_dim = state_dict["text_projection"].shape[1] 382 | context_length = state_dict["positional_embedding"].shape[0] 383 | vocab_size = state_dict["token_embedding.weight"].shape[0] 384 | transformer_width = state_dict["ln_final.weight"].shape[0] 385 | transformer_heads = transformer_width // 64 386 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 387 | 388 | model = CLIP( 389 | embed_dim, 390 | image_resolution, vision_layers, vision_width, vision_patch_size, 391 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 392 | ) 393 | 394 | for key in ["input_resolution", "context_length", "vocab_size"]: 395 | del state_dict[key] 396 | 397 | convert_weights(model) 398 | model.load_state_dict(state_dict) 399 | return model.eval() 400 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return "./assets/bpe_simple_vocab_16e6.txt.gz" 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from models import DeepMindBigGAN, StyleGAN2, GPT2 2 | from latent import DeepMindBigGANLatentSpace, StyleGAN2LatentSpace, GPT2LatentSpace 3 | from utils import biggan_norm, biggan_denorm 4 | 5 | configs = dict( 6 | GPT2 = dict( 7 | task = "img2txt", 8 | dim_z = 20, 9 | max_tokens_len = 30, 10 | max_text_len = 50, 11 | encoder_size = 50257, 12 | latent = GPT2LatentSpace, 13 | 
model = GPT2, 14 | use_discriminator = False, 15 | init_text = "the picture of", 16 | weights = "./gpt2/weights/gpt2-pytorch_model.bin", 17 | encoder = "./gpt2/weights/encoder.json", 18 | vocab = "./gpt2/weights/vocab.bpe", 19 | stochastic = False, 20 | algorithm = "ga", 21 | pop_size = 100, 22 | batch_size = 25, 23 | problem_args = dict( 24 | n_var = 20, 25 | n_obj = 1, 26 | n_constr = 20, 27 | xl = 0, 28 | xu = 50256 29 | ) 30 | ), 31 | DeepMindBigGAN256 = dict( 32 | task = "txt2img", 33 | dim_z = 128, 34 | num_classes = 1000, 35 | latent = DeepMindBigGANLatentSpace, 36 | model = DeepMindBigGAN, 37 | weights = "biggan-deep-256", 38 | use_discriminator = False, 39 | algorithm = "ga", 40 | norm = biggan_norm, 41 | denorm = biggan_denorm, 42 | truncation = 1.0, 43 | pop_size = 64, 44 | batch_size = 32, 45 | problem_args = dict( 46 | n_var = 128 + 1000, 47 | n_obj = 1, 48 | n_constr = 128, 49 | xl = -2, 50 | xu = 2 51 | ) 52 | ), 53 | DeepMindBigGAN512 = dict( 54 | task = "txt2img", 55 | dim_z = 128, 56 | num_classes = 1000, 57 | latent = DeepMindBigGANLatentSpace, 58 | model = DeepMindBigGAN, 59 | weights = "biggan-deep-512", 60 | use_discriminator = False, 61 | algorithm = "ga", 62 | norm = biggan_norm, 63 | denorm = biggan_denorm, 64 | truncation = 1.0, 65 | pop_size = 32, 66 | batch_size = 8, 67 | problem_args = dict( 68 | n_var = 128 + 1000, 69 | n_obj = 1, 70 | n_constr = 128, 71 | xl = -2, 72 | xu = 2 73 | ) 74 | ), 75 | StyleGAN2_ffhq_d = dict( 76 | task = "txt2img", 77 | dim_z = 512, 78 | latent = StyleGAN2LatentSpace, 79 | model = StyleGAN2, 80 | use_discriminator = True, 81 | weights = "./stylegan2/weights/ffhq-config-f", 82 | algorithm = "nsga2", 83 | norm = biggan_norm, 84 | denorm = biggan_denorm, 85 | pop_size = 16, 86 | batch_size = 4, 87 | problem_args = dict( 88 | n_var = 512, 89 | n_obj = 2, 90 | n_constr = 512, 91 | xl = -10, 92 | xu = 10, 93 | ), 94 | ), 95 | StyleGAN2_car_d = dict( 96 | task = "txt2img", 97 | dim_z = 512, 98 | latent = StyleGAN2LatentSpace, 99 | model = StyleGAN2, 100 | use_discriminator = True, 101 | weights = "./stylegan2/weights/car-config-f", 102 | algorithm = "nsga2", 103 | norm = biggan_norm, 104 | denorm = biggan_denorm, 105 | pop_size = 16, 106 | batch_size = 4, 107 | problem_args = dict( 108 | n_var = 512, 109 | n_obj = 2, 110 | n_constr = 512, 111 | xl = -10, 112 | xu = 10 113 | ), 114 | ), 115 | StyleGAN2_church_d = dict( 116 | task = "txt2img", 117 | dim_z = 512, 118 | latent = StyleGAN2LatentSpace, 119 | model = StyleGAN2, 120 | use_discriminator = True, 121 | weights = "./stylegan2/weights/church-config-f", 122 | algorithm = "nsga2", 123 | norm = biggan_norm, 124 | denorm = biggan_denorm, 125 | pop_size = 16, 126 | batch_size = 4, 127 | problem_args = dict( 128 | n_var = 512, 129 | n_obj = 2, 130 | n_constr = 512, 131 | xl = -10, 132 | xu = 10 133 | ), 134 | ), 135 | StyleGAN2_ffhq_nod = dict( 136 | task = "txt2img", 137 | dim_z = 512, 138 | latent = StyleGAN2LatentSpace, 139 | model = StyleGAN2, 140 | use_discriminator = False, 141 | weights = "./stylegan2/weights/ffhq-config-f", 142 | algorithm = "ga", 143 | norm = biggan_norm, 144 | denorm = biggan_denorm, 145 | pop_size = 16, 146 | batch_size = 4, 147 | problem_args = dict( 148 | n_var = 512, 149 | n_obj = 1, 150 | n_constr = 512, 151 | xl = -10, 152 | xu = 10 153 | ) 154 | ), 155 | StyleGAN2_car_nod = dict( 156 | task = "txt2img", 157 | dim_z = 512, 158 | latent = StyleGAN2LatentSpace, 159 | model = StyleGAN2, 160 | use_discriminator = False, 161 | weights = 
"./stylegan2/weights/car-config-f", 162 | algorithm = "ga", 163 | norm = biggan_norm, 164 | denorm = biggan_denorm, 165 | pop_size = 16, 166 | batch_size = 4, 167 | problem_args = dict( 168 | n_var = 512, 169 | n_obj = 1, 170 | n_constr = 512, 171 | xl = -10, 172 | xu = 10 173 | ) 174 | ), 175 | StyleGAN2_church_nod = dict( 176 | task = "txt2img", 177 | dim_z = 512, 178 | latent = StyleGAN2LatentSpace, 179 | model = StyleGAN2, 180 | use_discriminator = False, 181 | weights = "./stylegan2/weights/church-config-f", 182 | algorithm = "ga", 183 | norm = biggan_norm, 184 | denorm = biggan_denorm, 185 | pop_size = 16, 186 | batch_size = 4, 187 | problem_args = dict( 188 | n_var = 512, 189 | n_obj = 1, 190 | n_constr = 512, 191 | xl = -10, 192 | xu = 10 193 | ) 194 | ) 195 | ) 196 | 197 | 198 | 199 | def get_config(name): 200 | return configs[name] -------------------------------------------------------------------------------- /download-weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$#" -ne 1 ]; then 4 | echo "./download-weights.sh " 5 | echo "Possibile are: StyleGAN2-ffhq, StyleGAN2-church, StyleGAN2-car, GPT2" 6 | echo "Example:" 7 | echo "./download-weights.sh StyleGAN2-ffhq" 8 | exit 9 | fi 10 | 11 | die(){ 12 | echo "$1" 13 | exit 14 | } 15 | 16 | download_stylegan2(){ 17 | config="$1" 18 | dest="./stylegan2/weights/$config" 19 | [ -f "$dest/G.pth" ] && die "Weights already downloaded" 20 | [ ! -d "$dest" ] && mkdir -p "$dest" 21 | python -m stylegan2.convert_from_tf --download "$config" --output "$dest/G.pth" "$dest/D.pth" "$dest/Gs.pth" 22 | } 23 | 24 | 25 | case $1 in 26 | "StyleGAN2-ffhq") 27 | download_stylegan2 "ffhq-config-f" 28 | ;; 29 | "StyleGAN2-church") 30 | download_stylegan2 "church-config-f" 31 | ;; 32 | "StyleGAN2-car") 33 | download_stylegan2 "car-config-f" 34 | ;; 35 | "GPT2") 36 | [ -f "gpt2/weights/gpt2-pytorch_model.bin" ] && die "Weights already downloaded" 37 | curl --output gpt2/weights/gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin 38 | ;; 39 | *) 40 | echo "Unknown model '$1'" 41 | ;; 42 | esac -------------------------------------------------------------------------------- /generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pytorch_pretrained_biggan import BigGAN 3 | from clip import clip 4 | import kornia 5 | from PIL import Image 6 | from torchvision.utils import save_image 7 | 8 | 9 | from utils import save_grid, freeze_model 10 | 11 | class Generator: 12 | def __init__(self, config): 13 | self.config = config 14 | self.augmentation = None 15 | 16 | self.CLIP, clip_preprocess = clip.load("ViT-B/32", device=self.config.device, jit=False) 17 | self.CLIP = self.CLIP.eval() 18 | freeze_model(self.CLIP) 19 | self.model = self.config.model(config).to(self.config.device).eval() 20 | freeze_model(self.model) 21 | 22 | if config.task == "txt2img": 23 | self.tokens = clip.tokenize([self.config.target]).to(self.config.device) 24 | self.text_features = self.CLIP.encode_text(self.tokens).detach() 25 | if config.task == "img2txt": 26 | image = clip_preprocess(Image.open(self.config.target)).unsqueeze(0).to(self.config.device) 27 | self.image_features = self.CLIP.encode_image(image) 28 | 29 | def generate(self, ls, minibatch=None): 30 | z = ls() 31 | result = self.model.generate(*z, minibatch=minibatch) 32 | if hasattr(self.config, "norm"): 33 | result = self.config.norm(result) 34 | 
return result 35 | 36 | def discriminate(self, images, minibatch=None): 37 | images = self.config.denorm(images) 38 | return self.model.discriminate(images, minibatch) 39 | 40 | def has_discriminator(self): 41 | return self.model.has_discriminator() 42 | 43 | def clip_similarity(self, input): 44 | if self.config.task == "txt2img": 45 | image = kornia.resize(input, (224, 224)) 46 | if self.augmentation is not None: 47 | image = self.augmentation(image) 48 | 49 | image_features = self.CLIP.encode_image(image) 50 | 51 | sim = torch.cosine_similarity(image_features, self.text_features) 52 | elif self.config.task == "img2txt": 53 | try: 54 | text_tokens = clip.tokenize(input).to(self.config.device) 55 | except: 56 | return torch.zeros(len(input)) 57 | text_features = self.CLIP.encode_text(text_tokens) 58 | 59 | sim = torch.cosine_similarity(text_features, self.image_features) 60 | return sim 61 | 62 | 63 | def save(self, input, path): 64 | if self.config.task == "txt2img": 65 | if input.shape[0] > 1: 66 | save_grid(input.detach().cpu(), path) 67 | else: 68 | save_image(input[0], path) 69 | elif self.config.task == "img2txt": 70 | f = open(path, "w") 71 | f.write("\n".join(input)) 72 | f.close() -------------------------------------------------------------------------------- /gpt2/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | class GPT2Config(object): 7 | def __init__( 8 | self, 9 | vocab_size_or_config_json_file=50257, 10 | n_positions=1024, 11 | n_ctx=1024, 12 | n_embd=768, 13 | n_layer=12, 14 | n_head=12, 15 | layer_norm_epsilon=1e-5, 16 | initializer_range=0.02, 17 | ): 18 | self.vocab_size = vocab_size_or_config_json_file 19 | self.n_ctx = n_ctx 20 | self.n_positions = n_positions 21 | self.n_embd = n_embd 22 | self.n_layer = n_layer 23 | self.n_head = n_head 24 | self.layer_norm_epsilon = layer_norm_epsilon 25 | self.initializer_range = initializer_range -------------------------------------------------------------------------------- /gpt2/encoder.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | import os 4 | import json 5 | import regex as re 6 | from functools import lru_cache 7 | 8 | @lru_cache() 9 | def bytes_to_unicode(): 10 | """ 11 | Returns list of utf-8 byte and a corresponding list of unicode strings. 12 | The reversible bpe codes work on unicode strings. 13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 15 | This is a signficant percentage of your normal, say, 32K bpe vocab. 16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 17 | And avoids mapping to whitespace/control characters the bpe code barfs on. 18 | """ 19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 20 | cs = bs[:] 21 | n = 0 22 | for b in range(2**8): 23 | if b not in bs: 24 | bs.append(b) 25 | cs.append(2**8+n) 26 | n += 1 27 | cs = [chr(n) for n in cs] 28 | return dict(zip(bs, cs)) 29 | 30 | def get_pairs(word): 31 | """Return set of symbol pairs in a word. 
32 | Word is represented as tuple of symbols (symbols being variable-length strings). 33 | """ 34 | pairs = set() 35 | prev_char = word[0] 36 | for char in word[1:]: 37 | pairs.add((prev_char, char)) 38 | prev_char = char 39 | return pairs 40 | 41 | class Encoder: 42 | def __init__(self, encoder, bpe_merges, errors='replace'): 43 | self.encoder = encoder 44 | self.decoder = {v:k for k,v in self.encoder.items()} 45 | self.errors = errors # how to handle errors in decoding 46 | self.byte_encoder = bytes_to_unicode() 47 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 48 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 49 | self.cache = {} 50 | 51 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 52 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 53 | 54 | def bpe(self, token): 55 | if token in self.cache: 56 | return self.cache[token] 57 | word = tuple(token) 58 | pairs = get_pairs(word) 59 | 60 | if not pairs: 61 | return token 62 | 63 | while True: 64 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 65 | if bigram not in self.bpe_ranks: 66 | break 67 | first, second = bigram 68 | new_word = [] 69 | i = 0 70 | while i < len(word): 71 | try: 72 | j = word.index(first, i) 73 | new_word.extend(word[i:j]) 74 | i = j 75 | except: 76 | new_word.extend(word[i:]) 77 | break 78 | 79 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 80 | new_word.append(first+second) 81 | i += 2 82 | else: 83 | new_word.append(word[i]) 84 | i += 1 85 | new_word = tuple(new_word) 86 | word = new_word 87 | if len(word) == 1: 88 | break 89 | else: 90 | pairs = get_pairs(word) 91 | word = ' '.join(word) 92 | self.cache[token] = word 93 | return word 94 | 95 | def encode(self, text): 96 | bpe_tokens = [] 97 | for token in re.findall(self.pat, text): 98 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 99 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 100 | return bpe_tokens 101 | 102 | def decode(self, tokens): 103 | text = ''.join([self.decoder[token] for token in tokens]) 104 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 105 | return text 106 | 107 | def get_encoder(config): 108 | with open(config.encoder, 'r') as f: 109 | encoder = json.load(f) 110 | with open(config.vocab, 'r', encoding="utf-8") as f: 111 | bpe_data = f.read() 112 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 113 | return Encoder( 114 | encoder=encoder, 115 | bpe_merges=bpe_merges, 116 | ) -------------------------------------------------------------------------------- /gpt2/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import copy 7 | import torch 8 | import math 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | 12 | def gelu(x): 13 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 14 | 15 | class LayerNorm(nn.Module): 16 | def __init__(self, hidden_size, eps=1e-12): 17 | """Construct a layernorm module in the TF style (epsilon inside the square root). 
18 | """ 19 | super(LayerNorm, self).__init__() 20 | self.weight = nn.Parameter(torch.ones(hidden_size)) 21 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 22 | self.variance_epsilon = eps 23 | 24 | def forward(self, x): 25 | u = x.mean(-1, keepdim=True) 26 | s = (x - u).pow(2).mean(-1, keepdim=True) 27 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 28 | return self.weight * x + self.bias 29 | 30 | class Conv1D(nn.Module): 31 | def __init__(self, nf, nx): 32 | super(Conv1D, self).__init__() 33 | self.nf = nf 34 | w = torch.empty(nx, nf) 35 | nn.init.normal_(w, std=0.02) 36 | self.weight = Parameter(w) 37 | self.bias = Parameter(torch.zeros(nf)) 38 | 39 | def forward(self, x): 40 | size_out = x.size()[:-1] + (self.nf,) 41 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 42 | x = x.view(*size_out) 43 | return x 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, nx, n_ctx, config, scale=False): 47 | super(Attention, self).__init__() 48 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 49 | # [switch nx => n_state from Block to Attention to keep identical to TF implem] 50 | assert n_state % config.n_head == 0 51 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 52 | self.n_head = config.n_head 53 | self.split_size = n_state 54 | self.scale = scale 55 | self.c_attn = Conv1D(n_state * 3, nx) 56 | self.c_proj = Conv1D(n_state, nx) 57 | 58 | def _attn(self, q, k, v): 59 | w = torch.matmul(q, k) 60 | if self.scale: 61 | w = w / math.sqrt(v.size(-1)) 62 | nd, ns = w.size(-2), w.size(-1) 63 | b = self.bias[:, :, ns-nd:ns, :ns] 64 | w = w * b - 1e10 * (1 - b) 65 | w = nn.Softmax(dim=-1)(w) 66 | return torch.matmul(w, v) 67 | 68 | def merge_heads(self, x): 69 | x = x.permute(0, 2, 1, 3).contiguous() 70 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 71 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states 72 | 73 | def split_heads(self, x, k=False): 74 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 75 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states 76 | if k: 77 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 78 | else: 79 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 80 | 81 | def forward(self, x, layer_past=None): 82 | x = self.c_attn(x) 83 | query, key, value = x.split(self.split_size, dim=2) 84 | query = self.split_heads(query) 85 | key = self.split_heads(key, k=True) 86 | value = self.split_heads(value) 87 | if layer_past is not None: 88 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 89 | key = torch.cat((past_key, key), dim=-1) 90 | value = torch.cat((past_value, value), dim=-2) 91 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 92 | a = self._attn(query, key, value) 93 | a = self.merge_heads(a) 94 | a = self.c_proj(a) 95 | return a, present 96 | 97 | class MLP(nn.Module): 98 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 99 | super(MLP, self).__init__() 100 | nx = config.n_embd 101 | self.c_fc = Conv1D(n_state, nx) 102 | self.c_proj = Conv1D(nx, n_state) 103 | self.act = gelu 104 | 105 | def forward(self, x): 106 | h = self.act(self.c_fc(x)) 107 | h2 = self.c_proj(h) 108 | return h2 109 | 110 | class Block(nn.Module): 111 | def __init__(self, n_ctx, config, scale=False): 112 | super(Block, self).__init__() 113 | nx = config.n_embd 114 | self.ln_1 = 
LayerNorm(nx, eps=config.layer_norm_epsilon) 115 | self.attn = Attention(nx, n_ctx, config, scale) 116 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 117 | self.mlp = MLP(4 * nx, config) 118 | 119 | def forward(self, x, layer_past=None): 120 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 121 | x = x + a 122 | m = self.mlp(self.ln_2(x)) 123 | x = x + m 124 | return x, present 125 | 126 | class GPT2Model(nn.Module): 127 | def __init__(self, config): 128 | super(GPT2Model, self).__init__() 129 | self.n_layer = config.n_layer 130 | self.n_embd = config.n_embd 131 | self.n_vocab = config.vocab_size 132 | 133 | self.wte = nn.Embedding(config.vocab_size, config.n_embd) 134 | self.wpe = nn.Embedding(config.n_positions, config.n_embd) 135 | block = Block(config.n_ctx, config, scale=True) 136 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 137 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 138 | 139 | def set_embeddings_weights(self, model_embeddings_weights): 140 | embed_shape = model_embeddings_weights.shape 141 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 142 | self.decoder.weight = model_embeddings_weights # Tied weights 143 | 144 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None): 145 | if past is None: 146 | past_length = 0 147 | past = [None] * len(self.h) 148 | else: 149 | past_length = past[0][0].size(-2) 150 | if position_ids is None: 151 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, 152 | device=input_ids.device) 153 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 154 | 155 | input_shape = input_ids.size() 156 | input_ids = input_ids.view(-1, input_ids.size(-1)) 157 | position_ids = position_ids.view(-1, position_ids.size(-1)) 158 | 159 | inputs_embeds = self.wte(input_ids) 160 | position_embeds = self.wpe(position_ids) 161 | if token_type_ids is not None: 162 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) 163 | token_type_embeds = self.wte(token_type_ids) 164 | else: 165 | token_type_embeds = 0 166 | hidden_states = inputs_embeds + position_embeds + token_type_embeds 167 | 168 | presents = [] 169 | for block, layer_past in zip(self.h, past): 170 | hidden_states, present = block(hidden_states, layer_past) 171 | presents.append(present) 172 | hidden_states = self.ln_f(hidden_states) 173 | output_shape = input_shape + (hidden_states.size(-1),) 174 | 175 | return hidden_states.view(*output_shape), presents 176 | 177 | class GPT2LMHead(nn.Module): 178 | def __init__(self, model_embeddings_weights, config): 179 | super(GPT2LMHead, self).__init__() 180 | self.n_embd = config.n_embd 181 | self.set_embeddings_weights(model_embeddings_weights) 182 | 183 | def set_embeddings_weights(self, model_embeddings_weights): 184 | embed_shape = model_embeddings_weights.shape 185 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False) 186 | self.decoder.weight = model_embeddings_weights # Tied weights 187 | 188 | def forward(self, hidden_state): 189 | # Truncated Language modeling logits (we remove the last token) 190 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) 191 | lm_logits = self.decoder(hidden_state) 192 | return lm_logits 193 | 194 | class GPT2LMHeadModel(nn.Module): 195 | def __init__(self, config): 196 | super(GPT2LMHeadModel, self).__init__() 197 | self.transformer = GPT2Model(config) 198 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, 
config) 199 | 200 | def set_tied(self): 201 | """ Make sure we are sharing the embeddings 202 | """ 203 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight) 204 | 205 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None): 206 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past) 207 | lm_logits = self.lm_head(hidden_states) 208 | if lm_labels is not None: 209 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 210 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)) 211 | return loss 212 | return lm_logits, presents -------------------------------------------------------------------------------- /gpt2/sample.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import trange 9 | 10 | def top_k_logits(logits, k): 11 | if k == 0: 12 | return logits 13 | values, _ = torch.topk(logits, k) 14 | min_values = values[:, -1] 15 | rets = [] 16 | for l, m in zip(logits, min_values): 17 | rets.append(torch.where(l < m, torch.ones_like(l, dtype=l.dtype) * -1e10, l)) 18 | rets = torch.stack(rets) 19 | return rets 20 | 21 | def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device="cuda", sample=True): 22 | prev = context 23 | output = context 24 | past = None 25 | with torch.no_grad(): 26 | for i in range(length): 27 | logits, past = model(prev, past=past) 28 | logits = logits[:, -1, :] / temperature 29 | logits = top_k_logits(logits, k=top_k) 30 | log_probs = F.softmax(logits, dim=-1) 31 | if sample: 32 | prev = torch.multinomial(log_probs, num_samples=1) 33 | else: 34 | _, prev = torch.topk(log_probs, k=1, dim=-1) 35 | output = torch.cat((output, prev), dim=1) 36 | 37 | return output.cpu().numpy().tolist() -------------------------------------------------------------------------------- /gpt2/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by TaeHwan Jung(@graykode) 3 | Original Paper and repository here : https://github.com/openai/gpt-2 4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT 5 | ''' 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | def load_weight(model, state_dict): 11 | old_keys = [] 12 | new_keys = [] 13 | for key in state_dict.keys(): 14 | new_key = None 15 | if key.endswith(".g"): 16 | new_key = key[:-2] + ".weight" 17 | elif key.endswith(".b"): 18 | new_key = key[:-2] + ".bias" 19 | elif key.endswith(".w"): 20 | new_key = key[:-2] + ".weight" 21 | if new_key: 22 | old_keys.append(key) 23 | new_keys.append(new_key) 24 | for old_key, new_key in zip(old_keys, new_keys): 25 | state_dict[new_key] = state_dict.pop(old_key) 26 | 27 | missing_keys = [] 28 | unexpected_keys = [] 29 | error_msgs = [] 30 | # copy state_dict so _load_from_state_dict can modify it 31 | metadata = getattr(state_dict, "_metadata", None) 32 | state_dict = state_dict.copy() 33 | if metadata is not None: 34 | state_dict._metadata = metadata 35 | 36 | def load(module, prefix=""): 37 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 38 | module._load_from_state_dict( 39 | state_dict, prefix, local_metadata, True, 
missing_keys, unexpected_keys, error_msgs 40 | ) 41 | for name, child in module._modules.items(): 42 | if child is not None: 43 | load(child, prefix + name + ".") 44 | 45 | start_model = model 46 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()): 47 | start_model = model.transformer 48 | load(start_model, prefix="") 49 | 50 | # Make sure we are still sharing the output and input embeddings after loading weights 51 | model.set_tied() 52 | return model -------------------------------------------------------------------------------- /gpt2_images/dog.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/dog.jpeg -------------------------------------------------------------------------------- /gpt2_images/goldfish.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/goldfish.jpeg -------------------------------------------------------------------------------- /gpt2_images/harmonica.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/harmonica.jpeg -------------------------------------------------------------------------------- /gpt2_images/harp.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/harp.jpeg -------------------------------------------------------------------------------- /gpt2_images/knot.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/knot.jpeg -------------------------------------------------------------------------------- /gpt2_images/radio_telescope.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/radio_telescope.jpeg -------------------------------------------------------------------------------- /gpt2_images/teapot.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/teapot.jpeg -------------------------------------------------------------------------------- /gpt2_images/telephone.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/telephone.jpeg -------------------------------------------------------------------------------- /gpt2_images/zebra.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/zebra.jpeg -------------------------------------------------------------------------------- /latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 
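# Annotation (not part of latent.py): the three *LatentSpace modules below are thin holders for
# the evolved variables. set_from_population() copies a pymoo population matrix (numpy) into
# torch tensors on the configured device, and forward() returns the tuple of inputs expected by
# the matching generator: (z, class_labels) for BigGAN (z clipped to [-2, 2], class scores
# softmaxed), (z,) for StyleGAN2, and a batch of integer token ids for GPT-2.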
| from pytorch_pretrained_biggan import truncated_noise_sample 3 | 4 | class DeepMindBigGANLatentSpace(torch.nn.Module): 5 | def __init__(self, config): 6 | super(DeepMindBigGANLatentSpace, self).__init__() 7 | self.config = config 8 | 9 | self.z = torch.nn.Parameter(torch.tensor(truncated_noise_sample(self.config.batch_size)).to(self.config.device)) 10 | self.class_labels = torch.nn.Parameter(torch.rand(self.config.batch_size, self.config.num_classes).to(self.config.device)) 11 | 12 | def set_values(self, z, class_labels): 13 | self.z.data = z 14 | self.class_labels.data = class_labels 15 | 16 | def set_from_population(self, x): 17 | self.z.data = torch.tensor(x[:,:self.config.dim_z].astype(float)).float().to(self.config.device) 18 | self.class_labels.data = torch.tensor(x[:,self.config.dim_z:].astype(float)).float().to(self.config.device) 19 | 20 | def forward(self): 21 | z = torch.clip(self.z, -2, 2) 22 | class_labels = torch.softmax(self.class_labels, dim=1) 23 | 24 | return z, class_labels 25 | 26 | 27 | class StyleGAN2LatentSpace(torch.nn.Module): 28 | def __init__(self, config): 29 | super(StyleGAN2LatentSpace, self).__init__() 30 | self.config = config 31 | 32 | self.z = torch.nn.Parameter(torch.randn(self.config.batch_size, self.config.dim_z).to(self.config.device)) 33 | 34 | def set_values(self, z): 35 | self.z.data = z 36 | 37 | def set_from_population(self, x): 38 | self.z.data = torch.tensor(x.astype(float)).float().to(self.config.device) 39 | 40 | def forward(self): 41 | return (self.z, ) 42 | 43 | 44 | class GPT2LatentSpace(torch.nn.Module): 45 | def __init__(self, config): 46 | super(GPT2LatentSpace, self).__init__() 47 | self.config = config 48 | 49 | self.z = torch.randint(0, self.config.encoder_size, size=(self.config.batch_size, self.config.dim_z)).to(self.config.device) 50 | #self.z = torch.zeros(self.config.batch_size, self.config.dim_z) 51 | 52 | def set_values(self, z): 53 | self.z.data = z 54 | 55 | def set_from_population(self, x): 56 | self.z.data = torch.tensor(x.astype(int)).long().to(self.config.device) 57 | 58 | def forward(self): 59 | return (self.z, ) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from pytorch_pretrained_biggan import BigGAN as DMBigGAN 5 | import stylegan2 6 | 7 | from gpt2.model import GPT2LMHeadModel 8 | from gpt2.utils import load_weight 9 | from gpt2.config import GPT2Config 10 | from gpt2.sample import sample_sequence 11 | from gpt2.encoder import get_encoder 12 | 13 | 14 | class GPT2(torch.nn.Module): 15 | def __init__(self, config): 16 | super(GPT2, self).__init__() 17 | self.config = config 18 | if not os.path.exists(self.config.weights): 19 | print("Weights not found!\nRun: ./download-weights.sh GPT2") 20 | sys.exit(1) 21 | 22 | state_dict = torch.load(self.config.weights, map_location=self.config.device) 23 | 24 | self.enc = get_encoder(config) 25 | self.model = GPT2LMHeadModel(GPT2Config()) 26 | self.model = load_weight(self.model, state_dict) 27 | self.model.to(self.config.device) 28 | self.model.eval() 29 | 30 | self.init_tokens = torch.tensor(self.enc.encode(self.config.init_text)).to(self.config.device) 31 | 32 | def parse_out(self, out): 33 | texts = [] 34 | for seq in out: 35 | if self.enc.encoder["<|endoftext|>"] in seq: 36 | text = seq[self.config.dim_z:seq.index(self.enc.encoder["<|endoftext|>"])] 37 | else: 38 | text = seq[self.config.dim_z:] 39 
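# Annotation (not part of models.py): the sampled sequence is laid out as
# [dim_z evolved tokens][init_text tokens][continuation], so slicing from dim_z keeps the
# "the picture of ..." prompt plus whatever GPT-2 generated, truncated at <|endoftext|> above
# and clipped to max_text_len characters below.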
| text = self.enc.decode(text) 40 | 41 | texts.append(text[:self.config.max_text_len]) 42 | return texts 43 | 44 | 45 | def generate(self, z, minibatch=None): 46 | #TODO: implement minibatch 47 | init_tokens = self.init_tokens.repeat(z.shape[0], 1) 48 | z = torch.cat((z, init_tokens), dim=1) 49 | 50 | out = sample_sequence( 51 | model=self.model, 52 | length=self.config.max_tokens_len, 53 | context=z, 54 | start_token=None, 55 | batch_size=self.config.batch_size, 56 | temperature=0.7, 57 | top_k=40, 58 | device=self.config.device, 59 | sample=self.config.stochastic 60 | ) 61 | 62 | return self.parse_out(out) 63 | 64 | 65 | class DeepMindBigGAN(torch.nn.Module): 66 | def __init__(self, config): 67 | super(DeepMindBigGAN, self).__init__() 68 | self.config = config 69 | self.G = DMBigGAN.from_pretrained(config.weights) 70 | self.D = None 71 | 72 | def has_discriminator(self): 73 | return False 74 | 75 | def generate(self, z, class_labels, minibatch = None): 76 | if minibatch is None: 77 | return self.G(z, class_labels, self.config.truncation) 78 | else: 79 | assert z.shape[0] % minibatch == 0 80 | gen_images = [] 81 | for i in range(0, z.shape[0] // minibatch): 82 | z_minibatch = z[i*minibatch:(i+1)*minibatch, :] 83 | cl_minibatch = class_labels[i*minibatch:(i+1)*minibatch, :] 84 | gen_images.append(self.G(z_minibatch, cl_minibatch, self.config.truncation)) 85 | gen_images = torch.cat(gen_images) 86 | return gen_images 87 | 88 | 89 | 90 | class StyleGAN2(torch.nn.Module): 91 | def __init__(self, config): 92 | super(StyleGAN2, self).__init__() 93 | if not os.path.exists(os.path.join(config.weights, "G.pth")): 94 | if "ffhq" in config.config: 95 | model = "ffhq" 96 | elif "car" in config.config: 97 | model = "car" 98 | elif "church" in config.config: 99 | model = "church" 100 | print("Weights not found!\nRun : ./download-weights.sh StyleGAN2-%s" % (model)) 101 | sys.exit(1) 102 | self.G = stylegan2.models.load(os.path.join(config.weights, "G.pth")) 103 | self.D = stylegan2.models.load(os.path.join(config.weights, "D.pth")) 104 | 105 | def has_discriminator(self): 106 | return True 107 | 108 | def generate(self, z, minibatch = None): 109 | if minibatch is None: 110 | return self.G(z) 111 | else: 112 | assert z.shape[0] % minibatch == 0 113 | gen_images = [] 114 | for i in range(0, z.shape[0] // minibatch): 115 | z_minibatch = z[i*minibatch:(i+1)*minibatch, :] 116 | gen_images.append(self.G(z_minibatch)) 117 | gen_images = torch.cat(gen_images) 118 | return gen_images 119 | 120 | def discriminate(self, images, minibatch = None): 121 | if minibatch is None: 122 | return self.D(images) 123 | else: 124 | assert images.shape[0] % minibatch == 0 125 | discriminations = [] 126 | for i in range(0, images.shape[0] // minibatch): 127 | images_minibatch = images[i*minibatch:(i+1)*minibatch, :] 128 | discriminations.append(self.D(images_minibatch)) 129 | discriminations = torch.cat(discriminations) 130 | return discriminations -------------------------------------------------------------------------------- /operators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy.stats import truncnorm 4 | 5 | from pymoo.factory import get_sampling, get_crossover, get_mutation 6 | from pymoo.operators.mixed_variable_operator import MixedVariableSampling, MixedVariableMutation, MixedVariableCrossover 7 | from pymoo.model.sampling import Sampling 8 | 9 | class TruncatedNormalRandomSampling(Sampling): 10 | def __init__(self, 
var_type=np.float): 11 | super().__init__() 12 | self.var_type = var_type 13 | 14 | def _do(self, problem, n_samples, **kwargs): 15 | return truncnorm.rvs(-2, 2, size=(n_samples, problem.n_var)).astype(np.float32) 16 | 17 | class NormalRandomSampling(Sampling): 18 | def __init__(self, mu=0, std=1, var_type=np.float): 19 | super().__init__() 20 | self.mu = mu 21 | self.std = std 22 | self.var_type = var_type 23 | 24 | def _do(self, problem, n_samples, **kwargs): 25 | return np.random.normal(self.mu, self.std, size=(n_samples, problem.n_var)) 26 | 27 | class BinaryRandomSampling(Sampling): 28 | def __init__(self, prob=0.5): 29 | super().__init__() 30 | self.prob = prob 31 | 32 | def _do(self, problem, n_samples, **kwargs): 33 | val = np.random.random((n_samples, problem.n_var)) 34 | return (val < self.prob).astype(np.bool) 35 | 36 | 37 | def get_operators(config): 38 | if config.config == "DeepMindBigGAN256" or config.config == "DeepMindBigGAN512": 39 | mask = ["real"]*config.dim_z + ["bool"]*config.num_classes 40 | 41 | real_sampling = None 42 | if config.config == "DeepMindBigGAN256" or config.config == "DeepMindBigGAN512": 43 | real_sampling = TruncatedNormalRandomSampling() 44 | 45 | sampling = MixedVariableSampling(mask, { 46 | "real": real_sampling, 47 | "bool": BinaryRandomSampling(prob=5/1000) 48 | }) 49 | 50 | crossover = MixedVariableCrossover(mask, { 51 | "real": get_crossover("real_sbx", prob=1.0, eta=3.0), 52 | "bool": get_crossover("bin_hux", prob=0.2) 53 | }) 54 | 55 | mutation = MixedVariableMutation(mask, { 56 | "real": get_mutation("real_pm", prob=0.5, eta=3.0), 57 | "bool": get_mutation("bin_bitflip", prob=10/1000) 58 | }) 59 | 60 | return dict( 61 | sampling=sampling, 62 | crossover=crossover, 63 | mutation=mutation 64 | ) 65 | 66 | elif config.config.split("_")[0] == "StyleGAN2": 67 | return dict( 68 | sampling=NormalRandomSampling(), 69 | crossover=get_crossover("real_sbx", prob=1.0, eta=3.0), 70 | mutation=get_mutation("real_pm", prob=0.5, eta=3.0) 71 | ) 72 | 73 | elif config.config == "GPT2": 74 | return dict( 75 | sampling=get_sampling("int_random"), 76 | crossover=get_crossover("int_sbx", prob=1.0, eta=3.0), 77 | mutation=get_mutation("int_pm", prob=0.5, eta=3.0) 78 | ) 79 | 80 | else: 81 | raise Exception("Unknown config") 82 | 83 | -------------------------------------------------------------------------------- /problem.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from pymoo.model.problem import Problem 5 | from generator import Generator 6 | 7 | class GenerationProblem(Problem): 8 | def __init__(self, config): 9 | self.generator = Generator(config) 10 | self.config = config 11 | 12 | super().__init__(**self.config.problem_args) 13 | 14 | def _evaluate(self, x, out, *args, **kwargs): 15 | ls = self.config.latent(self.config) 16 | ls.set_from_population(x) 17 | 18 | with torch.no_grad(): 19 | generated = self.generator.generate(ls, minibatch=self.config.batch_size) 20 | sim = self.generator.clip_similarity(generated).cpu().numpy() 21 | if self.config.problem_args["n_obj"] == 2 and self.config.use_discriminator: 22 | dis = self.generator.discriminate(generated, minibatch=self.config.batch_size) 23 | hinge = torch.relu(1 - dis) 24 | hinge = hinge.squeeze(1).cpu().numpy() 25 | out["F"] = np.column_stack((-sim, hinge)) 26 | else: 27 | out["F"] = -sim 28 | 29 | out["G"] = np.zeros((x.shape[0])) 30 | 31 | 32 | 
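# Illustrative sketch (not part of the repository): a minimal, self-contained pymoo problem with
# the same structure as GenerationProblem above -- a population matrix comes in, out["F"] gets
# the objective(s) to minimize. The toy "similarity" below is a stand-in; in the repository it is
# the CLIP cosine similarity (negated, since pymoo minimizes).
import numpy as np
from pymoo.model.problem import Problem
from pymoo.algorithms.so_genetic_algorithm import GA
from pymoo.optimize import minimize

class ToyGenerationProblem(Problem):
    def __init__(self):
        # 8 real-valued latent variables in [-2, 2], one objective, no constraints
        super().__init__(n_var=8, n_obj=1, n_constr=0, xl=-2, xu=2)

    def _evaluate(self, x, out, *args, **kwargs):
        # stand-in for clip_similarity(): a score in (0, 1] that peaks at x = 0
        sim = np.exp(-np.sum(x ** 2, axis=1))
        out["F"] = -sim  # maximizing similarity == minimizing its negation

res = minimize(ToyGenerationProblem(), GA(pop_size=16), ("n_gen", 20), verbose=False)
print(res.X, res.F)  # best latent vector and its (negated) similarity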
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | autograd==1.3 3 | boto3==1.16.63 4 | botocore==1.19.63 5 | cachetools==4.2.1 6 | certifi==2020.12.5 7 | chardet==4.0.0 8 | cma==2.7.0 9 | cycler==0.10.0 10 | dataclasses==0.8 11 | ftfy==5.8 12 | future==0.18.2 13 | google-auth==1.24.0 14 | google-auth-oauthlib==0.4.2 15 | grpcio==1.35.0 16 | idna==2.10 17 | importlib-metadata==3.4.0 18 | jmespath==0.10.0 19 | kiwisolver==1.3.1 20 | kornia==0.4.1 21 | Markdown==3.3.3 22 | matplotlib==3.3.4 23 | numpy==1.19.5 24 | oauthlib==3.1.0 25 | Pillow==8.1.0 26 | protobuf==3.14.0 27 | pyasn1==0.4.8 28 | pyasn1-modules==0.2.8 29 | pymoo==0.4.2.1 30 | pyparsing==2.4.7 31 | python-dateutil==2.8.1 32 | pytorch-pretrained-biggan==0.1.1 33 | PyYAML==5.4.1 34 | regex==2020.11.13 35 | requests==2.25.1 36 | requests-oauthlib==1.3.0 37 | rsa==4.7 38 | s3transfer==0.3.4 39 | scipy==1.5.4 40 | six==1.15.0 41 | tensorboard==2.4.1 42 | tensorboard-plugin-wit==1.8.0 43 | torch==1.7.1 44 | torchvision==0.8.2 45 | tqdm==4.56.0 46 | typing-extensions==3.7.4.3 47 | urllib3==1.26.3 48 | wcwidth==0.2.5 49 | Werkzeug==1.0.1 50 | zipp==3.4.0 51 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import pickle 6 | from pymoo.optimize import minimize 7 | from pymoo.algorithms.so_genetic_algorithm import GA 8 | from pymoo.factory import get_algorithm, get_decision_making, get_decomposition 9 | from pymoo.visualization.scatter import Scatter 10 | 11 | from config import get_config 12 | from problem import GenerationProblem 13 | from operators import get_operators 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument("--device", type=str, default="cuda") 18 | parser.add_argument("--config", type=str, default="DeepMindBigGAN512") 19 | parser.add_argument("--generations", type=int, default=500) 20 | parser.add_argument("--save-each", type=int, default=50) 21 | parser.add_argument("--tmp-folder", type=str, default="./tmp") 22 | parser.add_argument("--target", type=str, default="a wolf at night with the moon in the background") 23 | 24 | config = parser.parse_args() 25 | vars(config).update(get_config(config.config)) 26 | 27 | 28 | iteration = 0 29 | def save_callback(algorithm): 30 | global iteration 31 | global config 32 | 33 | iteration += 1 34 | if iteration % config.save_each == 0 or iteration == config.generations: 35 | if config.problem_args["n_obj"] == 1: 36 | sortedpop = sorted(algorithm.pop, key=lambda p: p.F) 37 | X = np.stack([p.X for p in sortedpop]) 38 | else: 39 | X = algorithm.pop.get("X") 40 | 41 | ls = config.latent(config) 42 | ls.set_from_population(X) 43 | 44 | with torch.no_grad(): 45 | generated = algorithm.problem.generator.generate(ls, minibatch=config.batch_size) 46 | if config.task == "txt2img": 47 | ext = "jpg" 48 | elif config.task == "img2txt": 49 | ext = "txt" 50 | name = "genetic-it-%d.%s" % (iteration, ext) if iteration < config.generations else "genetic-it-final.%s" % (ext, ) 51 | algorithm.problem.generator.save(generated, os.path.join(config.tmp_folder, name)) 52 | 53 | 54 | problem = GenerationProblem(config) 55 | operators = get_operators(config) 56 | 57 | if not os.path.exists(config.tmp_folder): os.mkdir(config.tmp_folder) 58 
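# Annotation (not part of run.py): vars(config).update(get_config(...)) above merges the argparse
# namespace with the chosen entry from config.py, so fields such as pop_size, batch_size and
# problem_args come from that dict. Typical invocations, given the options and config names
# defined above:
#   python run.py --config DeepMindBigGAN512 --target "a wolf at night with the moon in the background"
#   python run.py --config GPT2 --target ./gpt2_images/dog.jpeg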
| 59 | algorithm = get_algorithm( 60 | config.algorithm, 61 | pop_size=config.pop_size, 62 | sampling=operators["sampling"], 63 | crossover=operators["crossover"], 64 | mutation=operators["mutation"], 65 | eliminate_duplicates=True, 66 | callback=save_callback, 67 | **(config.algorithm_args[config.algorithm] if "algorithm_args" in config and config.algorithm in config.algorithm_args else dict()) 68 | ) 69 | 70 | res = minimize( 71 | problem, 72 | algorithm, 73 | ("n_gen", config.generations), 74 | save_history=False, 75 | verbose=True, 76 | ) 77 | 78 | 79 | pickle.dump(dict( 80 | X = res.X, 81 | F = res.F, 82 | G = res.G, 83 | CV = res.CV, 84 | ), open(os.path.join(config.tmp_folder, "genetic_result"), "wb")) 85 | 86 | if config.problem_args["n_obj"] == 2: 87 | plot = Scatter(labels=["similarity", "discriminator",]) 88 | plot.add(res.F, color="red") 89 | plot.save(os.path.join(config.tmp_folder, "F.jpg")) 90 | 91 | 92 | if config.problem_args["n_obj"] == 1: 93 | sortedpop = sorted(res.pop, key=lambda p: p.F) 94 | X = np.stack([p.X for p in sortedpop]) 95 | else: 96 | X = res.pop.get("X") 97 | 98 | ls = config.latent(config) 99 | ls.set_from_population(X) 100 | 101 | torch.save(ls.state_dict(), os.path.join(config.tmp_folder, "ls_result")) 102 | 103 | if config.problem_args["n_obj"] == 1: 104 | X = np.atleast_2d(res.X) 105 | else: 106 | try: 107 | result = get_decision_making("pseudo-weights", [0, 1]).do(res.F) 108 | except: 109 | print("Warning: cant use pseudo-weights") 110 | result = get_decomposition("asf").do(res.F, [0, 1]).argmin() 111 | 112 | X = res.X[result] 113 | X = np.atleast_2d(X) 114 | 115 | ls.set_from_population(X) 116 | 117 | with torch.no_grad(): 118 | generated = problem.generator.generate(ls) 119 | 120 | if config.task == "txt2img": 121 | ext = "jpg" 122 | elif config.task == "img2txt": 123 | ext = "txt" 124 | 125 | problem.generator.save(generated, os.path.join(config.tmp_folder, "output.%s" % (ext))) -------------------------------------------------------------------------------- /stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | from . import external_models 2 | from . import metrics 3 | from . import models 4 | from . import project 5 | from . 
import train 6 | -------------------------------------------------------------------------------- /stylegan2/convert_from_tf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import pickle 4 | import argparse 5 | import io 6 | import requests 7 | import torch 8 | import stylegan2 9 | from stylegan2 import utils 10 | 11 | 12 | pretrained_model_urls = { 13 | 'car-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl', 14 | 'car-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl', 15 | 'cat-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl', 16 | 'church-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl', 17 | 'ffhq-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl', 18 | 'ffhq-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl', 19 | 'horse-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl', 20 | 'car-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl', 21 | 'car-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl', 22 | 'car-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl', 23 | 'car-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl', 24 | 'car-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl', 25 | 'car-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl', 26 | 'car-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl', 27 | 'car-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl', 28 | 'car-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl', 29 | 'ffhq-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl', 30 | 'ffhq-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl', 31 | 'ffhq-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl', 32 | 'ffhq-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl', 33 | 'ffhq-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl', 34 | 'ffhq-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl', 35 | 'ffhq-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl', 36 | 'ffhq-config-e-Gskip-Dresnet': 
'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl', 37 | 'ffhq-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl', 38 | } 39 | 40 | 41 | class Unpickler(pickle.Unpickler): 42 | def find_class(self, module, name): 43 | if module == 'dnnlib.tflib.network' and name == 'Network': 44 | return utils.AttributeDict 45 | return super(Unpickler, self).find_class(module, name) 46 | 47 | 48 | def load_tf_models_file(fpath): 49 | with open(fpath, 'rb') as fp: 50 | return Unpickler(fp).load() 51 | 52 | 53 | def load_tf_models_url(url): 54 | print('Downloading file {}...'.format(url)) 55 | with requests.Session() as session: 56 | with session.get(url) as ret: 57 | fp = io.BytesIO(ret.content) 58 | return Unpickler(fp).load() 59 | 60 | 61 | def convert_kwargs(static_kwargs, kwargs_mapping): 62 | kwargs = utils.AttributeDict() 63 | for key, value in static_kwargs.items(): 64 | if key in kwargs_mapping: 65 | if value == 'lrelu': 66 | value = 'leaky:0.2' 67 | for k in utils.to_list(kwargs_mapping[key]): 68 | kwargs[k] = value 69 | return kwargs 70 | 71 | 72 | _PERMITTED_MODELS = ['G_main', 'G_mapping', 'G_synthesis_stylegan2', 'D_stylegan2', 'D_main', 'G_synthesis'] 73 | def convert_from_tf(tf_state): 74 | tf_state = utils.AttributeDict.convert_dict_recursive(tf_state) 75 | model_type = tf_state.build_func_name 76 | assert model_type in _PERMITTED_MODELS, \ 77 | 'Found model type {}. '.format(model_type) + \ 78 | 'Allowed model types are: {}'.format(_PERMITTED_MODELS) 79 | 80 | if model_type == 'G_main': 81 | kwargs = convert_kwargs( 82 | static_kwargs=tf_state.static_kwargs, 83 | kwargs_mapping={ 84 | 'dlatent_avg_beta': 'dlatent_avg_beta' 85 | } 86 | ) 87 | kwargs.G_mapping = convert_from_tf(tf_state.components.mapping) 88 | kwargs.G_synthesis = convert_from_tf(tf_state.components.synthesis) 89 | G = stylegan2.models.Generator(**kwargs) 90 | for name, var in tf_state.variables: 91 | if name == 'dlatent_avg': 92 | G.dlatent_avg.data.copy_(torch.from_numpy(var)) 93 | kwargs = convert_kwargs( 94 | static_kwargs=tf_state.static_kwargs, 95 | kwargs_mapping={ 96 | 'truncation_psi': 'truncation_psi', 97 | 'truncation_cutoff': 'truncation_cutoff', 98 | 'truncation_psi_val': 'truncation_psi', 99 | 'truncation_cutoff_val': 'truncation_cutoff' 100 | } 101 | ) 102 | G.set_truncation(**kwargs) 103 | return G 104 | 105 | if model_type == 'G_mapping': 106 | kwargs = convert_kwargs( 107 | static_kwargs=tf_state.static_kwargs, 108 | kwargs_mapping={ 109 | 'mapping_nonlinearity': 'activation', 110 | 'normalize_latents': 'normalize_input', 111 | 'mapping_lr_mul': 'lr_mul' 112 | } 113 | ) 114 | kwargs.num_layers = sum( 115 | 1 for var_name, _ in tf_state.variables 116 | if re.match('Dense[0-9]+/weight', var_name) 117 | ) 118 | for var_name, var in tf_state.variables: 119 | if var_name == 'LabelConcat/weight': 120 | kwargs.label_size = var.shape[0] 121 | if var_name == 'Dense0/weight': 122 | kwargs.latent_size = var.shape[0] 123 | kwargs.hidden = var.shape[1] 124 | if var_name == 'Dense{}/bias'.format(kwargs.num_layers - 1): 125 | kwargs.out_size = var.shape[0] 126 | G_mapping = stylegan2.models.GeneratorMapping(**kwargs) 127 | for var_name, var in tf_state.variables: 128 | if re.match('Dense[0-9]+/[a-zA-Z]*', var_name): 129 | layer_idx = int(re.search('Dense(\d+)/[a-zA-Z]*', var_name).groups()[0]) 130 | if var_name.endswith('weight'): 131 | G_mapping.main[layer_idx].layer.weight.data.copy_( 132 
| torch.from_numpy(var.T).contiguous()) 133 | elif var_name.endswith('bias'): 134 | G_mapping.main[layer_idx].bias.data.copy_(torch.from_numpy(var)) 135 | if var_name == 'LabelConcat/weight': 136 | G_mapping.embedding.weight.data.copy_(torch.from_numpy(var)) 137 | return G_mapping 138 | 139 | if model_type == 'G_synthesis_stylegan2' or model_type == 'G_synthesis': 140 | assert tf_state.static_kwargs.get('fused_modconv', True), \ 141 | 'Can not load TF networks that use `fused_modconv=False`' 142 | noise_tensors = [] 143 | conv_vars = {} 144 | for var_name, var in tf_state.variables: 145 | if var_name.startswith('noise'): 146 | noise_tensors.append(torch.from_numpy(var)) 147 | else: 148 | layer_size = int(re.search('(\d+)x[0-9]+/*', var_name).groups()[0]) 149 | if layer_size not in conv_vars: 150 | conv_vars[layer_size] = {} 151 | var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '') 152 | conv_vars[layer_size][var_name] = var 153 | noise_tensors = sorted(noise_tensors, key=lambda x:x.size(-1)) 154 | kwargs = convert_kwargs( 155 | static_kwargs=tf_state.static_kwargs, 156 | kwargs_mapping={ 157 | 'nonlinearity': 'activation', 158 | 'resample_filter': ['conv_filter', 'skip_filter'] 159 | } 160 | ) 161 | kwargs.skip = False 162 | kwargs.resnet = True 163 | kwargs.channels = [] 164 | for size in sorted(conv_vars.keys(), reverse=True): 165 | if size == 4: 166 | if 'ToRGB/weight' in conv_vars[size]: 167 | kwargs.skip = True 168 | kwargs.resnet = False 169 | kwargs.latent_size = conv_vars[size]['Conv/mod_weight'].shape[0] 170 | kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0]) 171 | else: 172 | kwargs.channels.append(conv_vars[size]['Conv1/bias'].shape[0]) 173 | if 'ToRGB/bias' in conv_vars[size]: 174 | kwargs.data_channels = conv_vars[size]['ToRGB/bias'].shape[0] 175 | G_synthesis = stylegan2.models.GeneratorSynthesis(**kwargs) 176 | G_synthesis.const.data.copy_(torch.from_numpy(conv_vars[4]['Const/const']).squeeze(0)) 177 | def assign_weights(layer, weight, bias, mod_weight, mod_bias, noise_strength, transposed=False): 178 | layer.bias.data.copy_(torch.from_numpy(bias)) 179 | layer.layer.weight.data.copy_(torch.tensor(noise_strength)) 180 | layer.layer.layer.dense.layer.weight.data.copy_( 181 | torch.from_numpy(mod_weight.T).contiguous()) 182 | layer.layer.layer.dense.bias.data.copy_(torch.from_numpy(mod_bias + 1)) 183 | weight = torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous() 184 | if transposed: 185 | weight = weight.flip(dims=[2,3]) 186 | layer.layer.layer.weight.data.copy_(weight) 187 | conv_blocks = G_synthesis.conv_blocks 188 | for i, size in enumerate(sorted(conv_vars.keys())): 189 | block = conv_blocks[i] 190 | if size == 4: 191 | assign_weights( 192 | layer=block.conv_block[0], 193 | weight=conv_vars[size]['Conv/weight'], 194 | bias=conv_vars[size]['Conv/bias'], 195 | mod_weight=conv_vars[size]['Conv/mod_weight'], 196 | mod_bias=conv_vars[size]['Conv/mod_bias'], 197 | noise_strength=conv_vars[size]['Conv/noise_strength'], 198 | ) 199 | else: 200 | assign_weights( 201 | layer=block.conv_block[0], 202 | weight=conv_vars[size]['Conv0_up/weight'], 203 | bias=conv_vars[size]['Conv0_up/bias'], 204 | mod_weight=conv_vars[size]['Conv0_up/mod_weight'], 205 | mod_bias=conv_vars[size]['Conv0_up/mod_bias'], 206 | noise_strength=conv_vars[size]['Conv0_up/noise_strength'], 207 | transposed=True 208 | ) 209 | assign_weights( 210 | layer=block.conv_block[1], 211 | weight=conv_vars[size]['Conv1/weight'], 212 | bias=conv_vars[size]['Conv1/bias'], 213 | 
mod_weight=conv_vars[size]['Conv1/mod_weight'], 214 | mod_bias=conv_vars[size]['Conv1/mod_bias'], 215 | noise_strength=conv_vars[size]['Conv1/noise_strength'], 216 | ) 217 | if 'Skip/weight' in conv_vars[size]: 218 | block.projection.weight.data.copy_(torch.from_numpy( 219 | conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous()) 220 | to_RGB = G_synthesis.to_data_layers[i] 221 | if to_RGB is not None: 222 | to_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['ToRGB/bias'])) 223 | to_RGB.layer.weight.data.copy_(torch.from_numpy( 224 | conv_vars[size]['ToRGB/weight']).permute((3, 2, 0, 1)).contiguous()) 225 | to_RGB.layer.dense.bias.data.copy_( 226 | torch.from_numpy(conv_vars[size]['ToRGB/mod_bias'] + 1)) 227 | to_RGB.layer.dense.layer.weight.data.copy_( 228 | torch.from_numpy(conv_vars[size]['ToRGB/mod_weight'].T).contiguous()) 229 | if not tf_state.static_kwargs.get('randomize_noise', True): 230 | G_synthesis.static_noise(noise_tensors=noise_tensors) 231 | return G_synthesis 232 | 233 | if model_type == 'D_stylegan2' or model_type == 'D_main': 234 | output_vars = {} 235 | conv_vars = {} 236 | for var_name, var in tf_state.variables: 237 | if var_name.startswith('Output'): 238 | output_vars[var_name.replace('Output/', '')] = var 239 | else: 240 | layer_size = int(re.search('(\d+)x[0-9]+/*', var_name).groups()[0]) 241 | if layer_size not in conv_vars: 242 | conv_vars[layer_size] = {} 243 | var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '') 244 | conv_vars[layer_size][var_name] = var 245 | kwargs = convert_kwargs( 246 | static_kwargs=tf_state.static_kwargs, 247 | kwargs_mapping={ 248 | 'nonlinearity': 'activation', 249 | 'resample_filter': ['conv_filter', 'skip_filter'], 250 | 'mbstd_group_size': 'mbstd_group_size' 251 | } 252 | ) 253 | kwargs.skip = False 254 | kwargs.resnet = True 255 | kwargs.channels = [] 256 | for size in sorted(conv_vars.keys(), reverse=True): 257 | if size == 4: 258 | if 'FromRGB/weight' in conv_vars[size]: 259 | kwargs.skip = True 260 | kwargs.resnet = False 261 | kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0]) 262 | kwargs.dense_hidden = conv_vars[size]['Dense0/bias'].shape[0] 263 | else: 264 | kwargs.channels.append(conv_vars[size]['Conv0/bias'].shape[0]) 265 | if 'FromRGB/weight' in conv_vars[size]: 266 | kwargs.data_channels = conv_vars[size]['FromRGB/weight'].shape[-2] 267 | output_size = output_vars['bias'].shape[0] 268 | if output_size > 1: 269 | kwargs.label_size = output_size 270 | D = stylegan2.models.Discriminator(**kwargs) 271 | def assign_weights(layer, weight, bias): 272 | layer.bias.data.copy_(torch.from_numpy(bias)) 273 | layer.layer.weight.data.copy_( 274 | torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous()) 275 | conv_blocks = D.conv_blocks 276 | for i, size in enumerate(sorted(conv_vars.keys())): 277 | block = conv_blocks[-i - 1] 278 | if size == 4: 279 | assign_weights( 280 | layer=block[-1].conv_block[0], 281 | weight=conv_vars[size]['Conv/weight'], 282 | bias=conv_vars[size]['Conv/bias'], 283 | ) 284 | else: 285 | assign_weights( 286 | layer=block.conv_block[0], 287 | weight=conv_vars[size]['Conv0/weight'], 288 | bias=conv_vars[size]['Conv0/bias'], 289 | ) 290 | assign_weights( 291 | layer=block.conv_block[1], 292 | weight=conv_vars[size]['Conv1_down/weight'], 293 | bias=conv_vars[size]['Conv1_down/bias'], 294 | ) 295 | if 'Skip/weight' in conv_vars[size]: 296 | block.projection.weight.data.copy_(torch.from_numpy( 297 | conv_vars[size]['Skip/weight']).permute((3, 2, 0, 
1)).contiguous()) 298 | from_RGB = D.from_data_layers[-i - 1] 299 | if from_RGB is not None: 300 | from_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['FromRGB/bias'])) 301 | from_RGB.layer.weight.data.copy_(torch.from_numpy( 302 | conv_vars[size]['FromRGB/weight']).permute((3, 2, 0, 1)).contiguous()) 303 | return D 304 | 305 | 306 | def get_arg_parser(): 307 | parser = argparse.ArgumentParser( 308 | description='Convert tensorflow stylegan2 model to pytorch.', 309 | epilog='Pretrained models that can be downloaded:\n{}'.format( 310 | '\n'.join(pretrained_model_urls.keys())) 311 | ) 312 | 313 | parser.add_argument( 314 | '-i', 315 | '--input', 316 | help='File path to pickled tensorflow models.', 317 | type=str, 318 | default=None, 319 | ) 320 | 321 | parser.add_argument( 322 | '-d', 323 | '--download', 324 | help='Download the specified pretrained model. Use --help for info on available models.', 325 | type=str, 326 | default=None, 327 | ) 328 | 329 | parser.add_argument( 330 | '-o', 331 | '--output', 332 | help='One or more output file paths. Alternatively a directory path ' + \ 333 | 'where all models will be saved. Default: current directory', 334 | type=str, 335 | nargs='*', 336 | default=['.'], 337 | ) 338 | 339 | return parser 340 | 341 | 342 | def main(): 343 | args = get_arg_parser().parse_args() 344 | assert bool(args.input) != bool(args.download), \ 345 | 'Incorrect input format. Can only take either one ' + \ 346 | 'input filepath to a pickled tensorflow model or ' + \ 347 | 'a model name to download, but not both at the same ' + \ 348 | 'time or none at all.' 349 | if args.input: 350 | unpickled = load_tf_models_file(args.input) 351 | else: 352 | assert args.download in pretrained_model_urls.keys(), \ 353 | 'Unknown model {}. Use --help for list of models.'.format(args.download) 354 | unpickled = load_tf_models_url(pretrained_model_urls[args.download]) 355 | if not isinstance(unpickled, (tuple, list)): 356 | unpickled = [unpickled] 357 | print('Converting tensorflow models and saving them...') 358 | converted = [convert_from_tf(tf_state) for tf_state in unpickled] 359 | if len(args.output) == 1 and (os.path.isdir(args.output[0]) or not os.path.splitext(args.output[0])[-1]): 360 | if not os.path.exists(args.output[0]): 361 | os.makedirs(args.output[0]) 362 | for tf_state, torch_model in zip(unpickled, converted): 363 | torch_model.save(os.path.join(args.output[0], tf_state['name'] + '.pth')) 364 | else: 365 | assert len(args.output) == len(converted), 'Found {} models '.format(len(converted)) + \ 366 | 'in pickled file but only {} output paths were given.'.format(len(args.output)) 367 | for out_path, torch_model in zip(args.output, converted): 368 | torch_model.save(out_path) 369 | print('Done!') 370 | 371 | 372 | if __name__ == '__main__': 373 | main() -------------------------------------------------------------------------------- /stylegan2/external_models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import inception 2 | from . import lpips 3 | -------------------------------------------------------------------------------- /stylegan2/external_models/inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/mseitzer/pytorch-fid/ 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | Unless required by applicable law or agreed to in writing, software 9 | distributed under the License is distributed on an "AS IS" BASIS, 10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | See the License for the specific language governing permissions and 12 | limitations under the License. 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torchvision import models 18 | 19 | try: 20 | from torchvision.models.utils import load_state_dict_from_url 21 | except ImportError: 22 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 23 | 24 | # Inception weights ported to Pytorch from 25 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 26 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 27 | 28 | 29 | class InceptionV3FeatureExtractor(nn.Module): 30 | """Pretrained InceptionV3 network returning feature maps""" 31 | 32 | # Index of default block of inception to return, 33 | # corresponds to output of final average pooling 34 | DEFAULT_BLOCK_INDEX = 3 35 | 36 | # Maps feature dimensionality to their output blocks indices 37 | BLOCK_INDEX_BY_DIM = { 38 | 64: 0, # First max pooling features 39 | 192: 1, # Second max pooling featurs 40 | 768: 2, # Pre-aux classifier features 41 | 2048: 3 # Final average pooling features 42 | } 43 | 44 | def __init__(self, 45 | output_block=DEFAULT_BLOCK_INDEX, 46 | pixel_min=-1, 47 | pixel_max=1): 48 | """ 49 | Build pretrained InceptionV3 50 | Arguments: 51 | output_block (int): Index of block to return features of. 52 | Possible values are: 53 | - 0: corresponds to output of first max pooling 54 | - 1: corresponds to output of second max pooling 55 | - 2: corresponds to output which is fed to aux classifier 56 | - 3: corresponds to output of final average pooling 57 | pixel_min (float): Min value for inputs. Default value is -1. 58 | pixel_max (float): Max value for inputs. Default value is 1. 59 | """ 60 | super(InceptionV3FeatureExtractor, self).__init__() 61 | 62 | assert 0 <= output_block <= 3, '`output_block` can only be ' + \ 63 | '0 <= `output_block` <= 3.' 
64 | 65 | inception = fid_inception_v3() 66 | 67 | blocks = [] 68 | 69 | # Block 0: input to maxpool1 70 | block0 = [ 71 | inception.Conv2d_1a_3x3, 72 | inception.Conv2d_2a_3x3, 73 | inception.Conv2d_2b_3x3, 74 | nn.MaxPool2d(kernel_size=3, stride=2) 75 | ] 76 | blocks.append(nn.Sequential(*block0)) 77 | 78 | # Block 1: maxpool1 to maxpool2 79 | if output_block >= 1: 80 | block1 = [ 81 | inception.Conv2d_3b_1x1, 82 | inception.Conv2d_4a_3x3, 83 | nn.MaxPool2d(kernel_size=3, stride=2) 84 | ] 85 | blocks.append(nn.Sequential(*block1)) 86 | 87 | # Block 2: maxpool2 to aux classifier 88 | if output_block >= 2: 89 | block2 = [ 90 | inception.Mixed_5b, 91 | inception.Mixed_5c, 92 | inception.Mixed_5d, 93 | inception.Mixed_6a, 94 | inception.Mixed_6b, 95 | inception.Mixed_6c, 96 | inception.Mixed_6d, 97 | inception.Mixed_6e, 98 | ] 99 | blocks.append(nn.Sequential(*block2)) 100 | 101 | # Block 3: aux classifier to final avgpool 102 | if output_block >= 3: 103 | block3 = [ 104 | inception.Mixed_7a, 105 | inception.Mixed_7b, 106 | inception.Mixed_7c, 107 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 108 | ] 109 | blocks.append(nn.Sequential(*block3)) 110 | 111 | self.main = nn.Sequential(*blocks) 112 | self.pixel_min = pixel_min 113 | self.pixel_max = pixel_max 114 | self.requires_grad_(False) 115 | self.eval() 116 | 117 | def _scale(self, x): 118 | if self.pixel_min != -1 or self.pixel_max != 1: 119 | x = (2*x - self.pixel_min - self.pixel_max) \ 120 | / (self.pixel_max - self.pixel_min) 121 | return x 122 | 123 | def forward(self, input): 124 | """ 125 | Get Inception feature maps. 126 | Arguments: 127 | input (torch.Tensor) 128 | Returns: 129 | feature_maps (torch.Tensor) 130 | """ 131 | return self.main(self._scale(input)) 132 | 133 | 134 | def fid_inception_v3(): 135 | """Build pretrained Inception model for FID computation 136 | The Inception model for FID computation uses a different set of weights 137 | and has a slightly different structure than torchvision's Inception. 138 | This method first constructs torchvision's Inception and then patches the 139 | necessary parts that are different in the FID Inception model.
140 | """ 141 | inception = models.inception_v3(num_classes=1008, 142 | aux_logits=False, 143 | pretrained=False) 144 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 145 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 146 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 147 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 148 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 149 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 150 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 151 | inception.Mixed_7b = FIDInceptionE_1(1280) 152 | inception.Mixed_7c = FIDInceptionE_2(2048) 153 | 154 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 155 | inception.load_state_dict(state_dict) 156 | return inception 157 | 158 | 159 | class FIDInceptionA(models.inception.InceptionA): 160 | """InceptionA block patched for FID computation""" 161 | def __init__(self, in_channels, pool_features): 162 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 163 | 164 | def forward(self, x): 165 | branch1x1 = self.branch1x1(x) 166 | 167 | branch5x5 = self.branch5x5_1(x) 168 | branch5x5 = self.branch5x5_2(branch5x5) 169 | 170 | branch3x3dbl = self.branch3x3dbl_1(x) 171 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 172 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 173 | 174 | # Patch: Tensorflow's average pool does not use the padded zero's in 175 | # its average calculation 176 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 177 | count_include_pad=False) 178 | branch_pool = self.branch_pool(branch_pool) 179 | 180 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 181 | return torch.cat(outputs, 1) 182 | 183 | 184 | class FIDInceptionC(models.inception.InceptionC): 185 | """InceptionC block patched for FID computation""" 186 | def __init__(self, in_channels, channels_7x7): 187 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 188 | 189 | def forward(self, x): 190 | branch1x1 = self.branch1x1(x) 191 | 192 | branch7x7 = self.branch7x7_1(x) 193 | branch7x7 = self.branch7x7_2(branch7x7) 194 | branch7x7 = self.branch7x7_3(branch7x7) 195 | 196 | branch7x7dbl = self.branch7x7dbl_1(x) 197 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 198 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 199 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 200 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 201 | 202 | # Patch: Tensorflow's average pool does not use the padded zero's in 203 | # its average calculation 204 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 205 | count_include_pad=False) 206 | branch_pool = self.branch_pool(branch_pool) 207 | 208 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 209 | return torch.cat(outputs, 1) 210 | 211 | 212 | class FIDInceptionE_1(models.inception.InceptionE): 213 | """First InceptionE block patched for FID computation""" 214 | def __init__(self, in_channels): 215 | super(FIDInceptionE_1, self).__init__(in_channels) 216 | 217 | def forward(self, x): 218 | branch1x1 = self.branch1x1(x) 219 | 220 | branch3x3 = self.branch3x3_1(x) 221 | branch3x3 = [ 222 | self.branch3x3_2a(branch3x3), 223 | self.branch3x3_2b(branch3x3), 224 | ] 225 | branch3x3 = torch.cat(branch3x3, 1) 226 | 227 | branch3x3dbl = self.branch3x3dbl_1(x) 228 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 229 | branch3x3dbl = [ 230 | self.branch3x3dbl_3a(branch3x3dbl), 231 | self.branch3x3dbl_3b(branch3x3dbl), 232 | ] 
233 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 234 | 235 | # Patch: Tensorflow's average pool does not use the padded zero's in 236 | # its average calculation 237 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 238 | count_include_pad=False) 239 | branch_pool = self.branch_pool(branch_pool) 240 | 241 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 242 | return torch.cat(outputs, 1) 243 | 244 | 245 | class FIDInceptionE_2(models.inception.InceptionE): 246 | """Second InceptionE block patched for FID computation""" 247 | def __init__(self, in_channels): 248 | super(FIDInceptionE_2, self).__init__(in_channels) 249 | 250 | def forward(self, x): 251 | branch1x1 = self.branch1x1(x) 252 | 253 | branch3x3 = self.branch3x3_1(x) 254 | branch3x3 = [ 255 | self.branch3x3_2a(branch3x3), 256 | self.branch3x3_2b(branch3x3), 257 | ] 258 | branch3x3 = torch.cat(branch3x3, 1) 259 | 260 | branch3x3dbl = self.branch3x3dbl_1(x) 261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 262 | branch3x3dbl = [ 263 | self.branch3x3dbl_3a(branch3x3dbl), 264 | self.branch3x3dbl_3b(branch3x3dbl), 265 | ] 266 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 267 | 268 | # Patch: The FID Inception model uses max pooling instead of average 269 | # pooling. This is likely an error in this specific Inception 270 | # implementation, as other Inception models use average pooling here 271 | # (which matches the description in the paper). 272 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | -------------------------------------------------------------------------------- /stylegan2/external_models/lpips.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/richzhang/PerceptualSimilarity 3 | 4 | Original License: 5 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 6 | All rights reserved. 7 | 8 | Redistribution and use in source and binary forms, with or without 9 | modification, are permitted provided that the following conditions are met: 10 | 11 | * Redistributions of source code must retain the above copyright notice, this 12 | list of conditions and the following disclaimer. 13 | 14 | * Redistributions in binary form must reproduce the above copyright notice, 15 | this list of conditions and the following disclaimer in the documentation 16 | and/or other materials provided with the distribution. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | """ 29 | import torch 30 | from torch import nn 31 | import torchvision 32 | 33 | 34 | class LPIPS_VGG16(nn.Module): 35 | _FEATURE_IDX = [0, 4, 9, 16, 23, 30] 36 | _LINEAR_WEIGHTS_URL = 'https://github.com/richzhang/PerceptualSimilarity' + \ 37 | '/blob/master/lpips/weights/v0.1/vgg.pth?raw=true' 38 | 39 | def __init__(self, pixel_min=-1, pixel_max=1): 40 | super(LPIPS_VGG16, self).__init__() 41 | features = torchvision.models.vgg16(pretrained=True).features 42 | self.slices = nn.ModuleList() 43 | linear_weights = torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL) 44 | for i in range(1, len(self._FEATURE_IDX)): 45 | idx_range = range(self._FEATURE_IDX[i - 1], self._FEATURE_IDX[i]) 46 | self.slices.append(nn.Sequential(*[features[j] for j in idx_range])) 47 | self.linear_layers = nn.ModuleList() 48 | for weight in torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL).values(): 49 | weight = weight.view(1, -1) 50 | linear = nn.Linear(weight.size(1), 1, bias=False) 51 | linear.weight.data.copy_(weight) 52 | self.linear_layers.append(linear) 53 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188]).view(1, -1, 1, 1)) 54 | self.register_buffer('scale', torch.Tensor([.458,.448,.450]).view(1, -1, 1, 1)) 55 | self.pixel_min = pixel_min 56 | self.pixel_max = pixel_max 57 | self.requires_grad_(False) 58 | self.eval() 59 | 60 | def _scale(self, x): 61 | if self.pixel_min != -1 or self.pixel_max != 1: 62 | x = (2*x - self.pixel_min - self.pixel_max) \ 63 | / (self.pixel_max - self.pixel_min) 64 | return (x - self.shift) / self.scale 65 | 66 | @staticmethod 67 | def _normalize_tensor(feature_maps, eps=1e-8): 68 | rnorm = torch.rsqrt(torch.sum(feature_maps ** 2, dim=1, keepdim=True) + eps) 69 | return feature_maps * rnorm 70 | 71 | def forward(self, x0, x1, eps=1e-8): 72 | x0, x1 = self._scale(x0), self._scale(x1) 73 | dist = 0 74 | for slice, linear in zip(self.slices, self.linear_layers): 75 | x0, x1 = slice(x0), slice(x1) 76 | _x0, _x1 = self._normalize_tensor(x0, eps), self._normalize_tensor(x1, eps) 77 | dist += linear(torch.mean((_x0 - _x1) ** 2, dim=[-1, -2])) 78 | return dist.view(-1) 79 | -------------------------------------------------------------------------------- /stylegan2/loss_fns.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from . import utils 6 | 7 | 8 | def _grad(input, output, retain_graph): 9 | # https://discuss.pytorch.org/t/gradient-penalty-loss-with-modified-weights/64910 10 | # Currently not possible to not 11 | # retain graph for regularization losses. 12 | # Ugly hack is to always set it to True. 
13 | retain_graph = True 14 | grads = torch.autograd.grad( 15 | output.sum(), 16 | input, 17 | only_inputs=True, 18 | retain_graph=retain_graph, 19 | create_graph=True 20 | ) 21 | return grads[0] 22 | 23 | 24 | def _grad_pen(input, output, gamma, constraint=1, onesided=False, retain_graph=True): 25 | grad = _grad(input, output, retain_graph=retain_graph) 26 | grad = grad.view(grad.size(0), -1) 27 | grad_norm = grad.norm(2, dim=1) 28 | if onesided: 29 | gp = torch.max(0, grad_norm - constraint) 30 | else: 31 | gp = (grad_norm - constraint) ** 2 32 | return gamma * gp.mean() 33 | 34 | 35 | def _grad_reg(input, output, gamma, retain_graph=True): 36 | grad = _grad(input, output, retain_graph=retain_graph) 37 | grad = grad.view(grad.size(0), -1) 38 | gr = (grad ** 2).sum(1) 39 | return (0.5 * gamma) * gr.mean() 40 | 41 | 42 | def _pathreg(dlatents, fakes, pl_avg, pl_decay, gamma, retain_graph=True): 43 | retain_graph = True 44 | pl_noise = torch.empty_like(fakes).normal_().div_(np.sqrt(np.prod(fakes.size()[2:]))) 45 | pl_grad = _grad(dlatents, torch.sum(pl_noise * fakes), retain_graph=retain_graph) 46 | pl_length = torch.sqrt(torch.mean(torch.sum(pl_grad ** 2, dim=2), dim=1)) 47 | with torch.no_grad(): 48 | pl_avg.add_(pl_decay * (torch.mean(pl_length) - pl_avg)) 49 | return gamma * torch.mean((pl_length - pl_avg) ** 2) 50 | 51 | 52 | #---------------------------------------------------------------------------- 53 | # Logistic loss from the paper 54 | # "Generative Adversarial Nets", Goodfellow et al. 2014 55 | 56 | 57 | def G_logistic(G, 58 | D, 59 | latents, 60 | latent_labels=None, 61 | *args, 62 | **kwargs): 63 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float() 64 | loss = - F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores)) 65 | reg = None 66 | return loss, reg 67 | 68 | 69 | def G_logistic_ns(G, 70 | D, 71 | latents, 72 | latent_labels=None, 73 | *args, 74 | **kwargs): 75 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float() 76 | loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores)) 77 | reg = None 78 | return loss, reg 79 | 80 | 81 | def D_logistic(G, 82 | D, 83 | latents, 84 | reals, 85 | latent_labels=None, 86 | real_labels=None, 87 | *args, 88 | **kwargs): 89 | assert (latent_labels is None) == (real_labels is None) 90 | with torch.no_grad(): 91 | fakes = G(latents, labels=latent_labels) 92 | real_scores = D(reals, labels=real_labels).float() 93 | fake_scores = D(fakes, labels=latent_labels).float() 94 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores)) 95 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores)) 96 | loss = real_loss + fake_loss 97 | reg = None 98 | return loss, reg 99 | 100 | 101 | #---------------------------------------------------------------------------- 102 | # R1 and R2 regularizers from the paper 103 | # "Which Training Methods for GANs do actually Converge?", Mescheder et al. 
2018 104 | 105 | 106 | def D_r1(D, 107 | reals, 108 | real_labels=None, 109 | gamma=10, 110 | *args, 111 | **kwargs): 112 | loss = None 113 | reg = None 114 | if gamma: 115 | reals.requires_grad_(True) 116 | real_scores = D(reals, labels=real_labels) 117 | reg = _grad_reg( 118 | input=reals, output=real_scores, gamma=gamma, retain_graph=False).float() 119 | return loss, reg 120 | 121 | 122 | def D_r2(D, 123 | G, 124 | latents, 125 | latent_labels=None, 126 | gamma=10, 127 | *args, 128 | **kwargs): 129 | loss = None 130 | reg = None 131 | if gamma: 132 | with torch.no_grad(): 133 | fakes = G(latents, labels=latent_labels) 134 | fakes.requires_grad_(True) 135 | fake_scores = D(fakes, labels=latent_labels) 136 | reg = _grad_reg( 137 | input=fakes, output=fake_scores, gamma=gamma, retain_graph=False).float() 138 | return loss, reg 139 | 140 | 141 | def D_logistic_r1(G, 142 | D, 143 | latents, 144 | reals, 145 | latent_labels=None, 146 | real_labels=None, 147 | gamma=10, 148 | *args, 149 | **kwargs): 150 | assert (latent_labels is None) == (real_labels is None) 151 | with torch.no_grad(): 152 | fakes = G(latents, labels=latent_labels) 153 | if gamma: 154 | reals.requires_grad_(True) 155 | real_scores = D(reals, labels=real_labels).float() 156 | fake_scores = D(fakes, labels=latent_labels).float() 157 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores)) 158 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores)) 159 | loss = real_loss + fake_loss 160 | reg = None 161 | if gamma: 162 | reg = _grad_reg( 163 | input=reals, output=real_scores, gamma=gamma, retain_graph=True).float() 164 | return loss, reg 165 | 166 | 167 | def D_logistic_r2(G, 168 | D, 169 | latents, 170 | reals, 171 | latent_labels=None, 172 | real_labels=None, 173 | gamma=10, 174 | *args, 175 | **kwargs): 176 | assert (latent_labels is None) == (real_labels is None) 177 | with torch.no_grad(): 178 | fakes = G(latents, labels=latent_labels) 179 | if gamma: 180 | fakes.requires_grad_(True) 181 | real_scores = D(reals, labels=real_labels).float() 182 | fake_scores = D(fakes, labels=latent_labels).float() 183 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores)) 184 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores)) 185 | loss = real_loss + fake_loss 186 | reg = None 187 | if gamma: 188 | reg = _grad_reg( 189 | input=fakes, output=fake_scores, gamma=gamma, retain_graph=True).float() 190 | return loss, reg 191 | 192 | 193 | #---------------------------------------------------------------------------- 194 | # Non-saturating logistic loss with path length regularizer from the paper 195 | # "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 
2019 196 | 197 | 198 | def G_pathreg(G, 199 | latents, 200 | pl_avg, 201 | latent_labels=None, 202 | pl_decay=0.01, 203 | gamma=2, 204 | *args, 205 | **kwargs): 206 | loss = None 207 | reg = None 208 | if gamma: 209 | fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True, mapping_grad=False) 210 | reg = _pathreg( 211 | dlatents=dlatents, 212 | fakes=fakes, 213 | pl_avg=pl_avg, 214 | pl_decay=pl_decay, 215 | gamma=gamma, 216 | retain_graph=False 217 | ).float() 218 | return loss, reg 219 | 220 | 221 | def G_logistic_ns_pathreg(G, 222 | D, 223 | latents, 224 | pl_avg, 225 | latent_labels=None, 226 | pl_decay=0.01, 227 | gamma=2, 228 | *args, 229 | **kwargs): 230 | fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True) 231 | fake_scores = D(fakes, labels=latent_labels).float() 232 | loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores)) 233 | reg = None 234 | if gamma: 235 | reg = _pathreg( 236 | dlatents=dlatents, 237 | fakes=fakes, 238 | pl_avg=pl_avg, 239 | pl_decay=pl_decay, 240 | gamma=gamma, 241 | retain_graph=True 242 | ).float() 243 | return loss, reg 244 | 245 | 246 | #---------------------------------------------------------------------------- 247 | # WGAN loss from the paper 248 | # "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017 249 | 250 | 251 | def G_wgan(G, 252 | D, 253 | latents, 254 | latent_labels=None, 255 | *args, 256 | **kwargs): 257 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float() 258 | loss = -fake_scores.mean() 259 | reg = None 260 | return loss, reg 261 | 262 | 263 | def D_wgan(G, 264 | D, 265 | latents, 266 | reals, 267 | latent_labels=None, 268 | real_labels=None, 269 | drift_gamma=0.001, 270 | *args, 271 | **kwargs): 272 | assert (latent_labels is None) == (real_labels is None) 273 | with torch.no_grad(): 274 | fakes = G(latents, labels=latent_labels) 275 | real_scores = D(reals, labels=real_labels).float() 276 | fake_scores = D(fakes, labels=latent_labels).float() 277 | loss = fake_scores.mean() - real_scores.mean() 278 | if drift_gamma: 279 | loss += drift_gamma * torch.mean(real_scores ** 2) 280 | reg = None 281 | return loss, reg 282 | 283 | 284 | #---------------------------------------------------------------------------- 285 | # WGAN-GP loss from the paper 286 | # "Improved Training of Wasserstein GANs", Gulrajani et al. 
2017 287 | 288 | 289 | def D_gp(G, 290 | D, 291 | latents, 292 | reals, 293 | latent_labels=None, 294 | real_labels=None, 295 | gamma=0, 296 | constraint=1, 297 | *args, 298 | **kwargs): 299 | loss = None 300 | reg = None 301 | if gamma: 302 | assert (latent_labels is None) == (real_labels is None) 303 | with torch.no_grad(): 304 | fakes = G(latents, labels=latent_labels) 305 | assert reals.size() == fakes.size() 306 | if latent_labels: 307 | assert latent_labels == real_labels 308 | alpha = torch.empty(reals.size(0)).uniform_() 309 | alpha = alpha.view(-1, *[1] * (reals.dim() - 1)) 310 | interp = utils.lerp(reals, fakes, alpha).requires_grad_(True) 311 | interp_scores = D(interp, labels=latent_labels) 312 | reg = _grad_pen( 313 | input=interp, output=interp_scores, gamma=gamma, retain_graph=False).float() 314 | return loss, reg 315 | 316 | 317 | def D_wgan_gp(G, 318 | D, 319 | latents, 320 | reals, 321 | latent_labels=None, 322 | real_labels=None, 323 | gamma=0, 324 | drift_gamma=0.001, 325 | constraint=1, 326 | *args, 327 | **kwargs): 328 | assert (latent_labels is None) == (real_labels is None) 329 | with torch.no_grad(): 330 | fakes = G(latents, labels=latent_labels) 331 | real_scores = D(reals, labels=real_labels).float() 332 | fake_scores = D(fakes, labels=latent_labels).float() 333 | loss = fake_scores.mean() - real_scores.mean() 334 | if drift_gamma: 335 | loss += drift_gamma * torch.mean(real_scores ** 2) 336 | reg = None 337 | if gamma: 338 | assert reals.size() == fakes.size() 339 | if latent_labels: 340 | assert latent_labels == real_labels 341 | alpha = torch.empty(reals.size(0)).uniform_() 342 | alpha = alpha.view(-1, *[1] * (reals.dim() - 1)) 343 | interp = utils.lerp(reals, fakes, alpha).requires_grad_(True) 344 | interp_scores = D(interp, labels=latent_labels) 345 | reg = _grad_pen( 346 | input=interp, output=interp_scores, gamma=gamma, retain_graph=True).float() 347 | return loss, reg 348 | -------------------------------------------------------------------------------- /stylegan2/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from . import fid 2 | from . import ppl 3 | -------------------------------------------------------------------------------- /stylegan2/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numbers 3 | import numpy as np 4 | import scipy 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | from .. import models, utils 9 | from ..external_models import inception 10 | 11 | 12 | class _TruncatedDataset: 13 | """ 14 | Truncates a dataset, making only part of it accessible 15 | by `torch.utils.data.DataLoader`. 16 | """ 17 | 18 | def __init__(self, dataset, max_len): 19 | self.dataset = dataset 20 | self.max_len = max_len 21 | 22 | def __len__(self): 23 | return min(len(self.dataset), self.max_len) 24 | 25 | def __getitem__(self, index): 26 | return self.dataset[index] 27 | 28 | 29 | class FID: 30 | """ 31 | This class evaluates the FID metric of a generator. 32 | Arguments: 33 | G (Generator) 34 | prior_generator (PriorGenerator) 35 | dataset (indexable) 36 | device (int, str, torch.device, optional): The device 37 | to use for calculations. By default, the same device 38 | is chosen as the parameters in `generator` reside on. 39 | num_samples (int): Number of samples of reals and fakes 40 | to gather statistics for which are used for calculating 41 | the metric. Default value is 50 000. 
42 | fid_model (nn.Module): A model that returns feature maps 43 | of shape (batch_size, features, *). Default value 44 | is InceptionV3. 45 | fid_size (int, optional): Resize any data fed to `fid_model` by scaling 46 | the data so that its smallest side is the same size as this 47 | argument. 48 | truncation_psi (float, optional): Truncation of the generator 49 | when evaluating. 50 | truncation_cutoff (int, optional): Cutoff for truncation when 51 | evaluating. 52 | reals_batch_size (int, optional): Batch size to use for real 53 | samples statistics gathering. 54 | reals_data_workers (int, optional): Number of workers fetching 55 | the real data samples. Default value is 0. 56 | verbose (bool): Write progress of gathering statistics for reals 57 | to stdout. Default value is True. 58 | """ 59 | def __init__(self, 60 | G, 61 | prior_generator, 62 | dataset, 63 | device=None, 64 | num_samples=50000, 65 | fid_model=None, 66 | fid_size=None, 67 | truncation_psi=None, 68 | truncation_cutoff=None, 69 | reals_batch_size=None, 70 | reals_data_workers=0, 71 | verbose=True): 72 | device_ids = [] 73 | if isinstance(G, torch.nn.DataParallel): 74 | device_ids = G.device_ids 75 | G = utils.unwrap_module(G) 76 | assert isinstance(G, models.Generator) 77 | assert isinstance(prior_generator, utils.PriorGenerator) 78 | if device is None: 79 | device = next(G.parameters()).device 80 | else: 81 | device = torch.device(device) 82 | assert torch.device(prior_generator.device) == device, \ 83 | 'Prior generator device ({}) '.format(torch.device(prior_generator)) + \ 84 | 'is not the same as the specified (or infered from the model)' + \ 85 | 'device ({}) for the PPL evaluation.'.format(device) 86 | G.eval().to(device) 87 | if device_ids: 88 | G = torch.nn.DataParallel(G, device_ids=device_ids) 89 | self.G = G 90 | self.prior_generator = prior_generator 91 | self.device = device 92 | self.num_samples = num_samples 93 | self.batch_size = self.prior_generator.batch_size 94 | if fid_model is None: 95 | warnings.warn( 96 | 'Using default fid model metric based on Inception V3. ' + \ 97 | 'This metric will only work on image data where values are in ' + \ 98 | 'the range [-1, 1], please specify another module if you want ' + \ 99 | 'to use other kinds of data formats.' 
100 | ) 101 | fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1) 102 | if device_ids: 103 | fid_model = torch.nn.DataParallel(fid_model, device_ids) 104 | self.fid_model = fid_model.eval().to(device) 105 | self.fid_size = fid_size 106 | 107 | dataset = _TruncatedDataset(dataset, self.num_samples) 108 | dataloader = torch.utils.data.DataLoader( 109 | dataset, 110 | batch_size=reals_batch_size or self.batch_size, 111 | num_workers=reals_data_workers 112 | ) 113 | features = [] 114 | self.labels = [] 115 | 116 | if verbose: 117 | progress = utils.ProgressWriter( 118 | np.ceil(self.num_samples / (reals_batch_size or self.batch_size))) 119 | progress.write('FID: Gathering statistics for reals...', step=False) 120 | 121 | for batch in dataloader: 122 | data = batch 123 | if isinstance(batch, (tuple, list)): 124 | data = batch[0] 125 | if len(batch) > 1: 126 | self.labels.append(batch[1]) 127 | data = self._scale_for_fid(data).to(self.device) 128 | with torch.no_grad(): 129 | batch_features = self.fid_model(data) 130 | batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1) 131 | features.append(batch_features.cpu()) 132 | progress.step() 133 | 134 | if verbose: 135 | progress.write('FID: Statistics for reals gathered!', step=False) 136 | progress.close() 137 | 138 | features = torch.cat(features, dim=0).numpy() 139 | 140 | self.mu_real = np.mean(features, axis=0) 141 | self.sigma_real = np.cov(features, rowvar=False) 142 | self.truncation_psi = truncation_psi 143 | self.truncation_cutoff = truncation_cutoff 144 | 145 | def _scale_for_fid(self, data): 146 | if not self.fid_size: 147 | return data 148 | scale_factor = self.fid_size / min(data.size()[2:]) 149 | if scale_factor == 1: 150 | return data 151 | mode = 'nearest' 152 | if scale_factor < 1: 153 | mode = 'area' 154 | return F.interpolate(data, scale_factor=scale_factor, mode=mode) 155 | 156 | def __call__(self, *args, **kwargs): 157 | return self.evaluate(*args, **kwargs) 158 | 159 | def evaluate(self, verbose=True): 160 | """ 161 | Evaluate the FID. 162 | Arguments: 163 | verbose (bool): Write progress to stdout. 164 | Default value is True. 165 | Returns: 166 | fid (float): Metric value. 
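Example (an illustrative sketch; it assumes a trained `G`, a matching `prior_generator` and an image `dataset` have already been constructed as described in the class docstring):
    fid_metric = FID(G, prior_generator, dataset, num_samples=10000)
    score = fid_metric.evaluate()
    print('FID: {:.4f}'.format(score))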
167 | """ 168 | utils.unwrap_module(self.G).set_truncation( 169 | truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff) 170 | self.G.eval() 171 | features = [] 172 | 173 | if verbose: 174 | progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size)) 175 | progress.write('FID: Gathering statistics for fakes...', step=False) 176 | 177 | remaining = self.num_samples 178 | for i in range(0, self.num_samples, self.batch_size): 179 | 180 | latents, latent_labels = self.prior_generator( 181 | batch_size=min(self.batch_size, remaining)) 182 | if latent_labels is not None and self.labels: 183 | latent_labels = self.labels[i].to(self.device) 184 | length = min(len(latents), len(latent_labels)) 185 | latents, latent_labels = latents[:length], latent_labels[:length] 186 | 187 | with torch.no_grad(): 188 | fakes = self.G(latents, labels=latent_labels) 189 | 190 | with torch.no_grad(): 191 | batch_features = self.fid_model(fakes) 192 | batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1) 193 | features.append(batch_features.cpu()) 194 | 195 | remaining -= len(latents) 196 | progress.step() 197 | 198 | if verbose: 199 | progress.write('FID: Statistics for fakes gathered!', step=False) 200 | progress.close() 201 | 202 | features = torch.cat(features, dim=0).numpy() 203 | 204 | mu_fake = np.mean(features, axis=0) 205 | sigma_fake = np.cov(features, rowvar=False) 206 | 207 | m = np.square(mu_fake - self.mu_real).sum() 208 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False) 209 | dist = m + np.trace(sigma_fake + self.sigma_real - 2*s) 210 | return float(np.real(dist)) 211 | -------------------------------------------------------------------------------- /stylegan2/metrics/ppl.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numbers 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | from .. import models, utils 8 | from ..external_models import lpips 9 | 10 | 11 | class PPL: 12 | """ 13 | This class evaluates the PPL metric of a generator. 14 | Arguments: 15 | G (Generator) 16 | prior_generator (PriorGenerator) 17 | device (int, str, torch.device, optional): The device 18 | to use for calculations. By default, the same device 19 | is chosen as the parameters in `generator` reside on. 20 | num_samples (int): Number of samples of reals and fakes 21 | to gather statistics for which are used for calculating 22 | the metric. Default value is 50 000. 23 | epsilon (float): Perturbation value. Default value is 1e-4. 24 | use_dlatent (bool): Measure PPL against the dlatents instead 25 | of the latents. Default value is True. 26 | full_sampling (bool): Measure on a random interpolation between 27 | two inputs. Default value is False. 28 | crop (float, list, optional): Crop values that should be in the 29 | range [0, 1] with 1 representing the entire data length. 30 | If single value this will be the amount cropped from all 31 | sides of the data. If a list of same length as number of 32 | data dimensions, each crop is mirrored to both sides of 33 | each respective dimension. If the length is 2 * number 34 | of dimensions the crop values for the start and end of 35 | a dimension may be different. 36 | Example 1: 37 | We have 1d data of length 10. We want to crop 1 38 | from the start and end of the data. We then need 39 | to use `crop=0.1` or `crop=[0.1]` or `crop=[0.1, 0.9]`. 
40 | Example 2: 41 | We have 2d data (images) of size 10, 10 (height, width) 42 | and we want to use only the top left quarter of the image 43 | we would use `crop=[0, 0.5, 0, 0.5]`. 44 | lpips_model (nn.Module): A model that returns feature the distance 45 | between two inputs. Default value is the LPIPS VGG16 model. 46 | lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling 47 | the data so that its smallest side is the same size as this 48 | argument. Only has a default value of 256 if `lpips_model` is unspecified. 49 | """ 50 | FFHQ_CROP = [1/8 * 3, 1/8 * 7, 1/8 * 2, 1/8 * 6] 51 | 52 | def __init__(self, 53 | G, 54 | prior_generator, 55 | device=None, 56 | num_samples=50000, 57 | epsilon=1e-4, 58 | use_dlatent=True, 59 | full_sampling=False, 60 | crop=None, 61 | lpips_model=None, 62 | lpips_size=None): 63 | device_ids = [] 64 | if isinstance(G, torch.nn.DataParallel): 65 | device_ids = G.device_ids 66 | G = utils.unwrap_module(G) 67 | assert isinstance(G, models.Generator) 68 | assert isinstance(prior_generator, utils.PriorGenerator) 69 | if device is None: 70 | device = next(G.parameters()).device 71 | else: 72 | device = torch.device(device) 73 | assert torch.device(prior_generator.device) == device, \ 74 | 'Prior generator device ({}) '.format(torch.device(prior_generator)) + \ 75 | 'is not the same as the specified (or infered from the model)' + \ 76 | 'device ({}) for the PPL evaluation.'.format(device) 77 | G.eval().to(device) 78 | self.G_mapping = G.G_mapping 79 | self.G_synthesis = G.G_synthesis 80 | if device_ids: 81 | self.G_mapping = torch.nn.DataParallel(self.G_mapping, device_ids=device_ids) 82 | self.G_synthesis = torch.nn.DataParallel(self.G_synthesis, device_ids=device_ids) 83 | self.prior_generator = prior_generator 84 | self.device = device 85 | self.num_samples = num_samples 86 | self.epsilon = epsilon 87 | self.use_dlatent = use_dlatent 88 | self.full_sampling = full_sampling 89 | self.crop = crop 90 | self.batch_size = self.prior_generator.batch_size 91 | if lpips_model is None: 92 | warnings.warn( 93 | 'Using default LPIPS distance metric based on VGG 16. ' + \ 94 | 'This metric will only work on image data where values are in ' + \ 95 | 'the range [-1, 1], please specify an lpips module if you want ' + \ 96 | 'to use other kinds of data formats.' 
97 | ) 98 | lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1) 99 | if device_ids: 100 | lpips_model = torch.nn.DataParallel(lpips_model, device_ids=device_ids) 101 | lpips_size = lpips_size or 256 102 | self.lpips_model = lpips_model.eval().to(device) 103 | self.lpips_size = lpips_size 104 | 105 | def _scale_for_lpips(self, data): 106 | if not self.lpips_size: 107 | return data 108 | scale_factor = self.lpips_size / min(data.size()[2:]) 109 | if scale_factor == 1: 110 | return data 111 | mode = 'nearest' 112 | if scale_factor < 1: 113 | mode = 'area' 114 | return F.interpolate(data, scale_factor=scale_factor, mode=mode) 115 | 116 | def crop_data(self, data): 117 | if not self.crop: 118 | return data 119 | dim = data.dim() - 2 120 | if isinstance(self.crop, numbers.Number): 121 | self.crop = [self.crop] 122 | else: 123 | self.crop = list(self.crop) 124 | if len(self.crop) == 1: 125 | self.crop = [self.crop[0], (1 if self.crop[0] < 1 else size) - self.crop[0]] * dim 126 | if len(self.crop) == dim: 127 | crop = self.crop 128 | self.crop = [] 129 | for value in crop: 130 | self.crop += [value, (1 if value < 1 else size) - value] 131 | assert len(self.crop) == 2 * dim, 'Crop values has to be ' + \ 132 | 'a single value or a sequence of values of the same ' + \ 133 | 'size as number of dimensions of the data or twice of that.' 134 | pre_index = [Ellipsis] 135 | post_index = [slice(None, None, None) for _ in range(dim)] 136 | for i in range(0, 2 * dim, 2): 137 | j = i // 2 138 | size = data.size(2 + j) 139 | crop_min, crop_max = self.crop[i:i + 2] 140 | if crop_max < 1: 141 | crop_min, crop_max = crop_min * size, crop_max * size 142 | crop_min, crop_max = max(0, int(crop_min)), min(size, int(crop_max)) 143 | dim_index = post_index.copy() 144 | dim_index[j] = slice(crop_min, crop_max, None) 145 | data = data[pre_index + dim_index] 146 | return data 147 | 148 | def prep_latents(self, latents): 149 | if self.full_sampling: 150 | lerp = utils.slerp 151 | if self.use_dlatent: 152 | lerp = utils.lerp 153 | latents_a, latents_b = latents[:self.batch_size], latents[self.batch_size:] 154 | latents = lerp( 155 | latents_a, 156 | latents_b, 157 | torch.rand( 158 | latents_a.size()[:-1], 159 | dtype=latents_a.dtype, 160 | device=latents_a.device 161 | ).unsqueeze(-1) 162 | ) 163 | return torch.cat([latents, latents + self.epsilon], dim=0) 164 | 165 | def __call__(self, *args, **kwargs): 166 | return self.evaluate(*args, **kwargs) 167 | 168 | def evaluate(self, verbose=True): 169 | """ 170 | Evaluate the PPL. 171 | Arguments: 172 | verbose (bool): Write progress to stdout. 173 | Default value is True. 174 | Returns: 175 | ppl (float): Metric value. 
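Example (an illustrative sketch; it assumes a trained `G` and a matching `prior_generator` already exist; `PPL.FFHQ_CROP` restricts the comparison to a central face crop for FFHQ models):
    ppl_metric = PPL(G, prior_generator, crop=PPL.FFHQ_CROP)
    score = ppl_metric.evaluate()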
176 | """ 177 | distances = [] 178 | batch_size = self.batch_size 179 | if self.full_sampling: 180 | batch_size = 2 * batch_size 181 | 182 | if verbose: 183 | progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size)) 184 | progress.write('PPL: Evaluating metric...', step=False) 185 | 186 | for _ in range(0, self.num_samples, self.batch_size): 187 | utils.unwrap_module(self.G_synthesis).static_noise() 188 | 189 | latents, latent_labels = self.prior_generator(batch_size=batch_size) 190 | if latent_labels is not None and self.full_sampling: 191 | # Labels should be the same for the first and second half of latents 192 | latent_labels = latent_labels.view(2, -1)[0].repeat(2) 193 | 194 | if self.use_dlatent: 195 | with torch.no_grad(): 196 | dlatents = self.G_mapping(latents=latents, labels=latent_labels) 197 | dlatents = self.prep_latents(dlatents) 198 | else: 199 | latents = self.prep_latents(latents) 200 | with torch.no_grad(): 201 | dlatents = self.G_mapping(latents=latents, labels=latent_labels) 202 | 203 | dlatents = dlatents.unsqueeze(1).repeat(1, len(utils.unwrap_module(self.G_synthesis)), 1) 204 | 205 | with torch.no_grad(): 206 | output = self.G_synthesis(dlatents) 207 | 208 | output = self.crop_data(output) 209 | output = self._scale_for_lpips(output) 210 | 211 | output_a, output_b = output[:self.batch_size], output[self.batch_size:] 212 | 213 | with torch.no_grad(): 214 | dist = self.lpips_model(output_a, output_b) 215 | 216 | distances.append(dist.cpu() * (1 / self.epsilon ** 2)) 217 | 218 | if verbose: 219 | progress.step() 220 | 221 | if verbose: 222 | progress.write('PPL: Evaluated!', step=False) 223 | progress.close() 224 | 225 | distances = torch.cat(distances, dim=0).numpy() 226 | lo = np.percentile(distances, 1, interpolation='lower') 227 | hi = np.percentile(distances, 99, interpolation='higher') 228 | filtered_distances = np.extract(np.logical_and(lo <= distances, distances <= hi), distances) 229 | return float(np.mean(filtered_distances)) 230 | -------------------------------------------------------------------------------- /stylegan2/project.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from . import models, utils 8 | from .external_models import lpips 9 | 10 | 11 | class Projector(nn.Module): 12 | """ 13 | Projects data to latent space and noise tensors. 14 | Arguments: 15 | G (Generator) 16 | dlatent_avg_samples (int): Number of dlatent samples 17 | to collect to find the mean and std. 18 | Default value is 10 000. 19 | dlatent_avg_label (int, torch.Tensor, optional): The label to 20 | use when gathering dlatent statistics. 21 | dlatent_device (int, str, torch.device, optional): Device to use 22 | for gathering statistics of dlatents. By default uses 23 | the same device as parameters of `G` reside on. 24 | dlatent_batch_size (int): The batch size to sample 25 | dlatents with. Default value is 1024. 26 | lpips_model (nn.Module): A model that returns feature the distance 27 | between two inputs. Default value is the LPIPS VGG16 model. 28 | lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling 29 | the data so that its smallest side is the same size as this 30 | argument. Only has a default value of 256 if `lpips_model` is unspecified. 31 | verbose (bool): Write progress of dlatent statistics gathering to stdout. 32 | Default value is True. 
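Example (an illustrative sketch of the intended workflow; it assumes a trained `G` and a `target` tensor already preprocessed to the generator's output range, typically [-1, 1]):
    projector = Projector(G)
    projector.start(target, num_steps=1000)
    projector.step(1000)                  # run all remaining optimization steps
    image = projector.generate()          # synthesize from the projected dlatents
    dlatents = projector.get_dlatent()    # copy of the optimized dlatent values
    noise = projector.get_noise()         # copies of the optimized noise tensors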
33 | """ 34 | def __init__(self, 35 | G, 36 | dlatent_avg_samples=10000, 37 | dlatent_avg_label=None, 38 | dlatent_device=None, 39 | dlatent_batch_size=1024, 40 | lpips_model=None, 41 | lpips_size=None, 42 | verbose=True): 43 | super(Projector, self).__init__() 44 | assert isinstance(G, models.Generator) 45 | G.eval().requires_grad_(False) 46 | 47 | self.G_synthesis = G.G_synthesis 48 | 49 | G_mapping = G.G_mapping 50 | 51 | dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples) 52 | 53 | if dlatent_device is None: 54 | dlatent_device = next(G_mapping.parameters()).device() 55 | else: 56 | dlatent_device = torch.device(dlatent_device) 57 | 58 | G_mapping.to(dlatent_device) 59 | 60 | latents = torch.empty( 61 | dlatent_avg_samples, G_mapping.latent_size).normal_() 62 | dlatents = [] 63 | 64 | labels = None 65 | if dlatent_avg_label is not None: 66 | labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size) 67 | 68 | if verbose: 69 | progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size)) 70 | progress.write('Gathering dlatents...', step=False) 71 | 72 | for i in range(0, dlatent_avg_samples, dlatent_batch_size): 73 | batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device) 74 | batch_labels = None 75 | if labels is not None: 76 | batch_labels = labels[:len(batch_latents)] 77 | with torch.no_grad(): 78 | dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu()) 79 | if verbose: 80 | progress.step() 81 | 82 | if verbose: 83 | progress.write('Done!', step=False) 84 | progress.close() 85 | 86 | dlatents = torch.cat(dlatents, dim=0) 87 | 88 | self.register_buffer( 89 | '_dlatent_avg', 90 | dlatents.mean(dim=0).view(1, 1, -1) 91 | ) 92 | self.register_buffer( 93 | '_dlatent_std', 94 | torch.sqrt( 95 | torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8 96 | ).view(1, 1, 1) 97 | ) 98 | 99 | if lpips_model is None: 100 | warnings.warn( 101 | 'Using default LPIPS distance metric based on VGG 16. ' + \ 102 | 'This metric will only work on image data where values are in ' + \ 103 | 'the range [-1, 1], please specify an lpips module if you want ' + \ 104 | 'to use other kinds of data formats.' 105 | ) 106 | lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1) 107 | lpips_size = 256 108 | self.lpips_model = lpips_model.eval().requires_grad_(False) 109 | self.lpips_size = lpips_size 110 | 111 | self.to(dlatent_device) 112 | 113 | def _scale_for_lpips(self, data): 114 | if not self.lpips_size: 115 | return data 116 | scale_factor = self.lpips_size / min(data.size()[2:]) 117 | if scale_factor == 1: 118 | return data 119 | mode = 'nearest' 120 | if scale_factor < 1: 121 | mode = 'area' 122 | return F.interpolate(data, scale_factor=scale_factor, mode=mode) 123 | 124 | def _check_job(self): 125 | assert self._job is not None, 'Call `start()` first to set up target.' 126 | # device of dlatent param will not change with the rest of the models 127 | # and buffers of this class as it was never registered as a buffer or 128 | # parameter. Same goes for optimizer. Make sure it is on the correct device. 129 | if self._job.dlatent_param.device != self._dlatent_avg.device: 130 | self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg) 131 | self._job.opt.load_state_dict( 132 | utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0]) 133 | 134 | def generate(self): 135 | """ 136 | Generate an output with the current dlatent and noise values. 
137 | Returns: 138 | output (torch.Tensor) 139 | """ 140 | self._check_job() 141 | with torch.no_grad(): 142 | return self.G_synthesis(self._job.dlatent_param) 143 | 144 | def get_dlatent(self): 145 | """ 146 | Get a copy of the current dlatent values. 147 | Returns: 148 | dlatents (torch.Tensor) 149 | """ 150 | self._check_job() 151 | return self._job.dlatent_param.data.clone() 152 | 153 | def get_noise(self): 154 | """ 155 | Get a copy of the current noise values. 156 | Returns: 157 | noise_tensors (list) 158 | """ 159 | self._check_job() 160 | return [noise.data.clone() for noise in self._job.noise_params] 161 | 162 | def start(self, 163 | target, 164 | num_steps=1000, 165 | initial_learning_rate=0.1, 166 | initial_noise_factor=0.05, 167 | lr_rampdown_length=0.25, 168 | lr_rampup_length=0.05, 169 | noise_ramp_length=0.75, 170 | regularize_noise_weight=1e5, 171 | verbose=True, 172 | verbose_prefix=''): 173 | """ 174 | Set up a target and its projection parameters. 175 | Arguments: 176 | target (torch.Tensor): The data target. This should 177 | already be preprocessed (scaled to correct value range). 178 | num_steps (int): Number of optimization steps. Default 179 | value is 1000. 180 | initial_learning_rate (float): Default value is 0.1. 181 | initial_noise_factor (float): Default value is 0.05. 182 | lr_rampdown_length (float): Default value is 0.25. 183 | lr_rampup_length (float): Default value is 0.05. 184 | noise_ramp_length (float): Default value is 0.75. 185 | regularize_noise_weight (float): Default value is 1e5. 186 | verbose (bool): Write progress to stdout every time 187 | `step()` is called. 188 | verbose_prefix (str, optional): This is written before 189 | any other output to stdout. 190 | """ 191 | if target.dim() == self.G_synthesis.dim + 1: 192 | target = target.unsqueeze(0) 193 | assert target.dim() == self.G_synthesis.dim + 2, \ 194 | 'Number of dimensions of target data is incorrect.' 195 | 196 | target = target.to(self._dlatent_avg) 197 | target_scaled = self._scale_for_lpips(target) 198 | 199 | dlatent_param = nn.Parameter( 200 | self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1)) 201 | noise_params = self.G_synthesis.static_noise(trainable=True) 202 | params = [dlatent_param] + noise_params 203 | 204 | opt = torch.optim.Adam(params) 205 | 206 | noise_tensor = torch.empty_like(dlatent_param) 207 | 208 | if verbose: 209 | progress = utils.ProgressWriter(num_steps) 210 | value_tracker = utils.ValueTracker() 211 | 212 | self._job = utils.AttributeDict(**locals()) 213 | self._job.current_step = 0 214 | 215 | def step(self, steps=1): 216 | """ 217 | Take a projection step. 218 | Arguments: 219 | steps (int): Number of steps to take. If this 220 | exceeds the remaining steps of the projection 221 | that amount of steps is taken instead. Default 222 | value is 1. 223 | """ 224 | self._check_job() 225 | 226 | remaining_steps = self._job.num_steps - self._job.current_step 227 | if not remaining_steps > 0: 228 | warnings.warn( 229 | 'Trying to take a projection step after the ' + \ 230 | 'final projection iteration has been completed.' 231 | ) 232 | if steps < 0: 233 | steps = remaining_steps 234 | steps = min(remaining_steps, steps) 235 | 236 | if not steps > 0: 237 | return 238 | 239 | for _ in range(steps): 240 | 241 | if self._job.current_step >= self._job.num_steps: 242 | break 243 | 244 | # Hyperparameters. 
245 | t = self._job.current_step / self._job.num_steps 246 | noise_strength = self._dlatent_std * self._job.initial_noise_factor \ 247 | * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2 248 | lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length) 249 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi) 250 | lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length) 251 | learning_rate = self._job.initial_learning_rate * lr_ramp 252 | 253 | for param_group in self._job.opt.param_groups: 254 | param_group['lr'] = learning_rate 255 | 256 | dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_() 257 | 258 | output = self.G_synthesis(dlatents) 259 | assert output.size() == self._job.target.size(), \ 260 | 'target size {} does not fit output size {} of generator'.format( 261 | self._job.target.size(), output.size()) 262 | 263 | output_scaled = self._scale_for_lpips(output) 264 | 265 | # Main loss: LPIPS distance of output and target 266 | lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled)) 267 | 268 | # Calculate noise regularization loss 269 | reg_loss = 0 270 | for p in self._job.noise_params: 271 | size = min(p.size()[2:]) 272 | dim = p.dim() - 2 273 | while True: 274 | reg_loss += torch.mean( 275 | (p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2) 276 | if size <= 8: 277 | break 278 | p = F.interpolate(p, scale_factor=0.5, mode='area') 279 | size = size // 2 280 | 281 | # Combine loss, backward and update params 282 | loss = lpips_distance + self._job.regularize_noise_weight * reg_loss 283 | self._job.opt.zero_grad() 284 | loss.backward() 285 | self._job.opt.step() 286 | 287 | # Normalize noise values 288 | for p in self._job.noise_params: 289 | with torch.no_grad(): 290 | p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True) 291 | p_rstd = torch.rsqrt( 292 | torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8) 293 | p.data = (p.data - p_mean) * p_rstd 294 | 295 | self._job.current_step += 1 296 | 297 | if self._job.verbose: 298 | self._job.value_tracker.add('loss', float(loss)) 299 | self._job.value_tracker.add('lpips_distance', float(lpips_distance)) 300 | self._job.value_tracker.add('noise_reg', float(reg_loss)) 301 | self._job.value_tracker.add('lr', learning_rate, beta=0) 302 | self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker)) 303 | if self._job.current_step >= self._job.num_steps: 304 | self._job.progress.close() 305 | -------------------------------------------------------------------------------- /stylegan2/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numbers 3 | import re 4 | import sys 5 | import collections 6 | import argparse 7 | import yaml 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | from torch import nn 12 | from torch.nn import functional as F 13 | import torchvision 14 | try: 15 | import tqdm 16 | except ImportError: 17 | pass 18 | try: 19 | from IPython.display import display as notebook_display 20 | from IPython.display import clear_output as notebook_clear 21 | except ImportError: 22 | pass 23 | 24 | 25 | #---------------------------------------------------------------------------- 26 | # Miscellaneous utils 27 | 28 | 29 | class AttributeDict(dict): 30 | """ 31 | Dict where values can be accessed using attribute syntax. 32 | Same as "EasyDict" in the NVIDIA stylegan git repository.
33 | """ 34 | 35 | def __getattr__(self, name): 36 | try: 37 | return self[name] 38 | except KeyError: 39 | raise AttributeError(name) 40 | 41 | def __setattr__(self, name, value): 42 | self[name] = value 43 | 44 | def __delattr__(self, name): 45 | del self[name] 46 | 47 | def __getstate__(self): 48 | return dict(**self) 49 | 50 | def __setstate__(self, state): 51 | self.update(**state) 52 | 53 | def __repr__(self): 54 | return '{}({})'.format( 55 | self.__class__.__name__, 56 | ', '.join('{}={}'.format(key, value) for key, value in self.items()) 57 | ) 58 | 59 | @classmethod 60 | def convert_dict_recursive(cls, obj): 61 | if isinstance(obj, dict): 62 | for key in list(obj.keys()): 63 | obj[key] = cls.convert_dict_recursive(obj[key]) 64 | if not isinstance(obj, cls): 65 | return cls(**obj) 66 | return obj 67 | 68 | 69 | class Timer: 70 | 71 | def __init__(self): 72 | self.reset() 73 | 74 | def __enter__(self): 75 | self._t0 = time.time() 76 | 77 | def __exit__(self, *args): 78 | self._t += time.time() - self._t0 79 | 80 | def value(self): 81 | return self._t 82 | 83 | def reset(self): 84 | self._t = 0 85 | 86 | def __str__(self): 87 | """ 88 | Get a string representation of the recorded time. 89 | Returns: 90 | time_as_string (str) 91 | """ 92 | value = self.value() 93 | if not value or value >= 100: 94 | return '{} s'.format(int(value)) 95 | elif value >= 1: 96 | return '{:.3g} s'.format(value) 97 | elif value >= 1e-3: 98 | return '{:.3g} ms'.format(value * 1e+3) 99 | elif value >= 1e-6: 100 | return '{:.3g} us'.format(value * 1e+6) 101 | elif value >= 1e-9: 102 | return '{:.3g} ns'.format(value * 1e+9) 103 | else: 104 | return '{:.2E} s'.format(value) 105 | 106 | 107 | def to_list(values): 108 | if values is None: 109 | return [] 110 | if isinstance(values, tuple): 111 | return list(values) 112 | if not isinstance(values, list): 113 | return [values] 114 | return values 115 | 116 | 117 | def lerp(a, b, beta): 118 | if isinstance(beta, numbers.Number): 119 | if beta == 1: 120 | return b 121 | elif beta == 0: 122 | return a 123 | if torch.is_tensor(a) and a.dtype == torch.float32: 124 | # torch lerp only available for fp32 125 | return torch.lerp(a, b, beta) 126 | # More numerically stable than a + beta * (b - a) 127 | return (1 - beta) * a + beta * b 128 | 129 | 130 | def _normalize(v): 131 | return v * torch.rsqrt(torch.sum(v ** 2, dim=-1, keepdim=True)) 132 | 133 | 134 | def slerp(a, b, beta): 135 | assert a.size() == b.size(), 'Size mismatch between ' + \ 136 | 'slerp arguments, received {} and {}'.format(a.size(), b.size()) 137 | if not torch.is_tensor(beta): 138 | beta = torch.tensor(beta).to(a) 139 | a = _normalize(a) 140 | b = _normalize(b) 141 | d = torch.sum(a * b, axis=-1, keepdim=True) 142 | p = beta * torch.acos(d) 143 | c = _normalize(b - d * a) 144 | d = a * torch.cos(p) + c * torch.sin(p) 145 | return _normalize(d) 146 | 147 | 148 | #---------------------------------------------------------------------------- 149 | # Command line utils 150 | 151 | 152 | def _parse_configs(configs): 153 | kwargs = {} 154 | for config in configs: 155 | with open(config, 'r') as fp: 156 | kwargs.update(yaml.safe_load(fp)) 157 | return kwargs 158 | 159 | 160 | class ConfigArgumentParser(argparse.ArgumentParser): 161 | 162 | _CONFIG_ARG_KEY = '_configs' 163 | 164 | def __init__(self, *args, **kwargs): 165 | super(ConfigArgumentParser, self).__init__(*args, **kwargs) 166 | self.add_argument( 167 | self._CONFIG_ARG_KEY, 168 | nargs='*', 169 | help='Any yaml-style config file whose values will 
override the defaults of this argument parser.', 170 | type=str 171 | ) 172 | 173 | def parse_args(self, args=None): 174 | config_args = _parse_configs( 175 | getattr( 176 | super(ConfigArgumentParser, self).parse_args(args), 177 | self._CONFIG_ARG_KEY 178 | ) 179 | ) 180 | self.set_defaults(**config_args) 181 | return super(ConfigArgumentParser, self).parse_args(args) 182 | 183 | 184 | def bool_type(v): 185 | if isinstance(v, bool): 186 | return v 187 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 188 | return True 189 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 190 | return False 191 | else: 192 | raise argparse.ArgumentTypeError('Boolean value expected.') 193 | 194 | 195 | def range_type(s): 196 | """ 197 | Accept either a comma separated list of numbers 198 | 'a,b,c' or a range 'a-c' and return as a list of ints. 199 | """ 200 | range_re = re.compile(r'^(\d+)-(\d+)$') 201 | m = range_re.match(s) 202 | if m: 203 | return range(int(m.group(1)), int(m.group(2))+1) 204 | vals = s.split(',') 205 | return [int(x) for x in vals] 206 | 207 | 208 | #---------------------------------------------------------------------------- 209 | # Dataset and generation of latents 210 | 211 | 212 | class ResizeTransform: 213 | 214 | def __init__(self, height, width, resize=True, mode='bicubic'): 215 | if resize: 216 | assert height and width, 'Height and width have to be given ' + \ 217 | 'when resizing data.' 218 | self.height = height 219 | self.width = width 220 | self.resize = resize 221 | self.mode = mode 222 | 223 | def __call__(self, tensor): 224 | if self.height and self.width: 225 | if tensor.size(1) != self.height or tensor.size(2) != self.width: 226 | if self.resize: 227 | kwargs = {} 228 | if 'cubic' in self.mode or 'linear' in self.mode: 229 | kwargs.update(align_corners=False) 230 | tensor = F.interpolate( 231 | tensor.unsqueeze(0), 232 | size=(self.height, self.width), 233 | mode=self.mode, 234 | **kwargs 235 | ).squeeze(0) 236 | else: 237 | raise ValueError( 238 | 'Data shape incorrect, expected ({},{}) '.format(self.width, self.height) + \ 239 | 'but got ({},{}) (width, height)'.format(tensor.size(2), tensor.size(1)) 240 | ) 241 | return tensor 242 | 243 | 244 | def _PIL_RGB_loader(path): 245 | return Image.open(path).convert('RGB') 246 | 247 | 248 | def _PIL_grayscale_loader(path): 249 | return Image.open(path).convert('L') 250 | 251 | 252 | class ImageFolder(torchvision.datasets.ImageFolder): 253 | 254 | def __init__(self, 255 | *args, 256 | mirror=False, 257 | pixel_min=-1, 258 | pixel_max=1, 259 | height=None, 260 | width=None, 261 | resize=False, 262 | resize_mode='bicubic', 263 | grayscale=False, 264 | **kwargs): 265 | super(ImageFolder, self).__init__( 266 | *args, 267 | loader=_PIL_grayscale_loader if grayscale else _PIL_RGB_loader, 268 | **kwargs 269 | ) 270 | transforms = [] 271 | if mirror: 272 | transforms.append(torchvision.transforms.RandomHorizontalFlip()) 273 | transforms.append(torchvision.transforms.ToTensor()) 274 | transforms.append( 275 | torchvision.transforms.Normalize( 276 | mean=[-(pixel_min / (pixel_max - pixel_min))], 277 | std=[1. 
/ (pixel_max - pixel_min)] 278 | ) 279 | ) 280 | transforms.append(ResizeTransform( 281 | height=height, width=width, resize=resize, mode=resize_mode)) 282 | self.transform = torchvision.transforms.Compose(transforms) 283 | 284 | def _find_classes(self, *args, **kwargs): 285 | classes, class_to_idx = super(ImageFolder, self)._find_classes(*args, **kwargs) 286 | if not classes: 287 | classes = [''] 288 | class_to_idx = {'': 0} 289 | return classes, class_to_idx 290 | 291 | 292 | class PriorGenerator: 293 | 294 | def __init__(self, latent_size, label_size, batch_size, device): 295 | self.latent_size = latent_size 296 | self.label_size = label_size 297 | self.batch_size = batch_size 298 | self.device = device 299 | 300 | def __iter__(self): 301 | return self 302 | 303 | def __next__(self): 304 | return self() 305 | 306 | def __call__(self, batch_size=None, multi_latent_prob=0, seed=None): 307 | if batch_size is None: 308 | batch_size = self.batch_size 309 | shape = [batch_size, self.latent_size] 310 | if multi_latent_prob: 311 | if seed is not None: 312 | np.random.seed(seed) 313 | if np.random.uniform() < multi_latent_prob: 314 | shape = [batch_size, 2, self.latent_size] 315 | if seed is not None: 316 | torch.manual_seed(seed) 317 | latents = torch.empty(*shape, device=self.device).normal_() 318 | labels = None 319 | if self.label_size: 320 | label_shape = [batch_size] 321 | labels = torch.randint(0, self.label_size, label_shape, device=self.device) 322 | return latents, labels 323 | 324 | 325 | #---------------------------------------------------------------------------- 326 | # Training utils 327 | 328 | 329 | class MovingAverageModule: 330 | 331 | def __init__(self, 332 | from_module, 333 | to_module=None, 334 | param_beta=0.995, 335 | buffer_beta=0, 336 | device=None): 337 | from_module = unwrap_module(from_module) 338 | to_module = unwrap_module(to_module) 339 | if device is None: 340 | module = from_module 341 | if to_module is not None: 342 | module = to_module 343 | device = next(module.parameters()).device 344 | else: 345 | device = torch.device(device) 346 | self.from_module = from_module 347 | if to_module is None: 348 | self.module = from_module.clone().to(device) 349 | else: 350 | assert type(to_module) == type(from_module), \ 351 | 'Mismatch between type of source and target module.' 352 | assert set(self._get_named_parameters(to_module).keys()) \ 353 | == set(self._get_named_parameters(from_module).keys()), \ 354 | 'Mismatch between parameters of source and target module.' 355 | assert set(self._get_named_buffers(to_module).keys()) \ 356 | == set(self._get_named_buffers(from_module).keys()), \ 357 | 'Mismatch between buffers of source and target module.' 
358 | self.module = to_module.to(device) 359 | self.module.eval().requires_grad_(False) 360 | self.param_beta = param_beta 361 | self.buffer_beta = buffer_beta 362 | self.device = device 363 | 364 | def __getattr__(self, name): 365 | try: 366 | return super(object, self).__getattr__(name) 367 | except AttributeError: 368 | return getattr(self.module, name) 369 | 370 | def update(self): 371 | self._update_data( 372 | from_data=self._get_named_parameters(self.from_module), 373 | to_data=self._get_named_parameters(self.module), 374 | beta=self.param_beta 375 | ) 376 | self._update_data( 377 | from_data=self._get_named_buffers(self.from_module), 378 | to_data=self._get_named_buffers(self.module), 379 | beta=self.buffer_beta 380 | ) 381 | 382 | @staticmethod 383 | def _update_data(from_data, to_data, beta): 384 | for name in from_data.keys(): 385 | if name not in to_data: 386 | continue 387 | fr, to = from_data[name], to_data[name] 388 | with torch.no_grad(): 389 | if beta == 0: 390 | to.data.copy_(fr.data.to(to.data)) 391 | elif beta < 1: 392 | to.data.copy_(lerp(fr.data.to(to.data), to.data, beta)) 393 | 394 | @staticmethod 395 | def _get_named_parameters(module): 396 | return {name: value for name, value in module.named_parameters()} 397 | 398 | @staticmethod 399 | def _get_named_buffers(module): 400 | return {name: value for name, value in module.named_buffers()} 401 | 402 | def __call__(self, *args, **kwargs): 403 | return self.forward(*args, **kwargs) 404 | 405 | def forward(self, *args, **kwargs): 406 | self.module.eval() 407 | args, args_in_device = move_to_device(args, self.device) 408 | kwargs, kwargs_in_device = move_to_device(kwargs, self.device) 409 | in_device = None 410 | if args_in_device is not None: 411 | in_device = args_in_device 412 | if kwargs_in_device is not None: 413 | in_device = kwargs_in_device 414 | out = self.module(*args, **kwargs) 415 | if in_device is not None: 416 | out, _ = move_to_device(out, in_device) 417 | return out 418 | 419 | 420 | def move_to_device(value, device): 421 | if torch.is_tensor(value): 422 | return value.to(device), value.device 423 | orig_device = None 424 | if isinstance(value, (tuple, list)): 425 | values = [] 426 | for val in value: 427 | _val, orig_device = move_to_device(val, device) 428 | values.append(_val) 429 | return type(value)(values), orig_device 430 | if isinstance(value, dict): 431 | if isinstance(value, collections.OrderedDict): 432 | values = collections.OrderedDict() 433 | else: 434 | values = {} 435 | for key, val in value.items(): 436 | _val, orig_device = move_to_device(val, device) 437 | values[key] = _val 438 | return values, orig_device 439 | return value, orig_device 440 | 441 | 442 | _WRAPPER_CLASSES = (MovingAverageModule, nn.DataParallel, nn.parallel.DistributedDataParallel) 443 | def unwrap_module(module): 444 | if isinstance(module, _WRAPPER_CLASSES): 445 | return module.module 446 | return module 447 | 448 | 449 | def get_grad_norm_from_optimizer(optimizer, norm_type=2): 450 | """ 451 | Get the gradient norm for some parameters contained in an optimizer. 452 | Arguments: 453 | optimizer (torch.optim.Optimizer) 454 | norm_type (int): Type of norm. Default value is 2.
455 | Returns: 456 | norm (float) 457 | """ 458 | total_norm = 0 459 | if optimizer is not None: 460 | for param_group in optimizer.param_groups: 461 | for p in param_group['params']: 462 | if p.grad is not None: 463 | with torch.no_grad(): 464 | param_norm = p.grad.data.norm(norm_type) 465 | total_norm += param_norm ** norm_type 466 | total_norm = total_norm ** (1. / norm_type) 467 | return total_norm.item() 468 | 469 | 470 | #---------------------------------------------------------------------------- 471 | # printing and logging utils 472 | 473 | 474 | class ValueTracker: 475 | 476 | def __init__(self, beta=0.95): 477 | self.beta = beta 478 | self.values = {} 479 | 480 | def add(self, name, value, beta=None): 481 | if torch.is_tensor(value): 482 | value = value.item() 483 | if beta is None: 484 | beta = self.beta 485 | if name not in self.values: 486 | self.values[name] = value 487 | else: 488 | self.values[name] = lerp(value, self.values[name], beta) 489 | 490 | def __getitem__(self, key): 491 | return self.values[key] 492 | 493 | def __str__(self): 494 | string = '' 495 | for i, name in enumerate(sorted(self.values.keys())): 496 | if i and i % 3 == 0: 497 | string += '\n' 498 | elif string: 499 | string += ', ' 500 | format_string = '{}: {}' 501 | if isinstance(self.values[name], float): 502 | format_string = '{}: {:.4g}' 503 | string += format_string.format(name, self.values[name]) 504 | return string 505 | 506 | 507 | def is_notebook(): 508 | """ 509 | Check if code is running from jupyter notebook. 510 | Returns: 511 | notebook (bool): True if running from jupyter notebook, 512 | else False. 513 | """ 514 | try: 515 | __IPYTHON__ 516 | return True 517 | except NameError: 518 | return False 519 | 520 | 521 | def _progress_bar(count, total): 522 | """ 523 | Get a simple one-line string representing a progress bar. 524 | Arguments: 525 | count (int): Current count. Starts at 0. 526 | total (int): Total count. 527 | Returns: 528 | pbar_string (str): The string progress bar. 529 | """ 530 | bar_len = 60 531 | filled_len = int(round(bar_len * (count + 1) / float(total))) 532 | bar = '=' * filled_len + '-' * (bar_len - filled_len) 533 | return '[{}] {}/{}'.format(bar, count + 1, total) 534 | 535 | 536 | class ProgressWriter: 537 | """ 538 | Handles writing output and displaying a progress bar. Automatically 539 | adjust for notebooks. Supports outputting text 540 | that is compatible with the progressbar (in notebooks the text is 541 | refreshed instead of printed). 542 | Arguments: 543 | length (int, optional): Total length of the progressbar. 544 | Default value is None. 545 | progress_bar (bool, optional): Display a progressbar. 546 | Default value is True. 547 | clear (bool, optional): If running from a notebook, clear 548 | the current cell's output. Default value is False. 
549 | """ 550 | def __init__(self, length=None, progress_bar=True, clear=False): 551 | if is_notebook() and clear: 552 | notebook_clear() 553 | 554 | if length is not None: 555 | length = int(length) 556 | self.length = length 557 | self.count = 0 558 | 559 | self._simple_pbar = False 560 | if progress_bar and 'tqdm' not in sys.modules: 561 | self._simple_pbar = True 562 | 563 | progress_bar = progress_bar and 'tqdm' in sys.modules 564 | 565 | self._progress_bar = None 566 | if progress_bar: 567 | pbar = tqdm.tqdm 568 | if is_notebook(): 569 | pbar = tqdm.tqdm_notebook 570 | if length is not None: 571 | self._progress_bar = pbar(total=length, file=sys.stdout) 572 | else: 573 | self._progress_bar = pbar(file=sys.stdout) 574 | 575 | if is_notebook(): 576 | self._writer = notebook_display( 577 | _StrRepr(''), 578 | display_id=time.asctime() 579 | ) 580 | else: 581 | if progress_bar: 582 | self._writer = self._progress_bar 583 | else: 584 | self._writer = sys.stdout 585 | 586 | def write(self, *lines, step=True): 587 | """ 588 | Output values to stdout (or a display object if called from notebook). 589 | Arguments: 590 | *lines: The lines to write (positional arguments). 591 | step (bool): Update the progressbar if present. 592 | Default value is True. 593 | """ 594 | string = '\n'.join(str(line) for line in lines if line and line.strip()) 595 | if self._simple_pbar: 596 | string = _progress_bar(self.count, self.length) + '\n' + string 597 | if is_notebook(): 598 | self._writer.update(_StrRepr(string)) 599 | else: 600 | self._writer.write('\n\n' + string) 601 | if hasattr(self._writer, 'flush'): 602 | self._writer.flush() 603 | if step: 604 | self.step() 605 | 606 | def step(self): 607 | """ 608 | Update the progressbar if present. 609 | """ 610 | self.count += 1 611 | if self._progress_bar is not None: 612 | self._progress_bar.update() 613 | 614 | def __iter__(self): 615 | return self 616 | 617 | def __next__(self): 618 | return next(self.rnge) 619 | 620 | def close(self): 621 | if hasattr(self._writer, 'close'): 622 | can_close = True 623 | try: 624 | can_close = self._writer != sys.stdout and self._writer != sys.stderr 625 | except AttributeError: 626 | pass 627 | if can_close: 628 | self._writer.close() 629 | if hasattr(self._progress_bar, 'close'): 630 | self._progress_bar.close() 631 | 632 | def __del__(self): 633 | self.close() 634 | 635 | 636 | class _StrRepr: 637 | """ 638 | A wrapper for strings that returns the string 639 | on repr() calls. Used by notebooks. 
640 | """ 641 | def __init__(self, string): 642 | self.string = string 643 | 644 | def __repr__(self): 645 | return self.string 646 | 647 | 648 | #---------------------------------------------------------------------------- 649 | # image utils 650 | 651 | 652 | def tensor_to_PIL(image_tensor, pixel_min=-1, pixel_max=1): 653 | image_tensor = image_tensor.cpu() 654 | if pixel_min != 0 or pixel_max != 1: 655 | image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min) 656 | image_tensor.clamp_(min=0, max=1) 657 | to_pil = torchvision.transforms.functional.to_pil_image 658 | if image_tensor.dim() == 4: 659 | return [to_pil(img) for img in image_tensor] 660 | return to_pil(image_tensor) 661 | 662 | 663 | def PIL_to_tensor(image, pixel_min=-1, pixel_max=1): 664 | to_tensor = torchvision.transforms.functional.to_tensor 665 | if isinstance(image, (list, tuple)): 666 | image_tensor = torch.stack([to_tensor(img) for img in image]) 667 | else: 668 | image_tensor = to_tensor(image) 669 | if pixel_min != 0 or pixel_max != 1: 670 | image_tensor = image_tensor * (pixel_max - pixel_min) + pixel_min 671 | return image_tensor 672 | 673 | 674 | def stack_images_PIL(imgs, shape=None, individual_img_size=None): 675 | """ 676 | Concatenate multiple images into a grid within a single image. 677 | Arguments: 678 | imgs (Sequence of PIL.Image): Input images. 679 | shape (list, tuple, int, optional): Shape of the grid. Should consist 680 | of two values, (width, height). If an integer value is passed it 681 | is used for both width and height. If no value is passed the shape 682 | is infered from the number of images. Default value is None. 683 | individual_img_size (list, tuple, int, optional): The size of the 684 | images being concatenated. Default value is None. 685 | Returns: 686 | canvas (PIL.Image): Image containing input images in a grid. 687 | """ 688 | assert len(imgs) > 0, 'No images received.' 689 | if shape is None: 690 | size = int(np.ceil(np.sqrt(len(imgs)))) 691 | shape = [int(np.ceil(len(imgs) / size)), size] 692 | else: 693 | if isinstance(shape, numbers.Number): 694 | shape = 2 * [shape] 695 | assert len(shape) == 2, 'Shape should specify (width, height).' 696 | 697 | if individual_img_size is None: 698 | for i in range(len(imgs) - 1): 699 | assert imgs[i].size == imgs[i + 1].size, \ 700 | 'Images are of different sizes, please specify a ' + \ 701 | 'size (width, height). 
Found sizes:\n' + \ 702 | ', '.join(str(img.size) for img in imgs) 703 | individual_img_size = imgs[0].size 704 | else: 705 | if not isinstance(individual_img_size, (tuple, list)): 706 | individual_img_size = 2 * (individual_img_size,) 707 | individual_img_size = tuple(individual_img_size) 708 | for i in range(len(imgs)): 709 | if imgs[i].size != individual_img_size: 710 | imgs[i] = imgs[i].resize(individual_img_size) 711 | 712 | width, height = individual_img_size 713 | width, height = int(width), int(height) 714 | canvas = Image.new( 715 | 'RGB', 716 | (shape[0] * width, shape[1] * height), 717 | (0, 0, 0, 0) 718 | ) 719 | imgs = imgs.copy() 720 | for h_i in range(shape[1]): 721 | for w_i in range(shape[0]): 722 | if len(imgs) > 0: 723 | img = imgs.pop(0).convert('RGB') 724 | offset = (w_i * width, h_i * height) 725 | canvas.paste(img, offset) 726 | return canvas 727 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from matplotlib import pyplot as plt 4 | 5 | def save_grid(images, path): 6 | grid = torchvision.utils.make_grid(images) 7 | torchvision.utils.save_image(grid, path) 8 | 9 | def show_grid(images): 10 | grid = torchvision.utils.make_grid(images) 11 | plt.imshow(grid.permute(1, 2, 0).cpu().detach().numpy()) 12 | plt.show() 13 | 14 | def biggan_norm(images): 15 | images = (images + 1) / 2.0 16 | images = images.clip(0, 1) 17 | return images 18 | 19 | def biggan_denorm(images): 20 | images = images*2 - 1 21 | return images 22 | 23 | 24 | def freeze_model(model): 25 | for param in model.parameters(): 26 | param.requires_grad = False 27 | 28 | --------------------------------------------------------------------------------
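Editor's usage sketch (not part of the repository): the root-level utils.py above is a small set of free-standing helpers, and the snippet below shows one plausible way to combine them. The random tensor stands in for a generator's output in [-1, 1], the output filename is arbitrary, and the `from utils import ...` path assumes the script runs from the repository root.
import torch
import torchvision
from utils import save_grid, biggan_norm, freeze_model
# Fake batch of 8 RGB images in [-1, 1], standing in for generator output.
fake_images = torch.empty(8, 3, 64, 64).uniform_(-1, 1)
# Map [-1, 1] -> [0, 1] before visualization, then write a single grid image.
images = biggan_norm(fake_images)
save_grid(images, 'tmp_grid.png')
# Freeze a pretrained-style classifier so none of its parameters receive gradients.
classifier = torchvision.models.resnet18(pretrained=False)
freeze_model(classifier)
assert all(not p.requires_grad for p in classifier.parameters())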