├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── bpe_simple_vocab_16e6.txt.gz
├── clip
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── config.py
├── download-weights.sh
├── generator.py
├── gpt2
│   ├── config.py
│   ├── encoder.py
│   ├── model.py
│   ├── sample.py
│   ├── utils.py
│   └── weights
│       ├── encoder.json
│       └── vocab.bpe
├── gpt2_images
│   ├── dog.jpeg
│   ├── goldfish.jpeg
│   ├── harmonica.jpeg
│   ├── harp.jpeg
│   ├── knot.jpeg
│   ├── radio_telescope.jpeg
│   ├── teapot.jpeg
│   ├── telephone.jpeg
│   └── zebra.jpeg
├── latent.py
├── models.py
├── operators.py
├── problem.py
├── requirements.txt
├── run.py
├── stylegan2
│   ├── __init__.py
│   ├── convert_from_tf.py
│   ├── external_models
│   │   ├── __init__.py
│   │   ├── inception.py
│   │   └── lpips.py
│   ├── loss_fns.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── fid.py
│   │   └── ppl.py
│   ├── models.py
│   ├── modules.py
│   ├── project.py
│   ├── train.py
│   └── utils.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /env
2 | __pycache__/
3 | /tmp
4 |
5 | /stylegan2/weights
6 | /gpt2/weights/gpt2-pytorch_model.bin
7 |
8 | /.vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLIP-GLaSS
2 |
3 | Repository for the paper [Generating images from caption and vice versa via CLIP-Guided Generative Latent Space Search](https://arxiv.org/abs/2102.01645)
4 |
5 |
6 | ### **An in-browser demo is available [here](https://colab.research.google.com/drive/1fWka_U56NhCegbbrQPt4PWpHPtNRdU49?usp=sharing)**
7 |
8 |
9 | ## Installation
10 |
11 | Clone this repository
12 |
13 | ```
14 | git clone https://github.com/galatolofederico/clip-glass && cd clip-glass
15 | ```
16 |
17 | Create a virtual environment and install the requirements
18 |
19 | ```
20 | virtualenv --python=python3.6 env && . ./env/bin/activate
21 | pip install -r requirements.txt
22 | ```
23 |
24 | ## Run CLIP-GLaSS
25 |
26 | You can run `CLIP-GLaSS` with:
27 |
28 | ```
29 | python run.py --config <config> --target <target>
30 | ```
31 |
32 | Specifying `<config>` and `<target>` according to the following table:
33 |
34 | | Config | Meaning | Target Type |
35 | |:--------------------:|:--------------------------------------------------------------------------:|:-----------:|
36 | | GPT2 | Use GPT2 to solve the Image-to-Text task | Image |
37 | | DeepMindBigGAN512 | Use DeepMind's BigGAN 512x512 to solve the Text-to-Image task | Text |
38 | | DeepMindBigGAN256 | Use DeepMind's BigGAN 256x256 to solve the Text-to-Image task | Text |
39 | | StyleGAN2_ffhq_d | Use StyleGAN2-ffhq to solve the Text-to-Image task | Text |
40 | | StyleGAN2_ffhq_nod | Use StyleGAN2-ffhq without Discriminator to solve the Text-to-Image task | Text |
41 | | StyleGAN2_church_d | Use StyleGAN2-church to solve the Text-to-Image task | Text |
42 | | StyleGAN2_church_nod | Use StyleGAN2-church without Discriminator to solve the Text-to-Image task | Text |
43 | | StyleGAN2_car_d | Use StyleGAN2-car to solve the Text-to-Image task | Text |
44 | | StyleGAN2_car_nod | Use StyleGAN2-car without Discriminator to solve the Text-to-Image task | Text |
45 |
46 |
47 | If you have not downloaded the model weights yet, you will be prompted to run `./download-weights.sh`.
48 | You will find the results in the `./tmp` folder; a different output folder can be specified with `--tmp-folder`.
49 |
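For instance, a typical first run could look like this (a minimal sketch: `DeepMindBigGAN512` is one of the configs from the table above, while the prompt and the output folder name are arbitrary placeholders):

```
./download-weights.sh
python run.py --config DeepMindBigGAN512 --target "a red flower" --tmp-folder results
```
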
50 | #### Examples
51 |
52 | ```
53 | python run.py --config StyleGAN2_ffhq_d --target "the face of a man with brown eyes and stubble beard"
54 | python run.py --config GPT2 --target gpt2_images/dog.jpeg
55 | ```
56 |
57 |
58 | ## Acknowledgments and licensing
59 |
60 | This work heavily relies on the following amazing repositories and would not have been possible without them:
61 |
62 | * [CLIP](https://github.com/openai/CLIP) from [openai](https://github.com/openai) (included in the folder `clip`)
63 | * [pytorch-pretrained-BigGAN](https://github.com/huggingface/pytorch-pretrained-BigGAN) from [huggingface](https://github.com/huggingface)
64 | * [stylegan2-pytorch](https://github.com/Tetratrio/stylegan2_pytorch) from [Adrian Sahlman](https://github.com/Tetratrio) (included in the folder `stylegan2`)
65 | * [gpt-2-pytorch](https://github.com/graykode/gpt-2-Pytorch) from [Tae-Hwan Jung](https://github.com/graykode) (included in the folder `gpt2`)
66 |
67 | All of their work is shared under the terms of the respective original licenses.
68 |
69 | All my original work (everything except the content of the folders `clip`, `stylegan2` and `gpt2`) is released under the terms of the [GNU/GPLv3](https://choosealicense.com/licenses/gpl-3.0/) license. Copying, adapting and republishing it is not only permitted but also encouraged.
70 |
71 | ## Citing
72 |
73 | If you want to cite us, you can use this BibTeX entry:
74 |
75 | ```
76 | @article{generating2021,
77 | author={Federico Galatolo. and Mario Cimino. and Gigliola Vaglini},
78 | title={Generating Images from Caption and Vice Versa via CLIP-Guided Generative Latent Space Search},
79 | journal={Proceedings of the International Conference on Image Processing and Vision Engineering},
80 | year={2021},
81 | volume={},
82 | pages={},
83 | publisher={SCITEPRESS - Science and Technology Publications},
84 | doi={10.5220/0010503701660174},
85 | issn={},
86 | }
87 | ```
88 |
89 | ## Contacts
90 |
91 | For any further questions, feel free to reach me at [federico.galatolo@ing.unipi.it](mailto:federico.galatolo@ing.unipi.it) or on Telegram [@galatolo](https://t.me/galatolo)
92 |
--------------------------------------------------------------------------------
/assets/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/assets/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from clip.model import build_model
13 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | __all__ = ["available_models", "load", "tokenize"]
16 | _tokenizer = _Tokenizer()
17 |
18 | _MODELS = {
19 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
21 | }
22 |
23 |
24 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
25 | os.makedirs(root, exist_ok=True)
26 | filename = os.path.basename(url)
27 |
28 | expected_sha256 = url.split("/")[-2]
29 | download_target = os.path.join(root, filename)
30 |
31 | if os.path.exists(download_target) and not os.path.isfile(download_target):
32 | raise RuntimeError(f"{download_target} exists and is not a regular file")
33 |
34 | if os.path.isfile(download_target):
35 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
36 | return download_target
37 | else:
38 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
39 |
40 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
41 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80) as loop:
42 | while True:
43 | buffer = source.read(8192)
44 | if not buffer:
45 | break
46 |
47 | output.write(buffer)
48 | loop.update(len(buffer))
49 |
50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
51 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
52 |
53 | return download_target
54 |
55 |
56 | def available_models():
57 | return list(_MODELS.keys())
58 |
59 |
60 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
61 | if name not in _MODELS:
62 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
63 |
64 | model_path = _download(_MODELS[name])
65 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
66 | n_px = model.input_resolution.item()
67 |
68 | transform = Compose([
69 | Resize(n_px, interpolation=Image.BICUBIC),
70 | CenterCrop(n_px),
71 | lambda image: image.convert("RGB"),
72 | ToTensor(),
73 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
74 | ])
75 |
76 | if not jit:
77 | model = build_model(model.state_dict()).to(device)
78 | return model, transform
79 |
80 | # patch the device names
81 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
82 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
83 |
84 | def patch_device(module):
85 | graphs = [module.graph] if hasattr(module, "graph") else []
86 | if hasattr(module, "forward1"):
87 | graphs.append(module.forward1.graph)
88 |
89 | for graph in graphs:
90 | for node in graph.findAllNodes("prim::Constant"):
91 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
92 | node.copyAttributes(device_node)
93 |
94 | model.apply(patch_device)
95 | patch_device(model.encode_image)
96 | patch_device(model.encode_text)
97 |
98 | # patch dtype to float32 on CPU
99 | if device == "cpu":
100 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
101 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
102 | float_node = float_input.node()
103 |
104 | def patch_float(module):
105 | graphs = [module.graph] if hasattr(module, "graph") else []
106 | if hasattr(module, "forward1"):
107 | graphs.append(module.forward1.graph)
108 |
109 | for graph in graphs:
110 | for node in graph.findAllNodes("aten::to"):
111 | inputs = list(node.inputs())
112 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
113 | if inputs[i].node()["value"] == 5:
114 | inputs[i].node().copyAttributes(float_node)
115 |
116 | model.apply(patch_float)
117 | patch_float(model.encode_image)
118 | patch_float(model.encode_text)
119 |
120 | model.float()
121 |
122 | return model, transform
123 |
124 |
125 | def tokenize(texts: Union[str, List[str]], context_length: int = 77):
126 | if isinstance(texts, str):
127 | texts = [texts]
128 |
129 | sot_token = _tokenizer.encoder["<|startoftext|>"]
130 | eot_token = _tokenizer.encoder["<|endoftext|>"]
131 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
132 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
133 |
134 | for i, tokens in enumerate(all_tokens):
135 | if len(tokens) > context_length:
136 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
137 | result[i, :len(tokens)] = torch.tensor(tokens)
138 |
139 | return result
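# Example usage sketch (assuming the "ViT-B/32" weights listed in _MODELS are reachable):
#   model, transform = load("ViT-B/32", jit=False)
#   tokens = tokenize(["a diagram", "a dog"])   # -> LongTensor of shape (2, context_length)
#   with torch.no_grad():
#       text_features = model.encode_text(tokens)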
--------------------------------------------------------------------------------
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | expansion = 4
11 |
12 | def __init__(self, inplanes, planes, stride=1):
13 | super().__init__()
14 |
15 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
16 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 |
19 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 |
22 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
23 |
24 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
25 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
26 |
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | self.stride = stride
30 |
31 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
32 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
33 | self.downsample = nn.Sequential(OrderedDict([
34 | ("-1", nn.AvgPool2d(stride)),
35 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
36 | ("1", nn.BatchNorm2d(planes * self.expansion))
37 | ]))
38 |
39 | def forward(self, x: torch.Tensor):
40 | identity = x
41 |
42 | out = self.relu(self.bn1(self.conv1(x)))
43 | out = self.relu(self.bn2(self.conv2(out)))
44 | out = self.avgpool(out)
45 | out = self.bn3(self.conv3(out))
46 |
47 | if self.downsample is not None:
48 | identity = self.downsample(x)
49 |
50 | out += identity
51 | out = self.relu(out)
52 | return out
53 |
54 |
55 | class AttentionPool2d(nn.Module):
56 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
57 | super().__init__()
58 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
59 | self.k_proj = nn.Linear(embed_dim, embed_dim)
60 | self.q_proj = nn.Linear(embed_dim, embed_dim)
61 | self.v_proj = nn.Linear(embed_dim, embed_dim)
62 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
63 | self.num_heads = num_heads
64 |
65 | def forward(self, x):
66 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
67 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
68 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
69 | x, _ = F.multi_head_attention_forward(
70 | query=x, key=x, value=x,
71 | embed_dim_to_check=x.shape[-1],
72 | num_heads=self.num_heads,
73 | q_proj_weight=self.q_proj.weight,
74 | k_proj_weight=self.k_proj.weight,
75 | v_proj_weight=self.v_proj.weight,
76 | in_proj_weight=None,
77 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
78 | bias_k=None,
79 | bias_v=None,
80 | add_zero_attn=False,
81 | dropout_p=0,
82 | out_proj_weight=self.c_proj.weight,
83 | out_proj_bias=self.c_proj.bias,
84 | use_separate_proj_weight=True,
85 | training=self.training,
86 | need_weights=False
87 | )
88 |
89 | return x[0]
90 |
91 |
92 | class ModifiedResNet(nn.Module):
93 | """
94 | A ResNet class that is similar to torchvision's but contains the following changes:
95 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
96 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
97 | - The final pooling layer is a QKV attention instead of an average pool
98 | """
99 |
100 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
101 | super().__init__()
102 | self.output_dim = output_dim
103 | self.input_resolution = input_resolution
104 |
105 | # the 3-layer stem
106 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
107 | self.bn1 = nn.BatchNorm2d(width // 2)
108 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
109 | self.bn2 = nn.BatchNorm2d(width // 2)
110 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
111 | self.bn3 = nn.BatchNorm2d(width)
112 | self.avgpool = nn.AvgPool2d(2)
113 | self.relu = nn.ReLU(inplace=True)
114 |
115 | # residual layers
116 | self._inplanes = width # this is a *mutable* variable used during construction
117 | self.layer1 = self._make_layer(width, layers[0])
118 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
119 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
120 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
121 |
122 | embed_dim = width * 32 # the ResNet feature dimension
123 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
124 |
125 | def _make_layer(self, planes, blocks, stride=1):
126 | layers = [Bottleneck(self._inplanes, planes, stride)]
127 |
128 | self._inplanes = planes * Bottleneck.expansion
129 | for _ in range(1, blocks):
130 | layers.append(Bottleneck(self._inplanes, planes))
131 |
132 | return nn.Sequential(*layers)
133 |
134 | def forward(self, x):
135 | def stem(x):
136 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
137 | x = self.relu(bn(conv(x)))
138 | x = self.avgpool(x)
139 | return x
140 |
141 | x = x.type(self.conv1.weight.dtype)
142 | x = stem(x)
143 | x = self.layer1(x)
144 | x = self.layer2(x)
145 | x = self.layer3(x)
146 | x = self.layer4(x)
147 | x = self.attnpool(x)
148 |
149 | return x
150 |
151 |
152 | class LayerNorm(nn.LayerNorm):
153 | """Subclass torch's LayerNorm to handle fp16."""
154 |
155 | def forward(self, x: torch.Tensor):
156 | orig_type = x.dtype
157 | ret = super().forward(x.type(torch.float32))
158 | return ret.type(orig_type)
159 |
160 |
161 | class QuickGELU(nn.Module):
162 | def forward(self, x: torch.Tensor):
163 | return x * torch.sigmoid(1.702 * x)
164 |
165 |
166 | class ResidualAttentionBlock(nn.Module):
167 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
168 | super().__init__()
169 |
170 | self.attn = nn.MultiheadAttention(d_model, n_head)
171 | self.ln_1 = LayerNorm(d_model)
172 | self.mlp = nn.Sequential(OrderedDict([
173 | ("c_fc", nn.Linear(d_model, d_model * 4)),
174 | ("gelu", QuickGELU()),
175 | ("c_proj", nn.Linear(d_model * 4, d_model))
176 | ]))
177 | self.ln_2 = LayerNorm(d_model)
178 | self.attn_mask = attn_mask
179 |
180 | def attention(self, x: torch.Tensor):
181 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
182 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
183 |
184 | def forward(self, x: torch.Tensor):
185 | x = x + self.attention(self.ln_1(x))
186 | x = x + self.mlp(self.ln_2(x))
187 | return x
188 |
189 |
190 | class Transformer(nn.Module):
191 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
192 | super().__init__()
193 | self.width = width
194 | self.layers = layers
195 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
196 |
197 | def forward(self, x: torch.Tensor):
198 | return self.resblocks(x)
199 |
200 |
201 | class VisualTransformer(nn.Module):
202 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
203 | super().__init__()
204 | self.input_resolution = input_resolution
205 | self.output_dim = output_dim
206 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
207 |
208 | scale = width ** -0.5
209 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
210 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
211 | self.ln_pre = LayerNorm(width)
212 |
213 | self.transformer = Transformer(width, layers, heads)
214 |
215 | self.ln_post = LayerNorm(width)
216 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
217 |
218 | def forward(self, x: torch.Tensor):
219 | x = self.conv1(x) # shape = [*, width, grid, grid]
220 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
221 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
222 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
223 | x = x + self.positional_embedding.to(x.dtype)
224 | x = self.ln_pre(x)
225 |
226 | x = x.permute(1, 0, 2) # NLD -> LND
227 | x = self.transformer(x)
228 | x = x.permute(1, 0, 2) # LND -> NLD
229 |
230 | x = self.ln_post(x[:, 0, :])
231 |
232 | if self.proj is not None:
233 | x = x @ self.proj
234 |
235 | return x
236 |
237 |
238 | class CLIP(nn.Module):
239 | def __init__(self,
240 | embed_dim: int,
241 | # vision
242 | image_resolution: int,
243 | vision_layers: Union[Tuple[int, int, int, int], int],
244 | vision_width: int,
245 | vision_patch_size: int,
246 | # text
247 | context_length: int,
248 | vocab_size: int,
249 | transformer_width: int,
250 | transformer_heads: int,
251 | transformer_layers: int
252 | ):
253 | super().__init__()
254 |
255 | self.context_length = context_length
256 |
257 | if isinstance(vision_layers, (tuple, list)):
258 | vision_heads = vision_width * 32 // 64
259 | self.visual = ModifiedResNet(
260 | layers=vision_layers,
261 | output_dim=embed_dim,
262 | heads=vision_heads,
263 | input_resolution=image_resolution,
264 | width=vision_width
265 | )
266 | else:
267 | vision_heads = vision_width // 64
268 | self.visual = VisualTransformer(
269 | input_resolution=image_resolution,
270 | patch_size=vision_patch_size,
271 | width=vision_width,
272 | layers=vision_layers,
273 | heads=vision_heads,
274 | output_dim=embed_dim
275 | )
276 |
277 | self.transformer = Transformer(
278 | width=transformer_width,
279 | layers=transformer_layers,
280 | heads=transformer_heads,
281 | attn_mask=self.build_attention_mask()
282 | )
283 |
284 | self.vocab_size = vocab_size
285 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
286 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
287 | self.ln_final = LayerNorm(transformer_width)
288 |
289 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
290 | self.logit_scale = nn.Parameter(torch.ones([]))
291 |
292 | def build_attention_mask(self):
293 | # lazily create causal attention mask, with full attention between the vision tokens
294 | # pytorch uses additive attention mask; fill with -inf
295 | mask = torch.empty(self.context_length, self.context_length)
296 | mask.fill_(float("-inf"))
297 |         mask.triu_(1) # zero out the lower triangle (each token attends only to itself and earlier positions)
298 | return mask
299 |
300 | @property
301 | def dtype(self):
302 | return self.visual.conv1.weight.dtype
303 |
304 | def encode_image(self, image):
305 | return self.visual(image.type(self.dtype))
306 |
307 | def encode_text(self, text):
308 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
309 |
310 | x = x + self.positional_embedding.type(self.dtype)
311 | x = x.permute(1, 0, 2) # NLD -> LND
312 | x = self.transformer(x)
313 | x = x.permute(1, 0, 2) # LND -> NLD
314 | x = self.ln_final(x).type(self.dtype)
315 |
316 | # x.shape = [batch_size, n_ctx, transformer.width]
317 | # take features from the eot embedding (eot_token is the highest number in each sequence)
318 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
319 |
320 | return x
321 |
322 | def forward(self, image, text):
323 | image_features = self.encode_image(image)
324 | text_features = self.encode_text(text)
325 |
326 | # normalized features
327 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
328 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
329 |
330 | # cosine similarity as logits
331 | logit_scale = self.logit_scale.exp()
332 |         logits_per_image = logit_scale * image_features @ text_features.t()
333 | logits_per_text = logit_scale * text_features @ image_features.t()
334 |
335 | # shape = [global_batch_size, global_batch_size]
336 |         return logits_per_image, logits_per_text
337 |
338 |
339 | def convert_weights(model: nn.Module):
340 | """Convert applicable model parameters to fp16"""
341 |
342 | def _convert_weights_to_fp16(l):
343 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
344 | l.weight.data = l.weight.data.half()
345 | if l.bias is not None:
346 | l.bias.data = l.bias.data.half()
347 |
348 | if isinstance(l, nn.MultiheadAttention):
349 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
350 | tensor = getattr(l, attr)
351 | if tensor is not None:
352 | tensor.data = tensor.data.half()
353 |
354 | for name in ["text_projection", "proj"]:
355 | if hasattr(l, name):
356 | attr = getattr(l, name)
357 | if attr is not None:
358 | attr.data = attr.data.half()
359 |
360 | model.apply(_convert_weights_to_fp16)
361 |
362 |
363 | def build_model(state_dict: dict):
364 | vit = "visual.proj" in state_dict
365 |
366 | if vit:
367 | vision_width = state_dict["visual.conv1.weight"].shape[0]
368 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
369 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
370 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
371 | image_resolution = vision_patch_size * grid_size
372 | else:
373 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
374 | vision_layers = tuple(counts)
375 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
376 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
377 | vision_patch_size = None
378 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
379 | image_resolution = output_width * 32
380 |
381 | embed_dim = state_dict["text_projection"].shape[1]
382 | context_length = state_dict["positional_embedding"].shape[0]
383 | vocab_size = state_dict["token_embedding.weight"].shape[0]
384 | transformer_width = state_dict["ln_final.weight"].shape[0]
385 | transformer_heads = transformer_width // 64
386 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
387 |
388 | model = CLIP(
389 | embed_dim,
390 | image_resolution, vision_layers, vision_width, vision_patch_size,
391 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
392 | )
393 |
394 | for key in ["input_resolution", "context_length", "vocab_size"]:
395 | del state_dict[key]
396 |
397 | convert_weights(model)
398 | model.load_state_dict(state_dict)
399 | return model.eval()
400 |
--------------------------------------------------------------------------------
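
CLIP.forward above returns two scaled cosine-similarity matrices (image-to-text and text-to-image). A minimal sketch, with made-up logits, of how such a matrix is typically turned into per-image caption probabilities (the logit_scale.exp() factor acts as an inverse temperature that sharpens the softmax):

    import torch

    # illustrative only: rows = images, columns = candidate text prompts,
    # laid out the same way as logits_per_image returned by CLIP.forward
    logits_per_image = torch.tensor([[25.0, 20.0, 18.0],
                                     [19.0, 24.0, 21.0]])
    probs = logits_per_image.softmax(dim=-1)  # each row sums to 1
    print(probs)                              # the largest logit per row gets most of the mass
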
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return "./assets/bpe_simple_vocab_16e6.txt.gz"
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     It also avoids mapping to whitespace/control characters that the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 |                 except ValueError:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
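
A minimal round-trip sketch for the tokenizer above, run from the repo root so that default_bpe() resolves to ./assets/bpe_simple_vocab_16e6.txt.gz (the example string is illustrative):

    from clip.simple_tokenizer import SimpleTokenizer

    tokenizer = SimpleTokenizer()              # loads the bundled BPE merges
    ids = tokenizer.encode("a wolf at night")  # lowercased, cleaned, then BPE-encoded ids
    print(ids)
    print(tokenizer.decode(ids))               # word-final </w> markers are turned back into spaces

clip.tokenize (in the bundled clip/clip.py, used by generator.py) wraps these ids with the <|startoftext|> and <|endoftext|> tokens before they reach CLIP.encode_text.
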
/config.py:
--------------------------------------------------------------------------------
1 | from models import DeepMindBigGAN, StyleGAN2, GPT2
2 | from latent import DeepMindBigGANLatentSpace, StyleGAN2LatentSpace, GPT2LatentSpace
3 | from utils import biggan_norm, biggan_denorm
4 |
5 | configs = dict(
6 | GPT2 = dict(
7 | task = "img2txt",
8 | dim_z = 20,
9 | max_tokens_len = 30,
10 | max_text_len = 50,
11 | encoder_size = 50257,
12 | latent = GPT2LatentSpace,
13 | model = GPT2,
14 | use_discriminator = False,
15 | init_text = "the picture of",
16 | weights = "./gpt2/weights/gpt2-pytorch_model.bin",
17 | encoder = "./gpt2/weights/encoder.json",
18 | vocab = "./gpt2/weights/vocab.bpe",
19 | stochastic = False,
20 | algorithm = "ga",
21 | pop_size = 100,
22 | batch_size = 25,
23 | problem_args = dict(
24 | n_var = 20,
25 | n_obj = 1,
26 | n_constr = 20,
27 | xl = 0,
28 | xu = 50256
29 | )
30 | ),
31 | DeepMindBigGAN256 = dict(
32 | task = "txt2img",
33 | dim_z = 128,
34 | num_classes = 1000,
35 | latent = DeepMindBigGANLatentSpace,
36 | model = DeepMindBigGAN,
37 | weights = "biggan-deep-256",
38 | use_discriminator = False,
39 | algorithm = "ga",
40 | norm = biggan_norm,
41 | denorm = biggan_denorm,
42 | truncation = 1.0,
43 | pop_size = 64,
44 | batch_size = 32,
45 | problem_args = dict(
46 | n_var = 128 + 1000,
47 | n_obj = 1,
48 | n_constr = 128,
49 | xl = -2,
50 | xu = 2
51 | )
52 | ),
53 | DeepMindBigGAN512 = dict(
54 | task = "txt2img",
55 | dim_z = 128,
56 | num_classes = 1000,
57 | latent = DeepMindBigGANLatentSpace,
58 | model = DeepMindBigGAN,
59 | weights = "biggan-deep-512",
60 | use_discriminator = False,
61 | algorithm = "ga",
62 | norm = biggan_norm,
63 | denorm = biggan_denorm,
64 | truncation = 1.0,
65 | pop_size = 32,
66 | batch_size = 8,
67 | problem_args = dict(
68 | n_var = 128 + 1000,
69 | n_obj = 1,
70 | n_constr = 128,
71 | xl = -2,
72 | xu = 2
73 | )
74 | ),
75 | StyleGAN2_ffhq_d = dict(
76 | task = "txt2img",
77 | dim_z = 512,
78 | latent = StyleGAN2LatentSpace,
79 | model = StyleGAN2,
80 | use_discriminator = True,
81 | weights = "./stylegan2/weights/ffhq-config-f",
82 | algorithm = "nsga2",
83 | norm = biggan_norm,
84 | denorm = biggan_denorm,
85 | pop_size = 16,
86 | batch_size = 4,
87 | problem_args = dict(
88 | n_var = 512,
89 | n_obj = 2,
90 | n_constr = 512,
91 | xl = -10,
92 | xu = 10,
93 | ),
94 | ),
95 | StyleGAN2_car_d = dict(
96 | task = "txt2img",
97 | dim_z = 512,
98 | latent = StyleGAN2LatentSpace,
99 | model = StyleGAN2,
100 | use_discriminator = True,
101 | weights = "./stylegan2/weights/car-config-f",
102 | algorithm = "nsga2",
103 | norm = biggan_norm,
104 | denorm = biggan_denorm,
105 | pop_size = 16,
106 | batch_size = 4,
107 | problem_args = dict(
108 | n_var = 512,
109 | n_obj = 2,
110 | n_constr = 512,
111 | xl = -10,
112 | xu = 10
113 | ),
114 | ),
115 | StyleGAN2_church_d = dict(
116 | task = "txt2img",
117 | dim_z = 512,
118 | latent = StyleGAN2LatentSpace,
119 | model = StyleGAN2,
120 | use_discriminator = True,
121 | weights = "./stylegan2/weights/church-config-f",
122 | algorithm = "nsga2",
123 | norm = biggan_norm,
124 | denorm = biggan_denorm,
125 | pop_size = 16,
126 | batch_size = 4,
127 | problem_args = dict(
128 | n_var = 512,
129 | n_obj = 2,
130 | n_constr = 512,
131 | xl = -10,
132 | xu = 10
133 | ),
134 | ),
135 | StyleGAN2_ffhq_nod = dict(
136 | task = "txt2img",
137 | dim_z = 512,
138 | latent = StyleGAN2LatentSpace,
139 | model = StyleGAN2,
140 | use_discriminator = False,
141 | weights = "./stylegan2/weights/ffhq-config-f",
142 | algorithm = "ga",
143 | norm = biggan_norm,
144 | denorm = biggan_denorm,
145 | pop_size = 16,
146 | batch_size = 4,
147 | problem_args = dict(
148 | n_var = 512,
149 | n_obj = 1,
150 | n_constr = 512,
151 | xl = -10,
152 | xu = 10
153 | )
154 | ),
155 | StyleGAN2_car_nod = dict(
156 | task = "txt2img",
157 | dim_z = 512,
158 | latent = StyleGAN2LatentSpace,
159 | model = StyleGAN2,
160 | use_discriminator = False,
161 | weights = "./stylegan2/weights/car-config-f",
162 | algorithm = "ga",
163 | norm = biggan_norm,
164 | denorm = biggan_denorm,
165 | pop_size = 16,
166 | batch_size = 4,
167 | problem_args = dict(
168 | n_var = 512,
169 | n_obj = 1,
170 | n_constr = 512,
171 | xl = -10,
172 | xu = 10
173 | )
174 | ),
175 | StyleGAN2_church_nod = dict(
176 | task = "txt2img",
177 | dim_z = 512,
178 | latent = StyleGAN2LatentSpace,
179 | model = StyleGAN2,
180 | use_discriminator = False,
181 | weights = "./stylegan2/weights/church-config-f",
182 | algorithm = "ga",
183 | norm = biggan_norm,
184 | denorm = biggan_denorm,
185 | pop_size = 16,
186 | batch_size = 4,
187 | problem_args = dict(
188 | n_var = 512,
189 | n_obj = 1,
190 | n_constr = 512,
191 | xl = -10,
192 | xu = 10
193 | )
194 | )
195 | )
196 |
197 |
198 |
199 | def get_config(name):
200 | return configs[name]
--------------------------------------------------------------------------------
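
Each entry above is a plain dict that run.py merges into its argparse namespace before wiring the latent space, model and pymoo problem together. A small sketch (run from the repo root with the listed requirements installed; the printed values follow from the DeepMindBigGAN256 entry above):

    from config import get_config

    cfg = get_config("DeepMindBigGAN256")
    print(cfg["dim_z"], cfg["num_classes"])    # 128 latent dimensions, 1000 class logits
    print(cfg["problem_args"]["n_var"])        # 1128 = 128 + 1000 decision variables per individual
    print(cfg["latent"], cfg["model"])         # classes instantiated later from the merged namespace
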
/download-weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$#" -ne 1 ]; then
4 |     echo "./download-weights.sh <model>"
5 |     echo "Possible models are: StyleGAN2-ffhq, StyleGAN2-church, StyleGAN2-car, GPT2"
6 | echo "Example:"
7 | echo "./download-weights.sh StyleGAN2-ffhq"
8 | exit
9 | fi
10 |
11 | die(){
12 | echo "$1"
13 | exit
14 | }
15 |
16 | download_stylegan2(){
17 | config="$1"
18 | dest="./stylegan2/weights/$config"
19 | [ -f "$dest/G.pth" ] && die "Weights already downloaded"
20 | [ ! -d "$dest" ] && mkdir -p "$dest"
21 | python -m stylegan2.convert_from_tf --download "$config" --output "$dest/G.pth" "$dest/D.pth" "$dest/Gs.pth"
22 | }
23 |
24 |
25 | case $1 in
26 | "StyleGAN2-ffhq")
27 | download_stylegan2 "ffhq-config-f"
28 | ;;
29 | "StyleGAN2-church")
30 | download_stylegan2 "church-config-f"
31 | ;;
32 | "StyleGAN2-car")
33 | download_stylegan2 "car-config-f"
34 | ;;
35 | "GPT2")
36 | [ -f "gpt2/weights/gpt2-pytorch_model.bin" ] && die "Weights already downloaded"
37 | curl --output gpt2/weights/gpt2-pytorch_model.bin https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin
38 | ;;
39 | *)
40 | echo "Unknown model '$1'"
41 | ;;
42 | esac
--------------------------------------------------------------------------------
/generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_biggan import BigGAN
3 | from clip import clip
4 | import kornia
5 | from PIL import Image
6 | from torchvision.utils import save_image
7 |
8 |
9 | from utils import save_grid, freeze_model
10 |
11 | class Generator:
12 | def __init__(self, config):
13 | self.config = config
14 | self.augmentation = None
15 |
16 | self.CLIP, clip_preprocess = clip.load("ViT-B/32", device=self.config.device, jit=False)
17 | self.CLIP = self.CLIP.eval()
18 | freeze_model(self.CLIP)
19 | self.model = self.config.model(config).to(self.config.device).eval()
20 | freeze_model(self.model)
21 |
22 | if config.task == "txt2img":
23 | self.tokens = clip.tokenize([self.config.target]).to(self.config.device)
24 | self.text_features = self.CLIP.encode_text(self.tokens).detach()
25 | if config.task == "img2txt":
26 | image = clip_preprocess(Image.open(self.config.target)).unsqueeze(0).to(self.config.device)
27 | self.image_features = self.CLIP.encode_image(image)
28 |
29 | def generate(self, ls, minibatch=None):
30 | z = ls()
31 | result = self.model.generate(*z, minibatch=minibatch)
32 | if hasattr(self.config, "norm"):
33 | result = self.config.norm(result)
34 | return result
35 |
36 | def discriminate(self, images, minibatch=None):
37 | images = self.config.denorm(images)
38 | return self.model.discriminate(images, minibatch)
39 |
40 | def has_discriminator(self):
41 | return self.model.has_discriminator()
42 |
43 | def clip_similarity(self, input):
44 | if self.config.task == "txt2img":
45 | image = kornia.resize(input, (224, 224))
46 | if self.augmentation is not None:
47 | image = self.augmentation(image)
48 |
49 | image_features = self.CLIP.encode_image(image)
50 |
51 | sim = torch.cosine_similarity(image_features, self.text_features)
52 | elif self.config.task == "img2txt":
53 | try:
54 | text_tokens = clip.tokenize(input).to(self.config.device)
55 | except:
56 | return torch.zeros(len(input))
57 | text_features = self.CLIP.encode_text(text_tokens)
58 |
59 | sim = torch.cosine_similarity(text_features, self.image_features)
60 | return sim
61 |
62 |
63 | def save(self, input, path):
64 | if self.config.task == "txt2img":
65 | if input.shape[0] > 1:
66 | save_grid(input.detach().cpu(), path)
67 | else:
68 | save_image(input[0], path)
69 | elif self.config.task == "img2txt":
70 | f = open(path, "w")
71 | f.write("\n".join(input))
72 | f.close()
--------------------------------------------------------------------------------
/gpt2/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | code by TaeHwan Jung(@graykode)
3 | Original Paper and repository here : https://github.com/openai/gpt-2
4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
5 | '''
6 | class GPT2Config(object):
7 | def __init__(
8 | self,
9 | vocab_size_or_config_json_file=50257,
10 | n_positions=1024,
11 | n_ctx=1024,
12 | n_embd=768,
13 | n_layer=12,
14 | n_head=12,
15 | layer_norm_epsilon=1e-5,
16 | initializer_range=0.02,
17 | ):
18 | self.vocab_size = vocab_size_or_config_json_file
19 | self.n_ctx = n_ctx
20 | self.n_positions = n_positions
21 | self.n_embd = n_embd
22 | self.n_layer = n_layer
23 | self.n_head = n_head
24 | self.layer_norm_epsilon = layer_norm_epsilon
25 | self.initializer_range = initializer_range
--------------------------------------------------------------------------------
/gpt2/encoder.py:
--------------------------------------------------------------------------------
1 | """Byte pair encoding utilities"""
2 |
3 | import os
4 | import json
5 | import regex as re
6 | from functools import lru_cache
7 |
8 | @lru_cache()
9 | def bytes_to_unicode():
10 | """
11 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
12 | The reversible bpe codes work on unicode strings.
13 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
14 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
15 |     This is a significant percentage of your normal, say, 32K bpe vocab.
16 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
17 |     It also avoids mapping to whitespace/control characters that the bpe code barfs on.
18 | """
19 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
20 | cs = bs[:]
21 | n = 0
22 | for b in range(2**8):
23 | if b not in bs:
24 | bs.append(b)
25 | cs.append(2**8+n)
26 | n += 1
27 | cs = [chr(n) for n in cs]
28 | return dict(zip(bs, cs))
29 |
30 | def get_pairs(word):
31 | """Return set of symbol pairs in a word.
32 | Word is represented as tuple of symbols (symbols being variable-length strings).
33 | """
34 | pairs = set()
35 | prev_char = word[0]
36 | for char in word[1:]:
37 | pairs.add((prev_char, char))
38 | prev_char = char
39 | return pairs
40 |
41 | class Encoder:
42 | def __init__(self, encoder, bpe_merges, errors='replace'):
43 | self.encoder = encoder
44 | self.decoder = {v:k for k,v in self.encoder.items()}
45 | self.errors = errors # how to handle errors in decoding
46 | self.byte_encoder = bytes_to_unicode()
47 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
48 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
49 | self.cache = {}
50 |
51 |         # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
52 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
53 |
54 | def bpe(self, token):
55 | if token in self.cache:
56 | return self.cache[token]
57 | word = tuple(token)
58 | pairs = get_pairs(word)
59 |
60 | if not pairs:
61 | return token
62 |
63 | while True:
64 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
65 | if bigram not in self.bpe_ranks:
66 | break
67 | first, second = bigram
68 | new_word = []
69 | i = 0
70 | while i < len(word):
71 | try:
72 | j = word.index(first, i)
73 | new_word.extend(word[i:j])
74 | i = j
75 |                 except ValueError:
76 | new_word.extend(word[i:])
77 | break
78 |
79 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
80 | new_word.append(first+second)
81 | i += 2
82 | else:
83 | new_word.append(word[i])
84 | i += 1
85 | new_word = tuple(new_word)
86 | word = new_word
87 | if len(word) == 1:
88 | break
89 | else:
90 | pairs = get_pairs(word)
91 | word = ' '.join(word)
92 | self.cache[token] = word
93 | return word
94 |
95 | def encode(self, text):
96 | bpe_tokens = []
97 | for token in re.findall(self.pat, text):
98 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
99 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
100 | return bpe_tokens
101 |
102 | def decode(self, tokens):
103 | text = ''.join([self.decoder[token] for token in tokens])
104 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
105 | return text
106 |
107 | def get_encoder(config):
108 | with open(config.encoder, 'r') as f:
109 | encoder = json.load(f)
110 | with open(config.vocab, 'r', encoding="utf-8") as f:
111 | bpe_data = f.read()
112 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
113 | return Encoder(
114 | encoder=encoder,
115 | bpe_merges=bpe_merges,
116 | )
--------------------------------------------------------------------------------
/gpt2/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | code by TaeHwan Jung(@graykode)
3 | Original Paper and repository here : https://github.com/openai/gpt-2
4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
5 | '''
6 | import copy
7 | import torch
8 | import math
9 | import torch.nn as nn
10 | from torch.nn.parameter import Parameter
11 |
12 | def gelu(x):
13 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
14 |
15 | class LayerNorm(nn.Module):
16 | def __init__(self, hidden_size, eps=1e-12):
17 | """Construct a layernorm module in the TF style (epsilon inside the square root).
18 | """
19 | super(LayerNorm, self).__init__()
20 | self.weight = nn.Parameter(torch.ones(hidden_size))
21 | self.bias = nn.Parameter(torch.zeros(hidden_size))
22 | self.variance_epsilon = eps
23 |
24 | def forward(self, x):
25 | u = x.mean(-1, keepdim=True)
26 | s = (x - u).pow(2).mean(-1, keepdim=True)
27 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
28 | return self.weight * x + self.bias
29 |
30 | class Conv1D(nn.Module):
31 | def __init__(self, nf, nx):
32 | super(Conv1D, self).__init__()
33 | self.nf = nf
34 | w = torch.empty(nx, nf)
35 | nn.init.normal_(w, std=0.02)
36 | self.weight = Parameter(w)
37 | self.bias = Parameter(torch.zeros(nf))
38 |
39 | def forward(self, x):
40 | size_out = x.size()[:-1] + (self.nf,)
41 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
42 | x = x.view(*size_out)
43 | return x
44 |
45 | class Attention(nn.Module):
46 | def __init__(self, nx, n_ctx, config, scale=False):
47 | super(Attention, self).__init__()
48 | n_state = nx # in Attention: n_state=768 (nx=n_embd)
49 | # [switch nx => n_state from Block to Attention to keep identical to TF implem]
50 | assert n_state % config.n_head == 0
51 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
52 | self.n_head = config.n_head
53 | self.split_size = n_state
54 | self.scale = scale
55 | self.c_attn = Conv1D(n_state * 3, nx)
56 | self.c_proj = Conv1D(n_state, nx)
57 |
58 | def _attn(self, q, k, v):
59 | w = torch.matmul(q, k)
60 | if self.scale:
61 | w = w / math.sqrt(v.size(-1))
62 | nd, ns = w.size(-2), w.size(-1)
63 | b = self.bias[:, :, ns-nd:ns, :ns]
64 | w = w * b - 1e10 * (1 - b)
65 | w = nn.Softmax(dim=-1)(w)
66 | return torch.matmul(w, v)
67 |
68 | def merge_heads(self, x):
69 | x = x.permute(0, 2, 1, 3).contiguous()
70 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
71 | return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
72 |
73 | def split_heads(self, x, k=False):
74 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
75 | x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
76 | if k:
77 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
78 | else:
79 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
80 |
81 | def forward(self, x, layer_past=None):
82 | x = self.c_attn(x)
83 | query, key, value = x.split(self.split_size, dim=2)
84 | query = self.split_heads(query)
85 | key = self.split_heads(key, k=True)
86 | value = self.split_heads(value)
87 | if layer_past is not None:
88 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
89 | key = torch.cat((past_key, key), dim=-1)
90 | value = torch.cat((past_value, value), dim=-2)
91 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
92 | a = self._attn(query, key, value)
93 | a = self.merge_heads(a)
94 | a = self.c_proj(a)
95 | return a, present
96 |
97 | class MLP(nn.Module):
98 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
99 | super(MLP, self).__init__()
100 | nx = config.n_embd
101 | self.c_fc = Conv1D(n_state, nx)
102 | self.c_proj = Conv1D(nx, n_state)
103 | self.act = gelu
104 |
105 | def forward(self, x):
106 | h = self.act(self.c_fc(x))
107 | h2 = self.c_proj(h)
108 | return h2
109 |
110 | class Block(nn.Module):
111 | def __init__(self, n_ctx, config, scale=False):
112 | super(Block, self).__init__()
113 | nx = config.n_embd
114 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
115 | self.attn = Attention(nx, n_ctx, config, scale)
116 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
117 | self.mlp = MLP(4 * nx, config)
118 |
119 | def forward(self, x, layer_past=None):
120 | a, present = self.attn(self.ln_1(x), layer_past=layer_past)
121 | x = x + a
122 | m = self.mlp(self.ln_2(x))
123 | x = x + m
124 | return x, present
125 |
126 | class GPT2Model(nn.Module):
127 | def __init__(self, config):
128 | super(GPT2Model, self).__init__()
129 | self.n_layer = config.n_layer
130 | self.n_embd = config.n_embd
131 | self.n_vocab = config.vocab_size
132 |
133 | self.wte = nn.Embedding(config.vocab_size, config.n_embd)
134 | self.wpe = nn.Embedding(config.n_positions, config.n_embd)
135 | block = Block(config.n_ctx, config, scale=True)
136 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
137 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
138 |
139 | def set_embeddings_weights(self, model_embeddings_weights):
140 | embed_shape = model_embeddings_weights.shape
141 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
142 | self.decoder.weight = model_embeddings_weights # Tied weights
143 |
144 | def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
145 | if past is None:
146 | past_length = 0
147 | past = [None] * len(self.h)
148 | else:
149 | past_length = past[0][0].size(-2)
150 | if position_ids is None:
151 | position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long,
152 | device=input_ids.device)
153 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
154 |
155 | input_shape = input_ids.size()
156 | input_ids = input_ids.view(-1, input_ids.size(-1))
157 | position_ids = position_ids.view(-1, position_ids.size(-1))
158 |
159 | inputs_embeds = self.wte(input_ids)
160 | position_embeds = self.wpe(position_ids)
161 | if token_type_ids is not None:
162 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
163 | token_type_embeds = self.wte(token_type_ids)
164 | else:
165 | token_type_embeds = 0
166 | hidden_states = inputs_embeds + position_embeds + token_type_embeds
167 |
168 | presents = []
169 | for block, layer_past in zip(self.h, past):
170 | hidden_states, present = block(hidden_states, layer_past)
171 | presents.append(present)
172 | hidden_states = self.ln_f(hidden_states)
173 | output_shape = input_shape + (hidden_states.size(-1),)
174 |
175 | return hidden_states.view(*output_shape), presents
176 |
177 | class GPT2LMHead(nn.Module):
178 | def __init__(self, model_embeddings_weights, config):
179 | super(GPT2LMHead, self).__init__()
180 | self.n_embd = config.n_embd
181 | self.set_embeddings_weights(model_embeddings_weights)
182 |
183 | def set_embeddings_weights(self, model_embeddings_weights):
184 | embed_shape = model_embeddings_weights.shape
185 | self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
186 | self.decoder.weight = model_embeddings_weights # Tied weights
187 |
188 | def forward(self, hidden_state):
189 | # Truncated Language modeling logits (we remove the last token)
190 | # h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
191 | lm_logits = self.decoder(hidden_state)
192 | return lm_logits
193 |
194 | class GPT2LMHeadModel(nn.Module):
195 | def __init__(self, config):
196 | super(GPT2LMHeadModel, self).__init__()
197 | self.transformer = GPT2Model(config)
198 | self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
199 |
200 | def set_tied(self):
201 | """ Make sure we are sharing the embeddings
202 | """
203 | self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
204 |
205 | def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None):
206 | hidden_states, presents = self.transformer(input_ids, position_ids, token_type_ids, past)
207 | lm_logits = self.lm_head(hidden_states)
208 | if lm_labels is not None:
209 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
210 | loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1))
211 | return loss
212 | return lm_logits, presents
--------------------------------------------------------------------------------
/gpt2/sample.py:
--------------------------------------------------------------------------------
1 | '''
2 | code by TaeHwan Jung(@graykode)
3 | Original Paper and repository here : https://github.com/openai/gpt-2
4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
5 | '''
6 | import torch
7 | import torch.nn.functional as F
8 | from tqdm import trange
9 |
10 | def top_k_logits(logits, k):
11 | if k == 0:
12 | return logits
13 | values, _ = torch.topk(logits, k)
14 | min_values = values[:, -1]
15 | rets = []
16 | for l, m in zip(logits, min_values):
17 | rets.append(torch.where(l < m, torch.ones_like(l, dtype=l.dtype) * -1e10, l))
18 | rets = torch.stack(rets)
19 | return rets
20 |
21 | def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device="cuda", sample=True):
22 | prev = context
23 | output = context
24 | past = None
25 | with torch.no_grad():
26 | for i in range(length):
27 | logits, past = model(prev, past=past)
28 | logits = logits[:, -1, :] / temperature
29 | logits = top_k_logits(logits, k=top_k)
30 | log_probs = F.softmax(logits, dim=-1)
31 | if sample:
32 | prev = torch.multinomial(log_probs, num_samples=1)
33 | else:
34 | _, prev = torch.topk(log_probs, k=1, dim=-1)
35 | output = torch.cat((output, prev), dim=1)
36 |
37 | return output.cpu().numpy().tolist()
--------------------------------------------------------------------------------
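
A tiny worked example of the top-k filter above (the logits are made up): with k=2 only the two largest entries survive, so the softmax in sample_sequence spreads its probability mass over the top-k tokens only.

    import torch
    from gpt2.sample import top_k_logits

    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    print(top_k_logits(logits, k=2))
    # keeps 3.0 and 2.0 unchanged and pushes 1.0 and 0.5 down to -1e10
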
/gpt2/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | code by TaeHwan Jung(@graykode)
3 | Original Paper and repository here : https://github.com/openai/gpt-2
4 | GPT2 Pytorch Model : https://github.com/huggingface/pytorch-pretrained-BERT
5 | '''
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | def load_weight(model, state_dict):
11 | old_keys = []
12 | new_keys = []
13 | for key in state_dict.keys():
14 | new_key = None
15 | if key.endswith(".g"):
16 | new_key = key[:-2] + ".weight"
17 | elif key.endswith(".b"):
18 | new_key = key[:-2] + ".bias"
19 | elif key.endswith(".w"):
20 | new_key = key[:-2] + ".weight"
21 | if new_key:
22 | old_keys.append(key)
23 | new_keys.append(new_key)
24 | for old_key, new_key in zip(old_keys, new_keys):
25 | state_dict[new_key] = state_dict.pop(old_key)
26 |
27 | missing_keys = []
28 | unexpected_keys = []
29 | error_msgs = []
30 | # copy state_dict so _load_from_state_dict can modify it
31 | metadata = getattr(state_dict, "_metadata", None)
32 | state_dict = state_dict.copy()
33 | if metadata is not None:
34 | state_dict._metadata = metadata
35 |
36 | def load(module, prefix=""):
37 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
38 | module._load_from_state_dict(
39 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs
40 | )
41 | for name, child in module._modules.items():
42 | if child is not None:
43 | load(child, prefix + name + ".")
44 |
45 | start_model = model
46 | if hasattr(model, "transformer") and all(not s.startswith('transformer.') for s in state_dict.keys()):
47 | start_model = model.transformer
48 | load(start_model, prefix="")
49 |
50 | # Make sure we are still sharing the output and input embeddings after loading weights
51 | model.set_tied()
52 | return model
--------------------------------------------------------------------------------
/gpt2_images/dog.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/dog.jpeg
--------------------------------------------------------------------------------
/gpt2_images/goldfish.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/goldfish.jpeg
--------------------------------------------------------------------------------
/gpt2_images/harmonica.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/harmonica.jpeg
--------------------------------------------------------------------------------
/gpt2_images/harp.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/harp.jpeg
--------------------------------------------------------------------------------
/gpt2_images/knot.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/knot.jpeg
--------------------------------------------------------------------------------
/gpt2_images/radio_telescope.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/radio_telescope.jpeg
--------------------------------------------------------------------------------
/gpt2_images/teapot.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/teapot.jpeg
--------------------------------------------------------------------------------
/gpt2_images/telephone.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/telephone.jpeg
--------------------------------------------------------------------------------
/gpt2_images/zebra.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/galatolofederico/clip-glass/0887b13a19e75f20061574587cd2e03be59851e6/gpt2_images/zebra.jpeg
--------------------------------------------------------------------------------
/latent.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_pretrained_biggan import truncated_noise_sample
3 |
4 | class DeepMindBigGANLatentSpace(torch.nn.Module):
5 | def __init__(self, config):
6 | super(DeepMindBigGANLatentSpace, self).__init__()
7 | self.config = config
8 |
9 | self.z = torch.nn.Parameter(torch.tensor(truncated_noise_sample(self.config.batch_size)).to(self.config.device))
10 | self.class_labels = torch.nn.Parameter(torch.rand(self.config.batch_size, self.config.num_classes).to(self.config.device))
11 |
12 | def set_values(self, z, class_labels):
13 | self.z.data = z
14 | self.class_labels.data = class_labels
15 |
16 | def set_from_population(self, x):
17 | self.z.data = torch.tensor(x[:,:self.config.dim_z].astype(float)).float().to(self.config.device)
18 | self.class_labels.data = torch.tensor(x[:,self.config.dim_z:].astype(float)).float().to(self.config.device)
19 |
20 | def forward(self):
21 | z = torch.clip(self.z, -2, 2)
22 | class_labels = torch.softmax(self.class_labels, dim=1)
23 |
24 | return z, class_labels
25 |
26 |
27 | class StyleGAN2LatentSpace(torch.nn.Module):
28 | def __init__(self, config):
29 | super(StyleGAN2LatentSpace, self).__init__()
30 | self.config = config
31 |
32 | self.z = torch.nn.Parameter(torch.randn(self.config.batch_size, self.config.dim_z).to(self.config.device))
33 |
34 | def set_values(self, z):
35 | self.z.data = z
36 |
37 | def set_from_population(self, x):
38 | self.z.data = torch.tensor(x.astype(float)).float().to(self.config.device)
39 |
40 | def forward(self):
41 | return (self.z, )
42 |
43 |
44 | class GPT2LatentSpace(torch.nn.Module):
45 | def __init__(self, config):
46 | super(GPT2LatentSpace, self).__init__()
47 | self.config = config
48 |
49 | self.z = torch.randint(0, self.config.encoder_size, size=(self.config.batch_size, self.config.dim_z)).to(self.config.device)
50 | #self.z = torch.zeros(self.config.batch_size, self.config.dim_z)
51 |
52 | def set_values(self, z):
53 | self.z.data = z
54 |
55 | def set_from_population(self, x):
56 | self.z.data = torch.tensor(x.astype(int)).long().to(self.config.device)
57 |
58 | def forward(self):
59 | return (self.z, )
--------------------------------------------------------------------------------
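
Each latent space above bridges the optimizer's numpy population and the generator's tensor inputs. A minimal sketch, using a hypothetical SimpleNamespace in place of the merged run.py config (importing latent pulls in pytorch_pretrained_biggan, so the listed requirements must be installed):

    import numpy as np
    from types import SimpleNamespace

    from latent import StyleGAN2LatentSpace

    # hypothetical config: 4 individuals, 512-dimensional latents, evaluated on CPU
    cfg = SimpleNamespace(batch_size=4, dim_z=512, device="cpu")

    ls = StyleGAN2LatentSpace(cfg)
    ls.set_from_population(np.random.randn(4, 512))  # rows of the pymoo population matrix
    (z,) = ls()                                      # tuple unpacked into model.generate(*z)
    print(z.shape)                                   # torch.Size([4, 512])
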
/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from pytorch_pretrained_biggan import BigGAN as DMBigGAN
5 | import stylegan2
6 |
7 | from gpt2.model import GPT2LMHeadModel
8 | from gpt2.utils import load_weight
9 | from gpt2.config import GPT2Config
10 | from gpt2.sample import sample_sequence
11 | from gpt2.encoder import get_encoder
12 |
13 |
14 | class GPT2(torch.nn.Module):
15 | def __init__(self, config):
16 | super(GPT2, self).__init__()
17 | self.config = config
18 | if not os.path.exists(self.config.weights):
19 | print("Weights not found!\nRun: ./download-weights.sh GPT2")
20 | sys.exit(1)
21 |
22 | state_dict = torch.load(self.config.weights, map_location=self.config.device)
23 |
24 | self.enc = get_encoder(config)
25 | self.model = GPT2LMHeadModel(GPT2Config())
26 | self.model = load_weight(self.model, state_dict)
27 | self.model.to(self.config.device)
28 | self.model.eval()
29 |
30 | self.init_tokens = torch.tensor(self.enc.encode(self.config.init_text)).to(self.config.device)
31 |
32 | def parse_out(self, out):
33 | texts = []
34 | for seq in out:
35 | if self.enc.encoder["<|endoftext|>"] in seq:
36 | text = seq[self.config.dim_z:seq.index(self.enc.encoder["<|endoftext|>"])]
37 | else:
38 | text = seq[self.config.dim_z:]
39 | text = self.enc.decode(text)
40 |
41 | texts.append(text[:self.config.max_text_len])
42 | return texts
43 |
44 |
45 | def generate(self, z, minibatch=None):
46 | #TODO: implement minibatch
47 | init_tokens = self.init_tokens.repeat(z.shape[0], 1)
48 | z = torch.cat((z, init_tokens), dim=1)
49 |
50 | out = sample_sequence(
51 | model=self.model,
52 | length=self.config.max_tokens_len,
53 | context=z,
54 | start_token=None,
55 | batch_size=self.config.batch_size,
56 | temperature=0.7,
57 | top_k=40,
58 | device=self.config.device,
59 | sample=self.config.stochastic
60 | )
61 |
62 | return self.parse_out(out)
63 |
64 |
65 | class DeepMindBigGAN(torch.nn.Module):
66 | def __init__(self, config):
67 | super(DeepMindBigGAN, self).__init__()
68 | self.config = config
69 | self.G = DMBigGAN.from_pretrained(config.weights)
70 | self.D = None
71 |
72 | def has_discriminator(self):
73 | return False
74 |
75 | def generate(self, z, class_labels, minibatch = None):
76 | if minibatch is None:
77 | return self.G(z, class_labels, self.config.truncation)
78 | else:
79 | assert z.shape[0] % minibatch == 0
80 | gen_images = []
81 | for i in range(0, z.shape[0] // minibatch):
82 | z_minibatch = z[i*minibatch:(i+1)*minibatch, :]
83 | cl_minibatch = class_labels[i*minibatch:(i+1)*minibatch, :]
84 | gen_images.append(self.G(z_minibatch, cl_minibatch, self.config.truncation))
85 | gen_images = torch.cat(gen_images)
86 | return gen_images
87 |
88 |
89 |
90 | class StyleGAN2(torch.nn.Module):
91 | def __init__(self, config):
92 | super(StyleGAN2, self).__init__()
93 | if not os.path.exists(os.path.join(config.weights, "G.pth")):
94 | if "ffhq" in config.config:
95 | model = "ffhq"
96 | elif "car" in config.config:
97 | model = "car"
98 | elif "church" in config.config:
99 | model = "church"
100 |             print("Weights not found!\nRun: ./download-weights.sh StyleGAN2-%s" % (model))
101 | sys.exit(1)
102 | self.G = stylegan2.models.load(os.path.join(config.weights, "G.pth"))
103 | self.D = stylegan2.models.load(os.path.join(config.weights, "D.pth"))
104 |
105 | def has_discriminator(self):
106 | return True
107 |
108 | def generate(self, z, minibatch = None):
109 | if minibatch is None:
110 | return self.G(z)
111 | else:
112 | assert z.shape[0] % minibatch == 0
113 | gen_images = []
114 | for i in range(0, z.shape[0] // minibatch):
115 | z_minibatch = z[i*minibatch:(i+1)*minibatch, :]
116 | gen_images.append(self.G(z_minibatch))
117 | gen_images = torch.cat(gen_images)
118 | return gen_images
119 |
120 | def discriminate(self, images, minibatch = None):
121 | if minibatch is None:
122 | return self.D(images)
123 | else:
124 | assert images.shape[0] % minibatch == 0
125 | discriminations = []
126 | for i in range(0, images.shape[0] // minibatch):
127 | images_minibatch = images[i*minibatch:(i+1)*minibatch, :]
128 | discriminations.append(self.D(images_minibatch))
129 | discriminations = torch.cat(discriminations)
130 | return discriminations
--------------------------------------------------------------------------------
/operators.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy.stats import truncnorm
4 |
5 | from pymoo.factory import get_sampling, get_crossover, get_mutation
6 | from pymoo.operators.mixed_variable_operator import MixedVariableSampling, MixedVariableMutation, MixedVariableCrossover
7 | from pymoo.model.sampling import Sampling
8 |
9 | class TruncatedNormalRandomSampling(Sampling):
10 | def __init__(self, var_type=np.float):
11 | super().__init__()
12 | self.var_type = var_type
13 |
14 | def _do(self, problem, n_samples, **kwargs):
15 | return truncnorm.rvs(-2, 2, size=(n_samples, problem.n_var)).astype(np.float32)
16 |
17 | class NormalRandomSampling(Sampling):
18 | def __init__(self, mu=0, std=1, var_type=np.float):
19 | super().__init__()
20 | self.mu = mu
21 | self.std = std
22 | self.var_type = var_type
23 |
24 | def _do(self, problem, n_samples, **kwargs):
25 | return np.random.normal(self.mu, self.std, size=(n_samples, problem.n_var))
26 |
27 | class BinaryRandomSampling(Sampling):
28 | def __init__(self, prob=0.5):
29 | super().__init__()
30 | self.prob = prob
31 |
32 | def _do(self, problem, n_samples, **kwargs):
33 | val = np.random.random((n_samples, problem.n_var))
34 | return (val < self.prob).astype(np.bool)
35 |
36 |
37 | def get_operators(config):
38 | if config.config == "DeepMindBigGAN256" or config.config == "DeepMindBigGAN512":
39 | mask = ["real"]*config.dim_z + ["bool"]*config.num_classes
40 |
41 | real_sampling = None
42 | if config.config == "DeepMindBigGAN256" or config.config == "DeepMindBigGAN512":
43 | real_sampling = TruncatedNormalRandomSampling()
44 |
45 | sampling = MixedVariableSampling(mask, {
46 | "real": real_sampling,
47 | "bool": BinaryRandomSampling(prob=5/1000)
48 | })
49 |
50 | crossover = MixedVariableCrossover(mask, {
51 | "real": get_crossover("real_sbx", prob=1.0, eta=3.0),
52 | "bool": get_crossover("bin_hux", prob=0.2)
53 | })
54 |
55 | mutation = MixedVariableMutation(mask, {
56 | "real": get_mutation("real_pm", prob=0.5, eta=3.0),
57 | "bool": get_mutation("bin_bitflip", prob=10/1000)
58 | })
59 |
60 | return dict(
61 | sampling=sampling,
62 | crossover=crossover,
63 | mutation=mutation
64 | )
65 |
66 | elif config.config.split("_")[0] == "StyleGAN2":
67 | return dict(
68 | sampling=NormalRandomSampling(),
69 | crossover=get_crossover("real_sbx", prob=1.0, eta=3.0),
70 | mutation=get_mutation("real_pm", prob=0.5, eta=3.0)
71 | )
72 |
73 | elif config.config == "GPT2":
74 | return dict(
75 | sampling=get_sampling("int_random"),
76 | crossover=get_crossover("int_sbx", prob=1.0, eta=3.0),
77 | mutation=get_mutation("int_pm", prob=0.5, eta=3.0)
78 | )
79 |
80 | else:
81 | raise Exception("Unknown config")
82 |
83 |
--------------------------------------------------------------------------------
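
A short sketch of how the operator set above is selected, again with a hypothetical SimpleNamespace standing in for the merged config (pymoo 0.4.2.1 from requirements.txt provides the factory functions):

    from types import SimpleNamespace
    from operators import get_operators

    cfg = SimpleNamespace(config="DeepMindBigGAN256", dim_z=128, num_classes=1000)
    ops = get_operators(cfg)   # mixed real (latent) / bool (class flags) sampling, crossover, mutation
    print(sorted(ops.keys()))  # ['crossover', 'mutation', 'sampling']
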
/problem.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from pymoo.model.problem import Problem
5 | from generator import Generator
6 |
7 | class GenerationProblem(Problem):
8 | def __init__(self, config):
9 | self.generator = Generator(config)
10 | self.config = config
11 |
12 | super().__init__(**self.config.problem_args)
13 |
14 | def _evaluate(self, x, out, *args, **kwargs):
15 | ls = self.config.latent(self.config)
16 | ls.set_from_population(x)
17 |
18 | with torch.no_grad():
19 | generated = self.generator.generate(ls, minibatch=self.config.batch_size)
20 | sim = self.generator.clip_similarity(generated).cpu().numpy()
21 | if self.config.problem_args["n_obj"] == 2 and self.config.use_discriminator:
22 | dis = self.generator.discriminate(generated, minibatch=self.config.batch_size)
23 | hinge = torch.relu(1 - dis)
24 | hinge = hinge.squeeze(1).cpu().numpy()
25 | out["F"] = np.column_stack((-sim, hinge))
26 | else:
27 | out["F"] = -sim
28 |
29 | out["G"] = np.zeros((x.shape[0]))
30 |
31 |
32 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | autograd==1.3
3 | boto3==1.16.63
4 | botocore==1.19.63
5 | cachetools==4.2.1
6 | certifi==2020.12.5
7 | chardet==4.0.0
8 | cma==2.7.0
9 | cycler==0.10.0
10 | dataclasses==0.8
11 | ftfy==5.8
12 | future==0.18.2
13 | google-auth==1.24.0
14 | google-auth-oauthlib==0.4.2
15 | grpcio==1.35.0
16 | idna==2.10
17 | importlib-metadata==3.4.0
18 | jmespath==0.10.0
19 | kiwisolver==1.3.1
20 | kornia==0.4.1
21 | Markdown==3.3.3
22 | matplotlib==3.3.4
23 | numpy==1.19.5
24 | oauthlib==3.1.0
25 | Pillow==8.1.0
26 | protobuf==3.14.0
27 | pyasn1==0.4.8
28 | pyasn1-modules==0.2.8
29 | pymoo==0.4.2.1
30 | pyparsing==2.4.7
31 | python-dateutil==2.8.1
32 | pytorch-pretrained-biggan==0.1.1
33 | PyYAML==5.4.1
34 | regex==2020.11.13
35 | requests==2.25.1
36 | requests-oauthlib==1.3.0
37 | rsa==4.7
38 | s3transfer==0.3.4
39 | scipy==1.5.4
40 | six==1.15.0
41 | tensorboard==2.4.1
42 | tensorboard-plugin-wit==1.8.0
43 | torch==1.7.1
44 | torchvision==0.8.2
45 | tqdm==4.56.0
46 | typing-extensions==3.7.4.3
47 | urllib3==1.26.3
48 | wcwidth==0.2.5
49 | Werkzeug==1.0.1
50 | zipp==3.4.0
51 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import numpy as np
5 | import pickle
6 | from pymoo.optimize import minimize
7 | from pymoo.algorithms.so_genetic_algorithm import GA
8 | from pymoo.factory import get_algorithm, get_decision_making, get_decomposition
9 | from pymoo.visualization.scatter import Scatter
10 |
11 | from config import get_config
12 | from problem import GenerationProblem
13 | from operators import get_operators
14 |
15 | parser = argparse.ArgumentParser()
16 |
17 | parser.add_argument("--device", type=str, default="cuda")
18 | parser.add_argument("--config", type=str, default="DeepMindBigGAN512")
19 | parser.add_argument("--generations", type=int, default=500)
20 | parser.add_argument("--save-each", type=int, default=50)
21 | parser.add_argument("--tmp-folder", type=str, default="./tmp")
22 | parser.add_argument("--target", type=str, default="a wolf at night with the moon in the background")
23 |
24 | config = parser.parse_args()
25 | vars(config).update(get_config(config.config))
26 |
27 |
28 | iteration = 0
29 | def save_callback(algorithm):
30 | global iteration
31 | global config
32 |
33 | iteration += 1
34 | if iteration % config.save_each == 0 or iteration == config.generations:
35 | if config.problem_args["n_obj"] == 1:
36 | sortedpop = sorted(algorithm.pop, key=lambda p: p.F)
37 | X = np.stack([p.X for p in sortedpop])
38 | else:
39 | X = algorithm.pop.get("X")
40 |
41 | ls = config.latent(config)
42 | ls.set_from_population(X)
43 |
44 | with torch.no_grad():
45 | generated = algorithm.problem.generator.generate(ls, minibatch=config.batch_size)
46 | if config.task == "txt2img":
47 | ext = "jpg"
48 | elif config.task == "img2txt":
49 | ext = "txt"
50 | name = "genetic-it-%d.%s" % (iteration, ext) if iteration < config.generations else "genetic-it-final.%s" % (ext, )
51 | algorithm.problem.generator.save(generated, os.path.join(config.tmp_folder, name))
52 |
53 |
54 | problem = GenerationProblem(config)
55 | operators = get_operators(config)
56 |
57 | if not os.path.exists(config.tmp_folder): os.mkdir(config.tmp_folder)
58 |
59 | algorithm = get_algorithm(
60 | config.algorithm,
61 | pop_size=config.pop_size,
62 | sampling=operators["sampling"],
63 | crossover=operators["crossover"],
64 | mutation=operators["mutation"],
65 | eliminate_duplicates=True,
66 | callback=save_callback,
67 | **(config.algorithm_args[config.algorithm] if "algorithm_args" in config and config.algorithm in config.algorithm_args else dict())
68 | )
69 |
70 | res = minimize(
71 | problem,
72 | algorithm,
73 | ("n_gen", config.generations),
74 | save_history=False,
75 | verbose=True,
76 | )
77 |
78 |
79 | pickle.dump(dict(
80 | X = res.X,
81 | F = res.F,
82 | G = res.G,
83 | CV = res.CV,
84 | ), open(os.path.join(config.tmp_folder, "genetic_result"), "wb"))
85 |
86 | if config.problem_args["n_obj"] == 2:
87 | plot = Scatter(labels=["similarity", "discriminator",])
88 | plot.add(res.F, color="red")
89 | plot.save(os.path.join(config.tmp_folder, "F.jpg"))
90 |
91 |
92 | if config.problem_args["n_obj"] == 1:
93 | sortedpop = sorted(res.pop, key=lambda p: p.F)
94 | X = np.stack([p.X for p in sortedpop])
95 | else:
96 | X = res.pop.get("X")
97 |
98 | ls = config.latent(config)
99 | ls.set_from_population(X)
100 |
101 | torch.save(ls.state_dict(), os.path.join(config.tmp_folder, "ls_result"))
102 |
103 | if config.problem_args["n_obj"] == 1:
104 | X = np.atleast_2d(res.X)
105 | else:
106 | try:
107 | result = get_decision_making("pseudo-weights", [0, 1]).do(res.F)
108 | except:
109 |         print("Warning: can't use pseudo-weights")
110 | result = get_decomposition("asf").do(res.F, [0, 1]).argmin()
111 |
112 | X = res.X[result]
113 | X = np.atleast_2d(X)
114 |
115 | ls.set_from_population(X)
116 |
117 | with torch.no_grad():
118 | generated = problem.generator.generate(ls)
119 |
120 | if config.task == "txt2img":
121 | ext = "jpg"
122 | elif config.task == "img2txt":
123 | ext = "txt"
124 |
125 | problem.generator.save(generated, os.path.join(config.tmp_folder, "output.%s" % (ext)))
--------------------------------------------------------------------------------
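A minimal sketch of the multi-objective selection step used at the end of run.py, assuming the pymoo 0.4.x API that the script imports (`get_decision_making`, `get_decomposition`); the objective matrix `F` below is made-up toy data rather than output of the script.

import numpy as np
from pymoo.factory import get_decision_making, get_decomposition

# Hypothetical 2-objective Pareto front (both objectives minimized).
F = np.array([[0.1, 0.9], [0.4, 0.5], [0.8, 0.2]])
try:
    # Pseudo-weights picks the point whose trade-off best matches the weights [0, 1].
    idx = get_decision_making("pseudo-weights", [0, 1]).do(F)
except Exception:
    # Fallback mirrors run.py: scalarize with ASF and take the minimizer.
    idx = get_decomposition("asf").do(F, [0, 1]).argmin()
print("selected solution index:", int(idx))
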
/stylegan2/__init__.py:
--------------------------------------------------------------------------------
1 | from . import external_models
2 | from . import metrics
3 | from . import models
4 | from . import project
5 | from . import train
6 |
--------------------------------------------------------------------------------
/stylegan2/convert_from_tf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import pickle
4 | import argparse
5 | import io
6 | import requests
7 | import torch
8 | import stylegan2
9 | from stylegan2 import utils
10 |
11 |
12 | pretrained_model_urls = {
13 | 'car-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-e.pkl',
14 | 'car-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-car-config-f.pkl',
15 | 'cat-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-cat-config-f.pkl',
16 | 'church-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-church-config-f.pkl',
17 | 'ffhq-config-e': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-e.pkl',
18 | 'ffhq-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-ffhq-config-f.pkl',
19 | 'horse-config-f': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/stylegan2-horse-config-f.pkl',
20 | 'car-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dorig.pkl',
21 | 'car-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dresnet.pkl',
22 | 'car-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gorig-Dskip.pkl',
23 | 'car-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dorig.pkl',
24 | 'car-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dresnet.pkl',
25 | 'car-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gresnet-Dskip.pkl',
26 | 'car-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dorig.pkl',
27 | 'car-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dresnet.pkl',
28 | 'car-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-car-config-e-Gskip-Dskip.pkl',
29 | 'ffhq-config-e-Gorig-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dorig.pkl',
30 | 'ffhq-config-e-Gorig-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dresnet.pkl',
31 | 'ffhq-config-e-Gorig-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gorig-Dskip.pkl',
32 | 'ffhq-config-e-Gresnet-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dorig.pkl',
33 | 'ffhq-config-e-Gresnet-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dresnet.pkl',
34 | 'ffhq-config-e-Gresnet-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gresnet-Dskip.pkl',
35 | 'ffhq-config-e-Gskip-Dorig': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dorig.pkl',
36 | 'ffhq-config-e-Gskip-Dresnet': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dresnet.pkl',
37 | 'ffhq-config-e-Gskip-Dskip': 'http://d36zk2xti64re0.cloudfront.net/stylegan2/networks/table2/stylegan2-ffhq-config-e-Gskip-Dskip.pkl',
38 | }
39 |
40 |
41 | class Unpickler(pickle.Unpickler):
42 | def find_class(self, module, name):
43 | if module == 'dnnlib.tflib.network' and name == 'Network':
44 | return utils.AttributeDict
45 | return super(Unpickler, self).find_class(module, name)
46 |
47 |
48 | def load_tf_models_file(fpath):
49 | with open(fpath, 'rb') as fp:
50 | return Unpickler(fp).load()
51 |
52 |
53 | def load_tf_models_url(url):
54 | print('Downloading file {}...'.format(url))
55 | with requests.Session() as session:
56 | with session.get(url) as ret:
57 | fp = io.BytesIO(ret.content)
58 | return Unpickler(fp).load()
59 |
60 |
61 | def convert_kwargs(static_kwargs, kwargs_mapping):
62 | kwargs = utils.AttributeDict()
63 | for key, value in static_kwargs.items():
64 | if key in kwargs_mapping:
65 | if value == 'lrelu':
66 | value = 'leaky:0.2'
67 | for k in utils.to_list(kwargs_mapping[key]):
68 | kwargs[k] = value
69 | return kwargs
70 |
71 |
72 | _PERMITTED_MODELS = ['G_main', 'G_mapping', 'G_synthesis_stylegan2', 'D_stylegan2', 'D_main', 'G_synthesis']
73 | def convert_from_tf(tf_state):
74 | tf_state = utils.AttributeDict.convert_dict_recursive(tf_state)
75 | model_type = tf_state.build_func_name
76 | assert model_type in _PERMITTED_MODELS, \
77 | 'Found model type {}. '.format(model_type) + \
78 | 'Allowed model types are: {}'.format(_PERMITTED_MODELS)
79 |
80 | if model_type == 'G_main':
81 | kwargs = convert_kwargs(
82 | static_kwargs=tf_state.static_kwargs,
83 | kwargs_mapping={
84 | 'dlatent_avg_beta': 'dlatent_avg_beta'
85 | }
86 | )
87 | kwargs.G_mapping = convert_from_tf(tf_state.components.mapping)
88 | kwargs.G_synthesis = convert_from_tf(tf_state.components.synthesis)
89 | G = stylegan2.models.Generator(**kwargs)
90 | for name, var in tf_state.variables:
91 | if name == 'dlatent_avg':
92 | G.dlatent_avg.data.copy_(torch.from_numpy(var))
93 | kwargs = convert_kwargs(
94 | static_kwargs=tf_state.static_kwargs,
95 | kwargs_mapping={
96 | 'truncation_psi': 'truncation_psi',
97 | 'truncation_cutoff': 'truncation_cutoff',
98 | 'truncation_psi_val': 'truncation_psi',
99 | 'truncation_cutoff_val': 'truncation_cutoff'
100 | }
101 | )
102 | G.set_truncation(**kwargs)
103 | return G
104 |
105 | if model_type == 'G_mapping':
106 | kwargs = convert_kwargs(
107 | static_kwargs=tf_state.static_kwargs,
108 | kwargs_mapping={
109 | 'mapping_nonlinearity': 'activation',
110 | 'normalize_latents': 'normalize_input',
111 | 'mapping_lr_mul': 'lr_mul'
112 | }
113 | )
114 | kwargs.num_layers = sum(
115 | 1 for var_name, _ in tf_state.variables
116 | if re.match('Dense[0-9]+/weight', var_name)
117 | )
118 | for var_name, var in tf_state.variables:
119 | if var_name == 'LabelConcat/weight':
120 | kwargs.label_size = var.shape[0]
121 | if var_name == 'Dense0/weight':
122 | kwargs.latent_size = var.shape[0]
123 | kwargs.hidden = var.shape[1]
124 | if var_name == 'Dense{}/bias'.format(kwargs.num_layers - 1):
125 | kwargs.out_size = var.shape[0]
126 | G_mapping = stylegan2.models.GeneratorMapping(**kwargs)
127 | for var_name, var in tf_state.variables:
128 | if re.match('Dense[0-9]+/[a-zA-Z]*', var_name):
129 |             layer_idx = int(re.search(r'Dense(\d+)/[a-zA-Z]*', var_name).groups()[0])
130 | if var_name.endswith('weight'):
131 | G_mapping.main[layer_idx].layer.weight.data.copy_(
132 | torch.from_numpy(var.T).contiguous())
133 | elif var_name.endswith('bias'):
134 | G_mapping.main[layer_idx].bias.data.copy_(torch.from_numpy(var))
135 | if var_name == 'LabelConcat/weight':
136 | G_mapping.embedding.weight.data.copy_(torch.from_numpy(var))
137 | return G_mapping
138 |
139 | if model_type == 'G_synthesis_stylegan2' or model_type == 'G_synthesis':
140 | assert tf_state.static_kwargs.get('fused_modconv', True), \
141 |             'Cannot load TF networks that use `fused_modconv=False`'
142 | noise_tensors = []
143 | conv_vars = {}
144 | for var_name, var in tf_state.variables:
145 | if var_name.startswith('noise'):
146 | noise_tensors.append(torch.from_numpy(var))
147 | else:
148 |                 layer_size = int(re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
149 | if layer_size not in conv_vars:
150 | conv_vars[layer_size] = {}
151 | var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
152 | conv_vars[layer_size][var_name] = var
153 |         noise_tensors = sorted(noise_tensors, key=lambda x: x.size(-1))
154 | kwargs = convert_kwargs(
155 | static_kwargs=tf_state.static_kwargs,
156 | kwargs_mapping={
157 | 'nonlinearity': 'activation',
158 | 'resample_filter': ['conv_filter', 'skip_filter']
159 | }
160 | )
161 | kwargs.skip = False
162 | kwargs.resnet = True
163 | kwargs.channels = []
164 | for size in sorted(conv_vars.keys(), reverse=True):
165 | if size == 4:
166 | if 'ToRGB/weight' in conv_vars[size]:
167 | kwargs.skip = True
168 | kwargs.resnet = False
169 | kwargs.latent_size = conv_vars[size]['Conv/mod_weight'].shape[0]
170 | kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
171 | else:
172 | kwargs.channels.append(conv_vars[size]['Conv1/bias'].shape[0])
173 | if 'ToRGB/bias' in conv_vars[size]:
174 | kwargs.data_channels = conv_vars[size]['ToRGB/bias'].shape[0]
175 | G_synthesis = stylegan2.models.GeneratorSynthesis(**kwargs)
176 | G_synthesis.const.data.copy_(torch.from_numpy(conv_vars[4]['Const/const']).squeeze(0))
177 | def assign_weights(layer, weight, bias, mod_weight, mod_bias, noise_strength, transposed=False):
178 | layer.bias.data.copy_(torch.from_numpy(bias))
179 | layer.layer.weight.data.copy_(torch.tensor(noise_strength))
180 | layer.layer.layer.dense.layer.weight.data.copy_(
181 | torch.from_numpy(mod_weight.T).contiguous())
182 | layer.layer.layer.dense.bias.data.copy_(torch.from_numpy(mod_bias + 1))
183 | weight = torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous()
184 | if transposed:
185 | weight = weight.flip(dims=[2,3])
186 | layer.layer.layer.weight.data.copy_(weight)
187 | conv_blocks = G_synthesis.conv_blocks
188 | for i, size in enumerate(sorted(conv_vars.keys())):
189 | block = conv_blocks[i]
190 | if size == 4:
191 | assign_weights(
192 | layer=block.conv_block[0],
193 | weight=conv_vars[size]['Conv/weight'],
194 | bias=conv_vars[size]['Conv/bias'],
195 | mod_weight=conv_vars[size]['Conv/mod_weight'],
196 | mod_bias=conv_vars[size]['Conv/mod_bias'],
197 | noise_strength=conv_vars[size]['Conv/noise_strength'],
198 | )
199 | else:
200 | assign_weights(
201 | layer=block.conv_block[0],
202 | weight=conv_vars[size]['Conv0_up/weight'],
203 | bias=conv_vars[size]['Conv0_up/bias'],
204 | mod_weight=conv_vars[size]['Conv0_up/mod_weight'],
205 | mod_bias=conv_vars[size]['Conv0_up/mod_bias'],
206 | noise_strength=conv_vars[size]['Conv0_up/noise_strength'],
207 | transposed=True
208 | )
209 | assign_weights(
210 | layer=block.conv_block[1],
211 | weight=conv_vars[size]['Conv1/weight'],
212 | bias=conv_vars[size]['Conv1/bias'],
213 | mod_weight=conv_vars[size]['Conv1/mod_weight'],
214 | mod_bias=conv_vars[size]['Conv1/mod_bias'],
215 | noise_strength=conv_vars[size]['Conv1/noise_strength'],
216 | )
217 | if 'Skip/weight' in conv_vars[size]:
218 | block.projection.weight.data.copy_(torch.from_numpy(
219 | conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
220 | to_RGB = G_synthesis.to_data_layers[i]
221 | if to_RGB is not None:
222 | to_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['ToRGB/bias']))
223 | to_RGB.layer.weight.data.copy_(torch.from_numpy(
224 | conv_vars[size]['ToRGB/weight']).permute((3, 2, 0, 1)).contiguous())
225 | to_RGB.layer.dense.bias.data.copy_(
226 | torch.from_numpy(conv_vars[size]['ToRGB/mod_bias'] + 1))
227 | to_RGB.layer.dense.layer.weight.data.copy_(
228 | torch.from_numpy(conv_vars[size]['ToRGB/mod_weight'].T).contiguous())
229 | if not tf_state.static_kwargs.get('randomize_noise', True):
230 | G_synthesis.static_noise(noise_tensors=noise_tensors)
231 | return G_synthesis
232 |
233 | if model_type == 'D_stylegan2' or model_type == 'D_main':
234 | output_vars = {}
235 | conv_vars = {}
236 | for var_name, var in tf_state.variables:
237 | if var_name.startswith('Output'):
238 | output_vars[var_name.replace('Output/', '')] = var
239 | else:
240 |                 layer_size = int(re.search(r'(\d+)x[0-9]+/*', var_name).groups()[0])
241 | if layer_size not in conv_vars:
242 | conv_vars[layer_size] = {}
243 | var_name = var_name.replace('{}x{}/'.format(layer_size, layer_size), '')
244 | conv_vars[layer_size][var_name] = var
245 | kwargs = convert_kwargs(
246 | static_kwargs=tf_state.static_kwargs,
247 | kwargs_mapping={
248 | 'nonlinearity': 'activation',
249 | 'resample_filter': ['conv_filter', 'skip_filter'],
250 | 'mbstd_group_size': 'mbstd_group_size'
251 | }
252 | )
253 | kwargs.skip = False
254 | kwargs.resnet = True
255 | kwargs.channels = []
256 | for size in sorted(conv_vars.keys(), reverse=True):
257 | if size == 4:
258 | if 'FromRGB/weight' in conv_vars[size]:
259 | kwargs.skip = True
260 | kwargs.resnet = False
261 | kwargs.channels.append(conv_vars[size]['Conv/bias'].shape[0])
262 | kwargs.dense_hidden = conv_vars[size]['Dense0/bias'].shape[0]
263 | else:
264 | kwargs.channels.append(conv_vars[size]['Conv0/bias'].shape[0])
265 | if 'FromRGB/weight' in conv_vars[size]:
266 | kwargs.data_channels = conv_vars[size]['FromRGB/weight'].shape[-2]
267 | output_size = output_vars['bias'].shape[0]
268 | if output_size > 1:
269 | kwargs.label_size = output_size
270 | D = stylegan2.models.Discriminator(**kwargs)
271 | def assign_weights(layer, weight, bias):
272 | layer.bias.data.copy_(torch.from_numpy(bias))
273 | layer.layer.weight.data.copy_(
274 | torch.from_numpy(weight).permute((3, 2, 0, 1)).contiguous())
275 | conv_blocks = D.conv_blocks
276 | for i, size in enumerate(sorted(conv_vars.keys())):
277 | block = conv_blocks[-i - 1]
278 | if size == 4:
279 | assign_weights(
280 | layer=block[-1].conv_block[0],
281 | weight=conv_vars[size]['Conv/weight'],
282 | bias=conv_vars[size]['Conv/bias'],
283 | )
284 | else:
285 | assign_weights(
286 | layer=block.conv_block[0],
287 | weight=conv_vars[size]['Conv0/weight'],
288 | bias=conv_vars[size]['Conv0/bias'],
289 | )
290 | assign_weights(
291 | layer=block.conv_block[1],
292 | weight=conv_vars[size]['Conv1_down/weight'],
293 | bias=conv_vars[size]['Conv1_down/bias'],
294 | )
295 | if 'Skip/weight' in conv_vars[size]:
296 | block.projection.weight.data.copy_(torch.from_numpy(
297 | conv_vars[size]['Skip/weight']).permute((3, 2, 0, 1)).contiguous())
298 | from_RGB = D.from_data_layers[-i - 1]
299 | if from_RGB is not None:
300 | from_RGB.bias.data.copy_(torch.from_numpy(conv_vars[size]['FromRGB/bias']))
301 | from_RGB.layer.weight.data.copy_(torch.from_numpy(
302 | conv_vars[size]['FromRGB/weight']).permute((3, 2, 0, 1)).contiguous())
303 | return D
304 |
305 |
306 | def get_arg_parser():
307 | parser = argparse.ArgumentParser(
308 |         description='Convert a TensorFlow StyleGAN2 model to PyTorch.',
309 | epilog='Pretrained models that can be downloaded:\n{}'.format(
310 | '\n'.join(pretrained_model_urls.keys()))
311 | )
312 |
313 | parser.add_argument(
314 | '-i',
315 | '--input',
316 | help='File path to pickled tensorflow models.',
317 | type=str,
318 | default=None,
319 | )
320 |
321 | parser.add_argument(
322 | '-d',
323 | '--download',
324 | help='Download the specified pretrained model. Use --help for info on available models.',
325 | type=str,
326 | default=None,
327 | )
328 |
329 | parser.add_argument(
330 | '-o',
331 | '--output',
332 | help='One or more output file paths. Alternatively a directory path ' + \
333 | 'where all models will be saved. Default: current directory',
334 | type=str,
335 | nargs='*',
336 | default=['.'],
337 | )
338 |
339 | return parser
340 |
341 |
342 | def main():
343 | args = get_arg_parser().parse_args()
344 | assert bool(args.input) != bool(args.download), \
345 |         'Incorrect arguments. Specify either a filepath to ' + \
346 |         'a pickled TensorFlow model (--input) or the name of ' + \
347 |         'a pretrained model to download (--download), ' + \
348 |         'but not both and not neither.'
349 | if args.input:
350 | unpickled = load_tf_models_file(args.input)
351 | else:
352 | assert args.download in pretrained_model_urls.keys(), \
353 | 'Unknown model {}. Use --help for list of models.'.format(args.download)
354 | unpickled = load_tf_models_url(pretrained_model_urls[args.download])
355 | if not isinstance(unpickled, (tuple, list)):
356 | unpickled = [unpickled]
357 | print('Converting tensorflow models and saving them...')
358 | converted = [convert_from_tf(tf_state) for tf_state in unpickled]
359 | if len(args.output) == 1 and (os.path.isdir(args.output[0]) or not os.path.splitext(args.output[0])[-1]):
360 | if not os.path.exists(args.output[0]):
361 | os.makedirs(args.output[0])
362 | for tf_state, torch_model in zip(unpickled, converted):
363 | torch_model.save(os.path.join(args.output[0], tf_state['name'] + '.pth'))
364 | else:
365 | assert len(args.output) == len(converted), 'Found {} models '.format(len(converted)) + \
366 | 'in pickled file but only {} output paths were given.'.format(len(args.output))
367 | for out_path, torch_model in zip(args.output, converted):
368 | torch_model.save(out_path)
369 | print('Done!')
370 |
371 |
372 | if __name__ == '__main__':
373 | main()
--------------------------------------------------------------------------------
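A hedged usage sketch for the converter above: from the command line it would be invoked roughly as `python convert_from_tf.py --download ffhq-config-f --output ./stylegan2/weights`, and the same conversion can be driven from Python with the functions defined in this file (the output filename below is only an example).

from stylegan2.convert_from_tf import (
    pretrained_model_urls, load_tf_models_url, convert_from_tf)

# Download and unpickle the TF networks (typically a (G, D, Gs) tuple).
unpickled = load_tf_models_url(pretrained_model_urls['ffhq-config-f'])
if not isinstance(unpickled, (tuple, list)):
    unpickled = [unpickled]
# Convert each TF network state to its PyTorch counterpart.
converted = [convert_from_tf(tf_state) for tf_state in unpickled]
converted[-1].save('./Gs.pth')  # hypothetical output path
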
/stylegan2/external_models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import inception
2 | from . import lpips
3 |
--------------------------------------------------------------------------------
/stylegan2/external_models/inception.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/mseitzer/pytorch-fid/
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | http://www.apache.org/licenses/LICENSE-2.0
8 | Unless required by applicable law or agreed to in writing, software
9 | distributed under the License is distributed on an "AS IS" BASIS,
10 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | See the License for the specific language governing permissions and
12 | limitations under the License.
13 | """
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torchvision import models
18 |
19 | try:
20 | from torchvision.models.utils import load_state_dict_from_url
21 | except ImportError:
22 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
23 |
24 | # Inception weights ported to Pytorch from
25 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
26 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
27 |
28 |
29 | class InceptionV3FeatureExtractor(nn.Module):
30 | """Pretrained InceptionV3 network returning feature maps"""
31 |
32 | # Index of default block of inception to return,
33 | # corresponds to output of final average pooling
34 | DEFAULT_BLOCK_INDEX = 3
35 |
36 | # Maps feature dimensionality to their output blocks indices
37 | BLOCK_INDEX_BY_DIM = {
38 | 64: 0, # First max pooling features
39 |         192: 1,   # Second max pooling features
40 | 768: 2, # Pre-aux classifier features
41 | 2048: 3 # Final average pooling features
42 | }
43 |
44 | def __init__(self,
45 | output_block=DEFAULT_BLOCK_INDEX,
46 | pixel_min=-1,
47 | pixel_max=1):
48 | """
49 | Build pretrained InceptionV3
50 | Arguments:
51 | output_block (int): Index of block to return features of.
52 | Possible values are:
53 | - 0: corresponds to output of first max pooling
54 | - 1: corresponds to output of second max pooling
55 | - 2: corresponds to output which is fed to aux classifier
56 | - 3: corresponds to output of final average pooling
57 | pixel_min (float): Min value for inputs. Default value is -1.
58 | pixel_max (float): Max value for inputs. Default value is 1.
59 | """
60 | super(InceptionV3FeatureExtractor, self).__init__()
61 |
62 | assert 0 <= output_block <= 3, '`output_block` can only be ' + \
63 | '0 <= `output_block` <= 3.'
64 |
65 | inception = fid_inception_v3()
66 |
67 | blocks = []
68 |
69 | # Block 0: input to maxpool1
70 | block0 = [
71 | inception.Conv2d_1a_3x3,
72 | inception.Conv2d_2a_3x3,
73 | inception.Conv2d_2b_3x3,
74 | nn.MaxPool2d(kernel_size=3, stride=2)
75 | ]
76 | blocks.append(nn.Sequential(*block0))
77 |
78 | # Block 1: maxpool1 to maxpool2
79 | if output_block >= 1:
80 | block1 = [
81 | inception.Conv2d_3b_1x1,
82 | inception.Conv2d_4a_3x3,
83 | nn.MaxPool2d(kernel_size=3, stride=2)
84 | ]
85 | blocks.append(nn.Sequential(*block1))
86 |
87 | # Block 2: maxpool2 to aux classifier
88 | if output_block >= 2:
89 | block2 = [
90 | inception.Mixed_5b,
91 | inception.Mixed_5c,
92 | inception.Mixed_5d,
93 | inception.Mixed_6a,
94 | inception.Mixed_6b,
95 | inception.Mixed_6c,
96 | inception.Mixed_6d,
97 | inception.Mixed_6e,
98 | ]
99 | blocks.append(nn.Sequential(*block2))
100 |
101 | # Block 3: aux classifier to final avgpool
102 | if output_block >= 3:
103 | block3 = [
104 | inception.Mixed_7a,
105 | inception.Mixed_7b,
106 | inception.Mixed_7c,
107 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
108 | ]
109 | blocks.append(nn.Sequential(*block3))
110 |
111 | self.main = nn.Sequential(*blocks)
112 |         self.pixel_min = pixel_min
113 | self.pixel_max = pixel_max
114 | self.requires_grad_(False)
115 | self.eval()
116 |
117 | def _scale(self, x):
118 | if self.pixel_min != -1 or self.pixel_max != 1:
119 | x = (2*x - self.pixel_min - self.pixel_max) \
120 | / (self.pixel_max - self.pixel_min)
121 | return x
122 |
123 | def forward(self, input):
124 | """
125 | Get Inception feature maps.
126 | Arguments:
127 | input (torch.Tensor)
128 | Returns:
129 | feature_maps (torch.Tensor)
130 | """
131 |         return self.main(self._scale(input))
132 |
133 |
134 | def fid_inception_v3():
135 | """Build pretrained Inception model for FID computation
136 | The Inception model for FID computation uses a different set of weights
137 | and has a slightly different structure than torchvision's Inception.
138 | This method first constructs torchvision's Inception and then patches the
139 | necessary parts that are different in the FID Inception model.
140 | """
141 | inception = models.inception_v3(num_classes=1008,
142 | aux_logits=False,
143 | pretrained=False)
144 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
145 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
146 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
147 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
148 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
149 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
150 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
151 | inception.Mixed_7b = FIDInceptionE_1(1280)
152 | inception.Mixed_7c = FIDInceptionE_2(2048)
153 |
154 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
155 | inception.load_state_dict(state_dict)
156 | return inception
157 |
158 |
159 | class FIDInceptionA(models.inception.InceptionA):
160 | """InceptionA block patched for FID computation"""
161 | def __init__(self, in_channels, pool_features):
162 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
163 |
164 | def forward(self, x):
165 | branch1x1 = self.branch1x1(x)
166 |
167 | branch5x5 = self.branch5x5_1(x)
168 | branch5x5 = self.branch5x5_2(branch5x5)
169 |
170 | branch3x3dbl = self.branch3x3dbl_1(x)
171 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
172 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
173 |
174 | # Patch: Tensorflow's average pool does not use the padded zero's in
175 | # its average calculation
176 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
177 | count_include_pad=False)
178 | branch_pool = self.branch_pool(branch_pool)
179 |
180 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
181 | return torch.cat(outputs, 1)
182 |
183 |
184 | class FIDInceptionC(models.inception.InceptionC):
185 | """InceptionC block patched for FID computation"""
186 | def __init__(self, in_channels, channels_7x7):
187 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
188 |
189 | def forward(self, x):
190 | branch1x1 = self.branch1x1(x)
191 |
192 | branch7x7 = self.branch7x7_1(x)
193 | branch7x7 = self.branch7x7_2(branch7x7)
194 | branch7x7 = self.branch7x7_3(branch7x7)
195 |
196 | branch7x7dbl = self.branch7x7dbl_1(x)
197 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
198 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
199 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
200 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
201 |
202 | # Patch: Tensorflow's average pool does not use the padded zero's in
203 | # its average calculation
204 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
205 | count_include_pad=False)
206 | branch_pool = self.branch_pool(branch_pool)
207 |
208 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
209 | return torch.cat(outputs, 1)
210 |
211 |
212 | class FIDInceptionE_1(models.inception.InceptionE):
213 | """First InceptionE block patched for FID computation"""
214 | def __init__(self, in_channels):
215 | super(FIDInceptionE_1, self).__init__(in_channels)
216 |
217 | def forward(self, x):
218 | branch1x1 = self.branch1x1(x)
219 |
220 | branch3x3 = self.branch3x3_1(x)
221 | branch3x3 = [
222 | self.branch3x3_2a(branch3x3),
223 | self.branch3x3_2b(branch3x3),
224 | ]
225 | branch3x3 = torch.cat(branch3x3, 1)
226 |
227 | branch3x3dbl = self.branch3x3dbl_1(x)
228 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
229 | branch3x3dbl = [
230 | self.branch3x3dbl_3a(branch3x3dbl),
231 | self.branch3x3dbl_3b(branch3x3dbl),
232 | ]
233 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
234 |
235 | # Patch: Tensorflow's average pool does not use the padded zero's in
236 | # its average calculation
237 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
238 | count_include_pad=False)
239 | branch_pool = self.branch_pool(branch_pool)
240 |
241 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
242 | return torch.cat(outputs, 1)
243 |
244 |
245 | class FIDInceptionE_2(models.inception.InceptionE):
246 | """Second InceptionE block patched for FID computation"""
247 | def __init__(self, in_channels):
248 | super(FIDInceptionE_2, self).__init__(in_channels)
249 |
250 | def forward(self, x):
251 | branch1x1 = self.branch1x1(x)
252 |
253 | branch3x3 = self.branch3x3_1(x)
254 | branch3x3 = [
255 | self.branch3x3_2a(branch3x3),
256 | self.branch3x3_2b(branch3x3),
257 | ]
258 | branch3x3 = torch.cat(branch3x3, 1)
259 |
260 | branch3x3dbl = self.branch3x3dbl_1(x)
261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
262 | branch3x3dbl = [
263 | self.branch3x3dbl_3a(branch3x3dbl),
264 | self.branch3x3dbl_3b(branch3x3dbl),
265 | ]
266 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
267 |
268 | # Patch: The FID Inception model uses max pooling instead of average
269 | # pooling. This is likely an error in this specific Inception
270 | # implementation, as other Inception models use average pooling here
271 | # (which matches the description in the paper).
272 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
273 | branch_pool = self.branch_pool(branch_pool)
274 |
275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
276 | return torch.cat(outputs, 1)
277 |
--------------------------------------------------------------------------------
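A small sketch of how this extractor is consumed (mirroring what stylegan2/metrics/fid.py does below): images scaled to [-1, 1] go in, and the final average-pooled activations are flattened to one 2048-dimensional vector per image. Instantiating the module downloads the ported FID weights, so network access is assumed.

import torch
from stylegan2.external_models.inception import InceptionV3FeatureExtractor

extractor = InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
images = torch.rand(4, 3, 299, 299) * 2 - 1  # random stand-in batch in [-1, 1]
with torch.no_grad():
    feature_maps = extractor(images)  # (4, 2048, 1, 1) with the default output block
features = feature_maps.view(*feature_maps.size()[:2], -1).mean(-1)  # (4, 2048)
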
/stylegan2/external_models/lpips.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/richzhang/PerceptualSimilarity
3 |
4 | Original License:
5 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
6 | All rights reserved.
7 |
8 | Redistribution and use in source and binary forms, with or without
9 | modification, are permitted provided that the following conditions are met:
10 |
11 | * Redistributions of source code must retain the above copyright notice, this
12 | list of conditions and the following disclaimer.
13 |
14 | * Redistributions in binary form must reproduce the above copyright notice,
15 | this list of conditions and the following disclaimer in the documentation
16 | and/or other materials provided with the distribution.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | """
29 | import torch
30 | from torch import nn
31 | import torchvision
32 |
33 |
34 | class LPIPS_VGG16(nn.Module):
35 | _FEATURE_IDX = [0, 4, 9, 16, 23, 30]
36 | _LINEAR_WEIGHTS_URL = 'https://github.com/richzhang/PerceptualSimilarity' + \
37 | '/blob/master/lpips/weights/v0.1/vgg.pth?raw=true'
38 |
39 | def __init__(self, pixel_min=-1, pixel_max=1):
40 | super(LPIPS_VGG16, self).__init__()
41 | features = torchvision.models.vgg16(pretrained=True).features
42 | self.slices = nn.ModuleList()
43 | linear_weights = torch.utils.model_zoo.load_url(self._LINEAR_WEIGHTS_URL)
44 | for i in range(1, len(self._FEATURE_IDX)):
45 | idx_range = range(self._FEATURE_IDX[i - 1], self._FEATURE_IDX[i])
46 | self.slices.append(nn.Sequential(*[features[j] for j in idx_range]))
47 | self.linear_layers = nn.ModuleList()
48 |         for weight in linear_weights.values():
49 | weight = weight.view(1, -1)
50 | linear = nn.Linear(weight.size(1), 1, bias=False)
51 | linear.weight.data.copy_(weight)
52 | self.linear_layers.append(linear)
53 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188]).view(1, -1, 1, 1))
54 | self.register_buffer('scale', torch.Tensor([.458,.448,.450]).view(1, -1, 1, 1))
55 | self.pixel_min = pixel_min
56 | self.pixel_max = pixel_max
57 | self.requires_grad_(False)
58 | self.eval()
59 |
60 | def _scale(self, x):
61 | if self.pixel_min != -1 or self.pixel_max != 1:
62 | x = (2*x - self.pixel_min - self.pixel_max) \
63 | / (self.pixel_max - self.pixel_min)
64 | return (x - self.shift) / self.scale
65 |
66 | @staticmethod
67 | def _normalize_tensor(feature_maps, eps=1e-8):
68 | rnorm = torch.rsqrt(torch.sum(feature_maps ** 2, dim=1, keepdim=True) + eps)
69 | return feature_maps * rnorm
70 |
71 | def forward(self, x0, x1, eps=1e-8):
72 | x0, x1 = self._scale(x0), self._scale(x1)
73 | dist = 0
74 | for slice, linear in zip(self.slices, self.linear_layers):
75 | x0, x1 = slice(x0), slice(x1)
76 | _x0, _x1 = self._normalize_tensor(x0, eps), self._normalize_tensor(x1, eps)
77 | dist += linear(torch.mean((_x0 - _x1) ** 2, dim=[-1, -2]))
78 | return dist.view(-1)
79 |
--------------------------------------------------------------------------------
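A hedged usage sketch for the LPIPS distance above: it takes two image batches of the same shape with values in [pixel_min, pixel_max] (default [-1, 1]) and returns one perceptual distance per pair; constructing the module downloads the VGG16 backbone and the LPIPS linear weights.

import torch
from stylegan2.external_models.lpips import LPIPS_VGG16

lpips = LPIPS_VGG16(pixel_min=-1, pixel_max=1)
a = torch.rand(2, 3, 256, 256) * 2 - 1  # random stand-in batches in [-1, 1]
b = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = lpips(a, b)  # shape (2,); larger values mean perceptually more different
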
/stylegan2/loss_fns.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | from . import utils
6 |
7 |
8 | def _grad(input, output, retain_graph):
9 | # https://discuss.pytorch.org/t/gradient-penalty-loss-with-modified-weights/64910
10 |     # It is currently not possible to avoid retaining the
11 |     # graph for regularization losses, so as a workaround
12 |     # retain_graph is always overridden to True.
13 | retain_graph = True
14 | grads = torch.autograd.grad(
15 | output.sum(),
16 | input,
17 | only_inputs=True,
18 | retain_graph=retain_graph,
19 | create_graph=True
20 | )
21 | return grads[0]
22 |
23 |
24 | def _grad_pen(input, output, gamma, constraint=1, onesided=False, retain_graph=True):
25 | grad = _grad(input, output, retain_graph=retain_graph)
26 | grad = grad.view(grad.size(0), -1)
27 | grad_norm = grad.norm(2, dim=1)
28 | if onesided:
29 |         gp = torch.clamp(grad_norm - constraint, min=0)
30 | else:
31 | gp = (grad_norm - constraint) ** 2
32 | return gamma * gp.mean()
33 |
34 |
35 | def _grad_reg(input, output, gamma, retain_graph=True):
36 | grad = _grad(input, output, retain_graph=retain_graph)
37 | grad = grad.view(grad.size(0), -1)
38 | gr = (grad ** 2).sum(1)
39 | return (0.5 * gamma) * gr.mean()
40 |
41 |
42 | def _pathreg(dlatents, fakes, pl_avg, pl_decay, gamma, retain_graph=True):
43 | retain_graph = True
44 | pl_noise = torch.empty_like(fakes).normal_().div_(np.sqrt(np.prod(fakes.size()[2:])))
45 | pl_grad = _grad(dlatents, torch.sum(pl_noise * fakes), retain_graph=retain_graph)
46 | pl_length = torch.sqrt(torch.mean(torch.sum(pl_grad ** 2, dim=2), dim=1))
47 | with torch.no_grad():
48 | pl_avg.add_(pl_decay * (torch.mean(pl_length) - pl_avg))
49 | return gamma * torch.mean((pl_length - pl_avg) ** 2)
50 |
51 |
52 | #----------------------------------------------------------------------------
53 | # Logistic loss from the paper
54 | # "Generative Adversarial Nets", Goodfellow et al. 2014
55 |
56 |
57 | def G_logistic(G,
58 | D,
59 | latents,
60 | latent_labels=None,
61 | *args,
62 | **kwargs):
63 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
64 | loss = - F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
65 | reg = None
66 | return loss, reg
67 |
68 |
69 | def G_logistic_ns(G,
70 | D,
71 | latents,
72 | latent_labels=None,
73 | *args,
74 | **kwargs):
75 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
76 | loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
77 | reg = None
78 | return loss, reg
79 |
80 |
81 | def D_logistic(G,
82 | D,
83 | latents,
84 | reals,
85 | latent_labels=None,
86 | real_labels=None,
87 | *args,
88 | **kwargs):
89 | assert (latent_labels is None) == (real_labels is None)
90 | with torch.no_grad():
91 | fakes = G(latents, labels=latent_labels)
92 | real_scores = D(reals, labels=real_labels).float()
93 | fake_scores = D(fakes, labels=latent_labels).float()
94 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
95 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
96 | loss = real_loss + fake_loss
97 | reg = None
98 | return loss, reg
99 |
100 |
101 | #----------------------------------------------------------------------------
102 | # R1 and R2 regularizers from the paper
103 | # "Which Training Methods for GANs do actually Converge?", Mescheder et al. 2018
104 |
105 |
106 | def D_r1(D,
107 | reals,
108 | real_labels=None,
109 | gamma=10,
110 | *args,
111 | **kwargs):
112 | loss = None
113 | reg = None
114 | if gamma:
115 | reals.requires_grad_(True)
116 | real_scores = D(reals, labels=real_labels)
117 | reg = _grad_reg(
118 | input=reals, output=real_scores, gamma=gamma, retain_graph=False).float()
119 | return loss, reg
120 |
121 |
122 | def D_r2(D,
123 | G,
124 | latents,
125 | latent_labels=None,
126 | gamma=10,
127 | *args,
128 | **kwargs):
129 | loss = None
130 | reg = None
131 | if gamma:
132 | with torch.no_grad():
133 | fakes = G(latents, labels=latent_labels)
134 | fakes.requires_grad_(True)
135 | fake_scores = D(fakes, labels=latent_labels)
136 | reg = _grad_reg(
137 | input=fakes, output=fake_scores, gamma=gamma, retain_graph=False).float()
138 | return loss, reg
139 |
140 |
141 | def D_logistic_r1(G,
142 | D,
143 | latents,
144 | reals,
145 | latent_labels=None,
146 | real_labels=None,
147 | gamma=10,
148 | *args,
149 | **kwargs):
150 | assert (latent_labels is None) == (real_labels is None)
151 | with torch.no_grad():
152 | fakes = G(latents, labels=latent_labels)
153 | if gamma:
154 | reals.requires_grad_(True)
155 | real_scores = D(reals, labels=real_labels).float()
156 | fake_scores = D(fakes, labels=latent_labels).float()
157 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
158 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
159 | loss = real_loss + fake_loss
160 | reg = None
161 | if gamma:
162 | reg = _grad_reg(
163 | input=reals, output=real_scores, gamma=gamma, retain_graph=True).float()
164 | return loss, reg
165 |
166 |
167 | def D_logistic_r2(G,
168 | D,
169 | latents,
170 | reals,
171 | latent_labels=None,
172 | real_labels=None,
173 | gamma=10,
174 | *args,
175 | **kwargs):
176 | assert (latent_labels is None) == (real_labels is None)
177 | with torch.no_grad():
178 | fakes = G(latents, labels=latent_labels)
179 | if gamma:
180 | fakes.requires_grad_(True)
181 | real_scores = D(reals, labels=real_labels).float()
182 | fake_scores = D(fakes, labels=latent_labels).float()
183 | real_loss = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
184 | fake_loss = F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores))
185 | loss = real_loss + fake_loss
186 | reg = None
187 | if gamma:
188 | reg = _grad_reg(
189 | input=fakes, output=fake_scores, gamma=gamma, retain_graph=True).float()
190 | return loss, reg
191 |
192 |
193 | #----------------------------------------------------------------------------
194 | # Non-saturating logistic loss with path length regularizer from the paper
195 | # "Analyzing and Improving the Image Quality of StyleGAN", Karras et al. 2019
196 |
197 |
198 | def G_pathreg(G,
199 | latents,
200 | pl_avg,
201 | latent_labels=None,
202 | pl_decay=0.01,
203 | gamma=2,
204 | *args,
205 | **kwargs):
206 | loss = None
207 | reg = None
208 | if gamma:
209 | fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True, mapping_grad=False)
210 | reg = _pathreg(
211 | dlatents=dlatents,
212 | fakes=fakes,
213 | pl_avg=pl_avg,
214 | pl_decay=pl_decay,
215 | gamma=gamma,
216 | retain_graph=False
217 | ).float()
218 | return loss, reg
219 |
220 |
221 | def G_logistic_ns_pathreg(G,
222 | D,
223 | latents,
224 | pl_avg,
225 | latent_labels=None,
226 | pl_decay=0.01,
227 | gamma=2,
228 | *args,
229 | **kwargs):
230 | fakes, dlatents = G(latents, labels=latent_labels, return_dlatents=True)
231 | fake_scores = D(fakes, labels=latent_labels).float()
232 | loss = F.binary_cross_entropy_with_logits(fake_scores, torch.ones_like(fake_scores))
233 | reg = None
234 | if gamma:
235 | reg = _pathreg(
236 | dlatents=dlatents,
237 | fakes=fakes,
238 | pl_avg=pl_avg,
239 | pl_decay=pl_decay,
240 | gamma=gamma,
241 | retain_graph=True
242 | ).float()
243 | return loss, reg
244 |
245 |
246 | #----------------------------------------------------------------------------
247 | # WGAN loss from the paper
248 | # "Wasserstein Generative Adversarial Networks", Arjovsky et al. 2017
249 |
250 |
251 | def G_wgan(G,
252 | D,
253 | latents,
254 | latent_labels=None,
255 | *args,
256 | **kwargs):
257 | fake_scores = D(G(latents, labels=latent_labels), labels=latent_labels).float()
258 | loss = -fake_scores.mean()
259 | reg = None
260 | return loss, reg
261 |
262 |
263 | def D_wgan(G,
264 | D,
265 | latents,
266 | reals,
267 | latent_labels=None,
268 | real_labels=None,
269 | drift_gamma=0.001,
270 | *args,
271 | **kwargs):
272 | assert (latent_labels is None) == (real_labels is None)
273 | with torch.no_grad():
274 | fakes = G(latents, labels=latent_labels)
275 | real_scores = D(reals, labels=real_labels).float()
276 | fake_scores = D(fakes, labels=latent_labels).float()
277 | loss = fake_scores.mean() - real_scores.mean()
278 | if drift_gamma:
279 | loss += drift_gamma * torch.mean(real_scores ** 2)
280 | reg = None
281 | return loss, reg
282 |
283 |
284 | #----------------------------------------------------------------------------
285 | # WGAN-GP loss from the paper
286 | # "Improved Training of Wasserstein GANs", Gulrajani et al. 2017
287 |
288 |
289 | def D_gp(G,
290 | D,
291 | latents,
292 | reals,
293 | latent_labels=None,
294 | real_labels=None,
295 | gamma=0,
296 | constraint=1,
297 | *args,
298 | **kwargs):
299 | loss = None
300 | reg = None
301 | if gamma:
302 | assert (latent_labels is None) == (real_labels is None)
303 | with torch.no_grad():
304 | fakes = G(latents, labels=latent_labels)
305 | assert reals.size() == fakes.size()
306 |         if latent_labels is not None:
307 |             assert torch.equal(latent_labels, real_labels)
308 |         alpha = torch.empty(reals.size(0), device=reals.device).uniform_()
309 | alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
310 | interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
311 | interp_scores = D(interp, labels=latent_labels)
312 | reg = _grad_pen(
313 | input=interp, output=interp_scores, gamma=gamma, retain_graph=False).float()
314 | return loss, reg
315 |
316 |
317 | def D_wgan_gp(G,
318 | D,
319 | latents,
320 | reals,
321 | latent_labels=None,
322 | real_labels=None,
323 | gamma=0,
324 | drift_gamma=0.001,
325 | constraint=1,
326 | *args,
327 | **kwargs):
328 | assert (latent_labels is None) == (real_labels is None)
329 | with torch.no_grad():
330 | fakes = G(latents, labels=latent_labels)
331 | real_scores = D(reals, labels=real_labels).float()
332 | fake_scores = D(fakes, labels=latent_labels).float()
333 | loss = fake_scores.mean() - real_scores.mean()
334 | if drift_gamma:
335 | loss += drift_gamma * torch.mean(real_scores ** 2)
336 | reg = None
337 | if gamma:
338 | assert reals.size() == fakes.size()
339 |         if latent_labels is not None:
340 |             assert torch.equal(latent_labels, real_labels)
341 |         alpha = torch.empty(reals.size(0), device=reals.device).uniform_()
342 | alpha = alpha.view(-1, *[1] * (reals.dim() - 1))
343 | interp = utils.lerp(reals, fakes, alpha).requires_grad_(True)
344 | interp_scores = D(interp, labels=latent_labels)
345 | reg = _grad_pen(
346 | input=interp, output=interp_scores, gamma=gamma, retain_graph=True).float()
347 | return loss, reg
348 |
--------------------------------------------------------------------------------
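Each loss function above returns a `(loss, reg)` pair so that the regularizer can be applied lazily, i.e. only every few steps and scaled by the interval. The sketch below shows one way a discriminator step might consume `D_logistic_r1` and `D_r1`; it only illustrates the pattern (G, D, the optimizer and the data batches are assumed to exist) and is not a copy of stylegan2/train.py.

from stylegan2 import loss_fns

def d_step(G, D, d_opt, latents, reals, step, reg_interval=16, gamma=10):
    d_opt.zero_grad()
    # Main logistic loss every step; gamma=0 skips the built-in R1 term here.
    loss, _ = loss_fns.D_logistic_r1(G, D, latents, reals, gamma=0)
    loss.backward()
    if reg_interval and step % reg_interval == 0:
        # Lazy R1: evaluate the regularizer on its own, scaled by the interval.
        _, reg = loss_fns.D_r1(D, reals, gamma=gamma)
        (reg_interval * reg).backward()
    d_opt.step()
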
/stylegan2/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from . import fid
2 | from . import ppl
3 |
--------------------------------------------------------------------------------
/stylegan2/metrics/fid.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numbers
3 | import numpy as np
4 | import scipy.linalg
5 | import torch
6 | from torch.nn import functional as F
7 |
8 | from .. import models, utils
9 | from ..external_models import inception
10 |
11 |
12 | class _TruncatedDataset:
13 | """
14 | Truncates a dataset, making only part of it accessible
15 | by `torch.utils.data.DataLoader`.
16 | """
17 |
18 | def __init__(self, dataset, max_len):
19 | self.dataset = dataset
20 | self.max_len = max_len
21 |
22 | def __len__(self):
23 | return min(len(self.dataset), self.max_len)
24 |
25 | def __getitem__(self, index):
26 | return self.dataset[index]
27 |
28 |
29 | class FID:
30 | """
31 | This class evaluates the FID metric of a generator.
32 | Arguments:
33 | G (Generator)
34 | prior_generator (PriorGenerator)
35 | dataset (indexable)
36 | device (int, str, torch.device, optional): The device
37 | to use for calculations. By default, the same device
38 | is chosen as the parameters in `generator` reside on.
39 |         num_samples (int): Number of real and fake samples to
40 |             gather statistics from when calculating the metric.
41 |             Default value is 50 000.
42 | fid_model (nn.Module): A model that returns feature maps
43 | of shape (batch_size, features, *). Default value
44 | is InceptionV3.
45 | fid_size (int, optional): Resize any data fed to `fid_model` by scaling
46 | the data so that its smallest side is the same size as this
47 | argument.
48 | truncation_psi (float, optional): Truncation of the generator
49 | when evaluating.
50 | truncation_cutoff (int, optional): Cutoff for truncation when
51 | evaluating.
52 | reals_batch_size (int, optional): Batch size to use for real
53 | samples statistics gathering.
54 | reals_data_workers (int, optional): Number of workers fetching
55 | the real data samples. Default value is 0.
56 | verbose (bool): Write progress of gathering statistics for reals
57 | to stdout. Default value is True.
58 | """
59 | def __init__(self,
60 | G,
61 | prior_generator,
62 | dataset,
63 | device=None,
64 | num_samples=50000,
65 | fid_model=None,
66 | fid_size=None,
67 | truncation_psi=None,
68 | truncation_cutoff=None,
69 | reals_batch_size=None,
70 | reals_data_workers=0,
71 | verbose=True):
72 | device_ids = []
73 | if isinstance(G, torch.nn.DataParallel):
74 | device_ids = G.device_ids
75 | G = utils.unwrap_module(G)
76 | assert isinstance(G, models.Generator)
77 | assert isinstance(prior_generator, utils.PriorGenerator)
78 | if device is None:
79 | device = next(G.parameters()).device
80 | else:
81 | device = torch.device(device)
82 | assert torch.device(prior_generator.device) == device, \
83 |             'Prior generator device ({}) '.format(prior_generator.device) + \
84 |             'is not the same as the specified (or inferred from the model) ' + \
85 |             'device ({}) for the FID evaluation.'.format(device)
86 | G.eval().to(device)
87 | if device_ids:
88 | G = torch.nn.DataParallel(G, device_ids=device_ids)
89 | self.G = G
90 | self.prior_generator = prior_generator
91 | self.device = device
92 | self.num_samples = num_samples
93 | self.batch_size = self.prior_generator.batch_size
94 | if fid_model is None:
95 | warnings.warn(
96 |                 'Using the default FID feature model based on Inception V3. ' + \
97 | 'This metric will only work on image data where values are in ' + \
98 | 'the range [-1, 1], please specify another module if you want ' + \
99 | 'to use other kinds of data formats.'
100 | )
101 | fid_model = inception.InceptionV3FeatureExtractor(pixel_min=-1, pixel_max=1)
102 | if device_ids:
103 | fid_model = torch.nn.DataParallel(fid_model, device_ids)
104 | self.fid_model = fid_model.eval().to(device)
105 | self.fid_size = fid_size
106 |
107 | dataset = _TruncatedDataset(dataset, self.num_samples)
108 | dataloader = torch.utils.data.DataLoader(
109 | dataset,
110 | batch_size=reals_batch_size or self.batch_size,
111 | num_workers=reals_data_workers
112 | )
113 | features = []
114 | self.labels = []
115 |
116 | if verbose:
117 | progress = utils.ProgressWriter(
118 | np.ceil(self.num_samples / (reals_batch_size or self.batch_size)))
119 | progress.write('FID: Gathering statistics for reals...', step=False)
120 |
121 | for batch in dataloader:
122 | data = batch
123 | if isinstance(batch, (tuple, list)):
124 | data = batch[0]
125 | if len(batch) > 1:
126 | self.labels.append(batch[1])
127 | data = self._scale_for_fid(data).to(self.device)
128 | with torch.no_grad():
129 | batch_features = self.fid_model(data)
130 | batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
131 | features.append(batch_features.cpu())
132 |             if verbose: progress.step()
133 |
134 | if verbose:
135 | progress.write('FID: Statistics for reals gathered!', step=False)
136 | progress.close()
137 |
138 | features = torch.cat(features, dim=0).numpy()
139 |
140 | self.mu_real = np.mean(features, axis=0)
141 | self.sigma_real = np.cov(features, rowvar=False)
142 | self.truncation_psi = truncation_psi
143 | self.truncation_cutoff = truncation_cutoff
144 |
145 | def _scale_for_fid(self, data):
146 | if not self.fid_size:
147 | return data
148 | scale_factor = self.fid_size / min(data.size()[2:])
149 | if scale_factor == 1:
150 | return data
151 | mode = 'nearest'
152 | if scale_factor < 1:
153 | mode = 'area'
154 | return F.interpolate(data, scale_factor=scale_factor, mode=mode)
155 |
156 | def __call__(self, *args, **kwargs):
157 | return self.evaluate(*args, **kwargs)
158 |
159 | def evaluate(self, verbose=True):
160 | """
161 | Evaluate the FID.
162 | Arguments:
163 | verbose (bool): Write progress to stdout.
164 | Default value is True.
165 | Returns:
166 | fid (float): Metric value.
167 | """
168 | utils.unwrap_module(self.G).set_truncation(
169 | truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
170 | self.G.eval()
171 | features = []
172 |
173 | if verbose:
174 | progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
175 | progress.write('FID: Gathering statistics for fakes...', step=False)
176 |
177 | remaining = self.num_samples
178 | for i in range(0, self.num_samples, self.batch_size):
179 |
180 | latents, latent_labels = self.prior_generator(
181 | batch_size=min(self.batch_size, remaining))
182 | if latent_labels is not None and self.labels:
183 | latent_labels = self.labels[i].to(self.device)
184 | length = min(len(latents), len(latent_labels))
185 | latents, latent_labels = latents[:length], latent_labels[:length]
186 |
187 | with torch.no_grad():
188 | fakes = self.G(latents, labels=latent_labels)
189 |
190 | with torch.no_grad():
191 | batch_features = self.fid_model(fakes)
192 | batch_features = batch_features.view(*batch_features.size()[:2], -1).mean(-1)
193 | features.append(batch_features.cpu())
194 |
195 | remaining -= len(latents)
196 |             if verbose: progress.step()
197 |
198 | if verbose:
199 | progress.write('FID: Statistics for fakes gathered!', step=False)
200 | progress.close()
201 |
202 | features = torch.cat(features, dim=0).numpy()
203 |
204 | mu_fake = np.mean(features, axis=0)
205 | sigma_fake = np.cov(features, rowvar=False)
206 |
207 | m = np.square(mu_fake - self.mu_real).sum()
208 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, self.sigma_real), disp=False)
209 | dist = m + np.trace(sigma_fake + self.sigma_real - 2*s)
210 | return float(np.real(dist))
211 |
--------------------------------------------------------------------------------
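For reference, `evaluate` above reduces to the closed-form Frechet distance between two Gaussians fitted to the real and fake Inception features, FID = ||mu_r - mu_f||^2 + Tr(Sigma_r + Sigma_f - 2 * (Sigma_r Sigma_f)^(1/2)). A standalone sketch of that final computation on precomputed statistics:

import numpy as np
import scipy.linalg

def frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake):
    # Squared distance between the means plus the covariance term.
    m = np.square(mu_fake - mu_real).sum()
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_fake, sigma_real), disp=False)
    return float(np.real(m + np.trace(sigma_fake + sigma_real - 2 * s)))
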
/stylegan2/metrics/ppl.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numbers
3 | import numpy as np
4 | import torch
5 | from torch.nn import functional as F
6 |
7 | from .. import models, utils
8 | from ..external_models import lpips
9 |
10 |
11 | class PPL:
12 | """
13 | This class evaluates the PPL metric of a generator.
14 | Arguments:
15 | G (Generator)
16 | prior_generator (PriorGenerator)
17 | device (int, str, torch.device, optional): The device
18 | to use for calculations. By default, the same device
19 | is chosen as the parameters in `generator` reside on.
20 |         num_samples (int): Number of generated samples to
21 |             evaluate the metric over.
22 |             Default value is 50 000.
23 | epsilon (float): Perturbation value. Default value is 1e-4.
24 | use_dlatent (bool): Measure PPL against the dlatents instead
25 | of the latents. Default value is True.
26 | full_sampling (bool): Measure on a random interpolation between
27 | two inputs. Default value is False.
28 | crop (float, list, optional): Crop values that should be in the
29 | range [0, 1] with 1 representing the entire data length.
30 | If single value this will be the amount cropped from all
31 | sides of the data. If a list of same length as number of
32 | data dimensions, each crop is mirrored to both sides of
33 | each respective dimension. If the length is 2 * number
34 | of dimensions the crop values for the start and end of
35 | a dimension may be different.
36 | Example 1:
37 | We have 1d data of length 10. We want to crop 1
38 | from the start and end of the data. We then need
39 | to use `crop=0.1` or `crop=[0.1]` or `crop=[0.1, 0.9]`.
40 | Example 2:
41 | We have 2d data (images) of size 10, 10 (height, width)
42 | and we want to use only the top left quarter of the image
43 | we would use `crop=[0, 0.5, 0, 0.5]`.
44 |         lpips_model (nn.Module): A model that returns the feature
45 |             distance between two inputs. Default value is the LPIPS VGG16 model.
46 | lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
47 | the data so that its smallest side is the same size as this
48 | argument. Only has a default value of 256 if `lpips_model` is unspecified.
49 | """
50 | FFHQ_CROP = [1/8 * 3, 1/8 * 7, 1/8 * 2, 1/8 * 6]
51 |
52 | def __init__(self,
53 | G,
54 | prior_generator,
55 | device=None,
56 | num_samples=50000,
57 | epsilon=1e-4,
58 | use_dlatent=True,
59 | full_sampling=False,
60 | crop=None,
61 | lpips_model=None,
62 | lpips_size=None):
63 | device_ids = []
64 | if isinstance(G, torch.nn.DataParallel):
65 | device_ids = G.device_ids
66 | G = utils.unwrap_module(G)
67 | assert isinstance(G, models.Generator)
68 | assert isinstance(prior_generator, utils.PriorGenerator)
69 | if device is None:
70 | device = next(G.parameters()).device
71 | else:
72 | device = torch.device(device)
73 | assert torch.device(prior_generator.device) == device, \
74 |             'Prior generator device ({}) '.format(prior_generator.device) + \
75 |             'is not the same as the specified (or inferred from the model) ' + \
76 |             'device ({}) for the PPL evaluation.'.format(device)
77 | G.eval().to(device)
78 | self.G_mapping = G.G_mapping
79 | self.G_synthesis = G.G_synthesis
80 | if device_ids:
81 | self.G_mapping = torch.nn.DataParallel(self.G_mapping, device_ids=device_ids)
82 | self.G_synthesis = torch.nn.DataParallel(self.G_synthesis, device_ids=device_ids)
83 | self.prior_generator = prior_generator
84 | self.device = device
85 | self.num_samples = num_samples
86 | self.epsilon = epsilon
87 | self.use_dlatent = use_dlatent
88 | self.full_sampling = full_sampling
89 | self.crop = crop
90 | self.batch_size = self.prior_generator.batch_size
91 | if lpips_model is None:
92 | warnings.warn(
93 | 'Using default LPIPS distance metric based on VGG 16. ' + \
94 | 'This metric will only work on image data where values are in ' + \
95 | 'the range [-1, 1], please specify an lpips module if you want ' + \
96 | 'to use other kinds of data formats.'
97 | )
98 | lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
99 | if device_ids:
100 | lpips_model = torch.nn.DataParallel(lpips_model, device_ids=device_ids)
101 | lpips_size = lpips_size or 256
102 | self.lpips_model = lpips_model.eval().to(device)
103 | self.lpips_size = lpips_size
104 |
105 | def _scale_for_lpips(self, data):
106 | if not self.lpips_size:
107 | return data
108 | scale_factor = self.lpips_size / min(data.size()[2:])
109 | if scale_factor == 1:
110 | return data
111 | mode = 'nearest'
112 | if scale_factor < 1:
113 | mode = 'area'
114 | return F.interpolate(data, scale_factor=scale_factor, mode=mode)
115 |
116 | def crop_data(self, data):
117 | if not self.crop:
118 | return data
119 | dim = data.dim() - 2
120 | if isinstance(self.crop, numbers.Number):
121 | self.crop = [self.crop]
122 | else:
123 | self.crop = list(self.crop)
124 | if len(self.crop) == 1:
125 | self.crop = [self.crop[0], (1 if self.crop[0] < 1 else size) - self.crop[0]] * dim
126 | if len(self.crop) == dim:
127 | crop = self.crop
128 | self.crop = []
129 | for value in crop:
130 | self.crop += [value, (1 if value < 1 else size) - value]
131 | assert len(self.crop) == 2 * dim, 'Crop values has to be ' + \
132 | 'a single value or a sequence of values of the same ' + \
133 | 'size as number of dimensions of the data or twice of that.'
134 | pre_index = [Ellipsis]
135 | post_index = [slice(None, None, None) for _ in range(dim)]
136 | for i in range(0, 2 * dim, 2):
137 | j = i // 2
138 | size = data.size(2 + j)
139 | crop_min, crop_max = self.crop[i:i + 2]
140 | if crop_max < 1:
141 | crop_min, crop_max = crop_min * size, crop_max * size
142 | crop_min, crop_max = max(0, int(crop_min)), min(size, int(crop_max))
143 | dim_index = post_index.copy()
144 | dim_index[j] = slice(crop_min, crop_max, None)
145 | data = data[pre_index + dim_index]
146 | return data
147 |
148 | def prep_latents(self, latents):
149 | if self.full_sampling:
150 | lerp = utils.slerp
151 | if self.use_dlatent:
152 | lerp = utils.lerp
153 | latents_a, latents_b = latents[:self.batch_size], latents[self.batch_size:]
154 | latents = lerp(
155 | latents_a,
156 | latents_b,
157 | torch.rand(
158 | latents_a.size()[:-1],
159 | dtype=latents_a.dtype,
160 | device=latents_a.device
161 | ).unsqueeze(-1)
162 | )
163 | return torch.cat([latents, latents + self.epsilon], dim=0)
164 |
165 | def __call__(self, *args, **kwargs):
166 | return self.evaluate(*args, **kwargs)
167 |
168 | def evaluate(self, verbose=True):
169 | """
170 | Evaluate the PPL.
171 | Arguments:
172 | verbose (bool): Write progress to stdout.
173 | Default value is True.
174 | Returns:
175 | ppl (float): Metric value.
176 | """
177 | distances = []
178 | batch_size = self.batch_size
179 | if self.full_sampling:
180 | batch_size = 2 * batch_size
181 |
182 | if verbose:
183 | progress = utils.ProgressWriter(np.ceil(self.num_samples / self.batch_size))
184 | progress.write('PPL: Evaluating metric...', step=False)
185 |
186 | for _ in range(0, self.num_samples, self.batch_size):
187 | utils.unwrap_module(self.G_synthesis).static_noise()
188 |
189 | latents, latent_labels = self.prior_generator(batch_size=batch_size)
190 | if latent_labels is not None and self.full_sampling:
191 | # Labels should be the same for the first and second half of latents
192 | latent_labels = latent_labels.view(2, -1)[0].repeat(2)
193 |
194 | if self.use_dlatent:
195 | with torch.no_grad():
196 | dlatents = self.G_mapping(latents=latents, labels=latent_labels)
197 | dlatents = self.prep_latents(dlatents)
198 | else:
199 | latents = self.prep_latents(latents)
200 | with torch.no_grad():
201 | dlatents = self.G_mapping(latents=latents, labels=latent_labels)
202 |
203 | dlatents = dlatents.unsqueeze(1).repeat(1, len(utils.unwrap_module(self.G_synthesis)), 1)
204 |
205 | with torch.no_grad():
206 | output = self.G_synthesis(dlatents)
207 |
208 | output = self.crop_data(output)
209 | output = self._scale_for_lpips(output)
210 |
211 | output_a, output_b = output[:self.batch_size], output[self.batch_size:]
212 |
213 | with torch.no_grad():
214 | dist = self.lpips_model(output_a, output_b)
215 |
216 | distances.append(dist.cpu() * (1 / self.epsilon ** 2))
217 |
218 | if verbose:
219 | progress.step()
220 |
221 | if verbose:
222 | progress.write('PPL: Evaluated!', step=False)
223 | progress.close()
224 |
225 | distances = torch.cat(distances, dim=0).numpy()
226 | lo = np.percentile(distances, 1, interpolation='lower')
227 | hi = np.percentile(distances, 99, interpolation='higher')
228 | filtered_distances = np.extract(np.logical_and(lo <= distances, distances <= hi), distances)
229 | return float(np.mean(filtered_distances))
230 |
--------------------------------------------------------------------------------
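A minimal usage sketch for the PPL metric implemented above. The class name `PPL` and the way the generator checkpoint is loaded are assumptions made for illustration; the keyword arguments and the `utils.PriorGenerator` signature come from the code in this file and in stylegan2/utils.py.

    import torch
    from stylegan2 import utils
    from stylegan2.metrics import ppl

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    G = ...  # a trained stylegan2.models.Generator, loaded elsewhere

    # Prior that samples latents (and labels if label_size > 0) for the metric.
    prior = utils.PriorGenerator(
        latent_size=512, label_size=0, batch_size=8, device=device)

    metric = ppl.PPL(  # hypothetical class name for the metric defined above
        G=G,
        prior_generator=prior,
        device=device,
        num_samples=50000,
        use_dlatent=True,      # interpolate in dlatent (W) space with lerp
        full_sampling=False,   # True: interpolate a random point between two latents
    )
    print('PPL:', metric.evaluate(verbose=True))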
/stylegan2/project.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from . import models, utils
8 | from .external_models import lpips
9 |
10 |
11 | class Projector(nn.Module):
12 | """
13 | Projects data to latent space and noise tensors.
14 | Arguments:
15 | G (Generator)
16 | dlatent_avg_samples (int): Number of dlatent samples
17 | to collect to find the mean and std.
18 | Default value is 10 000.
19 | dlatent_avg_label (int, torch.Tensor, optional): The label to
20 | use when gathering dlatent statistics.
21 | dlatent_device (int, str, torch.device, optional): Device to use
22 | for gathering statistics of dlatents. By default uses
23 | the same device as parameters of `G` reside on.
24 | dlatent_batch_size (int): The batch size to sample
25 | dlatents with. Default value is 1024.
26 |         lpips_model (nn.Module): A model that returns the feature distance
27 |             between two inputs. Default value is the LPIPS VGG16 model.
28 | lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
29 | the data so that its smallest side is the same size as this
30 | argument. Only has a default value of 256 if `lpips_model` is unspecified.
31 | verbose (bool): Write progress of dlatent statistics gathering to stdout.
32 | Default value is True.
33 | """
34 | def __init__(self,
35 | G,
36 | dlatent_avg_samples=10000,
37 | dlatent_avg_label=None,
38 | dlatent_device=None,
39 | dlatent_batch_size=1024,
40 | lpips_model=None,
41 | lpips_size=None,
42 | verbose=True):
43 | super(Projector, self).__init__()
44 | assert isinstance(G, models.Generator)
45 | G.eval().requires_grad_(False)
46 |
47 | self.G_synthesis = G.G_synthesis
48 |
49 | G_mapping = G.G_mapping
50 |
51 | dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)
52 |
53 | if dlatent_device is None:
54 |             dlatent_device = next(G_mapping.parameters()).device
55 | else:
56 | dlatent_device = torch.device(dlatent_device)
57 |
58 | G_mapping.to(dlatent_device)
59 |
60 | latents = torch.empty(
61 | dlatent_avg_samples, G_mapping.latent_size).normal_()
62 | dlatents = []
63 |
64 | labels = None
65 | if dlatent_avg_label is not None:
66 | labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size)
67 |
68 | if verbose:
69 | progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size))
70 | progress.write('Gathering dlatents...', step=False)
71 |
72 | for i in range(0, dlatent_avg_samples, dlatent_batch_size):
73 | batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
74 | batch_labels = None
75 | if labels is not None:
76 | batch_labels = labels[:len(batch_latents)]
77 | with torch.no_grad():
78 | dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
79 | if verbose:
80 | progress.step()
81 |
82 | if verbose:
83 | progress.write('Done!', step=False)
84 | progress.close()
85 |
86 | dlatents = torch.cat(dlatents, dim=0)
87 |
88 | self.register_buffer(
89 | '_dlatent_avg',
90 | dlatents.mean(dim=0).view(1, 1, -1)
91 | )
92 | self.register_buffer(
93 | '_dlatent_std',
94 | torch.sqrt(
95 | torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
96 | ).view(1, 1, 1)
97 | )
98 |
99 | if lpips_model is None:
100 | warnings.warn(
101 | 'Using default LPIPS distance metric based on VGG 16. ' + \
102 | 'This metric will only work on image data where values are in ' + \
103 |                 'the range [-1, 1]. Please specify an lpips module if you want ' + \
104 | 'to use other kinds of data formats.'
105 | )
106 | lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
107 | lpips_size = 256
108 | self.lpips_model = lpips_model.eval().requires_grad_(False)
109 | self.lpips_size = lpips_size
110 |
111 | self.to(dlatent_device)
112 |
113 | def _scale_for_lpips(self, data):
114 | if not self.lpips_size:
115 | return data
116 | scale_factor = self.lpips_size / min(data.size()[2:])
117 | if scale_factor == 1:
118 | return data
119 | mode = 'nearest'
120 | if scale_factor < 1:
121 | mode = 'area'
122 | return F.interpolate(data, scale_factor=scale_factor, mode=mode)
123 |
124 | def _check_job(self):
125 |         assert getattr(self, '_job', None) is not None, 'Call `start()` first to set up target.'
126 | # device of dlatent param will not change with the rest of the models
127 | # and buffers of this class as it was never registered as a buffer or
128 | # parameter. Same goes for optimizer. Make sure it is on the correct device.
129 | if self._job.dlatent_param.device != self._dlatent_avg.device:
130 | self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
131 | self._job.opt.load_state_dict(
132 | utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])
133 |
134 | def generate(self):
135 | """
136 | Generate an output with the current dlatent and noise values.
137 | Returns:
138 | output (torch.Tensor)
139 | """
140 | self._check_job()
141 | with torch.no_grad():
142 | return self.G_synthesis(self._job.dlatent_param)
143 |
144 | def get_dlatent(self):
145 | """
146 | Get a copy of the current dlatent values.
147 | Returns:
148 | dlatents (torch.Tensor)
149 | """
150 | self._check_job()
151 | return self._job.dlatent_param.data.clone()
152 |
153 | def get_noise(self):
154 | """
155 | Get a copy of the current noise values.
156 | Returns:
157 | noise_tensors (list)
158 | """
159 | self._check_job()
160 | return [noise.data.clone() for noise in self._job.noise_params]
161 |
162 | def start(self,
163 | target,
164 | num_steps=1000,
165 | initial_learning_rate=0.1,
166 | initial_noise_factor=0.05,
167 | lr_rampdown_length=0.25,
168 | lr_rampup_length=0.05,
169 | noise_ramp_length=0.75,
170 | regularize_noise_weight=1e5,
171 | verbose=True,
172 | verbose_prefix=''):
173 | """
174 | Set up a target and its projection parameters.
175 | Arguments:
176 | target (torch.Tensor): The data target. This should
177 | already be preprocessed (scaled to correct value range).
178 | num_steps (int): Number of optimization steps. Default
179 | value is 1000.
180 | initial_learning_rate (float): Default value is 0.1.
181 | initial_noise_factor (float): Default value is 0.05.
182 | lr_rampdown_length (float): Default value is 0.25.
183 | lr_rampup_length (float): Default value is 0.05.
184 | noise_ramp_length (float): Default value is 0.75.
185 | regularize_noise_weight (float): Default value is 1e5.
186 | verbose (bool): Write progress to stdout every time
187 | `step()` is called.
188 | verbose_prefix (str, optional): This is written before
189 | any other output to stdout.
190 | """
191 | if target.dim() == self.G_synthesis.dim + 1:
192 | target = target.unsqueeze(0)
193 | assert target.dim() == self.G_synthesis.dim + 2, \
194 | 'Number of dimensions of target data is incorrect.'
195 |
196 | target = target.to(self._dlatent_avg)
197 | target_scaled = self._scale_for_lpips(target)
198 |
199 | dlatent_param = nn.Parameter(
200 | self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
201 | noise_params = self.G_synthesis.static_noise(trainable=True)
202 | params = [dlatent_param] + noise_params
203 |
204 | opt = torch.optim.Adam(params)
205 |
206 | noise_tensor = torch.empty_like(dlatent_param)
207 |
208 | if verbose:
209 | progress = utils.ProgressWriter(num_steps)
210 | value_tracker = utils.ValueTracker()
211 |
212 | self._job = utils.AttributeDict(**locals())
213 | self._job.current_step = 0
214 |
215 | def step(self, steps=1):
216 | """
217 | Take a projection step.
218 | Arguments:
219 |             steps (int): Number of steps to take. If this
220 |                 exceeds the remaining steps of the projection,
221 |                 only the remaining steps are taken. A negative
222 |                 value runs all remaining steps. Default value is 1.
223 | """
224 | self._check_job()
225 |
226 | remaining_steps = self._job.num_steps - self._job.current_step
227 | if not remaining_steps > 0:
228 | warnings.warn(
229 | 'Trying to take a projection step after the ' + \
230 | 'final projection iteration has been completed.'
231 | )
232 | if steps < 0:
233 | steps = remaining_steps
234 | steps = min(remaining_steps, steps)
235 |
236 | if not steps > 0:
237 | return
238 |
239 | for _ in range(steps):
240 |
241 | if self._job.current_step >= self._job.num_steps:
242 | break
243 |
244 | # Hyperparameters.
245 | t = self._job.current_step / self._job.num_steps
246 | noise_strength = self._dlatent_std * self._job.initial_noise_factor \
247 | * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
248 | lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
249 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
250 | lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
251 | learning_rate = self._job.initial_learning_rate * lr_ramp
252 |
253 | for param_group in self._job.opt.param_groups:
254 | param_group['lr'] = learning_rate
255 |
256 | dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
257 |
258 | output = self.G_synthesis(dlatents)
259 | assert output.size() == self._job.target.size(), \
260 | 'target size {} does not fit output size {} of generator'.format(
261 |                     self._job.target.size(), output.size())
262 |
263 | output_scaled = self._scale_for_lpips(output)
264 |
265 | # Main loss: LPIPS distance of output and target
266 | lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled))
267 |
268 | # Calculate noise regularization loss
269 | reg_loss = 0
270 | for p in self._job.noise_params:
271 | size = min(p.size()[2:])
272 | dim = p.dim() - 2
273 | while True:
274 | reg_loss += torch.mean(
275 | (p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
276 | if size <= 8:
277 | break
278 | p = F.interpolate(p, scale_factor=0.5, mode='area')
279 | size = size // 2
280 |
281 | # Combine loss, backward and update params
282 | loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
283 | self._job.opt.zero_grad()
284 | loss.backward()
285 | self._job.opt.step()
286 |
287 | # Normalize noise values
288 | for p in self._job.noise_params:
289 | with torch.no_grad():
290 | p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
291 | p_rstd = torch.rsqrt(
292 | torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
293 | p.data = (p.data - p_mean) * p_rstd
294 |
295 | self._job.current_step += 1
296 |
297 | if self._job.verbose:
298 | self._job.value_tracker.add('loss', float(loss))
299 | self._job.value_tracker.add('lpips_distance', float(lpips_distance))
300 | self._job.value_tracker.add('noise_reg', float(reg_loss))
301 | self._job.value_tracker.add('lr', learning_rate, beta=0)
302 | self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker))
303 | if self._job.current_step >= self._job.num_steps:
304 | self._job.progress.close()
305 |
--------------------------------------------------------------------------------
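A minimal sketch of how the Projector above is driven, using only the methods defined in this file (`start`, `step`, `generate`, `get_dlatent`, `get_noise`) together with `utils.PIL_to_tensor` and `utils.tensor_to_PIL` from stylegan2/utils.py. Loading the generator and the target image path are placeholders.

    from PIL import Image
    import torch
    from stylegan2 import utils
    from stylegan2.project import Projector

    G = ...  # a trained stylegan2.models.Generator
    proj = Projector(G, dlatent_avg_samples=10000)

    # PIL_to_tensor scales to [-1, 1] by default, matching the default LPIPS model.
    # The image should have the same resolution as the generator output.
    target = utils.PIL_to_tensor(Image.open('target.png').convert('RGB'))

    proj.start(target, num_steps=1000, verbose=True)
    proj.step(1000)                  # run all remaining optimization steps

    projected = proj.generate()      # synthesized output for the optimized dlatents
    dlatents = proj.get_dlatent()    # copy of the optimized dlatent values
    noise = proj.get_noise()         # copies of the optimized noise tensors
    utils.tensor_to_PIL(projected)[0].save('projected.png')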
/stylegan2/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numbers
3 | import re
4 | import sys
5 | import collections
6 | import argparse
7 | import yaml
8 | from PIL import Image
9 | import numpy as np
10 | import torch
11 | from torch import nn
12 | from torch.nn import functional as F
13 | import torchvision
14 | try:
15 | import tqdm
16 | except ImportError:
17 | pass
18 | try:
19 | from IPython.display import display as notebook_display
20 | from IPython.display import clear_output as notebook_clear
21 | except ImportError:
22 | pass
23 |
24 |
25 | #----------------------------------------------------------------------------
26 | # Miscellaneous utils
27 |
28 |
29 | class AttributeDict(dict):
30 | """
31 | Dict where values can be accessed using attribute syntax.
32 | Same as "EasyDict" in the NVIDIA stylegan git repository.
33 | """
34 |
35 | def __getattr__(self, name):
36 | try:
37 | return self[name]
38 | except KeyError:
39 | raise AttributeError(name)
40 |
41 | def __setattr__(self, name, value):
42 | self[name] = value
43 |
44 | def __delattr__(self, name):
45 | del self[name]
46 |
47 | def __getstate__(self):
48 | return dict(**self)
49 |
50 | def __setstate__(self, state):
51 | self.update(**state)
52 |
53 | def __repr__(self):
54 | return '{}({})'.format(
55 | self.__class__.__name__,
56 | ', '.join('{}={}'.format(key, value) for key, value in self.items())
57 | )
58 |
59 | @classmethod
60 | def convert_dict_recursive(cls, obj):
61 | if isinstance(obj, dict):
62 | for key in list(obj.keys()):
63 | obj[key] = cls.convert_dict_recursive(obj[key])
64 | if not isinstance(obj, cls):
65 | return cls(**obj)
66 | return obj
67 |
68 |
69 | class Timer:
70 |
71 | def __init__(self):
72 | self.reset()
73 |
74 | def __enter__(self):
75 | self._t0 = time.time()
76 |
77 | def __exit__(self, *args):
78 | self._t += time.time() - self._t0
79 |
80 | def value(self):
81 | return self._t
82 |
83 | def reset(self):
84 | self._t = 0
85 |
86 | def __str__(self):
87 | """
88 | Get a string representation of the recorded time.
89 | Returns:
90 | time_as_string (str)
91 | """
92 | value = self.value()
93 | if not value or value >= 100:
94 | return '{} s'.format(int(value))
95 | elif value >= 1:
96 | return '{:.3g} s'.format(value)
97 | elif value >= 1e-3:
98 | return '{:.3g} ms'.format(value * 1e+3)
99 | elif value >= 1e-6:
100 | return '{:.3g} us'.format(value * 1e+6)
101 | elif value >= 1e-9:
102 | return '{:.3g} ns'.format(value * 1e+9)
103 | else:
104 | return '{:.2E} s'.format(value)
105 |
106 |
107 | def to_list(values):
108 | if values is None:
109 | return []
110 | if isinstance(values, tuple):
111 | return list(values)
112 | if not isinstance(values, list):
113 | return [values]
114 | return values
115 |
116 |
117 | def lerp(a, b, beta):
118 | if isinstance(beta, numbers.Number):
119 | if beta == 1:
120 | return b
121 | elif beta == 0:
122 | return a
123 | if torch.is_tensor(a) and a.dtype == torch.float32:
124 | # torch lerp only available for fp32
125 | return torch.lerp(a, b, beta)
126 | # More numerically stable than a + beta * (b - a)
127 | return (1 - beta) * a + beta * b
128 |
129 |
130 | def _normalize(v):
131 | return v * torch.rsqrt(torch.sum(v ** 2, dim=-1, keepdim=True))
132 |
133 |
134 | def slerp(a, b, beta):
135 | assert a.size() == b.size(), 'Size mismatch between ' + \
136 | 'slerp arguments, received {} and {}'.format(a.size(), b.size())
137 | if not torch.is_tensor(beta):
138 | beta = torch.tensor(beta).to(a)
139 | a = _normalize(a)
140 | b = _normalize(b)
141 |     d = torch.sum(a * b, dim=-1, keepdim=True)
142 |     p = beta * torch.acos(d)
143 | c = _normalize(b - d * a)
144 | d = a * torch.cos(p) + c * torch.sin(p)
145 | return _normalize(d)
146 |
147 |
148 | #----------------------------------------------------------------------------
149 | # Command line utils
150 |
151 |
152 | def _parse_configs(configs):
153 | kwargs = {}
154 | for config in configs:
155 | with open(config, 'r') as fp:
156 | kwargs.update(yaml.safe_load(fp))
157 | return kwargs
158 |
159 |
160 | class ConfigArgumentParser(argparse.ArgumentParser):
161 |
162 | _CONFIG_ARG_KEY = '_configs'
163 |
164 | def __init__(self, *args, **kwargs):
165 | super(ConfigArgumentParser, self).__init__(*args, **kwargs)
166 | self.add_argument(
167 | self._CONFIG_ARG_KEY,
168 | nargs='*',
169 |             help='Any yaml-style config file whose values will override the defaults of this argument parser.',
170 | type=str
171 | )
172 |
173 | def parse_args(self, args=None):
174 | config_args = _parse_configs(
175 | getattr(
176 | super(ConfigArgumentParser, self).parse_args(args),
177 | self._CONFIG_ARG_KEY
178 | )
179 | )
180 | self.set_defaults(**config_args)
181 | return super(ConfigArgumentParser, self).parse_args(args)
182 |
183 |
184 | def bool_type(v):
185 | if isinstance(v, bool):
186 | return v
187 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
188 | return True
189 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
190 | return False
191 | else:
192 | raise argparse.ArgumentTypeError('Boolean value expected.')
193 |
194 |
195 | def range_type(s):
196 | """
197 | Accept either a comma separated list of numbers
198 | 'a,b,c' or a range 'a-c' and return as a list of ints.
199 | """
200 | range_re = re.compile(r'^(\d+)-(\d+)$')
201 | m = range_re.match(s)
202 | if m:
203 |         return list(range(int(m.group(1)), int(m.group(2)) + 1))
204 | vals = s.split(',')
205 | return [int(x) for x in vals]
206 |
207 |
208 | #----------------------------------------------------------------------------
209 | # Dataset and generation of latents
210 |
211 |
212 | class ResizeTransform:
213 |
214 | def __init__(self, height, width, resize=True, mode='bicubic'):
215 | if resize:
216 | assert height and width, 'Height and width have to be given ' + \
217 | 'when resizing data.'
218 | self.height = height
219 | self.width = width
220 | self.resize = resize
221 | self.mode = mode
222 |
223 | def __call__(self, tensor):
224 | if self.height and self.width:
225 | if tensor.size(1) != self.height or tensor.size(2) != self.width:
226 | if self.resize:
227 | kwargs = {}
228 | if 'cubic' in self.mode or 'linear' in self.mode:
229 | kwargs.update(align_corners=False)
230 | tensor = F.interpolate(
231 | tensor.unsqueeze(0),
232 | size=(self.height, self.width),
233 | mode=self.mode,
234 | **kwargs
235 | ).squeeze(0)
236 | else:
237 | raise ValueError(
238 | 'Data shape incorrect, expected ({},{}) '.format(self.width, self.height) + \
239 | 'but got ({},{}) (width, height)'.format(tensor.size(2), tensor.size(1))
240 | )
241 | return tensor
242 |
243 |
244 | def _PIL_RGB_loader(path):
245 | return Image.open(path).convert('RGB')
246 |
247 |
248 | def _PIL_grayscale_loader(path):
249 | return Image.open(path).convert('L')
250 |
251 |
252 | class ImageFolder(torchvision.datasets.ImageFolder):
253 |
254 | def __init__(self,
255 | *args,
256 | mirror=False,
257 | pixel_min=-1,
258 | pixel_max=1,
259 | height=None,
260 | width=None,
261 | resize=False,
262 | resize_mode='bicubic',
263 | grayscale=False,
264 | **kwargs):
265 | super(ImageFolder, self).__init__(
266 | *args,
267 | loader=_PIL_grayscale_loader if grayscale else _PIL_RGB_loader,
268 | **kwargs
269 | )
270 | transforms = []
271 | if mirror:
272 | transforms.append(torchvision.transforms.RandomHorizontalFlip())
273 | transforms.append(torchvision.transforms.ToTensor())
274 | transforms.append(
275 | torchvision.transforms.Normalize(
276 | mean=[-(pixel_min / (pixel_max - pixel_min))],
277 | std=[1. / (pixel_max - pixel_min)]
278 | )
279 | )
280 | transforms.append(ResizeTransform(
281 | height=height, width=width, resize=resize, mode=resize_mode))
282 | self.transform = torchvision.transforms.Compose(transforms)
283 |
284 | def _find_classes(self, *args, **kwargs):
285 | classes, class_to_idx = super(ImageFolder, self)._find_classes(*args, **kwargs)
286 | if not classes:
287 | classes = ['']
288 | class_to_idx = {'': 0}
289 | return classes, class_to_idx
290 |
291 |
292 | class PriorGenerator:
293 |
294 | def __init__(self, latent_size, label_size, batch_size, device):
295 | self.latent_size = latent_size
296 | self.label_size = label_size
297 | self.batch_size = batch_size
298 | self.device = device
299 |
300 | def __iter__(self):
301 | return self
302 |
303 | def __next__(self):
304 | return self()
305 |
306 | def __call__(self, batch_size=None, multi_latent_prob=0, seed=None):
307 | if batch_size is None:
308 | batch_size = self.batch_size
309 | shape = [batch_size, self.latent_size]
310 | if multi_latent_prob:
311 | if seed is not None:
312 | np.random.seed(seed)
313 | if np.random.uniform() < multi_latent_prob:
314 | shape = [batch_size, 2, self.latent_size]
315 | if seed is not None:
316 | torch.manual_seed(seed)
317 | latents = torch.empty(*shape, device=self.device).normal_()
318 | labels = None
319 | if self.label_size:
320 | label_shape = [batch_size]
321 | labels = torch.randint(0, self.label_size, label_shape, device=self.device)
322 | return latents, labels
323 |
324 |
325 | #----------------------------------------------------------------------------
326 | # Training utils
327 |
328 |
329 | class MovingAverageModule:
330 |
331 | def __init__(self,
332 | from_module,
333 | to_module=None,
334 | param_beta=0.995,
335 | buffer_beta=0,
336 | device=None):
337 | from_module = unwrap_module(from_module)
338 | to_module = unwrap_module(to_module)
339 | if device is None:
340 | module = from_module
341 | if to_module is not None:
342 | module = to_module
343 | device = next(module.parameters()).device
344 | else:
345 | device = torch.device(device)
346 | self.from_module = from_module
347 | if to_module is None:
348 | self.module = from_module.clone().to(device)
349 | else:
350 | assert type(to_module) == type(from_module), \
351 | 'Mismatch between type of source and target module.'
352 | assert set(self._get_named_parameters(to_module).keys()) \
353 | == set(self._get_named_parameters(from_module).keys()), \
354 | 'Mismatch between parameters of source and target module.'
355 | assert set(self._get_named_buffers(to_module).keys()) \
356 | == set(self._get_named_buffers(from_module).keys()), \
357 | 'Mismatch between buffers of source and target module.'
358 | self.module = to_module.to(device)
359 | self.module.eval().requires_grad_(False)
360 | self.param_beta = param_beta
361 | self.buffer_beta = buffer_beta
362 | self.device = device
363 |
364 |     def __getattr__(self, name):
365 |         # Only called when normal lookup fails: delegate to the wrapped module.
366 |         if name == 'module':
367 |             raise AttributeError(name)
368 |         return getattr(self.module, name)
369 |
370 | def update(self):
371 | self._update_data(
372 | from_data=self._get_named_parameters(self.from_module),
373 | to_data=self._get_named_parameters(self.module),
374 | beta=self.param_beta
375 | )
376 | self._update_data(
377 | from_data=self._get_named_buffers(self.from_module),
378 | to_data=self._get_named_buffers(self.module),
379 | beta=self.buffer_beta
380 | )
381 |
382 | @staticmethod
383 | def _update_data(from_data, to_data, beta):
384 | for name in from_data.keys():
385 | if name not in to_data:
386 | continue
387 | fr, to = from_data[name], to_data[name]
388 | with torch.no_grad():
389 | if beta == 0:
390 | to.data.copy_(fr.data.to(to.data))
391 | elif beta < 1:
392 | to.data.copy_(lerp(fr.data.to(to.data), to.data, beta))
393 |
394 | @staticmethod
395 | def _get_named_parameters(module):
396 | return {name: value for name, value in module.named_parameters()}
397 |
398 | @staticmethod
399 | def _get_named_buffers(module):
400 | return {name: value for name, value in module.named_buffers()}
401 |
402 | def __call__(self, *args, **kwargs):
403 | return self.forward(*args, **kwargs)
404 |
405 | def forward(self, *args, **kwargs):
406 | self.module.eval()
407 | args, args_in_device = move_to_device(args, self.device)
408 | kwargs, kwargs_in_device = move_to_device(kwargs, self.device)
409 | in_device = None
410 | if args_in_device is not None:
411 | in_device = args_in_device
412 | if kwargs_in_device is not None:
413 | in_device = kwargs_in_device
414 | out = self.module(*args, **kwargs)
415 | if in_device is not None:
416 | out, _ = move_to_device(out, in_device)
417 | return out
418 |
419 |
420 | def move_to_device(value, device):
421 | if torch.is_tensor(value):
422 |         return value.to(device), value.device
423 | orig_device = None
424 | if isinstance(value, (tuple, list)):
425 | values = []
426 | for val in value:
427 | _val, orig_device = move_to_device(val, device)
428 | values.append(_val)
429 | return type(value)(values), orig_device
430 | if isinstance(value, dict):
431 | if isinstance(value, collections.OrderedDict):
432 | values = collections.OrderedDict()
433 | else:
434 | values = {}
435 | for key, val in value.items():
436 | _val, orig_device = move_to_device(val, device)
437 |             values[key] = _val
438 | return values, orig_device
439 | return value, orig_device
440 |
441 |
442 | _WRAPPER_CLASSES = (MovingAverageModule, nn.DataParallel, nn.parallel.DistributedDataParallel)
443 | def unwrap_module(module):
444 | if isinstance(module, _WRAPPER_CLASSES):
445 | return module.module
446 | return module
447 |
448 |
449 | def get_grad_norm_from_optimizer(optimizer, norm_type=2):
450 | """
451 | Get the gradient norm for some parameters contained in an optimizer.
452 | Arguments:
453 | optimizer (torch.optim.Optimizer)
454 | norm_type (int): Type of norm. Default value is 2.
455 | Returns:
456 | norm (float)
457 | """
458 | total_norm = 0
459 | if optimizer is not None:
460 | for param_group in optimizer.param_groups:
461 | for p in param_group['params']:
462 | if p.grad is not None:
463 | with torch.no_grad():
464 | param_norm = p.grad.data.norm(norm_type)
465 | total_norm += param_norm ** norm_type
466 | total_norm = total_norm ** (1. / norm_type)
467 |     return float(total_norm)
468 |
469 |
470 | #----------------------------------------------------------------------------
471 | # printing and logging utils
472 |
473 |
474 | class ValueTracker:
475 |
476 | def __init__(self, beta=0.95):
477 | self.beta = beta
478 | self.values = {}
479 |
480 | def add(self, name, value, beta=None):
481 | if torch.is_tensor(value):
482 | value = value.item()
483 | if beta is None:
484 | beta = self.beta
485 | if name not in self.values:
486 | self.values[name] = value
487 | else:
488 | self.values[name] = lerp(value, self.values[name], beta)
489 |
490 | def __getitem__(self, key):
491 | return self.values[key]
492 |
493 | def __str__(self):
494 | string = ''
495 | for i, name in enumerate(sorted(self.values.keys())):
496 | if i and i % 3 == 0:
497 | string += '\n'
498 | elif string:
499 | string += ', '
500 | format_string = '{}: {}'
501 | if isinstance(self.values[name], float):
502 | format_string = '{}: {:.4g}'
503 | string += format_string.format(name, self.values[name])
504 | return string
505 |
506 |
507 | def is_notebook():
508 | """
509 | Check if code is running from jupyter notebook.
510 | Returns:
511 | notebook (bool): True if running from jupyter notebook,
512 | else False.
513 | """
514 | try:
515 | __IPYTHON__
516 | return True
517 | except NameError:
518 | return False
519 |
520 |
521 | def _progress_bar(count, total):
522 | """
523 | Get a simple one-line string representing a progress bar.
524 | Arguments:
525 | count (int): Current count. Starts at 0.
526 | total (int): Total count.
527 | Returns:
528 | pbar_string (str): The string progress bar.
529 | """
530 | bar_len = 60
531 | filled_len = int(round(bar_len * (count + 1) / float(total)))
532 | bar = '=' * filled_len + '-' * (bar_len - filled_len)
533 | return '[{}] {}/{}'.format(bar, count + 1, total)
534 |
535 |
536 | class ProgressWriter:
537 | """
538 | Handles writing output and displaying a progress bar. Automatically
539 |     adjusts for notebooks. Supports outputting text
540 | that is compatible with the progressbar (in notebooks the text is
541 | refreshed instead of printed).
542 | Arguments:
543 | length (int, optional): Total length of the progressbar.
544 | Default value is None.
545 | progress_bar (bool, optional): Display a progressbar.
546 | Default value is True.
547 | clear (bool, optional): If running from a notebook, clear
548 | the current cell's output. Default value is False.
549 | """
550 | def __init__(self, length=None, progress_bar=True, clear=False):
551 | if is_notebook() and clear:
552 | notebook_clear()
553 |
554 | if length is not None:
555 | length = int(length)
556 | self.length = length
557 | self.count = 0
558 |
559 | self._simple_pbar = False
560 | if progress_bar and 'tqdm' not in sys.modules:
561 | self._simple_pbar = True
562 |
563 | progress_bar = progress_bar and 'tqdm' in sys.modules
564 |
565 | self._progress_bar = None
566 | if progress_bar:
567 | pbar = tqdm.tqdm
568 | if is_notebook():
569 | pbar = tqdm.tqdm_notebook
570 | if length is not None:
571 | self._progress_bar = pbar(total=length, file=sys.stdout)
572 | else:
573 | self._progress_bar = pbar(file=sys.stdout)
574 |
575 | if is_notebook():
576 | self._writer = notebook_display(
577 | _StrRepr(''),
578 | display_id=time.asctime()
579 | )
580 | else:
581 | if progress_bar:
582 | self._writer = self._progress_bar
583 | else:
584 | self._writer = sys.stdout
585 |
586 | def write(self, *lines, step=True):
587 | """
588 | Output values to stdout (or a display object if called from notebook).
589 | Arguments:
590 | *lines: The lines to write (positional arguments).
591 | step (bool): Update the progressbar if present.
592 | Default value is True.
593 | """
594 |         string = '\n'.join(str(line) for line in lines if line and str(line).strip())
595 |         if self._simple_pbar and self.length:
596 |             string = _progress_bar(self.count, self.length) + '\n' + string
597 | if is_notebook():
598 | self._writer.update(_StrRepr(string))
599 | else:
600 | self._writer.write('\n\n' + string)
601 | if hasattr(self._writer, 'flush'):
602 | self._writer.flush()
603 | if step:
604 | self.step()
605 |
606 | def step(self):
607 | """
608 | Update the progressbar if present.
609 | """
610 | self.count += 1
611 | if self._progress_bar is not None:
612 | self._progress_bar.update()
613 |
614 |     def __iter__(self):
615 |         self.rnge = iter(range(self.length))  # requires `length` to be set
616 |         return self
617 |     def __next__(self):
618 |         return next(self.rnge)
619 |
620 | def close(self):
621 | if hasattr(self._writer, 'close'):
622 | can_close = True
623 | try:
624 | can_close = self._writer != sys.stdout and self._writer != sys.stderr
625 | except AttributeError:
626 | pass
627 | if can_close:
628 | self._writer.close()
629 | if hasattr(self._progress_bar, 'close'):
630 | self._progress_bar.close()
631 |
632 | def __del__(self):
633 | self.close()
634 |
635 |
636 | class _StrRepr:
637 | """
638 | A wrapper for strings that returns the string
639 | on repr() calls. Used by notebooks.
640 | """
641 | def __init__(self, string):
642 | self.string = string
643 |
644 | def __repr__(self):
645 | return self.string
646 |
647 |
648 | #----------------------------------------------------------------------------
649 | # image utils
650 |
651 |
652 | def tensor_to_PIL(image_tensor, pixel_min=-1, pixel_max=1):
653 | image_tensor = image_tensor.cpu()
654 | if pixel_min != 0 or pixel_max != 1:
655 | image_tensor = (image_tensor - pixel_min) / (pixel_max - pixel_min)
656 | image_tensor.clamp_(min=0, max=1)
657 | to_pil = torchvision.transforms.functional.to_pil_image
658 | if image_tensor.dim() == 4:
659 | return [to_pil(img) for img in image_tensor]
660 | return to_pil(image_tensor)
661 |
662 |
663 | def PIL_to_tensor(image, pixel_min=-1, pixel_max=1):
664 | to_tensor = torchvision.transforms.functional.to_tensor
665 | if isinstance(image, (list, tuple)):
666 | image_tensor = torch.stack([to_tensor(img) for img in image])
667 | else:
668 | image_tensor = to_tensor(image)
669 | if pixel_min != 0 or pixel_max != 1:
670 | image_tensor = image_tensor * (pixel_max - pixel_min) + pixel_min
671 | return image_tensor
672 |
673 |
674 | def stack_images_PIL(imgs, shape=None, individual_img_size=None):
675 | """
676 | Concatenate multiple images into a grid within a single image.
677 | Arguments:
678 | imgs (Sequence of PIL.Image): Input images.
679 | shape (list, tuple, int, optional): Shape of the grid. Should consist
680 | of two values, (width, height). If an integer value is passed it
681 | is used for both width and height. If no value is passed the shape
682 |             is inferred from the number of images. Default value is None.
683 | individual_img_size (list, tuple, int, optional): The size of the
684 | images being concatenated. Default value is None.
685 | Returns:
686 | canvas (PIL.Image): Image containing input images in a grid.
687 | """
688 | assert len(imgs) > 0, 'No images received.'
689 | if shape is None:
690 | size = int(np.ceil(np.sqrt(len(imgs))))
691 | shape = [int(np.ceil(len(imgs) / size)), size]
692 | else:
693 | if isinstance(shape, numbers.Number):
694 | shape = 2 * [shape]
695 | assert len(shape) == 2, 'Shape should specify (width, height).'
696 |
697 | if individual_img_size is None:
698 | for i in range(len(imgs) - 1):
699 | assert imgs[i].size == imgs[i + 1].size, \
700 | 'Images are of different sizes, please specify a ' + \
701 | 'size (width, height). Found sizes:\n' + \
702 | ', '.join(str(img.size) for img in imgs)
703 | individual_img_size = imgs[0].size
704 | else:
705 | if not isinstance(individual_img_size, (tuple, list)):
706 | individual_img_size = 2 * (individual_img_size,)
707 | individual_img_size = tuple(individual_img_size)
708 | for i in range(len(imgs)):
709 | if imgs[i].size != individual_img_size:
710 | imgs[i] = imgs[i].resize(individual_img_size)
711 |
712 | width, height = individual_img_size
713 | width, height = int(width), int(height)
714 | canvas = Image.new(
715 | 'RGB',
716 | (shape[0] * width, shape[1] * height),
717 |         (0, 0, 0)
718 | )
719 | imgs = imgs.copy()
720 | for h_i in range(shape[1]):
721 | for w_i in range(shape[0]):
722 | if len(imgs) > 0:
723 | img = imgs.pop(0).convert('RGB')
724 | offset = (w_i * width, h_i * height)
725 | canvas.paste(img, offset)
726 | return canvas
727 |
--------------------------------------------------------------------------------
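A small sketch of how the logging helpers above are typically combined inside a training loop. The loop body is a placeholder; the `ValueTracker` and `ProgressWriter` calls follow the definitions in this file.

    from stylegan2 import utils

    tracker = utils.ValueTracker()             # exponential moving averages of logged values
    progress = utils.ProgressWriter(length=100)

    for step in range(100):
        loss = (100 - step) / 100.0            # placeholder for a real loss value
        tracker.add('loss', loss)
        progress.write('step {}'.format(step), str(tracker))  # write() also advances the bar

    progress.write('Done!', step=False)
    progress.close()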
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | from matplotlib import pyplot as plt
4 |
5 | def save_grid(images, path):
6 | grid = torchvision.utils.make_grid(images)
7 | torchvision.utils.save_image(grid, path)
8 |
9 | def show_grid(images):
10 | grid = torchvision.utils.make_grid(images)
11 | plt.imshow(grid.permute(1, 2, 0).cpu().detach().numpy())
12 | plt.show()
13 |
14 | def biggan_norm(images):
15 | images = (images + 1) / 2.0
16 | images = images.clip(0, 1)
17 | return images
18 |
19 | def biggan_denorm(images):
20 | images = images*2 - 1
21 | return images
22 |
23 |
24 | def freeze_model(model):
25 | for param in model.parameters():
26 | param.requires_grad = False
27 |
28 |
--------------------------------------------------------------------------------
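A hypothetical example of the helpers in /utils.py above: map generator output from [-1, 1] back to [0, 1] with `biggan_norm` and write it out as an image grid. The random tensor stands in for real generator samples.

    import torch
    from utils import biggan_norm, save_grid

    images = torch.rand(16, 3, 64, 64) * 2 - 1   # stand-in for generator output in [-1, 1]
    save_grid(biggan_norm(images), 'samples.png')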