├── .gitignore
├── LICENSE.txt
├── LICENSE_ORIGINAL.txt
├── README.md
├── config.py
├── configs
│   ├── sample.yaml
│   ├── sample_cari2_128.yaml
│   ├── sample_cari2_128_truncation.yaml
│   ├── sample_cari2_128_truncation_logistic.yaml
│   ├── sample_celeba.yaml
│   ├── sample_celeba_128.yaml
│   ├── sample_celeba_128_trancation.yaml
│   ├── sample_celeba_128_trancation_logistic.yaml
│   ├── sample_conditional.yaml
│   ├── sample_ffhq_1024.yaml
│   ├── sample_ffhq_1024_truncation.yaml
│   ├── sample_ffhq_128.yaml
│   ├── sample_race.yaml
│   ├── sample_race_256.yaml
│   ├── sample_race_256_mix.yaml
│   └── sample_race_256_mix_truncation.yaml
├── convert.py
├── data
│   ├── __init__.py
│   ├── datasets.py
│   └── transforms.py
├── diagrams
│   ├── cari2_128.png
│   ├── ffhq_1024.png
│   ├── ffhq_128.png
│   ├── figure03-style-mixing-mix.png
│   ├── figure08-truncation-trick.png
│   └── grid.png
├── dnnlib
│   ├── __init__.py
│   ├── submission
│   │   ├── __init__.py
│   │   ├── _internal
│   │   │   └── run.py
│   │   ├── run_context.py
│   │   └── submit.py
│   ├── tflib
│   │   ├── __init__.py
│   │   ├── autosummary.py
│   │   ├── network.py
│   │   ├── optimizer.py
│   │   └── tfutil.py
│   └── util.py
├── generate_grid.py
├── generate_mixing_figure.py
├── generate_samples.py
├── generate_truncation_figure.py
├── models
│   ├── Blocks.py
│   ├── CustomLayers.py
│   ├── GAN.py
│   ├── Losses.py
│   └── __init__.py
├── requirements.txt
├── test
│   ├── __init__.py
│   ├── test_Blocks.py
│   └── test_CustomLayers.py
├── train.py
└── utils
    ├── __init__.py
    ├── copy.py
    └── logger.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 |
3 | .vscode
4 | .idea
5 | output
6 | scripts
7 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, akanimax. All rights reserved.
2 |
3 |
4 | Attribution-NonCommercial 4.0 International
5 |
6 | =======================================================================
7 |
8 | Creative Commons Corporation ("Creative Commons") is not a law firm and
9 | does not provide legal services or legal advice. Distribution of
10 | Creative Commons public licenses does not create a lawyer-client or
11 | other relationship. Creative Commons makes its licenses and related
12 | information available on an "as-is" basis. Creative Commons gives no
13 | warranties regarding its licenses, any material licensed under their
14 | terms and conditions, or any related information. Creative Commons
15 | disclaims all liability for damages resulting from their use to the
16 | fullest extent possible.
17 |
18 | Using Creative Commons Public Licenses
19 |
20 | Creative Commons public licenses provide a standard set of terms and
21 | conditions that creators and other rights holders may use to share
22 | original works of authorship and other material subject to copyright
23 | and certain other rights specified in the public license below. The
24 | following considerations are for informational purposes only, are not
25 | exhaustive, and do not form part of our licenses.
26 |
27 | Considerations for licensors: Our public licenses are
28 | intended for use by those authorized to give the public
29 | permission to use material in ways otherwise restricted by
30 | copyright and certain other rights. Our licenses are
31 | irrevocable. Licensors should read and understand the terms
32 | and conditions of the license they choose before applying it.
33 | Licensors should also secure all rights necessary before
34 | applying our licenses so that the public can reuse the
35 | material as expected. Licensors should clearly mark any
36 | material not subject to the license. This includes other CC-
37 | licensed material, or material used under an exception or
38 | limitation to copyright. More considerations for licensors:
39 | wiki.creativecommons.org/Considerations_for_licensors
40 |
41 | Considerations for the public: By using one of our public
42 | licenses, a licensor grants the public permission to use the
43 | licensed material under specified terms and conditions. If
44 | the licensor's permission is not necessary for any reason--for
45 | example, because of any applicable exception or limitation to
46 | copyright--then that use is not regulated by the license. Our
47 | licenses grant only permissions under copyright and certain
48 | other rights that a licensor has authority to grant. Use of
49 | the licensed material may still be restricted for other
50 | reasons, including because others have copyright or other
51 | rights in the material. A licensor may make special requests,
52 | such as asking that all changes be marked or described.
53 | Although not required by our licenses, you are encouraged to
54 | respect those requests where reasonable. More_considerations
55 | for the public:
56 | wiki.creativecommons.org/Considerations_for_licensees
57 |
58 | =======================================================================
59 |
60 | Creative Commons Attribution-NonCommercial 4.0 International Public
61 | License
62 |
63 | By exercising the Licensed Rights (defined below), You accept and agree
64 | to be bound by the terms and conditions of this Creative Commons
65 | Attribution-NonCommercial 4.0 International Public License ("Public
66 | License"). To the extent this Public License may be interpreted as a
67 | contract, You are granted the Licensed Rights in consideration of Your
68 | acceptance of these terms and conditions, and the Licensor grants You
69 | such rights in consideration of benefits the Licensor receives from
70 | making the Licensed Material available under these terms and
71 | conditions.
72 |
73 |
74 | Section 1 -- Definitions.
75 |
76 | a. Adapted Material means material subject to Copyright and Similar
77 | Rights that is derived from or based upon the Licensed Material
78 | and in which the Licensed Material is translated, altered,
79 | arranged, transformed, or otherwise modified in a manner requiring
80 | permission under the Copyright and Similar Rights held by the
81 | Licensor. For purposes of this Public License, where the Licensed
82 | Material is a musical work, performance, or sound recording,
83 | Adapted Material is always produced where the Licensed Material is
84 | synched in timed relation with a moving image.
85 |
86 | b. Adapter's License means the license You apply to Your Copyright
87 | and Similar Rights in Your contributions to Adapted Material in
88 | accordance with the terms and conditions of this Public License.
89 |
90 | c. Copyright and Similar Rights means copyright and/or similar rights
91 | closely related to copyright including, without limitation,
92 | performance, broadcast, sound recording, and Sui Generis Database
93 | Rights, without regard to how the rights are labeled or
94 | categorized. For purposes of this Public License, the rights
95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
96 | Rights.
97 | d. Effective Technological Measures means those measures that, in the
98 | absence of proper authority, may not be circumvented under laws
99 | fulfilling obligations under Article 11 of the WIPO Copyright
100 | Treaty adopted on December 20, 1996, and/or similar international
101 | agreements.
102 |
103 | e. Exceptions and Limitations means fair use, fair dealing, and/or
104 | any other exception or limitation to Copyright and Similar Rights
105 | that applies to Your use of the Licensed Material.
106 |
107 | f. Licensed Material means the artistic or literary work, database,
108 | or other material to which the Licensor applied this Public
109 | License.
110 |
111 | g. Licensed Rights means the rights granted to You subject to the
112 | terms and conditions of this Public License, which are limited to
113 | all Copyright and Similar Rights that apply to Your use of the
114 | Licensed Material and that the Licensor has authority to license.
115 |
116 | h. Licensor means the individual(s) or entity(ies) granting rights
117 | under this Public License.
118 |
119 | i. NonCommercial means not primarily intended for or directed towards
120 | commercial advantage or monetary compensation. For purposes of
121 | this Public License, the exchange of the Licensed Material for
122 | other material subject to Copyright and Similar Rights by digital
123 | file-sharing or similar means is NonCommercial provided there is
124 | no payment of monetary compensation in connection with the
125 | exchange.
126 |
127 | j. Share means to provide material to the public by any means or
128 | process that requires permission under the Licensed Rights, such
129 | as reproduction, public display, public performance, distribution,
130 | dissemination, communication, or importation, and to make material
131 | available to the public including in ways that members of the
132 | public may access the material from a place and at a time
133 | individually chosen by them.
134 |
135 | k. Sui Generis Database Rights means rights other than copyright
136 | resulting from Directive 96/9/EC of the European Parliament and of
137 | the Council of 11 March 1996 on the legal protection of databases,
138 | as amended and/or succeeded, as well as other essentially
139 | equivalent rights anywhere in the world.
140 |
141 | l. You means the individual or entity exercising the Licensed Rights
142 | under this Public License. Your has a corresponding meaning.
143 |
144 |
145 | Section 2 -- Scope.
146 |
147 | a. License grant.
148 |
149 | 1. Subject to the terms and conditions of this Public License,
150 | the Licensor hereby grants You a worldwide, royalty-free,
151 | non-sublicensable, non-exclusive, irrevocable license to
152 | exercise the Licensed Rights in the Licensed Material to:
153 |
154 | a. reproduce and Share the Licensed Material, in whole or
155 | in part, for NonCommercial purposes only; and
156 |
157 | b. produce, reproduce, and Share Adapted Material for
158 | NonCommercial purposes only.
159 |
160 | 2. Exceptions and Limitations. For the avoidance of doubt, where
161 | Exceptions and Limitations apply to Your use, this Public
162 | License does not apply, and You do not need to comply with
163 | its terms and conditions.
164 |
165 | 3. Term. The term of this Public License is specified in Section
166 | 6(a).
167 |
168 | 4. Media and formats; technical modifications allowed. The
169 | Licensor authorizes You to exercise the Licensed Rights in
170 | all media and formats whether now known or hereafter created,
171 | and to make technical modifications necessary to do so. The
172 | Licensor waives and/or agrees not to assert any right or
173 | authority to forbid You from making technical modifications
174 | necessary to exercise the Licensed Rights, including
175 | technical modifications necessary to circumvent Effective
176 | Technological Measures. For purposes of this Public License,
177 | simply making modifications authorized by this Section 2(a)
178 | (4) never produces Adapted Material.
179 |
180 | 5. Downstream recipients.
181 |
182 | a. Offer from the Licensor -- Licensed Material. Every
183 | recipient of the Licensed Material automatically
184 | receives an offer from the Licensor to exercise the
185 | Licensed Rights under the terms and conditions of this
186 | Public License.
187 |
188 | b. No downstream restrictions. You may not offer or impose
189 | any additional or different terms or conditions on, or
190 | apply any Effective Technological Measures to, the
191 | Licensed Material if doing so restricts exercise of the
192 | Licensed Rights by any recipient of the Licensed
193 | Material.
194 |
195 | 6. No endorsement. Nothing in this Public License constitutes or
196 | may be construed as permission to assert or imply that You
197 | are, or that Your use of the Licensed Material is, connected
198 | with, or sponsored, endorsed, or granted official status by,
199 | the Licensor or others designated to receive attribution as
200 | provided in Section 3(a)(1)(A)(i).
201 |
202 | b. Other rights.
203 |
204 | 1. Moral rights, such as the right of integrity, are not
205 | licensed under this Public License, nor are publicity,
206 | privacy, and/or other similar personality rights; however, to
207 | the extent possible, the Licensor waives and/or agrees not to
208 | assert any such rights held by the Licensor to the limited
209 | extent necessary to allow You to exercise the Licensed
210 | Rights, but not otherwise.
211 |
212 | 2. Patent and trademark rights are not licensed under this
213 | Public License.
214 |
215 | 3. To the extent possible, the Licensor waives any right to
216 | collect royalties from You for the exercise of the Licensed
217 | Rights, whether directly or through a collecting society
218 | under any voluntary or waivable statutory or compulsory
219 | licensing scheme. In all other cases the Licensor expressly
220 | reserves any right to collect such royalties, including when
221 | the Licensed Material is used other than for NonCommercial
222 | purposes.
223 |
224 |
225 | Section 3 -- License Conditions.
226 |
227 | Your exercise of the Licensed Rights is expressly made subject to the
228 | following conditions.
229 |
230 | a. Attribution.
231 |
232 | 1. If You Share the Licensed Material (including in modified
233 | form), You must:
234 |
235 | a. retain the following if it is supplied by the Licensor
236 | with the Licensed Material:
237 |
238 | i. identification of the creator(s) of the Licensed
239 | Material and any others designated to receive
240 | attribution, in any reasonable manner requested by
241 | the Licensor (including by pseudonym if
242 | designated);
243 |
244 | ii. a copyright notice;
245 |
246 | iii. a notice that refers to this Public License;
247 |
248 | iv. a notice that refers to the disclaimer of
249 | warranties;
250 |
251 | v. a URI or hyperlink to the Licensed Material to the
252 | extent reasonably practicable;
253 |
254 | b. indicate if You modified the Licensed Material and
255 | retain an indication of any previous modifications; and
256 |
257 | c. indicate the Licensed Material is licensed under this
258 | Public License, and include the text of, or the URI or
259 | hyperlink to, this Public License.
260 |
261 | 2. You may satisfy the conditions in Section 3(a)(1) in any
262 | reasonable manner based on the medium, means, and context in
263 | which You Share the Licensed Material. For example, it may be
264 | reasonable to satisfy the conditions by providing a URI or
265 | hyperlink to a resource that includes the required
266 | information.
267 |
268 | 3. If requested by the Licensor, You must remove any of the
269 | information required by Section 3(a)(1)(A) to the extent
270 | reasonably practicable.
271 |
272 | 4. If You Share Adapted Material You produce, the Adapter's
273 | License You apply must not prevent recipients of the Adapted
274 | Material from complying with this Public License.
275 |
276 |
277 | Section 4 -- Sui Generis Database Rights.
278 |
279 | Where the Licensed Rights include Sui Generis Database Rights that
280 | apply to Your use of the Licensed Material:
281 |
282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
283 | to extract, reuse, reproduce, and Share all or a substantial
284 | portion of the contents of the database for NonCommercial purposes
285 | only;
286 |
287 | b. if You include all or a substantial portion of the database
288 | contents in a database in which You have Sui Generis Database
289 | Rights, then the database in which You have Sui Generis Database
290 | Rights (but not its individual contents) is Adapted Material; and
291 |
292 | c. You must comply with the conditions in Section 3(a) if You Share
293 | all or a substantial portion of the contents of the database.
294 |
295 | For the avoidance of doubt, this Section 4 supplements and does not
296 | replace Your obligations under this Public License where the Licensed
297 | Rights include other Copyright and Similar Rights.
298 |
299 |
300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
301 |
302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
312 |
313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
322 |
323 | c. The disclaimer of warranties and limitation of liability provided
324 | above shall be interpreted in a manner that, to the extent
325 | possible, most closely approximates an absolute disclaimer and
326 | waiver of all liability.
327 |
328 |
329 | Section 6 -- Term and Termination.
330 |
331 | a. This Public License applies for the term of the Copyright and
332 | Similar Rights licensed here. However, if You fail to comply with
333 | this Public License, then Your rights under this Public License
334 | terminate automatically.
335 |
336 | b. Where Your right to use the Licensed Material has terminated under
337 | Section 6(a), it reinstates:
338 |
339 | 1. automatically as of the date the violation is cured, provided
340 | it is cured within 30 days of Your discovery of the
341 | violation; or
342 |
343 | 2. upon express reinstatement by the Licensor.
344 |
345 | For the avoidance of doubt, this Section 6(b) does not affect any
346 | right the Licensor may have to seek remedies for Your violations
347 | of this Public License.
348 |
349 | c. For the avoidance of doubt, the Licensor may also offer the
350 | Licensed Material under separate terms or conditions or stop
351 | distributing the Licensed Material at any time; however, doing so
352 | will not terminate this Public License.
353 |
354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
355 | License.
356 |
357 |
358 | Section 7 -- Other Terms and Conditions.
359 |
360 | a. The Licensor shall not be bound by any additional or different
361 | terms or conditions communicated by You unless expressly agreed.
362 |
363 | b. Any arrangements, understandings, or agreements regarding the
364 | Licensed Material not stated herein are separate from and
365 | independent of the terms and conditions of this Public License.
366 |
367 |
368 | Section 8 -- Interpretation.
369 |
370 | a. For the avoidance of doubt, this Public License does not, and
371 | shall not be interpreted to, reduce, limit, restrict, or impose
372 | conditions on any use of the Licensed Material that could lawfully
373 | be made without permission under this Public License.
374 |
375 | b. To the extent possible, if any provision of this Public License is
376 | deemed unenforceable, it shall be automatically reformed to the
377 | minimum extent necessary to make it enforceable. If the provision
378 | cannot be reformed, it shall be severed from this Public License
379 | without affecting the enforceability of the remaining terms and
380 | conditions.
381 |
382 | c. No term or condition of this Public License will be waived and no
383 | failure to comply consented to unless expressly agreed to by the
384 | Licensor.
385 |
386 | d. Nothing in this Public License constitutes or may be interpreted
387 | as a limitation upon, or waiver of, any privileges and immunities
388 | that apply to the Licensor or You, including from the legal
389 | processes of any jurisdiction or authority.
390 |
391 | =======================================================================
392 |
393 | Creative Commons is not a party to its public
394 | licenses. Notwithstanding, Creative Commons may elect to apply one of
395 | its public licenses to material it publishes and in those instances
396 | will be considered the "Licensor." The text of the Creative Commons
397 | public licenses is dedicated to the public domain under the CC0 Public
398 | Domain Dedication. Except for the limited purpose of indicating that
399 | material is shared under a Creative Commons public license or as
400 | otherwise permitted by the Creative Commons policies published at
401 | creativecommons.org/policies, Creative Commons does not authorize the
402 | use of the trademark "Creative Commons" or any other trademark or logo
403 | of Creative Commons without its prior written consent including,
404 | without limitation, in connection with any unauthorized modifications
405 | to any of its public licenses or any other arrangements,
406 | understandings, or agreements concerning use of licensed material. For
407 | the avoidance of doubt, this paragraph does not form part of the
408 | public licenses.
409 |
410 | Creative Commons may be contacted at creativecommons.org.
411 |
--------------------------------------------------------------------------------
/LICENSE_ORIGINAL.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 |
3 |
4 | Attribution-NonCommercial 4.0 International
5 |
6 | =======================================================================
7 |
8 | Creative Commons Corporation ("Creative Commons") is not a law firm and
9 | does not provide legal services or legal advice. Distribution of
10 | Creative Commons public licenses does not create a lawyer-client or
11 | other relationship. Creative Commons makes its licenses and related
12 | information available on an "as-is" basis. Creative Commons gives no
13 | warranties regarding its licenses, any material licensed under their
14 | terms and conditions, or any related information. Creative Commons
15 | disclaims all liability for damages resulting from their use to the
16 | fullest extent possible.
17 |
18 | Using Creative Commons Public Licenses
19 |
20 | Creative Commons public licenses provide a standard set of terms and
21 | conditions that creators and other rights holders may use to share
22 | original works of authorship and other material subject to copyright
23 | and certain other rights specified in the public license below. The
24 | following considerations are for informational purposes only, are not
25 | exhaustive, and do not form part of our licenses.
26 |
27 | Considerations for licensors: Our public licenses are
28 | intended for use by those authorized to give the public
29 | permission to use material in ways otherwise restricted by
30 | copyright and certain other rights. Our licenses are
31 | irrevocable. Licensors should read and understand the terms
32 | and conditions of the license they choose before applying it.
33 | Licensors should also secure all rights necessary before
34 | applying our licenses so that the public can reuse the
35 | material as expected. Licensors should clearly mark any
36 | material not subject to the license. This includes other CC-
37 | licensed material, or material used under an exception or
38 | limitation to copyright. More considerations for licensors:
39 | wiki.creativecommons.org/Considerations_for_licensors
40 |
41 | Considerations for the public: By using one of our public
42 | licenses, a licensor grants the public permission to use the
43 | licensed material under specified terms and conditions. If
44 | the licensor's permission is not necessary for any reason--for
45 | example, because of any applicable exception or limitation to
46 | copyright--then that use is not regulated by the license. Our
47 | licenses grant only permissions under copyright and certain
48 | other rights that a licensor has authority to grant. Use of
49 | the licensed material may still be restricted for other
50 | reasons, including because others have copyright or other
51 | rights in the material. A licensor may make special requests,
52 | such as asking that all changes be marked or described.
53 | Although not required by our licenses, you are encouraged to
54 | respect those requests where reasonable. More_considerations
55 | for the public:
56 | wiki.creativecommons.org/Considerations_for_licensees
57 |
58 | =======================================================================
59 |
60 | Creative Commons Attribution-NonCommercial 4.0 International Public
61 | License
62 |
63 | By exercising the Licensed Rights (defined below), You accept and agree
64 | to be bound by the terms and conditions of this Creative Commons
65 | Attribution-NonCommercial 4.0 International Public License ("Public
66 | License"). To the extent this Public License may be interpreted as a
67 | contract, You are granted the Licensed Rights in consideration of Your
68 | acceptance of these terms and conditions, and the Licensor grants You
69 | such rights in consideration of benefits the Licensor receives from
70 | making the Licensed Material available under these terms and
71 | conditions.
72 |
73 |
74 | Section 1 -- Definitions.
75 |
76 | a. Adapted Material means material subject to Copyright and Similar
77 | Rights that is derived from or based upon the Licensed Material
78 | and in which the Licensed Material is translated, altered,
79 | arranged, transformed, or otherwise modified in a manner requiring
80 | permission under the Copyright and Similar Rights held by the
81 | Licensor. For purposes of this Public License, where the Licensed
82 | Material is a musical work, performance, or sound recording,
83 | Adapted Material is always produced where the Licensed Material is
84 | synched in timed relation with a moving image.
85 |
86 | b. Adapter's License means the license You apply to Your Copyright
87 | and Similar Rights in Your contributions to Adapted Material in
88 | accordance with the terms and conditions of this Public License.
89 |
90 | c. Copyright and Similar Rights means copyright and/or similar rights
91 | closely related to copyright including, without limitation,
92 | performance, broadcast, sound recording, and Sui Generis Database
93 | Rights, without regard to how the rights are labeled or
94 | categorized. For purposes of this Public License, the rights
95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
96 | Rights.
97 | d. Effective Technological Measures means those measures that, in the
98 | absence of proper authority, may not be circumvented under laws
99 | fulfilling obligations under Article 11 of the WIPO Copyright
100 | Treaty adopted on December 20, 1996, and/or similar international
101 | agreements.
102 |
103 | e. Exceptions and Limitations means fair use, fair dealing, and/or
104 | any other exception or limitation to Copyright and Similar Rights
105 | that applies to Your use of the Licensed Material.
106 |
107 | f. Licensed Material means the artistic or literary work, database,
108 | or other material to which the Licensor applied this Public
109 | License.
110 |
111 | g. Licensed Rights means the rights granted to You subject to the
112 | terms and conditions of this Public License, which are limited to
113 | all Copyright and Similar Rights that apply to Your use of the
114 | Licensed Material and that the Licensor has authority to license.
115 |
116 | h. Licensor means the individual(s) or entity(ies) granting rights
117 | under this Public License.
118 |
119 | i. NonCommercial means not primarily intended for or directed towards
120 | commercial advantage or monetary compensation. For purposes of
121 | this Public License, the exchange of the Licensed Material for
122 | other material subject to Copyright and Similar Rights by digital
123 | file-sharing or similar means is NonCommercial provided there is
124 | no payment of monetary compensation in connection with the
125 | exchange.
126 |
127 | j. Share means to provide material to the public by any means or
128 | process that requires permission under the Licensed Rights, such
129 | as reproduction, public display, public performance, distribution,
130 | dissemination, communication, or importation, and to make material
131 | available to the public including in ways that members of the
132 | public may access the material from a place and at a time
133 | individually chosen by them.
134 |
135 | k. Sui Generis Database Rights means rights other than copyright
136 | resulting from Directive 96/9/EC of the European Parliament and of
137 | the Council of 11 March 1996 on the legal protection of databases,
138 | as amended and/or succeeded, as well as other essentially
139 | equivalent rights anywhere in the world.
140 |
141 | l. You means the individual or entity exercising the Licensed Rights
142 | under this Public License. Your has a corresponding meaning.
143 |
144 |
145 | Section 2 -- Scope.
146 |
147 | a. License grant.
148 |
149 | 1. Subject to the terms and conditions of this Public License,
150 | the Licensor hereby grants You a worldwide, royalty-free,
151 | non-sublicensable, non-exclusive, irrevocable license to
152 | exercise the Licensed Rights in the Licensed Material to:
153 |
154 | a. reproduce and Share the Licensed Material, in whole or
155 | in part, for NonCommercial purposes only; and
156 |
157 | b. produce, reproduce, and Share Adapted Material for
158 | NonCommercial purposes only.
159 |
160 | 2. Exceptions and Limitations. For the avoidance of doubt, where
161 | Exceptions and Limitations apply to Your use, this Public
162 | License does not apply, and You do not need to comply with
163 | its terms and conditions.
164 |
165 | 3. Term. The term of this Public License is specified in Section
166 | 6(a).
167 |
168 | 4. Media and formats; technical modifications allowed. The
169 | Licensor authorizes You to exercise the Licensed Rights in
170 | all media and formats whether now known or hereafter created,
171 | and to make technical modifications necessary to do so. The
172 | Licensor waives and/or agrees not to assert any right or
173 | authority to forbid You from making technical modifications
174 | necessary to exercise the Licensed Rights, including
175 | technical modifications necessary to circumvent Effective
176 | Technological Measures. For purposes of this Public License,
177 | simply making modifications authorized by this Section 2(a)
178 | (4) never produces Adapted Material.
179 |
180 | 5. Downstream recipients.
181 |
182 | a. Offer from the Licensor -- Licensed Material. Every
183 | recipient of the Licensed Material automatically
184 | receives an offer from the Licensor to exercise the
185 | Licensed Rights under the terms and conditions of this
186 | Public License.
187 |
188 | b. No downstream restrictions. You may not offer or impose
189 | any additional or different terms or conditions on, or
190 | apply any Effective Technological Measures to, the
191 | Licensed Material if doing so restricts exercise of the
192 | Licensed Rights by any recipient of the Licensed
193 | Material.
194 |
195 | 6. No endorsement. Nothing in this Public License constitutes or
196 | may be construed as permission to assert or imply that You
197 | are, or that Your use of the Licensed Material is, connected
198 | with, or sponsored, endorsed, or granted official status by,
199 | the Licensor or others designated to receive attribution as
200 | provided in Section 3(a)(1)(A)(i).
201 |
202 | b. Other rights.
203 |
204 | 1. Moral rights, such as the right of integrity, are not
205 | licensed under this Public License, nor are publicity,
206 | privacy, and/or other similar personality rights; however, to
207 | the extent possible, the Licensor waives and/or agrees not to
208 | assert any such rights held by the Licensor to the limited
209 | extent necessary to allow You to exercise the Licensed
210 | Rights, but not otherwise.
211 |
212 | 2. Patent and trademark rights are not licensed under this
213 | Public License.
214 |
215 | 3. To the extent possible, the Licensor waives any right to
216 | collect royalties from You for the exercise of the Licensed
217 | Rights, whether directly or through a collecting society
218 | under any voluntary or waivable statutory or compulsory
219 | licensing scheme. In all other cases the Licensor expressly
220 | reserves any right to collect such royalties, including when
221 | the Licensed Material is used other than for NonCommercial
222 | purposes.
223 |
224 |
225 | Section 3 -- License Conditions.
226 |
227 | Your exercise of the Licensed Rights is expressly made subject to the
228 | following conditions.
229 |
230 | a. Attribution.
231 |
232 | 1. If You Share the Licensed Material (including in modified
233 | form), You must:
234 |
235 | a. retain the following if it is supplied by the Licensor
236 | with the Licensed Material:
237 |
238 | i. identification of the creator(s) of the Licensed
239 | Material and any others designated to receive
240 | attribution, in any reasonable manner requested by
241 | the Licensor (including by pseudonym if
242 | designated);
243 |
244 | ii. a copyright notice;
245 |
246 | iii. a notice that refers to this Public License;
247 |
248 | iv. a notice that refers to the disclaimer of
249 | warranties;
250 |
251 | v. a URI or hyperlink to the Licensed Material to the
252 | extent reasonably practicable;
253 |
254 | b. indicate if You modified the Licensed Material and
255 | retain an indication of any previous modifications; and
256 |
257 | c. indicate the Licensed Material is licensed under this
258 | Public License, and include the text of, or the URI or
259 | hyperlink to, this Public License.
260 |
261 | 2. You may satisfy the conditions in Section 3(a)(1) in any
262 | reasonable manner based on the medium, means, and context in
263 | which You Share the Licensed Material. For example, it may be
264 | reasonable to satisfy the conditions by providing a URI or
265 | hyperlink to a resource that includes the required
266 | information.
267 |
268 | 3. If requested by the Licensor, You must remove any of the
269 | information required by Section 3(a)(1)(A) to the extent
270 | reasonably practicable.
271 |
272 | 4. If You Share Adapted Material You produce, the Adapter's
273 | License You apply must not prevent recipients of the Adapted
274 | Material from complying with this Public License.
275 |
276 |
277 | Section 4 -- Sui Generis Database Rights.
278 |
279 | Where the Licensed Rights include Sui Generis Database Rights that
280 | apply to Your use of the Licensed Material:
281 |
282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
283 | to extract, reuse, reproduce, and Share all or a substantial
284 | portion of the contents of the database for NonCommercial purposes
285 | only;
286 |
287 | b. if You include all or a substantial portion of the database
288 | contents in a database in which You have Sui Generis Database
289 | Rights, then the database in which You have Sui Generis Database
290 | Rights (but not its individual contents) is Adapted Material; and
291 |
292 | c. You must comply with the conditions in Section 3(a) if You Share
293 | all or a substantial portion of the contents of the database.
294 |
295 | For the avoidance of doubt, this Section 4 supplements and does not
296 | replace Your obligations under this Public License where the Licensed
297 | Rights include other Copyright and Similar Rights.
298 |
299 |
300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
301 |
302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
312 |
313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
322 |
323 | c. The disclaimer of warranties and limitation of liability provided
324 | above shall be interpreted in a manner that, to the extent
325 | possible, most closely approximates an absolute disclaimer and
326 | waiver of all liability.
327 |
328 |
329 | Section 6 -- Term and Termination.
330 |
331 | a. This Public License applies for the term of the Copyright and
332 | Similar Rights licensed here. However, if You fail to comply with
333 | this Public License, then Your rights under this Public License
334 | terminate automatically.
335 |
336 | b. Where Your right to use the Licensed Material has terminated under
337 | Section 6(a), it reinstates:
338 |
339 | 1. automatically as of the date the violation is cured, provided
340 | it is cured within 30 days of Your discovery of the
341 | violation; or
342 |
343 | 2. upon express reinstatement by the Licensor.
344 |
345 | For the avoidance of doubt, this Section 6(b) does not affect any
346 | right the Licensor may have to seek remedies for Your violations
347 | of this Public License.
348 |
349 | c. For the avoidance of doubt, the Licensor may also offer the
350 | Licensed Material under separate terms or conditions or stop
351 | distributing the Licensed Material at any time; however, doing so
352 | will not terminate this Public License.
353 |
354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
355 | License.
356 |
357 |
358 | Section 7 -- Other Terms and Conditions.
359 |
360 | a. The Licensor shall not be bound by any additional or different
361 | terms or conditions communicated by You unless expressly agreed.
362 |
363 | b. Any arrangements, understandings, or agreements regarding the
364 | Licensed Material not stated herein are separate from and
365 | independent of the terms and conditions of this Public License.
366 |
367 |
368 | Section 8 -- Interpretation.
369 |
370 | a. For the avoidance of doubt, this Public License does not, and
371 | shall not be interpreted to, reduce, limit, restrict, or impose
372 | conditions on any use of the Licensed Material that could lawfully
373 | be made without permission under this Public License.
374 |
375 | b. To the extent possible, if any provision of this Public License is
376 | deemed unenforceable, it shall be automatically reformed to the
377 | minimum extent necessary to make it enforceable. If the provision
378 | cannot be reformed, it shall be severed from this Public License
379 | without affecting the enforceability of the remaining terms and
380 | conditions.
381 |
382 | c. No term or condition of this Public License will be waived and no
383 | failure to comply consented to unless expressly agreed to by the
384 | Licensor.
385 |
386 | d. Nothing in this Public License constitutes or may be interpreted
387 | as a limitation upon, or waiver of, any privileges and immunities
388 | that apply to the Licensor or You, including from the legal
389 | processes of any jurisdiction or authority.
390 |
391 | =======================================================================
392 |
393 | Creative Commons is not a party to its public
394 | licenses. Notwithstanding, Creative Commons may elect to apply one of
395 | its public licenses to material it publishes and in those instances
396 | will be considered the "Licensor." The text of the Creative Commons
397 | public licenses is dedicated to the public domain under the CC0 Public
398 | Domain Dedication. Except for the limited purpose of indicating that
399 | material is shared under a Creative Commons public license or as
400 | otherwise permitted by the Creative Commons policies published at
401 | creativecommons.org/policies, Creative Commons does not authorize the
402 | use of the trademark "Creative Commons" or any other trademark or logo
403 | of Creative Commons without its prior written consent including,
404 | without limitation, in connection with any unauthorized modifications
405 | to any of its public licenses or any other arrangements,
406 | understandings, or agreements concerning use of licensed material. For
407 | the avoidance of doubt, this paragraph does not form part of the
408 | public licenses.
409 |
410 | Creative Commons may be contacted at creativecommons.org.
411 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StyleGAN.pytorch
2 |
3 | ## \[:star: New :star:\] Please head over to [Official PyTorch implementation](https://github.com/NVlabs/stylegan2-ada-pytorch).
4 |
5 |
6 | ![Generated samples](diagrams/grid.png)
7 | [ChineseGirl Dataset]
8 |
9 |
10 | This repository contains an unofficial PyTorch implementation of the following paper:
11 |
12 | > A Style-Based Generator Architecture for Generative Adversarial Networks
13 | > Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)
14 | > http://stylegan.xyz/paper
15 | >
16 | > Abstract: We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.
17 |
18 |
19 | ## Features
20 |
21 | - [x] Progressive Growing Training
22 | - [x] Exponential Moving Average
23 | - [x] Equalized Learning Rate
24 | - [x] PixelNorm Layer
25 | - [x] Minibatch Standard Deviation Layer
26 | - [x] Style Mixing Regularization
27 | - [x] Truncation Trick
28 | - [x] Using official tensorflow pretrained weights
29 | - [x] Gradient Clipping
30 | - [ ] Multi-GPU Training
31 | - [ ] FP-16 Support
32 | - [ ] Conditional GAN
33 |
34 | ## How to use
35 |
36 | ### Requirements
37 | - yacs
38 | - tqdm
39 | - numpy
40 | - torch
41 | - torchvision
42 | - tensorflow (optional, for ./convert.py)
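
The dependencies can be installed with pip; a minimal, unpinned sketch (see `requirements.txt` for the authoritative list):

```shell script
pip install yacs tqdm numpy torch torchvision
# tensorflow is only needed if you run ./convert.py
pip install tensorflow
```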
43 |
44 | ### Running the training script:
45 | Train from scratch:
46 | ```shell script
47 | python train.py --config configs/sample.yaml
48 | ```
49 |
50 | ### Using a trained model:
51 | Resume training from a checkpoint (start from 128x128):
52 | ```shell script
53 | python train.py --config configs/sample.yaml --start_depth 5 --generator_file [] [--gen_shadow_file] --discriminator_file [] --gen_optim_file [] --dis_optim_file []
54 | ```
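
For concreteness, a hypothetical resume command could look like this (the checkpoint filenames below are placeholders, not names the script requires; substitute the files written to your `output_dir`):

```shell script
python train.py --config configs/sample.yaml --start_depth 5 \
    --generator_file GAN_GEN_5.pth \
    --gen_shadow_file GAN_GEN_SHADOW_5.pth \
    --discriminator_file GAN_DIS_5.pth \
    --gen_optim_file GAN_GEN_OPTIM_5.pth \
    --dis_optim_file GAN_DIS_OPTIM_5.pth
```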
55 | ### Style Mixing
56 |
57 | ```shell script
58 | python generate_mixing_figure.py --config configs/sample.yaml --generator_file []
59 | ```
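
Conceptually, style mixing routes the styles of one mapped latent to the coarse synthesis layers and a second latent to the finer layers. A minimal PyTorch sketch of the idea (illustrative only, not this script's exact code; the layer count and crossover point are assumptions):

```python
import torch

num_layers, crossover = 18, 8  # assumed: 18 per-layer styles, mixing point at layer 8
w_a = torch.randn(1, 512).expand(num_layers, -1).clone()  # styles from latent A
w_b = torch.randn(1, 512).expand(num_layers, -1)          # styles from latent B
w_a[crossover:] = w_b[crossover:]  # coarse layers keep A; fine layers take B's styles
```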
60 |
61 |
62 | ![Style mixing figure](diagrams/figure03-style-mixing-mix.png)
63 |
64 |
65 | > Thanks to the dataset provider: Copyright (c) 2018, seeprettyface.com. BUPT_GWY contributed the dataset.
66 |
67 | ### Truncation trick
68 |
69 | ```shell script
70 | python generate_truncation_figure.py --config configs/sample_cari2_128_truncation.yaml --generator_file cari2_128_truncation_gen.pth
71 | ```
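
The underlying idea is to pull each per-layer style toward the average latent for the early layers only (cf. `truncation_psi` and `truncation_cutoff` in `config.py`). A minimal sketch of the trick itself, not this repo's exact implementation:

```python
import torch

def truncate(w, w_avg, psi=0.7, cutoff=8):
    """Interpolate w toward its running average for layers below `cutoff`.

    w: (num_layers, latent_size) per-layer styles; w_avg: (latent_size,).
    """
    coefs = torch.ones(w.shape[0], 1)
    coefs[:cutoff] = psi  # only the first `cutoff` (coarse) layers are truncated
    return w_avg + (w - w_avg) * coefs
```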
72 |
73 |
74 | ![Truncation trick figure](diagrams/figure08-truncation-trick.png)
75 |
76 |
77 | ### Convert from official format
78 | ```shell script
79 | python convert.py --config configs/sample_ffhq_1024.yaml --input_file PATH/karras2019stylegan-ffhq-1024x1024.pkl --output_file ffhq_1024_gen.pth
80 | ```
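
The converted `.pth` file can then be loaded back into this repo's generator. A minimal sketch, assuming the file holds a plain `state_dict` and using hypothetical constructor arguments (check `models/GAN.py` and your YAML config for the real signature):

```python
import torch

from models.GAN import Generator

# Constructor arguments here are assumptions -- mirror your config
# (configs/sample_ffhq_1024.yaml: resolution 1024, latent size 512).
gen = Generator(resolution=1024, latent_size=512, structure='linear')
gen.load_state_dict(torch.load('ffhq_1024_gen.pth', map_location='cpu'))
gen.eval()
```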
81 |
82 | ## Generated samples
83 |
84 |
85 | ![FFHQ Dataset (128x128)](diagrams/ffhq_128.png)
86 | [FFHQ Dataset] (128x128)
87 |
88 |
89 | Using weights transferred from the official TensorFlow repo.
90 |
91 | ![FFHQ Dataset (1024x1024)](diagrams/ffhq_1024.png)
92 | [FFHQ Dataset] (1024x1024)
93 |
94 |
95 |
96 | ![WebCaricature Dataset (128x128)](diagrams/cari2_128.png)
97 | [WebCaricature Dataset] (128x128)
98 |
99 |
100 | ## Reference
101 |
102 | - **stylegan[official]**: https://github.com/NVlabs/stylegan
103 | - **pro_gan_pytorch**: https://github.com/akanimax/pro_gan_pytorch
104 | - **pytorch_style_gan**: https://github.com/lernapparat/lernapparat
105 |
106 | ## Thanks
107 |
108 | Please feel free to open PRs / issues / suggestions here.
109 |
110 | ## Due Credit
111 | This code heavily builds on NVIDIA's original
112 | [StyleGAN](https://github.com/NVlabs/stylegan) code. We credit and acknowledge their work here. The
113 | [Original License](/LICENSE_ORIGINAL.txt)
114 | is located in the base directory (file named `LICENSE_ORIGINAL.txt`).
115 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: config.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/22
6 | Description: Global Configure.
7 | -------------------------------------------------
8 | """
9 |
10 | from yacs.config import CfgNode as CN
11 |
12 | cfg = CN()
13 |
14 | cfg.output_dir = ''
15 | cfg.device = 'cuda'
16 | cfg.device_id = '0'
17 |
18 | cfg.structure = 'fixed'
19 | cfg.conditional = False
20 | cfg.n_classes = 0
21 | cfg.loss = "logistic"
22 | cfg.drift = 0.001
23 | cfg.d_repeats = 1
24 | cfg.use_ema = True
25 | cfg.ema_decay = 0.999
26 |
27 | cfg.num_works = 4  # number of DataLoader worker processes
28 | cfg.num_samples = 36
29 | cfg.feedback_factor = 10
30 | cfg.checkpoint_factor = 10
31 |
32 | # ---------------------------------------------------------------------------- #
33 | # Options for scheduler
34 | # ---------------------------------------------------------------------------- #
35 | cfg.sched = CN()
36 |
37 | # example for {depth:9,resolution:1024}
38 | # res --> [4,8,16,32,64,128,256,512,1024]
39 | cfg.sched.epochs = [4, 4, 4, 4, 8, 16, 32, 64, 64]
40 | # batch sizes for one 1080Ti with 11GB memory
41 | cfg.sched.batch_sizes = [128, 128, 128, 64, 32, 16, 8, 4, 2]
42 | cfg.sched.fade_in_percentage = [50, 50, 50, 50, 50, 50, 50, 50, 50]
43 |
44 | # TODO
45 | # cfg.sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
46 | # cfg.sched.D_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003}
47 |
48 | # ---------------------------------------------------------------------------- #
49 | # Options for Dataset
50 | # ---------------------------------------------------------------------------- #
51 | cfg.dataset = CN()
52 | cfg.dataset.img_dir = ""
53 | cfg.dataset.folder = True
54 | cfg.dataset.resolution = 128
55 | cfg.dataset.channels = 3
56 |
57 | cfg.model = CN()
58 | # ---------------------------------------------------------------------------- #
59 | # Options for Generator
60 | # ---------------------------------------------------------------------------- #
61 | cfg.model.gen = CN()
62 | cfg.model.gen.latent_size = 512
63 | # 8 in the original paper
64 | cfg.model.gen.mapping_layers = 4
65 | cfg.model.gen.blur_filter = [1, 2, 1]
66 | cfg.model.gen.truncation_psi = 0.7
67 | cfg.model.gen.truncation_cutoff = 8
68 |
69 | # ---------------------------------------------------------------------------- #
70 | # Options for Discriminator
71 | # ---------------------------------------------------------------------------- #
72 | cfg.model.dis = CN()
73 | cfg.model.dis.use_wscale = True
74 | cfg.model.dis.blur_filter = [1, 2, 1]
75 |
76 | # ---------------------------------------------------------------------------- #
77 | # Options for Generator Optimizer
78 | # ---------------------------------------------------------------------------- #
79 | cfg.model.g_optim = CN()
80 | cfg.model.g_optim.learning_rate = 0.003
81 | cfg.model.g_optim.beta_1 = 0
82 | cfg.model.g_optim.beta_2 = 0.99
83 | cfg.model.g_optim.eps = 1e-8
84 |
85 | # ---------------------------------------------------------------------------- #
86 | # Options for Discriminator Optimizer
87 | # ---------------------------------------------------------------------------- #
88 | cfg.model.d_optim = CN()
89 | cfg.model.d_optim.learning_rate = 0.003
90 | cfg.model.d_optim.beta_1 = 0
91 | cfg.model.d_optim.beta_2 = 0.99
92 | cfg.model.d_optim.eps = 1e-8
93 |
--------------------------------------------------------------------------------
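The defaults in config.py above are overridden per experiment by the YAML files under configs/. A minimal sketch of how a yacs CfgNode merges one of them, using the standard yacs API:

```python
from config import cfg

cfg.merge_from_file('configs/sample.yaml')  # overlay experiment values onto the defaults
cfg.freeze()                                # make the merged config read-only

print(cfg.structure)              # 'linear' (overridden by the YAML)
print(cfg.model.gen.latent_size)  # 512 (default retained)
```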
/configs/sample.yaml:
--------------------------------------------------------------------------------
1 | # Config file for CelebA dataset (200k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_celeba_6'
4 | structure: 'linear'
5 | device_id: ('3')
6 | dataset:
7 | img_dir: '/home/hzh/data/img_align_celeba'
8 | folder: False
9 | resolution: 128
10 | sched:
11 | epochs: [2,4,4,4,4,8]
12 |
--------------------------------------------------------------------------------
/configs/sample_cari2_128.yaml:
--------------------------------------------------------------------------------
1 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_cari2'
2 | structure: 'linear'
3 | device_id: ('3')
4 | checkpoint_factor: 4
5 | feedback_factor: 4
6 | model:
7 | gen:
8 | mapping_layers: 8
9 | # A negative value disables the truncation trick
10 | truncation_psi: -1.
11 | dataset:
12 | img_dir: '/home/hzh/data/Cari2-128'
13 | folder: False
14 | resolution: 128
15 | sched:
16 | epochs: [8,16,32,32,64,64]
17 |
--------------------------------------------------------------------------------
/configs/sample_cari2_128_truncation.yaml:
--------------------------------------------------------------------------------
1 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_cari2_1203'
2 | structure: 'linear'
3 | device_id: ('1')
4 | checkpoint_factor: 4
5 | feedback_factor: 4
6 | model:
7 | gen:
8 | mapping_layers: 8
9 | dataset:
10 | img_dir: '/home/hzh/data/Cari2-128'
11 | folder: False
12 | resolution: 128
13 | sched:
14 | epochs: [8,16,32,32,64,64]
15 |
--------------------------------------------------------------------------------
/configs/sample_cari2_128_truncation_logistic.yaml:
--------------------------------------------------------------------------------
1 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_cari2_logistic'
2 | structure: 'linear'
3 | device_id: ('2')
4 | checkpoint_factor: 4
5 | feedback_factor: 4
6 | loss: 'logistic'
7 | model:
8 | gen:
9 | mapping_layers: 8
10 | dataset:
11 | img_dir: '/home/hzh/data/Cari2-128'
12 | folder: False
13 | resolution: 128
14 | sched:
15 | epochs: [8,16,32,32,64,64]
16 |
--------------------------------------------------------------------------------
/configs/sample_celeba.yaml:
--------------------------------------------------------------------------------
1 | # Config file for CelebA dataset (200k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_celeba_9'
4 | structure: 'linear'
5 | device_id: ('1')
6 | checkpoint_factor: 2
7 | dataset:
8 | img_dir: '/home/hzh/data/CelebA'
9 | folder: False
10 | resolution: 128
11 | sched:
12 | epochs: [2,4,8,8,16,24]
13 |
--------------------------------------------------------------------------------
/configs/sample_celeba_128.yaml:
--------------------------------------------------------------------------------
1 | # Config file for CelebA dataset (200k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_celeba_9'
4 | structure: 'linear'
5 | device_id: ('1')
6 | checkpoint_factor: 2
7 | dataset:
8 | img_dir: '/home/hzh/data/CelebA'
9 | folder: False
10 | resolution: 128
11 | sched:
12 | epochs: [4,8,16,16,32,48]
13 |
--------------------------------------------------------------------------------
/configs/sample_celeba_128_trancation.yaml:
--------------------------------------------------------------------------------
1 | # Config file for CelebA dataset (200k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_celeba_11'
4 | structure: 'linear'
5 | device_id: ('1')
6 | checkpoint_factor: 2
7 | dataset:
8 | img_dir: '/home/hzh/data/CelebA'
9 | folder: False
10 | resolution: 128
11 | sched:
12 | epochs: [4,4,8,16,16,32]
13 |
--------------------------------------------------------------------------------
/configs/sample_celeba_128_trancation_logistic.yaml:
--------------------------------------------------------------------------------
1 | # Config file for CelebA dataset (200k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_celeba_12'
4 | structure: 'linear'
5 | device_id: ('1')
6 | checkpoint_factor: 2
7 | loss: 'logistic'
8 | dataset:
9 | img_dir: '/home/hzh/data/CelebA'
10 | folder: False
11 | resolution: 128
12 | sched:
13 | epochs: [4,4,8,16,16,32]
14 |
--------------------------------------------------------------------------------
/configs/sample_conditional.yaml:
--------------------------------------------------------------------------------
1 | # Config file for conditional training (81 classes)
2 |
3 | output_dir: '/home/itdfh/dev/StyleGAN.pytorch/GTA/stylegan-conditional'
4 | structure: 'linear'
5 | conditional: True
6 | n_classes: 81
7 | loss: 'conditional-loss'
8 | device_id: ('0')
9 | dataset:
10 | img_dir: '/home/itdfh/dev/StyleGAN.pytorch/GTA/gta-vggface2-age-imagefolder'
11 | folder: False
12 | resolution: 128
13 | sched:
14 | epochs: [2,4,4,4,4,8]
--------------------------------------------------------------------------------
/configs/sample_ffhq_1024.yaml:
--------------------------------------------------------------------------------
1 | # Config file for FFHQ dataset (1024x1024)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_ffhq_1'
4 | structure: 'linear'
5 | device_id: ('3')
6 | checkpoint_factor: 4
7 | feedback_factor: 4
8 | dataset:
9 | img_dir: '/home/hzh/data/FFHQ'
10 | folder: True
11 | resolution: 1024
12 | model:
13 | gen:
14 | mapping_layers: 8
15 | # A negative value disables the truncation trick
16 | truncation_psi: -1.
17 | sched:
18 | epochs: [8,16,32,32,64,64]
19 |
--------------------------------------------------------------------------------
/configs/sample_ffhq_1024_truncation.yaml:
--------------------------------------------------------------------------------
1 | # Config file for FFHQ dataset (1024x1024, with truncation)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_ffhq_1'
4 | structure: 'linear'
5 | device_id: ('3')
6 | checkpoint_factor: 4
7 | feedback_factor: 4
8 | dataset:
9 | img_dir: '/home/hzh/data/FFHQ'
10 | folder: True
11 | resolution: 1024
12 | model:
13 | gen:
14 | mapping_layers: 8
15 |
16 | sched:
17 | epochs: [8,16,32,32,64,64]
18 |
--------------------------------------------------------------------------------
/configs/sample_ffhq_128.yaml:
--------------------------------------------------------------------------------
1 | # Config file for FFHQ dataset (128x128)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_ffhq_1'
4 | structure: 'linear'
5 | device_id: ('3')
6 | checkpoint_factor: 4
7 | feedback_factor: 4
8 | dataset:
9 | img_dir: '/home/hzh/data/FFHQ128x128'
10 | folder: True
11 | resolution: 128
12 | sched:
13 | epochs: [8,16,32,32,64,64]
14 |
--------------------------------------------------------------------------------
/configs/sample_race.yaml:
--------------------------------------------------------------------------------
1 | # Config file for ChineseGirl Dataset (10k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_race'
4 | structure: 'linear'
5 | device_id: ('1')
6 | dataset:
7 | img_dir: '/data/hzh/datasets/race_chinese_girl'
8 | folder: False
9 | resolution: 128
10 | sched:
11 | epochs: [40,80,80,80,80,160]
12 |
--------------------------------------------------------------------------------
/configs/sample_race_256.yaml:
--------------------------------------------------------------------------------
1 | # Config file for ChineseGirl Dataset (10k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_race_256'
4 | structure: 'linear'
5 | device_id: ('3')
6 | dataset:
7 | img_dir: '/home/hzh/data/ChineseGirl'
8 | folder: False
9 | resolution: 256
10 | sched:
11 | # 4,8,16,32,64,128,256
12 | epochs: [40,80,80,80,80,160,160]
13 |
--------------------------------------------------------------------------------
/configs/sample_race_256_mix.yaml:
--------------------------------------------------------------------------------
1 | # Config file for ChineseGirl Dataset (10k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_race_256_mix'
4 | structure: 'linear'
5 | device_id: ('3')
6 | checkpoint_factor: 20
7 | dataset:
8 | img_dir: '/home/hzh/data/ChineseGirl'
9 | folder: False
10 | resolution: 256
11 | sched:
12 | # 4,8,16,32,64,128,256
13 | epochs: [40,80,80,80,80,160,160]
14 |
--------------------------------------------------------------------------------
/configs/sample_race_256_mix_truncation.yaml:
--------------------------------------------------------------------------------
1 | # Config file for ChineseGirl Dataset (10k)
2 |
3 | output_dir: '/data/hzh/checkpoints/StyleGAN.pytorch/ckp_race_256_mix_trunc'
4 | structure: 'linear'
5 | device_id: ('3')
6 | checkpoint_factor: 20
7 | dataset:
8 | img_dir: '/home/hzh/data/ChineseGirl'
9 | folder: False
10 | resolution: 256
11 | sched:
12 | # 4,8,16,32,64,128,256
13 | epochs: [40,60,80,80,80,120,160]
--------------------------------------------------------------------------------
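
Note on the `sched.epochs` lists in these configs: each entry is the number of epochs spent at one resolution stage of progressive growing, starting at 4x4 (the `# 4,8,16,32,64,128,256` comments in the 256-resolution configs spell out the correspondence). A minimal sketch of that mapping, assuming stages double from 4x4 up to `dataset.resolution`:

    epochs = [40, 80, 80, 80, 80, 160, 160]          # from sample_race_256.yaml
    resolutions = [4 * 2 ** i for i in range(len(epochs))]
    for res, n in zip(resolutions, epochs):
        print('{0}x{0}: {1} epochs'.format(res, n))  # 4x4: 40 ... 256x256: 160
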
/convert.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: convert.py
4 | Date: 2019/11/20
5 | Description: Modified from:
6 | https://github.com/lernapparat/lernapparat
7 | -------------------------------------------------
8 | """
9 |
10 | import argparse
11 | import pickle
12 | import collections
13 | import numpy as np
14 |
15 | import torch
16 |
17 | from models.GAN import Generator
18 | from dnnlib import tflib
19 |
20 |
21 | def load_weights(weights_dir):
22 | tflib.init_tf()
23 | weights = pickle.load(open(weights_dir, 'rb'))
24 | weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval()))
25 | for k, v in w.trainables.items()]) for w in weights]
26 |
27 | # dlatent_avg
28 | for k, v in weights[2].vars.items():
29 | if k == 'dlatent_avg':
30 | weights_pt.append(collections.OrderedDict([(k, torch.from_numpy(v.value().eval()))]))
31 | return weights_pt
32 |
33 |
34 | def key_translate(k):
35 | k = k.lower().split('/')
36 | if k[0] == 'g_synthesis':
37 | if not k[1].startswith('torgb'):
38 | if k[1] != '4x4':
39 | k.insert(1, 'blocks')
40 | k[2] = str(int(np.log2(int(k[2].split('x')[0])) - 3))
41 | else:
42 | k[1] = 'init_block'
43 | k = '.'.join(k)
44 | k = (k.replace('const.const', 'const').replace('const.bias', 'bias')
45 | .replace('const.stylemod', 'epi1.style_mod.lin')
46 | .replace('const.noise.weight', 'epi1.top_epi.noise.weight')
47 | .replace('conv.noise.weight', 'epi2.top_epi.noise.weight')
48 | .replace('conv.stylemod', 'epi2.style_mod.lin')
49 | .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
50 | .replace('conv0_up.stylemod', 'epi1.style_mod.lin')
51 | .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
52 | .replace('conv1.stylemod', 'epi2.style_mod.lin')
53 | .replace('torgb_lod0', 'to_rgb.{}'.format(out_depth)))  # out_depth: module-level global, set in __main__ below
54 | elif k[0] == 'g_mapping':
55 | k.insert(1, 'map')
56 | k = '.'.join(k)
57 | else:
58 | k = '.'.join(k)
59 |
60 | return k
61 |
62 |
63 | def weight_translate(k, w):
64 | k = key_translate(k)
65 | if k.endswith('.weight'):
66 | if w.dim() == 2:
67 | w = w.t()
68 | elif w.dim() == 1:
69 | pass
70 | else:
71 | assert w.dim() == 4
72 | w = w.permute(3, 2, 0, 1)
73 | return w
74 |
75 |
76 | def parse_arguments():
77 | """
78 | default command line argument parser
79 | :return: args => parsed command line arguments
80 | """
81 |
82 | parser = argparse.ArgumentParser()
83 |
84 | parser.add_argument('--config', default='./configs/sample.yaml')
85 | parser.add_argument("--input_file", action="store", type=str,
86 | help="pretrained weights from official tensorflow repo.", required=True)
87 | parser.add_argument("--output_file", action="store", type=str, required=True,
88 | help="path to the output weights.")
89 |
90 | args = parser.parse_args()
91 |
92 | return args
93 |
94 |
95 | if __name__ == '__main__':
96 | args = parse_arguments()
97 |
98 | from config import cfg as opt
99 |
100 | opt.merge_from_file(args.config)
101 | opt.freeze()
102 |
103 | print("Creating generator object ...")
104 | # create the generator object
105 | gen = Generator(resolution=opt.dataset.resolution,
106 | num_channels=opt.dataset.channels,
107 | structure=opt.structure,
108 | **opt.model.gen)
109 | out_depth = gen.g_synthesis.depth - 1
110 |
111 | state_G, state_D, state_Gs, dlatent_avg = load_weights(args.input_file)
112 |
113 | # drop the to_rgb filters of intermediate resolutions (the actual filtering happens in `param_dict` below)
114 | params = {}  # note: collected but never used
115 | for k, v in state_Gs.items():
116 | params[k] = v
117 | param_dict = {key_translate(k): weight_translate(k, v) for k, v in state_Gs.items()
118 | if 'torgb_lod' not in key_translate(k)}
119 |
120 | for k, v in dlatent_avg.items():
121 | param_dict['truncation.avg_latent'] = v
122 |
123 | sd_shapes = {k: v.shape for k, v in gen.state_dict().items()}
124 | param_shapes = {k: v.shape for k, v in param_dict.items()}
125 |
126 | # check for mismatch
127 | for k in list(sd_shapes) + list(param_shapes):
128 | pds = param_shapes.get(k)
129 | sds = sd_shapes.get(k)
130 | if pds is None:
131 | print("sd only", k, sds)
132 | elif sds is None:
133 | print("pd only", k, pds)
134 | elif sds != pds:
135 | print("mismatch!", k, pds, sds)
136 |
137 | gen.load_state_dict(param_dict, strict=False) # needed for the blur kernels
138 | torch.save(gen.state_dict(), args.output_file)
139 | print('Done.')
140 |
--------------------------------------------------------------------------------
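
After conversion, the checkpoint written by convert.py can be loaded back into the same Generator it was built for. A minimal sketch, mirroring the construction in `__main__` above (the config choice and the .pth path are placeholders):

    import torch
    from config import cfg as opt
    from models.GAN import Generator

    opt.merge_from_file('./configs/sample_ffhq_1024.yaml')
    opt.freeze()

    gen = Generator(resolution=opt.dataset.resolution,
                    num_channels=opt.dataset.channels,
                    structure=opt.structure,
                    **opt.model.gen)
    gen.load_state_dict(torch.load('./converted_gen.pth'))  # file produced via --output_file
    gen.eval()
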
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: __init__.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/22
6 | Description:
7 | -------------------------------------------------
8 | """
9 |
10 | from torchvision.datasets import ImageFolder
11 |
12 | from data.datasets import FlatDirectoryImageDataset, FoldersDistributedDataset
13 | from data.transforms import get_transform
14 |
15 |
16 | def make_dataset(cfg, conditional=False):
17 |
18 | if conditional:
19 | Dataset = ImageFolder
20 | else:
21 | if cfg.folder:
22 | Dataset = FoldersDistributedDataset
23 | else:
24 | Dataset = FlatDirectoryImageDataset
25 |
26 | transforms = get_transform(new_size=(cfg.resolution, cfg.resolution))
27 | _dataset = Dataset(cfg.img_dir, transform=transforms)
28 |
29 | return _dataset
30 |
31 |
32 | def get_data_loader(dataset, batch_size, num_workers):
33 | """
34 | generate the data_loader from the given dataset
35 | :param dataset: dataset for training (Should be a PyTorch dataset)
36 | Make sure every item is an Image
37 | :param batch_size: batch size of the data
38 | :param num_workers: num of parallel readers
39 | :return: dl => data_loader for the dataset
40 | """
41 | from torch.utils.data import DataLoader
42 |
43 | dl = DataLoader(
44 | dataset,
45 | batch_size=batch_size,
46 | shuffle=True,
47 | num_workers=num_workers,
48 | drop_last=True,
49 | pin_memory=True
50 | )
51 |
52 | return dl
53 |
--------------------------------------------------------------------------------
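
For orientation, a minimal sketch of the pipeline this module exposes; `cfg` mimics the `dataset:` block of the YAML configs and the image path is a placeholder:

    from types import SimpleNamespace
    from data import make_dataset, get_data_loader

    cfg = SimpleNamespace(img_dir='/path/to/images', folder=False, resolution=128)
    dataset = make_dataset(cfg)          # FlatDirectoryImageDataset, since folder=False
    loader = get_data_loader(dataset, batch_size=16, num_workers=4)
    images = next(iter(loader))          # (16, 3, 128, 128) tensor, values in [-1, 1]
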
/data/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: datasets.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/22
6 | Description: Module for the data loading
7 | pipeline for the model to train.
8 | -------------------------------------------------
9 | """
10 |
11 | import os
12 | import numpy as np
13 |
14 | from torch.utils.data import Dataset
15 |
16 |
17 | class FlatDirectoryImageDataset(Dataset):
18 | """ pyTorch Dataset wrapper for the generic flat directory images dataset """
19 |
20 | def __setup_files(self):
21 | """
22 | private helper for setting up the files_list
23 | :return: files => list of paths of files
24 | """
25 | file_names = os.listdir(self.data_dir)
26 | files = [] # initialize to empty list
27 |
28 | for file_name in file_names:
29 | possible_file = os.path.join(self.data_dir, file_name)
30 | if os.path.isfile(possible_file):
31 | files.append(possible_file)
32 |
33 | # return the files list
34 | return files
35 |
36 | def __init__(self, data_dir, transform=None):
37 | """
38 | constructor for the class
39 | :param data_dir: path to the directory containing the data
40 | :param transform: transforms to be applied to the images
41 | """
42 | # define the state of the object
43 | self.data_dir = data_dir
44 | self.transform = transform
45 |
46 | # setup the files for reading
47 | self.files = self.__setup_files()
48 |
49 | def __len__(self):
50 | """
51 | compute the length of the dataset
52 | :return: len => length of dataset
53 | """
54 | return len(self.files)
55 |
56 | def __getitem__(self, idx):
57 | """
58 | obtain the image (read and transform)
59 | :param idx: index of the file required
60 | :return: img => image array
61 | """
62 | from PIL import Image
63 |
64 | img_file = self.files[idx]
65 |
66 | if img_file[-4:] == ".npy":
67 | # files are in .npy format
68 | img = np.load(img_file)
69 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0))
70 |
71 | else:
72 | # read the image:
73 | img = Image.open(self.files[idx]).convert('RGB')
74 |
75 | # apply the transforms on the image
76 | if self.transform is not None:
77 | img = self.transform(img)
78 |
79 | if img.shape[0] >= 4:
80 | # ignore the alpha channel
81 | # in the image if it exists
82 | img = img[:3, :, :]
83 |
84 | # return the image:
85 | return img
86 |
87 |
88 | class FoldersDistributedDataset(Dataset):
89 | """ pyTorch Dataset wrapper for folder distributed dataset """
90 |
91 | def __setup_files(self):
92 | """
93 | private helper for setting up the files_list
94 | :return: files => list of paths of files
95 | """
96 |
97 | dir_names = os.listdir(self.data_dir)
98 | files = [] # initialize to empty list
99 |
100 | for dir_name in dir_names:
101 | file_path = os.path.join(self.data_dir, dir_name)
102 | file_names = os.listdir(file_path)
103 | for file_name in file_names:
104 | possible_file = os.path.join(file_path, file_name)
105 | if os.path.isfile(possible_file):
106 | files.append(possible_file)
107 |
108 | # return the files list
109 | return files
110 |
111 | def __init__(self, data_dir, transform=None):
112 | """
113 | constructor for the class
114 | :param data_dir: path to the directory containing the data
115 | :param transform: transforms to be applied to the images
116 | """
117 | # define the state of the object
118 | self.data_dir = data_dir
119 | self.transform = transform
120 |
121 | # setup the files for reading
122 | self.files = self.__setup_files()
123 |
124 | def __len__(self):
125 | """
126 | compute the length of the dataset
127 | :return: len => length of dataset
128 | """
129 | return len(self.files)
130 |
131 | def __getitem__(self, idx):
132 | """
133 | obtain the image (read and transform)
134 | :param idx: index of the file required
135 | :return: img => image array
136 | """
137 | from PIL import Image
138 |
139 | # read the image:
140 | img_name = self.files[idx]
141 | if img_name[-4:] == ".npy":
142 | img = np.load(img_name)
143 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0))
144 | else:
145 | img = Image.open(img_name).convert('RGB')
146 |
147 | # apply the transforms on the image
148 | if self.transform is not None:
149 | img = self.transform(img)
150 |
151 | if img.shape[0] >= 4:
152 | # ignore the alpha channel
153 | # in the image if it exists
154 | img = img[:3, :, :]
155 |
156 | # return the image:
157 | return img
158 |
--------------------------------------------------------------------------------
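
The two classes differ only in the directory layout they expect, which is what the `folder:` key in the configs selects between. A sketch with placeholder paths:

    from data.datasets import FlatDirectoryImageDataset, FoldersDistributedDataset
    from data.transforms import get_transform

    t = get_transform(new_size=(128, 128))
    flat = FlatDirectoryImageDataset('/path/flat', transform=t)      # images directly in the dir (folder: False)
    nested = FoldersDistributedDataset('/path/nested', transform=t)  # one level of sub-folders (folder: True)
    print(len(flat), len(nested))
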
/data/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: transforms.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/22
6 | Description:
7 | -------------------------------------------------
8 | """
9 |
10 |
11 | def get_transform(new_size=None):
12 | """
13 | obtain the image transforms required for the input data
14 | :param new_size: size of the resized images
15 | :return: image_transform => transform object from TorchVision
16 | """
17 | from torchvision.transforms import ToTensor, Normalize, Compose, Resize, RandomHorizontalFlip
18 |
19 | if new_size is not None:
20 | image_transform = Compose([
21 | RandomHorizontalFlip(),
22 | Resize(new_size),
23 | ToTensor(),
24 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
25 | ])
26 |
27 | else:
28 | image_transform = Compose([
29 | RandomHorizontalFlip(),
30 | ToTensor(),
31 | Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
32 | ])
33 | return image_transform
34 |
--------------------------------------------------------------------------------
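
The Normalize(mean=0.5, std=0.5) step maps ToTensor's [0, 1] range to [-1, 1] via (x - 0.5) / 0.5. A quick check (solid-color test images, so the random flip has no effect):

    from PIL import Image
    from data.transforms import get_transform

    transform = get_transform(new_size=(64, 64))
    white = Image.new('RGB', (256, 256), (255, 255, 255))
    black = Image.new('RGB', (256, 256), (0, 0, 0))
    print(transform(white).max().item())  # 1.0
    print(transform(black).min().item())  # -1.0
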
/diagrams/cari2_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/cari2_128.png
--------------------------------------------------------------------------------
/diagrams/ffhq_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/ffhq_1024.png
--------------------------------------------------------------------------------
/diagrams/ffhq_128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/ffhq_128.png
--------------------------------------------------------------------------------
/diagrams/figure03-style-mixing-mix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/figure03-style-mixing-mix.png
--------------------------------------------------------------------------------
/diagrams/figure08-truncation-trick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/figure08-truncation-trick.png
--------------------------------------------------------------------------------
/diagrams/grid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangzh13/StyleGAN.pytorch/b1dfc473eab7c1c590b39dfa7306802a0363c198/diagrams/grid.png
--------------------------------------------------------------------------------
/dnnlib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import submission
9 |
10 | from .submission.run_context import RunContext
11 |
12 | from .submission.submit import SubmitTarget
13 | from .submission.submit import PathType
14 | from .submission.submit import SubmitConfig
15 | from .submission.submit import get_path_from_template
16 | from .submission.submit import submit_run
17 |
18 | from .util import EasyDict
19 |
20 | submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
21 |
--------------------------------------------------------------------------------
/dnnlib/submission/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import run_context
9 | from . import submit
10 |
--------------------------------------------------------------------------------
/dnnlib/submission/_internal/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for launching run functions in computing clusters.
9 |
10 | During the submit process, this file is copied to the appropriate run dir.
11 | When the job is launched in the cluster, this module is the first thing that
12 | is run inside the docker container.
13 | """
14 |
15 | import os
16 | import pickle
17 | import sys
18 |
19 | # PYTHONPATH should have been set so that the run_dir/src is in it
20 | import dnnlib
21 |
22 | def main():
23 | if not len(sys.argv) >= 4:
24 | raise RuntimeError("This script needs three arguments: run_dir, task_name and host_name!")
25 |
26 | run_dir = str(sys.argv[1])
27 | task_name = str(sys.argv[2])
28 | host_name = str(sys.argv[3])
29 |
30 | submit_config_path = os.path.join(run_dir, "submit_config.pkl")
31 |
32 | # SubmitConfig should have been pickled to the run dir
33 | if not os.path.exists(submit_config_path):
34 | raise RuntimeError("SubmitConfig pickle file does not exist!")
35 |
36 | submit_config: dnnlib.SubmitConfig = pickle.load(open(submit_config_path, "rb"))
37 | dnnlib.submission.submit.set_user_name_override(submit_config.user_name)
38 |
39 | submit_config.task_name = task_name
40 | submit_config.host_name = host_name
41 |
42 | dnnlib.submission.submit.run_wrapper(submit_config)
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
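
Once the submit process has copied this file into the run dir, the launcher is expected to invoke it with the three positional arguments checked in main() (the values below are placeholders):

    python run.py /path/to/run_dir my_task my_host
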
/dnnlib/submission/run_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helpers for managing the run/training loop."""
9 |
10 | import datetime
11 | import os
12 | import pprint
13 | import time
14 | import types
15 |
16 | from typing import Any
17 |
18 | from . import submit
19 |
20 |
21 | class RunContext(object):
22 | """Helper class for managing the run/training loop.
23 |
24 | The context will hide the implementation details of a basic run/training loop.
25 | It sets things up properly, tells whether the run should be stopped, and then cleans up.
26 | The user should call update periodically and use should_stop to determine whether the run should be stopped.
27 |
28 | Args:
29 | submit_config: The SubmitConfig that is used for the current run.
30 | config_module: The whole config module that is used for the current run.
31 | max_epoch: Optional cached value for the max_epoch variable used in update.
32 | """
33 |
34 | def __init__(self, submit_config: submit.SubmitConfig, config_module: types.ModuleType = None, max_epoch: Any = None):
35 | self.submit_config = submit_config
36 | self.should_stop_flag = False
37 | self.has_closed = False
38 | self.start_time = time.time()
39 | self.last_update_time = time.time()
40 | self.last_update_interval = 0.0
41 | self.max_epoch = max_epoch
42 |
43 | # pretty print all the relevant content of the config module to a text file
44 | if config_module is not None:
45 | with open(os.path.join(submit_config.run_dir, "config.txt"), "w") as f:
46 | filtered_dict = {k: v for k, v in config_module.__dict__.items() if not k.startswith("_") and not isinstance(v, (types.ModuleType, types.FunctionType, types.LambdaType, submit.SubmitConfig, type))}
47 | pprint.pprint(filtered_dict, stream=f, indent=4, width=200, compact=False)
48 |
49 | # write out details about the run to a text file
50 | self.run_txt_data = {"task_name": submit_config.task_name, "host_name": submit_config.host_name, "start_time": datetime.datetime.now().isoformat(sep=" ")}
51 | with open(os.path.join(submit_config.run_dir, "run.txt"), "w") as f:
52 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
53 |
54 | def __enter__(self) -> "RunContext":
55 | return self
56 |
57 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
58 | self.close()
59 |
60 | def update(self, loss: Any = 0, cur_epoch: Any = 0, max_epoch: Any = None) -> None:
61 | """Do general housekeeping and keep the state of the context up-to-date.
62 | Should be called often enough but not in a tight loop."""
63 | assert not self.has_closed
64 |
65 | self.last_update_interval = time.time() - self.last_update_time
66 | self.last_update_time = time.time()
67 |
68 | if os.path.exists(os.path.join(self.submit_config.run_dir, "abort.txt")):
69 | self.should_stop_flag = True
70 |
71 | max_epoch_val = self.max_epoch if max_epoch is None else max_epoch  # resolved here but currently unused
72 |
73 | def should_stop(self) -> bool:
74 | """Tell whether a stopping condition has been triggered one way or another."""
75 | return self.should_stop_flag
76 |
77 | def get_time_since_start(self) -> float:
78 | """How much time has passed since the creation of the context."""
79 | return time.time() - self.start_time
80 |
81 | def get_time_since_last_update(self) -> float:
82 | """How much time has passed since the last call to update."""
83 | return time.time() - self.last_update_time
84 |
85 | def get_last_update_interval(self) -> float:
86 | """How much time passed between the previous two calls to update."""
87 | return self.last_update_interval
88 |
89 | def close(self) -> None:
90 | """Close the context and clean up.
91 | Should only be called once."""
92 | if not self.has_closed:
93 | # update the run.txt with stopping time
94 | self.run_txt_data["stop_time"] = datetime.datetime.now().isoformat(sep=" ")
95 | with open(os.path.join(self.submit_config.run_dir, "run.txt"), "w") as f:
96 | pprint.pprint(self.run_txt_data, stream=f, indent=4, width=200, compact=False)
97 |
98 | self.has_closed = True
99 |
--------------------------------------------------------------------------------
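
A minimal sketch of the loop contract described in the class docstring; the SubmitConfig fields set here are placeholders and the run dir is created by hand rather than by submit_run:

    import os
    import dnnlib

    submit_config = dnnlib.SubmitConfig()
    submit_config.run_dir = './output/demo-run'
    submit_config.task_name = 'demo'
    os.makedirs(submit_config.run_dir, exist_ok=True)

    with dnnlib.RunContext(submit_config, max_epoch=10) as ctx:
        for epoch in range(10):
            ctx.update(loss=0.0, cur_epoch=epoch)  # housekeeping; also checks for abort.txt
            if ctx.should_stop():                  # True once abort.txt appears in run_dir
                break
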
/dnnlib/submission/submit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Submit a function to be run either locally or in a computing cluster."""
9 |
10 | import copy
11 | import os
12 | import pathlib
13 | import pickle
14 | import platform
15 | import pprint
16 | import re
17 | import shutil
18 | import time
19 | import traceback
20 |
21 | from enum import Enum
22 |
23 | from .. import util
24 |
25 |
26 | class SubmitTarget(Enum):
27 | """The target where the function should be run.
28 |
29 | LOCAL: Run it locally.
30 | """
31 | LOCAL = 1
32 |
33 |
34 | class PathType(Enum):
35 | """Determines in which format should a path be formatted.
36 |
37 | WINDOWS: Format with Windows style.
38 | LINUX: Format with Linux/Posix style.
39 | AUTO: Use current OS type to select either WINDOWS or LINUX.
40 | """
41 | WINDOWS = 1
42 | LINUX = 2
43 | AUTO = 3
44 |
45 |
46 | _user_name_override = None
47 |
48 |
49 | class SubmitConfig(util.EasyDict):
50 | """Strongly typed config dict needed to submit runs.
51 |
52 | Attributes:
53 | run_dir_root: Path to the run dir root. Can be optionally templated with tags. Needs to always be run through get_path_from_template.
54 | run_desc: Description of the run. Will be used in the run dir and task name.
55 | run_dir_ignore: List of file patterns used to ignore files when copying files to the run dir.
56 | run_dir_extra_files: List of (abs_path, rel_path) tuples of file paths. rel_path root will be the src directory inside the run dir.
57 | submit_target: Submit target enum value. Used to select where the run is actually launched.
58 | num_gpus: Number of GPUs used/requested for the run.
59 | print_info: Whether to print debug information when submitting.
60 | ask_confirmation: Whether to ask a confirmation before submitting.
61 | run_id: Automatically populated value during submit.
62 | run_name: Automatically populated value during submit.
63 | run_dir: Automatically populated value during submit.
64 | run_func_name: Automatically populated value during submit.
65 | run_func_kwargs: Automatically populated value during submit.
66 | user_name: Automatically populated value during submit. Can be set by the user which will then override the automatic value.
67 | task_name: Automatically populated value during submit.
68 | host_name: Automatically populated value during submit.
69 | """
70 |
71 | def __init__(self):
72 | super().__init__()
73 |
74 | # run (set these)
75 | self.run_dir_root = "" # should always be passed through get_path_from_template
76 | self.run_desc = ""
77 | self.run_dir_ignore = ["__pycache__", "*.pyproj", "*.sln", "*.suo", ".cache", ".idea", ".vs", ".vscode"]
78 | self.run_dir_extra_files = None
79 |
80 | # submit (set these)
81 | self.submit_target = SubmitTarget.LOCAL
82 | self.num_gpus = 1
83 | self.print_info = False
84 | self.ask_confirmation = False
85 |
86 | # (automatically populated)
87 | self.run_id = None
88 | self.run_name = None
89 | self.run_dir = None
90 | self.run_func_name = None
91 | self.run_func_kwargs = None
92 | self.user_name = None
93 | self.task_name = None
94 | self.host_name = "localhost"
95 |
96 |
97 | def get_path_from_template(path_template: str, path_type: PathType = PathType.AUTO) -> str:
98 | """Replace tags in the given path template and return either Windows or Linux formatted path."""
99 | # automatically select path type depending on running OS
100 | if path_type == PathType.AUTO:
101 | if platform.system() == "Windows":
102 | path_type = PathType.WINDOWS
103 | elif platform.system() == "Linux":
104 | path_type = PathType.LINUX
105 | else:
106 | raise RuntimeError("Unknown platform")
107 |
108 | path_template = path_template.replace("<USERNAME>", get_user_name())
109 |
110 | # return correctly formatted path
111 | if path_type == PathType.WINDOWS:
112 | return str(pathlib.PureWindowsPath(path_template))
113 | elif path_type == PathType.LINUX:
114 | return str(pathlib.PurePosixPath(path_template))
115 | else:
116 | raise RuntimeError("Unknown platform")
117 |
118 |
119 | def get_template_from_path(path: str) -> str:
120 | """Convert a normal path back to its template representation."""
121 | # replace all path parts with the template tags
122 | path = path.replace("\\", "/")
123 | return path
124 |
125 |
126 | def convert_path(path: str, path_type: PathType = PathType.AUTO) -> str:
127 | """Convert a normal path to template and the convert it back to a normal path with given path type."""
128 | path_template = get_template_from_path(path)
129 | path = get_path_from_template(path_template, path_type)
130 | return path
131 |
132 |
133 | def set_user_name_override(name: str) -> None:
134 | """Set the global username override value."""
135 | global _user_name_override
136 | _user_name_override = name
137 |
138 |
139 | def get_user_name():
140 | """Get the current user name."""
141 | if _user_name_override is not None:
142 | return _user_name_override
143 | elif platform.system() == "Windows":
144 | return os.getlogin()
145 | elif platform.system() == "Linux":
146 | try:
147 | import pwd # pylint: disable=import-error
148 | return pwd.getpwuid(os.geteuid()).pw_name # pylint: disable=no-member
149 | except:
150 | return "unknown"
151 | else:
152 | raise RuntimeError("Unknown platform")
153 |
154 |
155 | def _create_run_dir_local(submit_config: SubmitConfig) -> str:
156 | """Create a new run dir with increasing ID number at the start."""
157 | run_dir_root = get_path_from_template(submit_config.run_dir_root, PathType.AUTO)
158 |
159 | if not os.path.exists(run_dir_root):
160 | print("Creating the run dir root: {}".format(run_dir_root))
161 | os.makedirs(run_dir_root)
162 |
163 | submit_config.run_id = _get_next_run_id_local(run_dir_root)
164 | submit_config.run_name = "{0:05d}-{1}".format(submit_config.run_id, submit_config.run_desc)
165 | run_dir = os.path.join(run_dir_root, submit_config.run_name)
166 |
167 | if os.path.exists(run_dir):
168 | raise RuntimeError("The run dir already exists! ({0})".format(run_dir))
169 |
170 | print("Creating the run dir: {}".format(run_dir))
171 | os.makedirs(run_dir)
172 |
173 | return run_dir
174 |
175 |
176 | def _get_next_run_id_local(run_dir_root: str) -> int:
177 | """Reads all directory names in a given directory (non-recursive) and returns the next (increasing) run id. Assumes IDs are numbers at the start of the directory names."""
178 | dir_names = [d for d in os.listdir(run_dir_root) if os.path.isdir(os.path.join(run_dir_root, d))]
179 | r = re.compile("^\\d+") # match one or more digits at the start of the string
180 | run_id = 0
181 |
182 | for dir_name in dir_names:
183 | m = r.match(dir_name)
184 |
185 | if m is not None:
186 | i = int(m.group())
187 | run_id = max(run_id, i + 1)
188 |
189 | return run_id
190 |
191 |
192 | def _populate_run_dir(run_dir: str, submit_config: SubmitConfig) -> None:
193 | """Copy all necessary files into the run dir. Assumes that the dir exists, is local, and is writable."""
194 | print("Copying files to the run dir")
195 | files = []
196 |
197 | run_func_module_dir_path = util.get_module_dir_by_obj_name(submit_config.run_func_name)
198 | assert '.' in submit_config.run_func_name
199 | for _idx in range(submit_config.run_func_name.count('.') - 1):
200 | run_func_module_dir_path = os.path.dirname(run_func_module_dir_path)
201 | files += util.list_dir_recursively_with_ignore(run_func_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=False)
202 |
203 | dnnlib_module_dir_path = util.get_module_dir_by_obj_name("dnnlib")
204 | files += util.list_dir_recursively_with_ignore(dnnlib_module_dir_path, ignores=submit_config.run_dir_ignore, add_base_to_relative=True)
205 |
206 | if submit_config.run_dir_extra_files is not None:
207 | files += submit_config.run_dir_extra_files
208 |
209 | files = [(f[0], os.path.join(run_dir, "src", f[1])) for f in files]
210 | files += [(os.path.join(dnnlib_module_dir_path, "submission", "_internal", "run.py"), os.path.join(run_dir, "run.py"))]
211 |
212 | util.copy_files_and_create_dirs(files)
213 |
214 | pickle.dump(submit_config, open(os.path.join(run_dir, "submit_config.pkl"), "wb"))
215 |
216 | with open(os.path.join(run_dir, "submit_config.txt"), "w") as f:
217 | pprint.pprint(submit_config, stream=f, indent=4, width=200, compact=False)
218 |
219 |
220 | def run_wrapper(submit_config: SubmitConfig) -> None:
221 | """Wrap the actual run function call for handling logging, exceptions, typing, etc."""
222 | is_local = submit_config.submit_target == SubmitTarget.LOCAL
223 |
224 | checker = None
225 |
226 | # when running locally, redirect stderr to stdout, log stdout to a file, and force flushing
227 | if is_local:
228 | logger = util.Logger(file_name=os.path.join(submit_config.run_dir, "log.txt"), file_mode="w", should_flush=True)
229 | else: # when running in a cluster, redirect stderr to stdout, and just force flushing (log writing is handled by run.sh)
230 | logger = util.Logger(file_name=None, should_flush=True)
231 |
232 | import dnnlib
233 | dnnlib.submit_config = submit_config
234 |
235 | try:
236 | print("dnnlib: Running {0}() on {1}...".format(submit_config.run_func_name, submit_config.host_name))
237 | start_time = time.time()
238 | util.call_func_by_name(func_name=submit_config.run_func_name, submit_config=submit_config, **submit_config.run_func_kwargs)
239 | print("dnnlib: Finished {0}() in {1}.".format(submit_config.run_func_name, util.format_time(time.time() - start_time)))
240 | except:
241 | if is_local:
242 | raise
243 | else:
244 | traceback.print_exc()
245 |
246 | log_src = os.path.join(submit_config.run_dir, "log.txt")
247 | log_dst = os.path.join(get_path_from_template(submit_config.run_dir_root), "{0}-error.txt".format(submit_config.run_name))
248 | shutil.copyfile(log_src, log_dst)
249 | finally:
250 | open(os.path.join(submit_config.run_dir, "_finished.txt"), "w").close()
251 |
252 | dnnlib.submit_config = None
253 | logger.close()
254 |
255 | if checker is not None:
256 | checker.stop()
257 |
258 |
259 | def submit_run(submit_config: SubmitConfig, run_func_name: str, **run_func_kwargs) -> None:
260 | """Create a run dir, gather files related to the run, copy files to the run dir, and launch the run in appropriate place."""
261 | submit_config = copy.copy(submit_config)
262 |
263 | if submit_config.user_name is None:
264 | submit_config.user_name = get_user_name()
265 |
266 | submit_config.run_func_name = run_func_name
267 | submit_config.run_func_kwargs = run_func_kwargs
268 |
269 | assert submit_config.submit_target == SubmitTarget.LOCAL
270 | if submit_config.submit_target in {SubmitTarget.LOCAL}:
271 | run_dir = _create_run_dir_local(submit_config)
272 |
273 | submit_config.task_name = "{0}-{1:05d}-{2}".format(submit_config.user_name, submit_config.run_id, submit_config.run_desc)
274 | submit_config.run_dir = run_dir
275 | _populate_run_dir(run_dir, submit_config)
276 |
277 | if submit_config.print_info:
278 | print("\nSubmit config:\n")
279 | pprint.pprint(submit_config, indent=4, width=200, compact=False)
280 | print()
281 |
282 | if submit_config.ask_confirmation:
283 | if not util.ask_yes_no("Continue submitting the job?"):
284 | return
285 |
286 | run_wrapper(submit_config)
287 |
--------------------------------------------------------------------------------
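
A hypothetical local submission, assuming `my_project.train.run` is the fully qualified name of an importable function taking `submit_config` plus keyword arguments (all names here are placeholders):

    import dnnlib

    submit_config = dnnlib.SubmitConfig()
    submit_config.run_dir_root = './output'
    submit_config.run_desc = 'demo'

    # creates ./output/00000-demo, copies the sources into it, then calls the function
    dnnlib.submit_run(submit_config, 'my_project.train.run', total_kimg=1)
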
/dnnlib/tflib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | from . import autosummary
9 | from . import network
10 | from . import optimizer
11 | from . import tfutil
12 |
13 | from .tfutil import *
14 | from .network import Network
15 |
16 | from .optimizer import Optimizer
17 |
--------------------------------------------------------------------------------
/dnnlib/tflib/autosummary.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for adding automatically tracked values to Tensorboard.
9 |
10 | Autosummary creates an identity op that internally keeps track of the input
11 | values and automatically shows up in TensorBoard. The reported value
12 | represents an average over input components. The average is accumulated
13 | constantly over time and flushed when save_summaries() is called.
14 |
15 | Notes:
16 | - The output tensor must be used as an input for something else in the
17 | graph. Otherwise, the autosummary op will not get executed, and the average
18 | value will not get accumulated.
19 | - It is perfectly fine to include autosummaries with the same name in
20 | several places throughout the graph, even if they are executed concurrently.
21 | - It is ok to also pass in a python scalar or numpy array. In this case, it
22 | is added to the average immediately.
23 | """
24 |
25 | from collections import OrderedDict
26 | import numpy as np
27 | import tensorflow as tf
28 | from tensorboard import summary as summary_lib
29 | from tensorboard.plugins.custom_scalar import layout_pb2
30 |
31 | from . import tfutil
32 | from .tfutil import TfExpression
33 | from .tfutil import TfExpressionEx
34 |
35 | _dtype = tf.float64
36 | _vars = OrderedDict() # name => [var, ...]
37 | _immediate = OrderedDict() # name => update_op, update_value
38 | _finalized = False
39 | _merge_op = None
40 |
41 |
42 | def _create_var(name: str, value_expr: TfExpression) -> TfExpression:
43 | """Internal helper for creating autosummary accumulators."""
44 | assert not _finalized
45 | name_id = name.replace("/", "_")
46 | v = tf.cast(value_expr, _dtype)
47 |
48 | if v.shape.is_fully_defined():
49 | size = np.prod(tfutil.shape_to_list(v.shape))
50 | size_expr = tf.constant(size, dtype=_dtype)
51 | else:
52 | size = None
53 | size_expr = tf.reduce_prod(tf.cast(tf.shape(v), _dtype))
54 |
55 | if size == 1:
56 | if v.shape.ndims != 0:
57 | v = tf.reshape(v, [])
58 | v = [size_expr, v, tf.square(v)]
59 | else:
60 | v = [size_expr, tf.reduce_sum(v), tf.reduce_sum(tf.square(v))]
61 | v = tf.cond(tf.is_finite(v[1]), lambda: tf.stack(v), lambda: tf.zeros(3, dtype=_dtype))
62 |
63 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.control_dependencies(None):
64 | var = tf.Variable(tf.zeros(3, dtype=_dtype), trainable=False) # [sum(1), sum(x), sum(x**2)]
65 | update_op = tf.cond(tf.is_variable_initialized(var), lambda: tf.assign_add(var, v), lambda: tf.assign(var, v))
66 |
67 | if name in _vars:
68 | _vars[name].append(var)
69 | else:
70 | _vars[name] = [var]
71 | return update_op
72 |
73 |
74 | def autosummary(name: str, value: TfExpressionEx, passthru: TfExpressionEx = None) -> TfExpressionEx:
75 | """Create a new autosummary.
76 |
77 | Args:
78 | name: Name to use in TensorBoard
79 | value: TensorFlow expression or python value to track
80 | passthru: Optionally return this TF node without modifications but tack an autosummary update side-effect to this node.
81 |
82 | Example use of the passthru mechanism:
83 |
84 | n = autosummary('l2loss', loss, passthru=n)
85 |
86 | This is a shorthand for the following code:
87 |
88 | with tf.control_dependencies([autosummary('l2loss', loss)]):
89 | n = tf.identity(n)
90 | """
91 | tfutil.assert_tf_initialized()
92 | name_id = name.replace("/", "_")
93 |
94 | if tfutil.is_tf_expression(value):
95 | with tf.name_scope("summary_" + name_id), tf.device(value.device):
96 | update_op = _create_var(name, value)
97 | with tf.control_dependencies([update_op]):
98 | return tf.identity(value if passthru is None else passthru)
99 |
100 | else: # python scalar or numpy array
101 | if name not in _immediate:
102 | with tfutil.absolute_name_scope("Autosummary/" + name_id), tf.device(None), tf.control_dependencies(None):
103 | update_value = tf.placeholder(_dtype)
104 | update_op = _create_var(name, update_value)
105 | _immediate[name] = update_op, update_value
106 |
107 | update_op, update_value = _immediate[name]
108 | tfutil.run(update_op, {update_value: value})
109 | return value if passthru is None else passthru
110 |
111 |
112 | def finalize_autosummaries():  # returns a custom_scalar layout, or None if already finalized
113 | """Create the necessary ops to include autosummaries in TensorBoard report.
114 | Note: This should be done only once per graph.
115 | """
116 | global _finalized
117 | tfutil.assert_tf_initialized()
118 |
119 | if _finalized:
120 | return None
121 |
122 | _finalized = True
123 | tfutil.init_uninitialized_vars([var for vars_list in _vars.values() for var in vars_list])
124 |
125 | # Create summary ops.
126 | with tf.device(None), tf.control_dependencies(None):
127 | for name, vars_list in _vars.items():
128 | name_id = name.replace("/", "_")
129 | with tfutil.absolute_name_scope("Autosummary/" + name_id):
130 | moments = tf.add_n(vars_list)
131 | moments /= moments[0]
132 | with tf.control_dependencies([moments]): # read before resetting
133 | reset_ops = [tf.assign(var, tf.zeros(3, dtype=_dtype)) for var in vars_list]
134 | with tf.name_scope(None), tf.control_dependencies(reset_ops): # reset before reporting
135 | mean = moments[1]
136 | std = tf.sqrt(moments[2] - tf.square(moments[1]))
137 | tf.summary.scalar(name, mean)
138 | tf.summary.scalar("xCustomScalars/" + name + "/margin_lo", mean - std)
139 | tf.summary.scalar("xCustomScalars/" + name + "/margin_hi", mean + std)
140 |
141 | # Group by category and chart name.
142 | cat_dict = OrderedDict()
143 | for series_name in sorted(_vars.keys()):
144 | p = series_name.split("/")
145 | cat = p[0] if len(p) >= 2 else ""
146 | chart = "/".join(p[1:-1]) if len(p) >= 3 else p[-1]
147 | if cat not in cat_dict:
148 | cat_dict[cat] = OrderedDict()
149 | if chart not in cat_dict[cat]:
150 | cat_dict[cat][chart] = []
151 | cat_dict[cat][chart].append(series_name)
152 |
153 | # Setup custom_scalar layout.
154 | categories = []
155 | for cat_name, chart_dict in cat_dict.items():
156 | charts = []
157 | for chart_name, series_names in chart_dict.items():
158 | series = []
159 | for series_name in series_names:
160 | series.append(layout_pb2.MarginChartContent.Series(
161 | value=series_name,
162 | lower="xCustomScalars/" + series_name + "/margin_lo",
163 | upper="xCustomScalars/" + series_name + "/margin_hi"))
164 | margin = layout_pb2.MarginChartContent(series=series)
165 | charts.append(layout_pb2.Chart(title=chart_name, margin=margin))
166 | categories.append(layout_pb2.Category(title=cat_name, chart=charts))
167 | layout = summary_lib.custom_scalar_pb(layout_pb2.Layout(category=categories))
168 | return layout
169 |
170 | def save_summaries(file_writer, global_step=None):
171 | """Call FileWriter.add_summary() with all summaries in the default graph,
172 | automatically finalizing and merging them on the first call.
173 | """
174 | global _merge_op
175 | tfutil.assert_tf_initialized()
176 |
177 | if _merge_op is None:
178 | layout = finalize_autosummaries()
179 | if layout is not None:
180 | file_writer.add_summary(layout)
181 | with tf.device(None), tf.control_dependencies(None):
182 | _merge_op = tf.summary.merge_all()
183 |
184 | file_writer.add_summary(_merge_op.eval(), global_step)
185 |
--------------------------------------------------------------------------------
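
Putting the module docstring together, a sketch of the full autosummary round-trip (TF1-style; the log dir is a placeholder, and `tflib.init_tf`/`tflib.run` are the tfutil helpers re-exported by the package):

    import tensorflow as tf
    from dnnlib import tflib
    from dnnlib.tflib.autosummary import autosummary, save_summaries

    tflib.init_tf()
    loss = tf.constant(0.25)
    tracked = autosummary('Loss/total', loss)  # identity op with an accumulator side-effect
    tflib.run(tracked)                         # executing it feeds the running average

    writer = tf.summary.FileWriter('./logs')
    save_summaries(writer, global_step=0)      # finalizes and merges summaries on first call
    writer.close()
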
/dnnlib/tflib/network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper for managing networks."""
9 |
10 | import types
11 | import inspect
12 | import re
13 | import uuid
14 | import sys
15 | import numpy as np
16 | import tensorflow as tf
17 |
18 | from collections import OrderedDict
19 | from typing import Any, List, Tuple, Union
20 |
21 | from . import tfutil
22 | from .. import util
23 |
24 | from .tfutil import TfExpression, TfExpressionEx
25 |
26 | _import_handlers = [] # Custom import handlers for dealing with legacy data in pickle import.
27 | _import_module_src = dict() # Source code for temporary modules created during pickle import.
28 |
29 |
30 | def import_handler(handler_func):
31 | """Function decorator for declaring custom import handlers."""
32 | _import_handlers.append(handler_func)
33 | return handler_func
34 |
35 |
36 | class Network:
37 | """Generic network abstraction.
38 |
39 | Acts as a convenience wrapper for a parameterized network construction
40 | function, providing several utility methods and convenient access to
41 | the inputs/outputs/weights.
42 |
43 | Network objects can be safely pickled and unpickled for long-term
44 | archival purposes. The pickling works reliably as long as the underlying
45 | network construction function is defined in a standalone Python module
46 | that has no side effects or application-specific imports.
47 |
48 | Args:
49 | name: Network name. Used to select TensorFlow name and variable scopes.
50 | func_name: Fully qualified name of the underlying network construction function, or a top-level function object.
51 | static_kwargs: Keyword arguments to be passed in to the network construction function.
52 |
53 | Attributes:
54 | name: User-specified name, defaults to build func name if None.
55 | scope: Unique TensorFlow scope containing template graph and variables, derived from the user-specified name.
56 | static_kwargs: Arguments passed to the user-supplied build func.
57 | components: Container for sub-networks. Passed to the build func, and retained between calls.
58 | num_inputs: Number of input tensors.
59 | num_outputs: Number of output tensors.
60 | input_shapes: Input tensor shapes (NC or NCHW), including minibatch dimension.
61 | output_shapes: Output tensor shapes (NC or NCHW), including minibatch dimension.
62 | input_shape: Short-hand for input_shapes[0].
63 | output_shape: Short-hand for output_shapes[0].
64 | input_templates: Input placeholders in the template graph.
65 | output_templates: Output tensors in the template graph.
66 | input_names: Name string for each input.
67 | output_names: Name string for each output.
68 | own_vars: Variables defined by this network (local_name => var), excluding sub-networks.
69 | vars: All variables (local_name => var).
70 | trainables: All trainable variables (local_name => var).
71 | var_global_to_local: Mapping from variable global names to local names.
72 | """
73 |
74 | def __init__(self, name: str = None, func_name: Any = None, **static_kwargs):
75 | tfutil.assert_tf_initialized()
76 | assert isinstance(name, str) or name is None
77 | assert func_name is not None
78 | assert isinstance(func_name, str) or util.is_top_level_function(func_name)
79 | assert util.is_pickleable(static_kwargs)
80 |
81 | self._init_fields()
82 | self.name = name
83 | self.static_kwargs = util.EasyDict(static_kwargs)
84 |
85 | # Locate the user-specified network build function.
86 | if util.is_top_level_function(func_name):
87 | func_name = util.get_top_level_function_name(func_name)
88 | module, self._build_func_name = util.get_module_from_obj_name(func_name)
89 | self._build_func = util.get_obj_from_module(module, self._build_func_name)
90 | assert callable(self._build_func)
91 |
92 | # Dig up source code for the module containing the build function.
93 | self._build_module_src = _import_module_src.get(module, None)
94 | if self._build_module_src is None:
95 | self._build_module_src = inspect.getsource(module)
96 |
97 | # Init TensorFlow graph.
98 | self._init_graph()
99 | self.reset_own_vars()
100 |
101 | def _init_fields(self) -> None:
102 | self.name = None
103 | self.scope = None
104 | self.static_kwargs = util.EasyDict()
105 | self.components = util.EasyDict()
106 | self.num_inputs = 0
107 | self.num_outputs = 0
108 | self.input_shapes = [[]]
109 | self.output_shapes = [[]]
110 | self.input_shape = []
111 | self.output_shape = []
112 | self.input_templates = []
113 | self.output_templates = []
114 | self.input_names = []
115 | self.output_names = []
116 | self.own_vars = OrderedDict()
117 | self.vars = OrderedDict()
118 | self.trainables = OrderedDict()
119 | self.var_global_to_local = OrderedDict()
120 |
121 | self._build_func = None # User-supplied build function that constructs the network.
122 | self._build_func_name = None # Name of the build function.
123 | self._build_module_src = None # Full source code of the module containing the build function.
124 | self._run_cache = dict() # Cached graph data for Network.run().
125 |
126 | def _init_graph(self) -> None:
127 | # Collect inputs.
128 | self.input_names = []
129 |
130 | for param in inspect.signature(self._build_func).parameters.values():
131 | if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
132 | self.input_names.append(param.name)
133 |
134 | self.num_inputs = len(self.input_names)
135 | assert self.num_inputs >= 1
136 |
137 | # Choose name and scope.
138 | if self.name is None:
139 | self.name = self._build_func_name
140 | assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
141 | with tf.name_scope(None):
142 | self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)
143 |
144 | # Finalize build func kwargs.
145 | build_kwargs = dict(self.static_kwargs)
146 | build_kwargs["is_template_graph"] = True
147 | build_kwargs["components"] = self.components
148 |
149 | # Build template graph.
150 | with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope): # ignore surrounding scopes
151 | assert tf.get_variable_scope().name == self.scope
152 | assert tf.get_default_graph().get_name_scope() == self.scope
153 | with tf.control_dependencies(None): # ignore surrounding control dependencies
154 | self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
155 | out_expr = self._build_func(*self.input_templates, **build_kwargs)
156 |
157 | # Collect outputs.
158 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
159 | self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
160 | self.num_outputs = len(self.output_templates)
161 | assert self.num_outputs >= 1
162 | assert all(tfutil.is_tf_expression(t) for t in self.output_templates)
163 |
164 | # Perform sanity checks.
165 | if any(t.shape.ndims is None for t in self.input_templates):
166 | raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
167 | if any(t.shape.ndims is None for t in self.output_templates):
168 | raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
169 | if any(not isinstance(comp, Network) for comp in self.components.values()):
170 | raise ValueError("Components of a Network must be Networks themselves.")
171 | if len(self.components) != len(set(comp.name for comp in self.components.values())):
172 | raise ValueError("Components of a Network must have unique names.")
173 |
174 | # List inputs and outputs.
175 | self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
176 | self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
177 | self.input_shape = self.input_shapes[0]
178 | self.output_shape = self.output_shapes[0]
179 | self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]
180 |
181 | # List variables.
182 | self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
183 | self.vars = OrderedDict(self.own_vars)
184 | self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
185 | self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
186 | self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
187 |
188 | def reset_own_vars(self) -> None:
189 | """Re-initialize all variables of this network, excluding sub-networks."""
190 | tfutil.run([var.initializer for var in self.own_vars.values()])
191 |
192 | def reset_vars(self) -> None:
193 | """Re-initialize all variables of this network, including sub-networks."""
194 | tfutil.run([var.initializer for var in self.vars.values()])
195 |
196 | def reset_trainables(self) -> None:
197 | """Re-initialize all trainable variables of this network, including sub-networks."""
198 | tfutil.run([var.initializer for var in self.trainables.values()])
199 |
200 | def get_output_for(self, *in_expr: TfExpression, return_as_list: bool = False, **dynamic_kwargs) -> Union[TfExpression, List[TfExpression]]:
201 | """Construct TensorFlow expression(s) for the output(s) of this network, given the input expression(s)."""
202 | assert len(in_expr) == self.num_inputs
203 | assert not all(expr is None for expr in in_expr)
204 |
205 | # Finalize build func kwargs.
206 | build_kwargs = dict(self.static_kwargs)
207 | build_kwargs.update(dynamic_kwargs)
208 | build_kwargs["is_template_graph"] = False
209 | build_kwargs["components"] = self.components
210 |
211 | # Build TensorFlow graph to evaluate the network.
212 | with tfutil.absolute_variable_scope(self.scope, reuse=True), tf.name_scope(self.name):
213 | assert tf.get_variable_scope().name == self.scope
214 | valid_inputs = [expr for expr in in_expr if expr is not None]
215 | final_inputs = []
216 | for expr, name, shape in zip(in_expr, self.input_names, self.input_shapes):
217 | if expr is not None:
218 | expr = tf.identity(expr, name=name)
219 | else:
220 | expr = tf.zeros([tf.shape(valid_inputs[0])[0]] + shape[1:], name=name)
221 | final_inputs.append(expr)
222 | out_expr = self._build_func(*final_inputs, **build_kwargs)
223 |
224 | # Propagate input shapes back to the user-specified expressions.
225 | for expr, final in zip(in_expr, final_inputs):
226 | if isinstance(expr, tf.Tensor):
227 | expr.set_shape(final.shape)
228 |
229 | # Express outputs in the desired format.
230 | assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
231 | if return_as_list:
232 | out_expr = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
233 | return out_expr
234 |
235 | def get_var_local_name(self, var_or_global_name: Union[TfExpression, str]) -> str:
236 | """Get the local name of a given variable, without any surrounding name scopes."""
237 | assert tfutil.is_tf_expression(var_or_global_name) or isinstance(var_or_global_name, str)
238 | global_name = var_or_global_name if isinstance(var_or_global_name, str) else var_or_global_name.name
239 | return self.var_global_to_local[global_name]
240 |
241 | def find_var(self, var_or_local_name: Union[TfExpression, str]) -> TfExpression:
242 | """Find variable by local or global name."""
243 | assert tfutil.is_tf_expression(var_or_local_name) or isinstance(var_or_local_name, str)
244 | return self.vars[var_or_local_name] if isinstance(var_or_local_name, str) else var_or_local_name
245 |
246 | def get_var(self, var_or_local_name: Union[TfExpression, str]) -> np.ndarray:
247 | """Get the value of a given variable as NumPy array.
248 | Note: This method is very inefficient -- prefer to use tflib.run(list_of_vars) whenever possible."""
249 | return self.find_var(var_or_local_name).eval()
250 |
251 | def set_var(self, var_or_local_name: Union[TfExpression, str], new_value: Union[int, float, np.ndarray]) -> None:
252 | """Set the value of a given variable based on the given NumPy array.
253 | Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
254 | tfutil.set_vars({self.find_var(var_or_local_name): new_value})
255 |
256 | def __getstate__(self) -> dict:
257 | """Pickle export."""
258 | state = dict()
259 | state["version"] = 3
260 | state["name"] = self.name
261 | state["static_kwargs"] = dict(self.static_kwargs)
262 | state["components"] = dict(self.components)
263 | state["build_module_src"] = self._build_module_src
264 | state["build_func_name"] = self._build_func_name
265 | state["variables"] = list(zip(self.own_vars.keys(), tfutil.run(list(self.own_vars.values()))))
266 | return state
267 |
268 | def __setstate__(self, state: dict) -> None:
269 | """Pickle import."""
270 | # pylint: disable=attribute-defined-outside-init
271 | tfutil.assert_tf_initialized()
272 | self._init_fields()
273 |
274 | # Execute custom import handlers.
275 | for handler in _import_handlers:
276 | state = handler(state)
277 |
278 | # Set basic fields.
279 | assert state["version"] in [2, 3]
280 | self.name = state["name"]
281 | self.static_kwargs = util.EasyDict(state["static_kwargs"])
282 | self.components = util.EasyDict(state.get("components", {}))
283 | self._build_module_src = state["build_module_src"]
284 | self._build_func_name = state["build_func_name"]
285 |
286 | # Create temporary module from the imported source code.
287 | module_name = "_tflib_network_import_" + uuid.uuid4().hex
288 | module = types.ModuleType(module_name)
289 | sys.modules[module_name] = module
290 | _import_module_src[module] = self._build_module_src
291 | exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used
292 |
293 | # Locate network build function in the temporary module.
294 | self._build_func = util.get_obj_from_module(module, self._build_func_name)
295 | assert callable(self._build_func)
296 |
297 | # Init TensorFlow graph.
298 | self._init_graph()
299 | self.reset_own_vars()
300 | tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})
301 |
302 | def clone(self, name: str = None, **new_static_kwargs) -> "Network":
303 | """Create a clone of this network with its own copy of the variables."""
304 | # pylint: disable=protected-access
305 | net = object.__new__(Network)
306 | net._init_fields()
307 | net.name = name if name is not None else self.name
308 | net.static_kwargs = util.EasyDict(self.static_kwargs)
309 | net.static_kwargs.update(new_static_kwargs)
310 | net._build_module_src = self._build_module_src
311 | net._build_func_name = self._build_func_name
312 | net._build_func = self._build_func
313 | net._init_graph()
314 | net.copy_vars_from(self)
315 | return net
316 |
317 | def copy_own_vars_from(self, src_net: "Network") -> None:
318 | """Copy the values of all variables from the given network, excluding sub-networks."""
319 | names = [name for name in self.own_vars.keys() if name in src_net.own_vars]
320 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
321 |
322 | def copy_vars_from(self, src_net: "Network") -> None:
323 | """Copy the values of all variables from the given network, including sub-networks."""
324 | names = [name for name in self.vars.keys() if name in src_net.vars]
325 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
326 |
327 | def copy_trainables_from(self, src_net: "Network") -> None:
328 | """Copy the values of all trainable variables from the given network, including sub-networks."""
329 | names = [name for name in self.trainables.keys() if name in src_net.trainables]
330 | tfutil.set_vars(tfutil.run({self.vars[name]: src_net.vars[name] for name in names}))
331 |
332 | def convert(self, new_func_name: str, new_name: str = None, **new_static_kwargs) -> "Network":
333 | """Create new network with the given parameters, and copy all variables from this network."""
334 | if new_name is None:
335 | new_name = self.name
336 | static_kwargs = dict(self.static_kwargs)
337 | static_kwargs.update(new_static_kwargs)
338 | net = Network(name=new_name, func_name=new_func_name, **static_kwargs)
339 | net.copy_vars_from(self)
340 | return net
341 |
342 | def setup_as_moving_average_of(self, src_net: "Network", beta: TfExpressionEx = 0.99, beta_nontrainable: TfExpressionEx = 0.0) -> tf.Operation:
343 | """Construct a TensorFlow op that updates the variables of this network
344 | to be slightly closer to those of the given network."""
345 | with tfutil.absolute_name_scope(self.scope + "/_MovingAvg"):
346 | ops = []
347 | for name, var in self.vars.items():
348 | if name in src_net.vars:
349 | cur_beta = beta if name in self.trainables else beta_nontrainable
350 | new_value = tfutil.lerp(src_net.vars[name], var, cur_beta)
351 | ops.append(var.assign(new_value))
352 | return tf.group(*ops)
353 |
354 | def run(self,
355 | *in_arrays: Tuple[Union[np.ndarray, None], ...],
356 | input_transform: dict = None,
357 | output_transform: dict = None,
358 | return_as_list: bool = False,
359 | print_progress: bool = False,
360 | minibatch_size: int = None,
361 | num_gpus: int = 1,
362 | assume_frozen: bool = False,
363 | **dynamic_kwargs) -> Union[np.ndarray, Tuple[np.ndarray, ...], List[np.ndarray]]:
364 | """Run this network for the given NumPy array(s), and return the output(s) as NumPy array(s).
365 |
366 | Args:
367 | input_transform: A dict specifying a custom transformation to be applied to the input tensor(s) before evaluating the network.
368 | The dict must contain a 'func' field that points to a top-level function. The function is called with the input
369 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
370 | output_transform: A dict specifying a custom transformation to be applied to the output tensor(s) after evaluating the network.
371 | The dict must contain a 'func' field that points to a top-level function. The function is called with the output
372 | TensorFlow expression(s) as positional arguments. Any remaining fields of the dict will be passed in as kwargs.
373 | return_as_list: True = return a list of NumPy arrays, False = return a single NumPy array, or a tuple if there are multiple outputs.
374 | print_progress: Print progress to the console? Useful for very large input arrays.
375 | minibatch_size: Maximum minibatch size to use, None = disable batching.
376 | num_gpus: Number of GPUs to use.
377 |             assume_frozen: Improve multi-GPU performance by assuming that the trainable parameters will remain unchanged between calls.
378 | dynamic_kwargs: Additional keyword arguments to be passed into the network build function.
379 | """
380 | assert len(in_arrays) == self.num_inputs
381 | assert not all(arr is None for arr in in_arrays)
382 | assert input_transform is None or util.is_top_level_function(input_transform["func"])
383 | assert output_transform is None or util.is_top_level_function(output_transform["func"])
384 | output_transform, dynamic_kwargs = _handle_legacy_output_transforms(output_transform, dynamic_kwargs)
385 | num_items = in_arrays[0].shape[0]
386 | if minibatch_size is None:
387 | minibatch_size = num_items
388 |
389 | # Construct unique hash key from all arguments that affect the TensorFlow graph.
390 | key = dict(input_transform=input_transform, output_transform=output_transform, num_gpus=num_gpus, assume_frozen=assume_frozen, dynamic_kwargs=dynamic_kwargs)
391 | def unwind_key(obj):
392 | if isinstance(obj, dict):
393 | return [(key, unwind_key(value)) for key, value in sorted(obj.items())]
394 | if callable(obj):
395 | return util.get_top_level_function_name(obj)
396 | return obj
397 | key = repr(unwind_key(key))
398 |
399 | # Build graph.
400 | if key not in self._run_cache:
401 | with tfutil.absolute_name_scope(self.scope + "/_Run"), tf.control_dependencies(None):
402 | with tf.device("/cpu:0"):
403 | in_expr = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
404 | in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
405 |
406 | out_split = []
407 | for gpu in range(num_gpus):
408 | with tf.device("/gpu:%d" % gpu):
409 | net_gpu = self.clone() if assume_frozen else self
410 | in_gpu = in_split[gpu]
411 |
412 | if input_transform is not None:
413 | in_kwargs = dict(input_transform)
414 | in_gpu = in_kwargs.pop("func")(*in_gpu, **in_kwargs)
415 | in_gpu = [in_gpu] if tfutil.is_tf_expression(in_gpu) else list(in_gpu)
416 |
417 | assert len(in_gpu) == self.num_inputs
418 | out_gpu = net_gpu.get_output_for(*in_gpu, return_as_list=True, **dynamic_kwargs)
419 |
420 | if output_transform is not None:
421 | out_kwargs = dict(output_transform)
422 | out_gpu = out_kwargs.pop("func")(*out_gpu, **out_kwargs)
423 | out_gpu = [out_gpu] if tfutil.is_tf_expression(out_gpu) else list(out_gpu)
424 |
425 | assert len(out_gpu) == self.num_outputs
426 | out_split.append(out_gpu)
427 |
428 | with tf.device("/cpu:0"):
429 | out_expr = [tf.concat(outputs, axis=0) for outputs in zip(*out_split)]
430 | self._run_cache[key] = in_expr, out_expr
431 |
432 | # Run minibatches.
433 | in_expr, out_expr = self._run_cache[key]
434 | out_arrays = [np.empty([num_items] + tfutil.shape_to_list(expr.shape)[1:], expr.dtype.name) for expr in out_expr]
435 |
436 | for mb_begin in range(0, num_items, minibatch_size):
437 | if print_progress:
438 | print("\r%d / %d" % (mb_begin, num_items), end="")
439 |
440 | mb_end = min(mb_begin + minibatch_size, num_items)
441 | mb_num = mb_end - mb_begin
442 | mb_in = [src[mb_begin : mb_end] if src is not None else np.zeros([mb_num] + shape[1:]) for src, shape in zip(in_arrays, self.input_shapes)]
443 | mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
444 |
445 | for dst, src in zip(out_arrays, mb_out):
446 | dst[mb_begin: mb_end] = src
447 |
448 | # Done.
449 | if print_progress:
450 | print("\r%d / %d" % (num_items, num_items))
451 |
452 | if not return_as_list:
453 | out_arrays = out_arrays[0] if len(out_arrays) == 1 else tuple(out_arrays)
454 | return out_arrays
455 |
456 | def list_ops(self) -> List[TfExpression]:
457 | include_prefix = self.scope + "/"
458 | exclude_prefix = include_prefix + "_"
459 | ops = tf.get_default_graph().get_operations()
460 | ops = [op for op in ops if op.name.startswith(include_prefix)]
461 | ops = [op for op in ops if not op.name.startswith(exclude_prefix)]
462 | return ops
463 |
464 | def list_layers(self) -> List[Tuple[str, TfExpression, List[TfExpression]]]:
465 | """Returns a list of (layer_name, output_expr, trainable_vars) tuples corresponding to
466 | individual layers of the network. Mainly intended to be used for reporting."""
467 | layers = []
468 |
469 | def recurse(scope, parent_ops, parent_vars, level):
470 | # Ignore specific patterns.
471 | if any(p in scope for p in ["/Shape", "/strided_slice", "/Cast", "/concat", "/Assign"]):
472 | return
473 |
474 | # Filter ops and vars by scope.
475 | global_prefix = scope + "/"
476 | local_prefix = global_prefix[len(self.scope) + 1:]
477 | cur_ops = [op for op in parent_ops if op.name.startswith(global_prefix) or op.name == global_prefix[:-1]]
478 | cur_vars = [(name, var) for name, var in parent_vars if name.startswith(local_prefix) or name == local_prefix[:-1]]
479 | if not cur_ops and not cur_vars:
480 | return
481 |
482 | # Filter out all ops related to variables.
483 | for var in [op for op in cur_ops if op.type.startswith("Variable")]:
484 | var_prefix = var.name + "/"
485 | cur_ops = [op for op in cur_ops if not op.name.startswith(var_prefix)]
486 |
487 | # Scope does not contain ops as immediate children => recurse deeper.
488 | contains_direct_ops = any("/" not in op.name[len(global_prefix):] and op.type != "Identity" for op in cur_ops)
489 | if (level == 0 or not contains_direct_ops) and (len(cur_ops) + len(cur_vars)) > 1:
490 | visited = set()
491 | for rel_name in [op.name[len(global_prefix):] for op in cur_ops] + [name[len(local_prefix):] for name, _var in cur_vars]:
492 | token = rel_name.split("/")[0]
493 | if token not in visited:
494 | recurse(global_prefix + token, cur_ops, cur_vars, level + 1)
495 | visited.add(token)
496 | return
497 |
498 | # Report layer.
499 | layer_name = scope[len(self.scope) + 1:]
500 | layer_output = cur_ops[-1].outputs[0] if cur_ops else cur_vars[-1][1]
501 | layer_trainables = [var for _name, var in cur_vars if var.trainable]
502 | layers.append((layer_name, layer_output, layer_trainables))
503 |
504 | recurse(self.scope, self.list_ops(), list(self.vars.items()), 0)
505 | return layers
506 |
507 | def print_layers(self, title: str = None, hide_layers_with_no_params: bool = False) -> None:
508 | """Print a summary table of the network structure."""
509 | rows = [[title if title is not None else self.name, "Params", "OutputShape", "WeightShape"]]
510 | rows += [["---"] * 4]
511 | total_params = 0
512 |
513 | for layer_name, layer_output, layer_trainables in self.list_layers():
514 | num_params = sum(np.prod(tfutil.shape_to_list(var.shape)) for var in layer_trainables)
515 | weights = [var for var in layer_trainables if var.name.endswith("/weight:0")]
516 | weights.sort(key=lambda x: len(x.name))
517 | if len(weights) == 0 and len(layer_trainables) == 1:
518 | weights = layer_trainables
519 | total_params += num_params
520 |
521 | if not hide_layers_with_no_params or num_params != 0:
522 | num_params_str = str(num_params) if num_params > 0 else "-"
523 | output_shape_str = str(layer_output.shape)
524 | weight_shape_str = str(weights[0].shape) if len(weights) >= 1 else "-"
525 | rows += [[layer_name, num_params_str, output_shape_str, weight_shape_str]]
526 |
527 | rows += [["---"] * 4]
528 | rows += [["Total", str(total_params), "", ""]]
529 |
530 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
531 | print()
532 | for row in rows:
533 | print(" ".join(cell + " " * (width - len(cell)) for cell, width in zip(row, widths)))
534 | print()
535 |
536 | def setup_weight_histograms(self, title: str = None) -> None:
537 | """Construct summary ops to include histograms of all trainable parameters in TensorBoard."""
538 | if title is None:
539 | title = self.name
540 |
541 | with tf.name_scope(None), tf.device(None), tf.control_dependencies(None):
542 | for local_name, var in self.trainables.items():
543 | if "/" in local_name:
544 | p = local_name.split("/")
545 | name = title + "_" + p[-1] + "/" + "_".join(p[:-1])
546 | else:
547 | name = title + "_toplevel/" + local_name
548 |
549 | tf.summary.histogram(name, var)
550 |
551 | #----------------------------------------------------------------------------
552 | # Backwards-compatible emulation of legacy output transformation in Network.run().
553 |
554 | _print_legacy_warning = True
555 |
556 | def _handle_legacy_output_transforms(output_transform, dynamic_kwargs):
557 | global _print_legacy_warning
558 | legacy_kwargs = ["out_mul", "out_add", "out_shrink", "out_dtype"]
559 | if not any(kwarg in dynamic_kwargs for kwarg in legacy_kwargs):
560 | return output_transform, dynamic_kwargs
561 |
562 | if _print_legacy_warning:
563 | _print_legacy_warning = False
564 | print()
565 | print("WARNING: Old-style output transformations in Network.run() are deprecated.")
566 | print("Consider using 'output_transform=dict(func=tflib.convert_images_to_uint8)'")
567 | print("instead of 'out_mul=127.5, out_add=127.5, out_dtype=np.uint8'.")
568 | print()
569 | assert output_transform is None
570 |
571 | new_kwargs = dict(dynamic_kwargs)
572 | new_transform = {kwarg: new_kwargs.pop(kwarg) for kwarg in legacy_kwargs if kwarg in dynamic_kwargs}
573 | new_transform["func"] = _legacy_output_transform_func
574 | return new_transform, new_kwargs
575 |
576 | def _legacy_output_transform_func(*expr, out_mul=1.0, out_add=0.0, out_shrink=1, out_dtype=None):
577 | if out_mul != 1.0:
578 | expr = [x * out_mul for x in expr]
579 |
580 | if out_add != 0.0:
581 | expr = [x + out_add for x in expr]
582 |
583 | if out_shrink > 1:
584 | ksize = [1, 1, out_shrink, out_shrink]
585 | expr = [tf.nn.avg_pool(x, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW") for x in expr]
586 |
587 | if out_dtype is not None:
588 | if tf.as_dtype(out_dtype).is_integer:
589 | expr = [tf.round(x) for x in expr]
590 | expr = [tf.saturate_cast(x, out_dtype) for x in expr]
591 | return expr
592 |
--------------------------------------------------------------------------------
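
A quick usage sketch for the Network class above -- illustrative only, assuming `Gs` is an already-constructed Network whose two inputs are a latent batch and a (here unused) label batch, as in the original StyleGAN release:

    import numpy as np
    import dnnlib.tflib as tflib

    tflib.init_tf()
    # `Gs` is assumed to exist; its first input shape gives the latent size.
    latents = np.random.RandomState(0).randn(8, Gs.input_shapes[0][1])
    # Network.run() handles minibatching and the output transform recommended
    # by the deprecation warning above.
    images = Gs.run(latents, None,
                    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
                    minibatch_size=4)
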
/dnnlib/tflib/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Helper wrapper for a Tensorflow optimizer."""
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 |
13 | from collections import OrderedDict
14 | from typing import List, Union
15 |
16 | from . import autosummary
17 | from . import tfutil
18 | from .. import util
19 |
20 | from .tfutil import TfExpression, TfExpressionEx
21 |
22 | try:
23 | # TensorFlow 1.13
24 | from tensorflow.python.ops import nccl_ops
25 | except:
26 | # Older TensorFlow versions
27 | import tensorflow.contrib.nccl as nccl_ops
28 |
29 | class Optimizer:
30 | """A Wrapper for tf.train.Optimizer.
31 |
32 | Automatically takes care of:
33 | - Gradient averaging for multi-GPU training.
34 | - Dynamic loss scaling and typecasts for FP16 training.
35 | - Ignoring corrupted gradients that contain NaNs/Infs.
36 | - Reporting statistics.
37 | - Well-chosen default settings.
38 | """
39 |
40 | def __init__(self,
41 | name: str = "Train",
42 | tf_optimizer: str = "tf.train.AdamOptimizer",
43 | learning_rate: TfExpressionEx = 0.001,
44 | use_loss_scaling: bool = False,
45 | loss_scaling_init: float = 64.0,
46 | loss_scaling_inc: float = 0.0005,
47 | loss_scaling_dec: float = 1.0,
48 | **kwargs):
49 |
50 | # Init fields.
51 | self.name = name
52 | self.learning_rate = tf.convert_to_tensor(learning_rate)
53 | self.id = self.name.replace("/", ".")
54 | self.scope = tf.get_default_graph().unique_name(self.id)
55 | self.optimizer_class = util.get_obj_by_name(tf_optimizer)
56 | self.optimizer_kwargs = dict(kwargs)
57 | self.use_loss_scaling = use_loss_scaling
58 | self.loss_scaling_init = loss_scaling_init
59 | self.loss_scaling_inc = loss_scaling_inc
60 | self.loss_scaling_dec = loss_scaling_dec
61 | self._grad_shapes = None # [shape, ...]
62 | self._dev_opt = OrderedDict() # device => optimizer
63 | self._dev_grads = OrderedDict() # device => [[(grad, var), ...], ...]
64 | self._dev_ls_var = OrderedDict() # device => variable (log2 of loss scaling factor)
65 | self._updates_applied = False
66 |
67 | def register_gradients(self, loss: TfExpression, trainable_vars: Union[List, dict]) -> None:
68 | """Register the gradients of the given loss function with respect to the given variables.
69 | Intended to be called once per GPU."""
70 | assert not self._updates_applied
71 |
72 | # Validate arguments.
73 | if isinstance(trainable_vars, dict):
74 | trainable_vars = list(trainable_vars.values()) # allow passing in Network.trainables as vars
75 |
76 | assert isinstance(trainable_vars, list) and len(trainable_vars) >= 1
77 | assert all(tfutil.is_tf_expression(expr) for expr in trainable_vars + [loss])
78 |
79 | if self._grad_shapes is None:
80 | self._grad_shapes = [tfutil.shape_to_list(var.shape) for var in trainable_vars]
81 |
82 | assert len(trainable_vars) == len(self._grad_shapes)
83 | assert all(
84 | tfutil.shape_to_list(var.shape) == var_shape for var, var_shape in zip(trainable_vars, self._grad_shapes))
85 |
86 | dev = loss.device
87 |
88 | assert all(var.device == dev for var in trainable_vars)
89 |
90 | # Register device and compute gradients.
91 | with tf.name_scope(self.id + "_grad"), tf.device(dev):
92 | if dev not in self._dev_opt:
93 | opt_name = self.scope.replace("/", "_") + "_opt%d" % len(self._dev_opt)
94 | assert callable(self.optimizer_class)
95 | self._dev_opt[dev] = self.optimizer_class(name=opt_name, learning_rate=self.learning_rate, **self.optimizer_kwargs)
96 | self._dev_grads[dev] = []
97 |
98 | loss = self.apply_loss_scaling(tf.cast(loss, tf.float32))
99 | grads = self._dev_opt[dev].compute_gradients(loss, trainable_vars, gate_gradients=tf.train.Optimizer.GATE_NONE) # disable gating to reduce memory usage
100 | grads = [(g, v) if g is not None else (tf.zeros_like(v), v) for g, v in grads] # replace disconnected gradients with zeros
101 | self._dev_grads[dev].append(grads)
102 |
103 | def apply_updates(self) -> tf.Operation:
104 | """Construct training op to update the registered variables based on their gradients."""
105 | tfutil.assert_tf_initialized()
106 | assert not self._updates_applied
107 | self._updates_applied = True
108 | devices = list(self._dev_grads.keys())
109 | total_grads = sum(len(grads) for grads in self._dev_grads.values())
110 | assert len(devices) >= 1 and total_grads >= 1
111 | ops = []
112 |
113 | with tfutil.absolute_name_scope(self.scope):
114 | # Cast gradients to FP32 and calculate partial sum within each device.
115 | dev_grads = OrderedDict() # device => [(grad, var), ...]
116 |
117 | for dev_idx, dev in enumerate(devices):
118 | with tf.name_scope("ProcessGrads%d" % dev_idx), tf.device(dev):
119 | sums = []
120 |
121 | for gv in zip(*self._dev_grads[dev]):
122 | assert all(v is gv[0][1] for g, v in gv)
123 | g = [tf.cast(g, tf.float32) for g, v in gv]
124 | g = g[0] if len(g) == 1 else tf.add_n(g)
125 | sums.append((g, gv[0][1]))
126 |
127 | dev_grads[dev] = sums
128 |
129 | # Sum gradients across devices.
130 | if len(devices) > 1:
131 | with tf.name_scope("SumAcrossGPUs"), tf.device(None):
132 | for var_idx, grad_shape in enumerate(self._grad_shapes):
133 | g = [dev_grads[dev][var_idx][0] for dev in devices]
134 |
135 | if np.prod(grad_shape): # nccl does not support zero-sized tensors
136 | g = nccl_ops.all_sum(g)
137 |
138 | for dev, gg in zip(devices, g):
139 | dev_grads[dev][var_idx] = (gg, dev_grads[dev][var_idx][1])
140 |
141 | # Apply updates separately on each device.
142 | for dev_idx, (dev, grads) in enumerate(dev_grads.items()):
143 | with tf.name_scope("ApplyGrads%d" % dev_idx), tf.device(dev):
144 | # Scale gradients as needed.
145 | if self.use_loss_scaling or total_grads > 1:
146 | with tf.name_scope("Scale"):
147 | coef = tf.constant(np.float32(1.0 / total_grads), name="coef")
148 | coef = self.undo_loss_scaling(coef)
149 | grads = [(g * coef, v) for g, v in grads]
150 |
151 | # Check for overflows.
152 | with tf.name_scope("CheckOverflow"):
153 | grad_ok = tf.reduce_all(tf.stack([tf.reduce_all(tf.is_finite(g)) for g, v in grads]))
154 |
155 | # Update weights and adjust loss scaling.
156 | with tf.name_scope("UpdateWeights"):
157 | # pylint: disable=cell-var-from-loop
158 | opt = self._dev_opt[dev]
159 | ls_var = self.get_loss_scaling_var(dev)
160 |
161 | if not self.use_loss_scaling:
162 | ops.append(tf.cond(grad_ok, lambda: opt.apply_gradients(grads), tf.no_op))
163 | else:
164 | ops.append(tf.cond(grad_ok,
165 | lambda: tf.group(tf.assign_add(ls_var, self.loss_scaling_inc), opt.apply_gradients(grads)),
166 | lambda: tf.group(tf.assign_sub(ls_var, self.loss_scaling_dec))))
167 |
168 | # Report statistics on the last device.
169 | if dev == devices[-1]:
170 | with tf.name_scope("Statistics"):
171 | ops.append(autosummary.autosummary(self.id + "/learning_rate", self.learning_rate))
172 | ops.append(
173 | autosummary.autosummary(self.id + "/overflow_frequency", tf.where(grad_ok, 0, 1)))
174 |
175 | if self.use_loss_scaling:
176 | ops.append(autosummary.autosummary(self.id + "/loss_scaling_log2", ls_var))
177 |
178 | # Initialize variables and group everything into a single op.
179 | self.reset_optimizer_state()
180 | tfutil.init_uninitialized_vars(list(self._dev_ls_var.values()))
181 |
182 | return tf.group(*ops, name="TrainingOp")
183 |
184 | def reset_optimizer_state(self) -> None:
185 | """Reset internal state of the underlying optimizer."""
186 | tfutil.assert_tf_initialized()
187 | tfutil.run([var.initializer for opt in self._dev_opt.values() for var in opt.variables()])
188 |
189 | def get_loss_scaling_var(self, device: str) -> Union[tf.Variable, None]:
190 | """Get or create variable representing log2 of the current dynamic loss scaling factor."""
191 | if not self.use_loss_scaling:
192 | return None
193 |
194 | if device not in self._dev_ls_var:
195 | with tfutil.absolute_name_scope(self.scope + "/LossScalingVars"), tf.control_dependencies(None):
196 | self._dev_ls_var[device] = tf.Variable(np.float32(self.loss_scaling_init), name="loss_scaling_var")
197 |
198 | return self._dev_ls_var[device]
199 |
200 | def apply_loss_scaling(self, value: TfExpression) -> TfExpression:
201 | """Apply dynamic loss scaling for the given expression."""
202 | assert tfutil.is_tf_expression(value)
203 |
204 | if not self.use_loss_scaling:
205 | return value
206 |
207 | return value * tfutil.exp2(self.get_loss_scaling_var(value.device))
208 |
209 | def undo_loss_scaling(self, value: TfExpression) -> TfExpression:
210 | """Undo the effect of dynamic loss scaling for the given expression."""
211 | assert tfutil.is_tf_expression(value)
212 |
213 | if not self.use_loss_scaling:
214 | return value
215 |
216 | return value * tfutil.exp2(-self.get_loss_scaling_var(value.device)) # pylint: disable=invalid-unary-operand-type
217 |
--------------------------------------------------------------------------------
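
A minimal single-GPU sketch of the Optimizer workflow above (register gradients once per GPU tower, then build one combined update op); `loss` and `net` are assumed to exist, and the Adam hyper-parameters are illustrative:

    import dnnlib.tflib as tflib
    from dnnlib.tflib.optimizer import Optimizer

    opt = Optimizer(name="TrainG", learning_rate=0.001,
                    beta1=0.0, beta2=0.99, epsilon=1e-8)  # forwarded to tf.train.AdamOptimizer
    opt.register_gradients(loss, net.trainables)  # assumed Network; call once per GPU
    train_op = opt.apply_updates()  # sums gradients across devices and applies them
    tflib.run(train_op)             # one training step
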
/dnnlib/tflib/tfutil.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous helper utils for Tensorflow."""
9 |
10 | import os
11 | import numpy as np
12 | import tensorflow as tf
13 |
14 | from typing import Any, Iterable, List, Union
15 |
16 | TfExpression = Union[tf.Tensor, tf.Variable, tf.Operation]
17 | """A type that represents a valid TensorFlow expression."""
18 |
19 | TfExpressionEx = Union[TfExpression, int, float, np.ndarray]
20 | """A type that can be converted to a valid TensorFlow expression."""
21 |
22 |
23 | def run(*args, **kwargs) -> Any:
24 | """Run the specified ops in the default session."""
25 | assert_tf_initialized()
26 | return tf.get_default_session().run(*args, **kwargs)
27 |
28 |
29 | def is_tf_expression(x: Any) -> bool:
30 |     """Check whether the input is a valid TensorFlow expression, i.e., TensorFlow Tensor, Variable, or Operation."""
31 | return isinstance(x, (tf.Tensor, tf.Variable, tf.Operation))
32 |
33 |
34 | def shape_to_list(shape: Iterable[tf.Dimension]) -> List[Union[int, None]]:
35 |     """Convert a TensorFlow shape to a list of ints (None for unknown dimensions)."""
36 | return [dim.value for dim in shape]
37 |
38 |
39 | def flatten(x: TfExpressionEx) -> TfExpression:
40 | """Shortcut function for flattening a tensor."""
41 | with tf.name_scope("Flatten"):
42 | return tf.reshape(x, [-1])
43 |
44 |
45 | def log2(x: TfExpressionEx) -> TfExpression:
46 | """Logarithm in base 2."""
47 | with tf.name_scope("Log2"):
48 | return tf.log(x) * np.float32(1.0 / np.log(2.0))
49 |
50 |
51 | def exp2(x: TfExpressionEx) -> TfExpression:
52 |     """Exponential in base 2, i.e. 2**x."""
53 | with tf.name_scope("Exp2"):
54 | return tf.exp(x * np.float32(np.log(2.0)))
55 |
56 |
57 | def lerp(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpressionEx:
58 | """Linear interpolation."""
59 | with tf.name_scope("Lerp"):
60 | return a + (b - a) * t
61 |
62 |
63 | def lerp_clip(a: TfExpressionEx, b: TfExpressionEx, t: TfExpressionEx) -> TfExpression:
64 | """Linear interpolation with clip."""
65 | with tf.name_scope("LerpClip"):
66 | return a + (b - a) * tf.clip_by_value(t, 0.0, 1.0)
67 |
68 |
69 | def absolute_name_scope(scope: str) -> tf.name_scope:
70 | """Forcefully enter the specified name scope, ignoring any surrounding scopes."""
71 | return tf.name_scope(scope + "/")
72 |
73 |
74 | def absolute_variable_scope(scope: str, **kwargs) -> tf.variable_scope:
75 | """Forcefully enter the specified variable scope, ignoring any surrounding scopes."""
76 | return tf.variable_scope(tf.VariableScope(name=scope, **kwargs), auxiliary_name_scope=False)
77 |
78 |
79 | def _sanitize_tf_config(config_dict: dict = None) -> dict:
80 | # Defaults.
81 | cfg = dict()
82 | cfg["rnd.np_random_seed"] = None # Random seed for NumPy. None = keep as is.
83 | cfg["rnd.tf_random_seed"] = "auto" # Random seed for TensorFlow. 'auto' = derive from NumPy random state. None = keep as is.
84 | cfg["env.TF_CPP_MIN_LOG_LEVEL"] = "1" # 0 = Print all available debug info from TensorFlow. 1 = Print warnings and errors, but disable debug info.
85 | cfg["graph_options.place_pruned_graph"] = True # False = Check that all ops are available on the designated device. True = Skip the check for ops that are not used.
86 | cfg["gpu_options.allow_growth"] = True # False = Allocate all GPU memory at the beginning. True = Allocate only as much GPU memory as needed.
87 |
88 | # User overrides.
89 | if config_dict is not None:
90 | cfg.update(config_dict)
91 | return cfg
92 |
93 |
94 | def init_tf(config_dict: dict = None) -> None:
95 | """Initialize TensorFlow session using good default settings."""
96 | # Skip if already initialized.
97 | if tf.get_default_session() is not None:
98 | return
99 |
100 | # Setup config dict and random seeds.
101 | cfg = _sanitize_tf_config(config_dict)
102 | np_random_seed = cfg["rnd.np_random_seed"]
103 | if np_random_seed is not None:
104 | np.random.seed(np_random_seed)
105 | tf_random_seed = cfg["rnd.tf_random_seed"]
106 | if tf_random_seed == "auto":
107 | tf_random_seed = np.random.randint(1 << 31)
108 | if tf_random_seed is not None:
109 | tf.set_random_seed(tf_random_seed)
110 |
111 | # Setup environment variables.
112 | for key, value in list(cfg.items()):
113 | fields = key.split(".")
114 | if fields[0] == "env":
115 | assert len(fields) == 2
116 | os.environ[fields[1]] = str(value)
117 |
118 | # Create default TensorFlow session.
119 | create_session(cfg, force_as_default=True)
120 |
121 |
122 | def assert_tf_initialized():
123 | """Check that TensorFlow session has been initialized."""
124 | if tf.get_default_session() is None:
125 | raise RuntimeError("No default TensorFlow session found. Please call dnnlib.tflib.init_tf().")
126 |
127 |
128 | def create_session(config_dict: dict = None, force_as_default: bool = False) -> tf.Session:
129 | """Create tf.Session based on config dict."""
130 | # Setup TensorFlow config proto.
131 | cfg = _sanitize_tf_config(config_dict)
132 | config_proto = tf.ConfigProto()
133 | for key, value in cfg.items():
134 | fields = key.split(".")
135 | if fields[0] not in ["rnd", "env"]:
136 | obj = config_proto
137 | for field in fields[:-1]:
138 | obj = getattr(obj, field)
139 | setattr(obj, fields[-1], value)
140 |
141 | # Create session.
142 | session = tf.Session(config=config_proto)
143 | if force_as_default:
144 | # pylint: disable=protected-access
145 | session._default_session = session.as_default()
146 | session._default_session.enforce_nesting = False
147 | session._default_session.__enter__() # pylint: disable=no-member
148 |
149 | return session
150 |
151 |
152 | def init_uninitialized_vars(target_vars: List[tf.Variable] = None) -> None:
153 | """Initialize all tf.Variables that have not already been initialized.
154 |
155 | Equivalent to the following, but more efficient and does not bloat the tf graph:
156 | tf.variables_initializer(tf.report_uninitialized_variables()).run()
157 | """
158 | assert_tf_initialized()
159 | if target_vars is None:
160 | target_vars = tf.global_variables()
161 |
162 | test_vars = []
163 | test_ops = []
164 |
165 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
166 | for var in target_vars:
167 | assert is_tf_expression(var)
168 |
169 | try:
170 | tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/IsVariableInitialized:0"))
171 | except KeyError:
172 | # Op does not exist => variable may be uninitialized.
173 | test_vars.append(var)
174 |
175 | with absolute_name_scope(var.name.split(":")[0]):
176 | test_ops.append(tf.is_variable_initialized(var))
177 |
178 | init_vars = [var for var, inited in zip(test_vars, run(test_ops)) if not inited]
179 | run([var.initializer for var in init_vars])
180 |
181 |
182 | def set_vars(var_to_value_dict: dict) -> None:
183 | """Set the values of given tf.Variables.
184 |
185 | Equivalent to the following, but more efficient and does not bloat the tf graph:
186 |         tflib.run([tf.assign(var, value) for var, value in var_to_value_dict.items()])
187 | """
188 | assert_tf_initialized()
189 | ops = []
190 | feed_dict = {}
191 |
192 | for var, value in var_to_value_dict.items():
193 | assert is_tf_expression(var)
194 |
195 | try:
196 | setter = tf.get_default_graph().get_tensor_by_name(var.name.replace(":0", "/setter:0")) # look for existing op
197 | except KeyError:
198 | with absolute_name_scope(var.name.split(":")[0]):
199 | with tf.control_dependencies(None): # ignore surrounding control_dependencies
200 | setter = tf.assign(var, tf.placeholder(var.dtype, var.shape, "new_value"), name="setter") # create new setter
201 |
202 | ops.append(setter)
203 | feed_dict[setter.op.inputs[1]] = value
204 |
205 | run(ops, feed_dict)
206 |
207 |
208 | def create_var_with_large_initial_value(initial_value: np.ndarray, *args, **kwargs):
209 | """Create tf.Variable with large initial value without bloating the tf graph."""
210 | assert_tf_initialized()
211 | assert isinstance(initial_value, np.ndarray)
212 | zeros = tf.zeros(initial_value.shape, initial_value.dtype)
213 | var = tf.Variable(zeros, *args, **kwargs)
214 | set_vars({var: initial_value})
215 | return var
216 |
217 |
218 | def convert_images_from_uint8(images, drange=[-1,1], nhwc_to_nchw=False):
219 | """Convert a minibatch of images from uint8 to float32 with configurable dynamic range.
220 | Can be used as an input transformation for Network.run().
221 | """
222 | images = tf.cast(images, tf.float32)
223 | if nhwc_to_nchw:
224 | images = tf.transpose(images, [0, 3, 1, 2])
225 |     return images * ((drange[1] - drange[0]) / 255) + drange[0]  # map [0, 255] to [drange[0], drange[1]]
226 |
227 |
228 | def convert_images_to_uint8(images, drange=[-1,1], nchw_to_nhwc=False, shrink=1):
229 | """Convert a minibatch of images from float32 to uint8 with configurable dynamic range.
230 | Can be used as an output transformation for Network.run().
231 | """
232 | images = tf.cast(images, tf.float32)
233 | if shrink > 1:
234 | ksize = [1, 1, shrink, shrink]
235 | images = tf.nn.avg_pool(images, ksize=ksize, strides=ksize, padding="VALID", data_format="NCHW")
236 | if nchw_to_nhwc:
237 | images = tf.transpose(images, [0, 2, 3, 1])
238 | scale = 255 / (drange[1] - drange[0])
239 | images = images * scale + (0.5 - drange[0] * scale)
240 | return tf.saturate_cast(images, tf.uint8)
241 |
--------------------------------------------------------------------------------
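
For orientation, a small sketch of the helpers above; every name used here is defined in tfutil.py, and the seed override is just an example:

    import numpy as np
    from dnnlib.tflib import tfutil

    tfutil.init_tf({"rnd.np_random_seed": 1000})  # default session with a fixed NumPy seed

    # Large constants bypass the graph initializer via set_vars():
    var = tfutil.create_var_with_large_initial_value(np.ones([4, 4], np.float32))
    tfutil.set_vars({var: np.zeros([4, 4], np.float32)})  # reuses a cached 'setter' op
    print(tfutil.run(var))
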
/dnnlib/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial
4 | # 4.0 International License. To view a copy of this license, visit
5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
7 |
8 | """Miscellaneous utility classes and functions."""
9 |
10 | import ctypes
11 | import fnmatch
12 | import importlib
13 | import inspect
14 | import numpy as np
15 | import os
16 | import shutil
17 | import sys
18 | import types
19 | import io
20 | import pickle
21 | import re
22 | import requests
23 | import html
24 | import hashlib
25 | import glob
26 | import uuid
27 |
28 | from distutils.util import strtobool
29 | from typing import Any, List, Tuple, Union
30 |
31 |
32 | # Util classes
33 | # ------------------------------------------------------------------------------------------
34 |
35 |
36 | class EasyDict(dict):
37 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
38 |
39 | def __getattr__(self, name: str) -> Any:
40 | try:
41 | return self[name]
42 | except KeyError:
43 | raise AttributeError(name)
44 |
45 | def __setattr__(self, name: str, value: Any) -> None:
46 | self[name] = value
47 |
48 | def __delattr__(self, name: str) -> None:
49 | del self[name]
50 |
51 |
52 | class Logger(object):
53 | """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
54 |
55 | def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
56 | self.file = None
57 |
58 | if file_name is not None:
59 | self.file = open(file_name, file_mode)
60 |
61 | self.should_flush = should_flush
62 | self.stdout = sys.stdout
63 | self.stderr = sys.stderr
64 |
65 | sys.stdout = self
66 | sys.stderr = self
67 |
68 | def __enter__(self) -> "Logger":
69 | return self
70 |
71 | def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
72 | self.close()
73 |
74 | def write(self, text: str) -> None:
75 | """Write text to stdout (and a file) and optionally flush."""
76 | if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
77 | return
78 |
79 | if self.file is not None:
80 | self.file.write(text)
81 |
82 | self.stdout.write(text)
83 |
84 | if self.should_flush:
85 | self.flush()
86 |
87 | def flush(self) -> None:
88 | """Flush written text to both stdout and a file, if open."""
89 | if self.file is not None:
90 | self.file.flush()
91 |
92 | self.stdout.flush()
93 |
94 | def close(self) -> None:
95 | """Flush, close possible files, and remove stdout/stderr mirroring."""
96 | self.flush()
97 |
98 | # if using multiple loggers, prevent closing in wrong order
99 | if sys.stdout is self:
100 | sys.stdout = self.stdout
101 | if sys.stderr is self:
102 | sys.stderr = self.stderr
103 |
104 | if self.file is not None:
105 | self.file.close()
106 |
107 |
108 | # Small util functions
109 | # ------------------------------------------------------------------------------------------
110 |
111 |
112 | def format_time(seconds: Union[int, float]) -> str:
113 |     """Convert seconds to a human-readable string with days, hours, minutes, and seconds."""
114 | s = int(np.rint(seconds))
115 |
116 | if s < 60:
117 | return "{0}s".format(s)
118 | elif s < 60 * 60:
119 | return "{0}m {1:02}s".format(s // 60, s % 60)
120 | elif s < 24 * 60 * 60:
121 | return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
122 | else:
123 | return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
124 |
125 |
126 | def ask_yes_no(question: str) -> bool:
127 |     """Ask the user a yes/no question until a valid answer is given."""
128 | while True:
129 | try:
130 | print("{0} [y/n]".format(question))
131 | return strtobool(input().lower())
132 | except ValueError:
133 | pass
134 |
135 |
136 | def tuple_product(t: Tuple) -> Any:
137 | """Calculate the product of the tuple elements."""
138 | result = 1
139 |
140 | for v in t:
141 | result *= v
142 |
143 | return result
144 |
145 |
146 | _str_to_ctype = {
147 | "uint8": ctypes.c_ubyte,
148 | "uint16": ctypes.c_uint16,
149 | "uint32": ctypes.c_uint32,
150 | "uint64": ctypes.c_uint64,
151 | "int8": ctypes.c_byte,
152 | "int16": ctypes.c_int16,
153 | "int32": ctypes.c_int32,
154 | "int64": ctypes.c_int64,
155 | "float32": ctypes.c_float,
156 | "float64": ctypes.c_double
157 | }
158 |
159 |
160 | def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
161 | """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
162 | type_str = None
163 |
164 | if isinstance(type_obj, str):
165 | type_str = type_obj
166 | elif hasattr(type_obj, "__name__"):
167 | type_str = type_obj.__name__
168 | elif hasattr(type_obj, "name"):
169 | type_str = type_obj.name
170 | else:
171 | raise RuntimeError("Cannot infer type name from input")
172 |
173 | assert type_str in _str_to_ctype.keys()
174 |
175 | my_dtype = np.dtype(type_str)
176 | my_ctype = _str_to_ctype[type_str]
177 |
178 | assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
179 |
180 | return my_dtype, my_ctype
181 |
182 |
183 | def is_pickleable(obj: Any) -> bool:
184 | try:
185 | with io.BytesIO() as stream:
186 | pickle.dump(obj, stream)
187 | return True
188 | except:
189 | return False
190 |
191 |
192 | # Functionality to import modules/objects by name, and call functions by name
193 | # ------------------------------------------------------------------------------------------
194 |
195 | def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
196 | """Searches for the underlying module behind the name to some python object.
197 | Returns the module and the object name (original name with module part removed)."""
198 |
199 | # allow convenience shorthands, substitute them by full names
200 |     obj_name = re.sub(r"^np\.", "numpy.", obj_name)
201 |     obj_name = re.sub(r"^tf\.", "tensorflow.", obj_name)
202 |
203 | # list alternatives for (module_name, local_obj_name)
204 | parts = obj_name.split(".")
205 | name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
206 |
207 | # try each alternative in turn
208 | for module_name, local_obj_name in name_pairs:
209 | try:
210 | module = importlib.import_module(module_name) # may raise ImportError
211 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
212 | return module, local_obj_name
213 | except:
214 | pass
215 |
216 | # maybe some of the modules themselves contain errors?
217 | for module_name, _local_obj_name in name_pairs:
218 | try:
219 | importlib.import_module(module_name) # may raise ImportError
220 | except ImportError:
221 | if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
222 | raise
223 |
224 | # maybe the requested attribute is missing?
225 | for module_name, local_obj_name in name_pairs:
226 | try:
227 | module = importlib.import_module(module_name) # may raise ImportError
228 | get_obj_from_module(module, local_obj_name) # may raise AttributeError
229 | except ImportError:
230 | pass
231 |
232 | # we are out of luck, but we have no idea why
233 | raise ImportError(obj_name)
234 |
235 |
236 | def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
237 | """Traverses the object name and returns the last (rightmost) python object."""
238 | if obj_name == '':
239 | return module
240 | obj = module
241 | for part in obj_name.split("."):
242 | obj = getattr(obj, part)
243 | return obj
244 |
245 |
246 | def get_obj_by_name(name: str) -> Any:
247 | """Finds the python object with the given name."""
248 | module, obj_name = get_module_from_obj_name(name)
249 | return get_obj_from_module(module, obj_name)
250 |
251 |
252 | def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
253 | """Finds the python object with the given name and calls it as a function."""
254 | assert func_name is not None
255 | func_obj = get_obj_by_name(func_name)
256 | assert callable(func_obj)
257 | return func_obj(*args, **kwargs)
258 |
259 |
260 | def get_module_dir_by_obj_name(obj_name: str) -> str:
261 | """Get the directory path of the module containing the given object name."""
262 | module, _ = get_module_from_obj_name(obj_name)
263 | return os.path.dirname(inspect.getfile(module))
264 |
265 |
266 | def is_top_level_function(obj: Any) -> bool:
267 | """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
268 | return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
269 |
270 |
271 | def get_top_level_function_name(obj: Any) -> str:
272 | """Return the fully-qualified name of a top-level function."""
273 | assert is_top_level_function(obj)
274 | return obj.__module__ + "." + obj.__name__
275 |
276 |
277 | # File system helpers
278 | # ------------------------------------------------------------------------------------------
279 |
280 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
281 | """List all files recursively in a given directory while ignoring given file and directory names.
282 | Returns list of tuples containing both absolute and relative paths."""
283 | assert os.path.isdir(dir_path)
284 | base_name = os.path.basename(os.path.normpath(dir_path))
285 |
286 | if ignores is None:
287 | ignores = []
288 |
289 | result = []
290 |
291 | for root, dirs, files in os.walk(dir_path, topdown=True):
292 | for ignore_ in ignores:
293 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
294 |
295 | # dirs need to be edited in-place
296 | for d in dirs_to_remove:
297 | dirs.remove(d)
298 |
299 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
300 |
301 | absolute_paths = [os.path.join(root, f) for f in files]
302 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
303 |
304 | if add_base_to_relative:
305 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
306 |
307 | assert len(absolute_paths) == len(relative_paths)
308 | result += zip(absolute_paths, relative_paths)
309 |
310 | return result
311 |
312 |
313 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
314 | """Takes in a list of tuples of (src, dst) paths and copies files.
315 | Will create all necessary directories."""
316 | for file in files:
317 | target_dir_name = os.path.dirname(file[1])
318 |
319 | # will create all intermediate-level directories
320 | if not os.path.exists(target_dir_name):
321 | os.makedirs(target_dir_name)
322 |
323 | shutil.copyfile(file[0], file[1])
324 |
325 |
326 | # URL helpers
327 | # ------------------------------------------------------------------------------------------
328 |
329 | def is_url(obj: Any) -> bool:
330 | """Determine whether the given object is a valid URL string."""
331 | if not isinstance(obj, str) or not "://" in obj:
332 | return False
333 | try:
334 | res = requests.compat.urlparse(obj)
335 | if not res.scheme or not res.netloc or not "." in res.netloc:
336 | return False
337 | res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
338 | if not res.scheme or not res.netloc or not "." in res.netloc:
339 | return False
340 | except:
341 | return False
342 | return True
343 |
344 |
345 | def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True) -> Any:
346 | """Download the given URL and return a binary-mode file object to access the data."""
347 | assert is_url(url)
348 | assert num_attempts >= 1
349 |
350 | # Lookup from cache.
351 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
352 | if cache_dir is not None:
353 | cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
354 | if len(cache_files) == 1:
355 | return open(cache_files[0], "rb")
356 |
357 | # Download.
358 | url_name = None
359 | url_data = None
360 | with requests.Session() as session:
361 | if verbose:
362 | print("Downloading %s ..." % url, end="", flush=True)
363 | for attempts_left in reversed(range(num_attempts)):
364 | try:
365 | with session.get(url) as res:
366 | res.raise_for_status()
367 | if len(res.content) == 0:
368 | raise IOError("No data received")
369 |
370 | if len(res.content) < 8192:
371 | content_str = res.content.decode("utf-8")
372 | if "download_warning" in res.headers.get("Set-Cookie", ""):
373 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
374 | if len(links) == 1:
375 | url = requests.compat.urljoin(url, links[0])
376 | raise IOError("Google Drive virus checker nag")
377 | if "Google Drive - Quota exceeded" in content_str:
378 | raise IOError("Google Drive quota exceeded")
379 |
380 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
381 | url_name = match[1] if match else url
382 | url_data = res.content
383 | if verbose:
384 | print(" done")
385 | break
386 | except:
387 | if not attempts_left:
388 | if verbose:
389 | print(" failed")
390 | raise
391 | if verbose:
392 | print(".", end="", flush=True)
393 |
394 | # Save to cache.
395 | if cache_dir is not None:
396 | safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
397 | cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
398 | temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
399 | os.makedirs(cache_dir, exist_ok=True)
400 | with open(temp_file, "wb") as f:
401 | f.write(url_data)
402 | os.replace(temp_file, cache_file) # atomic
403 |
404 | # Return data as file object.
405 | return io.BytesIO(url_data)
406 |
--------------------------------------------------------------------------------
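
A few of the utilities above in action -- a small, self-contained sketch (the numeric arguments are arbitrary):

    import dnnlib.util as util

    cfg = util.EasyDict(lr=0.003)
    cfg.batch = 32                  # attribute access writes through to the dict
    assert cfg["batch"] == cfg.batch

    # The 'np.' shorthand is expanded to 'numpy.' by get_module_from_obj_name():
    zeros = util.call_func_by_name([2, 3], func_name="np.zeros")

    print(util.format_time(90061))  # '1d 01h 01m'
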
/generate_grid.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: generate_grid.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/27
6 | Description: Generate an image sample grid from a particular depth of a model
7 | -------------------------------------------------
8 | """
9 |
10 | import os
11 | import argparse
12 | import numpy as np
13 |
14 | import torch
15 | from torchvision.utils import save_image
16 |
17 | from models.GAN import Generator
18 |
19 |
20 | def parse_arguments():
21 | """
22 | default command line argument parser
23 | :return: args => parsed command line arguments
24 | """
25 |
26 | parser = argparse.ArgumentParser()
27 |
28 | parser.add_argument('--config', default='./configs/sample.yaml')
29 | parser.add_argument("--generator_file", action="store", type=str,
30 | help="pretrained weights file for generator", required=True)
31 | parser.add_argument("--n_row", action="store", type=int,
32 |                         default=10, help="number of rows in the generated sample grid")
33 | parser.add_argument("--n_col", action="store", type=int,
34 |                         default=4, help="number of columns in the generated sample grid")
35 | parser.add_argument("--output_dir", action="store", type=str,
36 | default="output/",
37 | help="path to the output directory for the frames")
38 |
39 | args = parser.parse_args()
40 |
41 | return args
42 |
43 |
44 | def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
45 | """
46 | adjust the dynamic colour range of the given input data
47 | :param data: input image data
48 | :param drange_in: original range of input
49 | :param drange_out: required range of output
50 | :return: img => colour range adjusted images
51 | """
52 | if drange_in != drange_out:
53 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
54 | np.float32(drange_in[1]) - np.float32(drange_in[0]))
55 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
56 | data = data * scale + bias
57 | return torch.clamp(data, min=0, max=1)
58 |
59 |
60 | def main(args):
61 | """
62 | Main function for the script
63 | :param args: parsed command line arguments
64 | :return: None
65 | """
66 |
67 | from config import cfg as opt
68 |
69 | opt.merge_from_file(args.config)
70 | opt.freeze()
71 |
72 | print("Creating generator object ...")
73 | # create the generator object
74 | gen = Generator(resolution=opt.dataset.resolution,
75 | num_channels=opt.dataset.channels,
76 | structure=opt.structure,
77 | **opt.model.gen)
78 |
79 | print("Loading the generator weights from:", args.generator_file)
80 | # load the weights into it
81 | gen.load_state_dict(torch.load(args.generator_file))
82 |
83 | # path for saving the files:
84 | save_path = args.output_dir
85 | os.makedirs(save_path, exist_ok=True)
86 | latent_size = opt.model.gen.latent_size
87 | out_depth = int(np.log2(opt.dataset.resolution)) - 2
88 |
89 | print("Generating scale synchronized images ...")
90 | # generate the images:
91 | with torch.no_grad():
92 | point = torch.randn(args.n_row * args.n_col, latent_size)
93 | point = (point / point.norm()) * (latent_size ** 0.5)
94 | ss_image = gen(point, depth=out_depth, alpha=1)
95 | # color adjust the generated image:
96 | ss_image = adjust_dynamic_range(ss_image)
97 |
98 | # save the ss_image in the directory
99 | save_image(ss_image, os.path.join(save_path, "grid.png"), nrow=args.n_row,
100 | normalize=True, scale_each=True, pad_value=128, padding=1)
101 |
102 | print('Done.')
103 |
104 |
105 | if __name__ == '__main__':
106 | main(parse_arguments())
107 |
--------------------------------------------------------------------------------
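
adjust_dynamic_range() maps generator output from [-1, 1] into [0, 1] before saving; a tiny sanity check, runnable from the repository root (the weights path in the comment is a placeholder):

    # python generate_grid.py --config configs/sample.yaml --generator_file <generator.pth>
    import torch
    from generate_grid import adjust_dynamic_range

    x = torch.tensor([-1.0, 0.0, 1.0])
    print(adjust_dynamic_range(x))  # tensor([0.0000, 0.5000, 1.0000])
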
/generate_mixing_figure.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | from models.GAN import Generator
9 | from generate_grid import adjust_dynamic_range
10 |
11 |
12 | def draw_style_mixing_figure(png, gen, out_depth, src_seeds, dst_seeds, style_ranges):
13 | n_col = len(src_seeds)
14 | n_row = len(dst_seeds)
15 | w = h = 2 ** (out_depth + 2)
16 | with torch.no_grad():
17 | latent_size = gen.g_mapping.latent_size
18 | src_latents_np = np.stack([np.random.RandomState(seed).randn(latent_size, ) for seed in src_seeds])
19 | dst_latents_np = np.stack([np.random.RandomState(seed).randn(latent_size, ) for seed in dst_seeds])
20 | src_latents = torch.from_numpy(src_latents_np.astype(np.float32))
21 | dst_latents = torch.from_numpy(dst_latents_np.astype(np.float32))
22 | src_dlatents = gen.g_mapping(src_latents) # [seed, layer, component]
23 | dst_dlatents = gen.g_mapping(dst_latents) # [seed, layer, component]
24 | src_images = gen.g_synthesis(src_dlatents, depth=out_depth, alpha=1)
25 | dst_images = gen.g_synthesis(dst_dlatents, depth=out_depth, alpha=1)
26 |
27 | src_dlatents_np = src_dlatents.numpy()
28 | dst_dlatents_np = dst_dlatents.numpy()
29 | canvas = Image.new('RGB', (w * (n_col + 1), h * (n_row + 1)), 'white')
30 | for col, src_image in enumerate(list(src_images)):
31 | src_image = adjust_dynamic_range(src_image)
32 | src_image = src_image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
33 | canvas.paste(Image.fromarray(src_image, 'RGB'), ((col + 1) * w, 0))
34 | for row, dst_image in enumerate(list(dst_images)):
35 | dst_image = adjust_dynamic_range(dst_image)
36 | dst_image = dst_image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
37 | canvas.paste(Image.fromarray(dst_image, 'RGB'), (0, (row + 1) * h))
38 |
39 | row_dlatents = np.stack([dst_dlatents_np[row]] * n_col)
40 | row_dlatents[:, style_ranges[row]] = src_dlatents_np[:, style_ranges[row]]
41 | row_dlatents = torch.from_numpy(row_dlatents)
42 |
43 | row_images = gen.g_synthesis(row_dlatents, depth=out_depth, alpha=1)
44 | for col, image in enumerate(list(row_images)):
45 | image = adjust_dynamic_range(image)
46 | image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
47 | canvas.paste(Image.fromarray(image, 'RGB'), ((col + 1) * w, (row + 1) * h))
48 | canvas.save(png)
49 |
50 |
51 | def main(args):
52 | """
53 | Main function for the script
54 | :param args: parsed command line arguments
55 | :return: None
56 | """
57 |
58 | from config import cfg as opt
59 |
60 | opt.merge_from_file(args.config)
61 | opt.freeze()
62 |
63 | print("Creating generator object ...")
64 | # create the generator object
65 | gen = Generator(resolution=opt.dataset.resolution,
66 | num_channels=opt.dataset.channels,
67 | structure=opt.structure,
68 | **opt.model.gen)
69 |
70 | print("Loading the generator weights from:", args.generator_file)
71 | # load the weights into it
72 | gen.load_state_dict(torch.load(args.generator_file))
73 |
74 | # path for saving the files:
75 | # generate the images:
76 | # src_seeds = [639, 701, 687, 615, 1999], dst_seeds = [888, 888, 888],
77 | draw_style_mixing_figure(os.path.join('figure03-style-mixing.png'), gen,
78 | out_depth=6, src_seeds=[639, 1995, 687, 615, 1999], dst_seeds=[888, 888, 888],
79 | style_ranges=[range(0, 2)] * 1 + [range(2, 8)] * 1 + [range(8, 14)] * 1)
80 | print('Done.')
81 |
82 |
83 | def parse_arguments():
84 | """
85 | default command line argument parser
86 | :return: args => parsed command line arguments
87 | """
88 |
89 | parser = argparse.ArgumentParser()
90 |
91 | parser.add_argument('--config', default='./configs/sample_race_256.yaml')
92 | parser.add_argument("--generator_file", action="store", type=str,
93 | help="pretrained weights file for generator", required=True)
94 |
95 | args = parser.parse_args()
96 |
97 | return args
98 |
99 |
100 | if __name__ == '__main__':
101 | main(parse_arguments())
102 |
--------------------------------------------------------------------------------
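
The style_ranges argument controls which dlatent layers each row copies from the source columns; with out_depth=6 the model is 256x256 and has 14 style layers, so the three rows above follow the coarse/middle/fine split from the StyleGAN paper. A NumPy-only sketch of the mixing step (the latent size of 512 is an assumption):

    import numpy as np

    n_layers, latent_size = 14, 512  # 2 * log2(256) - 2 layers at 256x256
    style_ranges = [range(0, 2), range(2, 8), range(8, 14)]  # coarse / middle / fine

    dst = np.zeros((n_layers, latent_size))    # stand-in dlatents for one dst seed
    src = np.ones((3, n_layers, latent_size))  # stand-in dlatents for three src seeds
    mixed = np.stack([dst] * 3)
    mixed[:, style_ranges[0]] = src[:, style_ranges[0]]  # row 0: copy coarse layers
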
/generate_samples.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: generate_samples.py
4 | Date: 2019/10/27
5 | Description: Generate single image samples from a particular depth of a model
6 | Modified from: https://github.com/akanimax/pro_gan_pytorch
7 | -------------------------------------------------
8 | """
9 |
10 | import os
11 | import argparse
12 | import numpy as np
13 | from tqdm import tqdm
14 |
15 | import torch
16 | from torchvision.utils import save_image
17 |
18 | from models.GAN import Generator
19 |
20 |
21 | def parse_arguments():
22 | """
23 | default command line argument parser
24 | :return: args => parsed command line arguments
25 | """
26 |
27 | parser = argparse.ArgumentParser()
28 |
29 | parser.add_argument('--config', default='./configs/sample.yaml')
30 | parser.add_argument("--generator_file", action="store", type=str,
31 | help="pretrained weights file for generator", required=True)
32 | parser.add_argument("--num_samples", action="store", type=int,
33 |                         default=300, help="number of samples to be generated")
34 | parser.add_argument("--output_dir", action="store", type=str,
35 | default="output/",
36 | help="path to the output directory for the frames")
37 | parser.add_argument("--input", action="store", type=str,
38 | default=None, help="the dlatent code (W) for a certain sample")
39 | parser.add_argument("--output", action="store", type=str,
40 |                         default="output.png", help="output filename for the generated sample")
41 |
42 | args = parser.parse_args()
43 |
44 | return args
45 |
46 |
47 | def adjust_dynamic_range(data, drange_in=(-1, 1), drange_out=(0, 1)):
48 | """
49 | adjust the dynamic colour range of the given input data
50 | :param data: input image data
51 | :param drange_in: original range of input
52 | :param drange_out: required range of output
53 | :return: img => colour range adjusted images
54 | """
55 | if drange_in != drange_out:
56 | scale = (np.float32(drange_out[1]) - np.float32(drange_out[0])) / (
57 | np.float32(drange_in[1]) - np.float32(drange_in[0]))
58 | bias = (np.float32(drange_out[0]) - np.float32(drange_in[0]) * scale)
59 | data = data * scale + bias
60 | return torch.clamp(data, min=0, max=1)
61 |
62 |
63 | def main(args):
64 | """
65 | Main function for the script
66 | :param args: parsed command line arguments
67 | :return: None
68 | """
69 |
70 | from config import cfg as opt
71 |
72 | opt.merge_from_file(args.config)
73 | opt.freeze()
74 |
75 | print("Creating generator object ...")
76 | # create the generator object
77 | gen = Generator(resolution=opt.dataset.resolution,
78 | num_channels=opt.dataset.channels,
79 | structure=opt.structure,
80 | **opt.model.gen)
81 |
82 | print("Loading the generator weights from:", args.generator_file)
83 | # load the weights into it
84 | gen.load_state_dict(torch.load(args.generator_file))
85 |
86 | # path for saving the files:
87 | save_path = args.output_dir
88 | os.makedirs(save_path, exist_ok=True)
89 | latent_size = opt.model.gen.latent_size
90 | out_depth = int(np.log2(opt.dataset.resolution)) - 2
91 |
92 | if args.input is None:
93 | print("Generating scale synchronized images ...")
94 | for img_num in tqdm(range(1, args.num_samples + 1)):
95 | # generate the images:
96 | with torch.no_grad():
97 | point = torch.randn(1, latent_size)
98 |                 point = (point / point.norm()) * (latent_size ** 0.5)  # re-project onto the sphere of radius sqrt(latent_size)
99 | ss_image = gen(point, depth=out_depth, alpha=1)
100 | # color adjust the generated image:
101 | ss_image = adjust_dynamic_range(ss_image)
102 |
103 | # save the ss_image in the directory
104 | save_image(ss_image, os.path.join(save_path, str(img_num) + ".png"))
105 |
106 | print("Generated %d images at %s" % (args.num_samples, save_path))
107 | else:
108 |         code = np.load(args.input)
109 |         dlatent_in = torch.unsqueeze(torch.from_numpy(code.astype(np.float32)), 0)
110 |         with torch.no_grad():
111 |             ss_image = gen.g_synthesis(dlatent_in, depth=out_depth, alpha=1)
112 |         # color adjust the generated image:
113 |         ss_image = adjust_dynamic_range(ss_image)
114 |         save_image(ss_image, args.output)
115 | 
116 | 
117 | if __name__ == '__main__':
118 |     main(parse_arguments())
119 | 
--------------------------------------------------------------------------------
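For the default arguments, adjust_dynamic_range above reduces to data * 0.5 + 0.5 followed by the clamp: scale = (1 - 0) / (1 - (-1)) = 0.5 and bias = 0 - (-1) * 0.5 = 0.5. A quick sanity check, assuming the repo root is on PYTHONPATH:

    import torch
    from generate_samples import adjust_dynamic_range

    x = torch.tensor([-1.0, 0.0, 1.0, 2.0])
    print(adjust_dynamic_range(x))  # tensor([0.0000, 0.5000, 1.0000, 1.0000]); 2.0 maps to 1.5 and is clamped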
/generate_truncation_figure.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: generate_truncation_figure.py
4 | Author: Zhonghao Huang
5 | Date: 2019/11/23
6 | Description: Generate the truncation trick figure from a trained generator.
7 | -------------------------------------------------
8 | """
9 |
10 | import argparse
11 | import numpy as np
12 | from PIL import Image
13 |
14 | import torch
15 |
16 | from generate_grid import adjust_dynamic_range
17 | from models.GAN import Generator
18 |
19 |
20 | def draw_truncation_trick_figure(png, gen, out_depth, seeds, psis):
21 | w = h = 2 ** (out_depth + 2)
22 | latent_size = gen.g_mapping.latent_size
23 |
24 | with torch.no_grad():
25 | latents_np = np.stack([np.random.RandomState(seed).randn(latent_size) for seed in seeds])
26 | latents = torch.from_numpy(latents_np.astype(np.float32))
27 | dlatents = gen.g_mapping(latents).detach().numpy() # [seed, layer, component]
28 | dlatent_avg = gen.truncation.avg_latent.numpy() # [component]
29 |
30 | canvas = Image.new('RGB', (w * len(psis), h * len(seeds)), 'white')
31 | for row, dlatent in enumerate(list(dlatents)):
32 |             row_dlatents = (dlatent[np.newaxis] - dlatent_avg) * np.reshape(psis, [-1, 1, 1]) + dlatent_avg  # w' = w_avg + psi * (w - w_avg)
33 | row_dlatents = torch.from_numpy(row_dlatents.astype(np.float32))
34 | row_images = gen.g_synthesis(row_dlatents, depth=out_depth, alpha=1)
35 | for col, image in enumerate(list(row_images)):
36 | image = adjust_dynamic_range(image)
37 | image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
38 | canvas.paste(Image.fromarray(image, 'RGB'), (col * w, row * h))
39 | canvas.save(png)
40 |
41 |
42 | def main(args):
43 | """
44 | Main function for the script
45 | :param args: parsed command line arguments
46 | :return: None
47 | """
48 |
49 | from config import cfg as opt
50 |
51 | opt.merge_from_file(args.config)
52 | opt.freeze()
53 |
54 | print("Creating generator object ...")
55 | # create the generator object
56 | gen = Generator(resolution=opt.dataset.resolution,
57 | num_channels=opt.dataset.channels,
58 | structure=opt.structure,
59 | **opt.model.gen)
60 |
61 | print("Loading the generator weights from:", args.generator_file)
62 | # load the weights into it
63 | gen.load_state_dict(torch.load(args.generator_file))
64 |
65 | draw_truncation_trick_figure('figure08-truncation-trick.png', gen, out_depth=5,
66 | seeds=[91, 388], psis=[1, 0.7, 0.5, 0, -0.5, -1])
67 |
68 | print('Done.')
69 |
70 |
71 | def parse_arguments():
72 | """
73 | default command line argument parser
74 | :return: args => parsed command line arguments
75 | """
76 |
77 | parser = argparse.ArgumentParser()
78 |
79 | parser.add_argument('--config', default='./configs/sample.yaml')
80 | parser.add_argument("--generator_file", action="store", type=str,
81 | help="pretrained weights file for generator", required=True)
82 |
83 | args = parser.parse_args()
84 |
85 | return args
86 |
87 |
88 | if __name__ == '__main__':
89 | main(parse_arguments())
90 |
--------------------------------------------------------------------------------
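draw_truncation_trick_figure applies the truncation trick directly: w' = w_avg + psi * (w - w_avg). With psi = 1 the sample is unchanged, psi = 0 collapses to the average face, and negative psi walks past the average to an "anti-face". The operation in isolation, with w_avg as a stand-in for gen.truncation.avg_latent:

    import numpy as np

    latent_size = 512
    w = np.random.randn(latent_size)
    w_avg = np.zeros(latent_size)            # stand-in for the running average dlatent

    for psi in [1, 0.7, 0.5, 0, -0.5, -1]:
        w_trunc = w_avg + psi * (w - w_avg)  # psi=1 -> w, psi=0 -> w_avg, psi=-1 -> 2*w_avg - w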
/models/Blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: Blocks.py
4 | Date: 2019/10/17
5 | Description: Copied from: https://github.com/lernapparat/lernapparat
6 | -------------------------------------------------
7 | """
8 |
9 | from collections import OrderedDict
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 | from models.CustomLayers import EqualizedLinear, LayerEpilogue, EqualizedConv2d, BlurLayer, View, StddevLayer
15 |
16 |
17 | class InputBlock(nn.Module):
18 | """
19 | The first block (4x4 "pixels") doesn't have an input.
20 | The result of the first convolution is just replaced by a (trained) constant.
21 | We call it the InputBlock, the others GSynthesisBlock.
22 | (It might be nicer to do this the other way round,
23 | i.e. have the LayerEpilogue be the Layer and call the conv from that.)
24 | """
25 |
26 | def __init__(self, nf, dlatent_size, const_input_layer, gain,
27 | use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
28 | super().__init__()
29 | self.const_input_layer = const_input_layer
30 | self.nf = nf
31 |
32 | if self.const_input_layer:
33 | # called 'const' in tf
34 | self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
35 | self.bias = nn.Parameter(torch.ones(nf))
36 | else:
37 | self.dense = EqualizedLinear(dlatent_size, nf * 16, gain=gain / 4,
38 | use_wscale=use_wscale)
39 |             # tweak gain to match the official implementation of Progressive GAN
40 |
41 | self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
42 | use_styles, activation_layer)
43 | self.conv = EqualizedConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
44 | self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
45 | use_styles, activation_layer)
46 |
47 | def forward(self, dlatents_in_range):
48 | batch_size = dlatents_in_range.size(0)
49 |
50 | if self.const_input_layer:
51 | x = self.const.expand(batch_size, -1, -1, -1)
52 | x = x + self.bias.view(1, -1, 1, 1)
53 | else:
54 | x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
55 |
56 | x = self.epi1(x, dlatents_in_range[:, 0])
57 | x = self.conv(x)
58 | x = self.epi2(x, dlatents_in_range[:, 1])
59 |
60 | return x
61 |
62 |
63 | class GSynthesisBlock(nn.Module):
64 | def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain,
65 | use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
66 | # 2**res x 2**res
67 | # res = 3..resolution_log2
68 | super().__init__()
69 |
70 | if blur_filter:
71 | blur = BlurLayer(blur_filter)
72 | else:
73 | blur = None
74 |
75 | self.conv0_up = EqualizedConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
76 | intermediate=blur, upscale=True)
77 | self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
78 | use_styles, activation_layer)
79 | self.conv1 = EqualizedConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
80 | self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm,
81 | use_styles, activation_layer)
82 |
83 | def forward(self, x, dlatents_in_range):
84 | x = self.conv0_up(x)
85 | x = self.epi1(x, dlatents_in_range[:, 0])
86 | x = self.conv1(x)
87 | x = self.epi2(x, dlatents_in_range[:, 1])
88 | return x
89 |
90 |
91 | class DiscriminatorTop(nn.Sequential):
92 | def __init__(self,
93 | mbstd_group_size,
94 | mbstd_num_features,
95 | in_channels,
96 | intermediate_channels,
97 | gain, use_wscale,
98 | activation_layer,
99 | resolution=4,
100 | in_channels2=None,
101 | output_features=1,
102 | last_gain=1):
103 | """
104 | :param mbstd_group_size:
105 | :param mbstd_num_features:
106 | :param in_channels:
107 | :param intermediate_channels:
108 | :param gain:
109 | :param use_wscale:
110 | :param activation_layer:
111 | :param resolution:
112 | :param in_channels2:
113 | :param output_features:
114 | :param last_gain:
115 | """
116 |
117 | layers = []
118 | if mbstd_group_size > 1:
119 | layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))
120 |
121 | if in_channels2 is None:
122 | in_channels2 = in_channels
123 |
124 | layers.append(('conv', EqualizedConv2d(in_channels + mbstd_num_features, in_channels2, kernel_size=3,
125 | gain=gain, use_wscale=use_wscale)))
126 | layers.append(('act0', activation_layer))
127 | layers.append(('view', View(-1)))
128 | layers.append(('dense0', EqualizedLinear(in_channels2 * resolution * resolution, intermediate_channels,
129 | gain=gain, use_wscale=use_wscale)))
130 | layers.append(('act1', activation_layer))
131 | layers.append(('dense1', EqualizedLinear(intermediate_channels, output_features,
132 | gain=last_gain, use_wscale=use_wscale)))
133 |
134 | super().__init__(OrderedDict(layers))
135 |
136 |
137 | class DiscriminatorBlock(nn.Sequential):
138 | def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer, blur_kernel):
139 | super().__init__(OrderedDict([
140 | ('conv0', EqualizedConv2d(in_channels, in_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)),
141 | # out channels nf(res-1)
142 | ('act0', activation_layer),
143 | ('blur', BlurLayer(kernel=blur_kernel)),
144 | ('conv1_down', EqualizedConv2d(in_channels, out_channels, kernel_size=3,
145 | gain=gain, use_wscale=use_wscale, downscale=True)),
146 | ('act1', activation_layer)]))
147 |
148 |
149 | if __name__ == '__main__':
150 | # discriminator = DiscriminatorTop()
151 | print('Done.')
152 |
--------------------------------------------------------------------------------
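A hedged shape check for the two synthesis blocks above; the constructor arguments mirror defaults used elsewhere in this repo, but this is an illustration rather than a training configuration:

    import torch
    import torch.nn as nn
    from models.Blocks import InputBlock, GSynthesisBlock

    dlatent_size, act = 512, nn.LeakyReLU(negative_slope=0.2)
    inp = InputBlock(512, dlatent_size, const_input_layer=True, gain=2 ** 0.5,
                     use_wscale=True, use_noise=True, use_pixel_norm=False,
                     use_instance_norm=True, use_styles=True, activation_layer=act)
    blk = GSynthesisBlock(512, 256, blur_filter=[1, 2, 1], dlatent_size=dlatent_size,
                          gain=2 ** 0.5, use_wscale=True, use_noise=True, use_pixel_norm=False,
                          use_instance_norm=True, use_styles=True, activation_layer=act)

    dlatents = torch.randn(2, 2, dlatent_size)  # a batch of two, two style layers per block
    x = inp(dlatents)                           # [2, 512, 4, 4]
    x = blk(x, dlatents)                        # [2, 256, 8, 8] after upscale + conv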
/models/CustomLayers.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: CustomLayers.py
4 | Date: 2019/10/17
5 | Description: Copied from: https://github.com/lernapparat/lernapparat
6 | -------------------------------------------------
7 | """
8 |
9 | import numpy as np
10 | from collections import OrderedDict
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 |
17 | class PixelNormLayer(nn.Module):
18 | def __init__(self, epsilon=1e-8):
19 | super().__init__()
20 | self.epsilon = epsilon
21 |
22 | def forward(self, x):
23 | return x * torch.rsqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
24 |
25 |
26 | class Upscale2d(nn.Module):
27 | @staticmethod
28 | def upscale2d(x, factor=2, gain=1):
29 | assert x.dim() == 4
30 | if gain != 1:
31 | x = x * gain
32 | if factor != 1:
33 | shape = x.shape
34 | x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
35 | x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
36 | return x
37 |
38 | def __init__(self, factor=2, gain=1):
39 | super().__init__()
40 | assert isinstance(factor, int) and factor >= 1
41 | self.gain = gain
42 | self.factor = factor
43 |
44 | def forward(self, x):
45 | return self.upscale2d(x, factor=self.factor, gain=self.gain)
46 |
47 |
48 | class Downscale2d(nn.Module):
49 | def __init__(self, factor=2, gain=1):
50 | super().__init__()
51 | assert isinstance(factor, int) and factor >= 1
52 | self.factor = factor
53 | self.gain = gain
54 | if factor == 2:
55 | f = [np.sqrt(gain) / factor] * factor
56 | self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)
57 | else:
58 | self.blur = None
59 |
60 | def forward(self, x):
61 | assert x.dim() == 4
62 | # 2x2, float32 => downscale using _blur2d().
63 | if self.blur is not None and x.dtype == torch.float32:
64 | return self.blur(x)
65 |
66 | # Apply gain.
67 | if self.gain != 1:
68 | x = x * self.gain
69 |
70 | # No-op => early exit.
71 | if self.factor == 1:
72 | return x
73 |
74 |         # Large factor => downscale using average pooling.
75 |         # (the original TF note about place_pruned_graph does not apply in PyTorch)
76 | return F.avg_pool2d(x, self.factor)
77 |
78 |
79 | class EqualizedLinear(nn.Module):
80 | """Linear layer with equalized learning rate and custom learning rate multiplier."""
81 |
82 | def __init__(self, input_size, output_size, gain=2 ** 0.5, use_wscale=False, lrmul=1, bias=True):
83 | super().__init__()
84 | he_std = gain * input_size ** (-0.5) # He init
85 | # Equalized learning rate and custom learning rate multiplier.
86 | if use_wscale:
87 | init_std = 1.0 / lrmul
88 | self.w_mul = he_std * lrmul
89 | else:
90 | init_std = he_std / lrmul
91 | self.w_mul = lrmul
92 | self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
93 | if bias:
94 | self.bias = torch.nn.Parameter(torch.zeros(output_size))
95 | self.b_mul = lrmul
96 | else:
97 | self.bias = None
98 |
99 | def forward(self, x):
100 | bias = self.bias
101 | if bias is not None:
102 | bias = bias * self.b_mul
103 | return F.linear(x, self.weight * self.w_mul, bias)
104 |
105 |
106 | class EqualizedConv2d(nn.Module):
107 | """Conv layer with equalized learning rate and custom learning rate multiplier."""
108 |
109 | def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2 ** 0.5, use_wscale=False,
110 | lrmul=1, bias=True, intermediate=None, upscale=False, downscale=False):
111 | super().__init__()
112 | if upscale:
113 | self.upscale = Upscale2d()
114 | else:
115 | self.upscale = None
116 | if downscale:
117 | self.downscale = Downscale2d()
118 | else:
119 | self.downscale = None
120 | he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
121 | self.kernel_size = kernel_size
122 | if use_wscale:
123 | init_std = 1.0 / lrmul
124 | self.w_mul = he_std * lrmul
125 | else:
126 | init_std = he_std / lrmul
127 | self.w_mul = lrmul
128 | self.weight = torch.nn.Parameter(
129 | torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
130 | if bias:
131 | self.bias = torch.nn.Parameter(torch.zeros(output_channels))
132 | self.b_mul = lrmul
133 | else:
134 | self.bias = None
135 | self.intermediate = intermediate
136 |
137 | def forward(self, x):
138 | bias = self.bias
139 | if bias is not None:
140 | bias = bias * self.b_mul
141 |
142 | have_convolution = False
143 | if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
144 | # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
145 | # this really needs to be cleaned up and go into the conv...
146 | w = self.weight * self.w_mul
147 | w = w.permute(1, 0, 2, 3)
148 | # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
149 | w = F.pad(w, [1, 1, 1, 1])
150 | w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
151 | x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
152 | have_convolution = True
153 | elif self.upscale is not None:
154 | x = self.upscale(x)
155 |
156 | downscale = self.downscale
157 | intermediate = self.intermediate
158 | if downscale is not None and min(x.shape[2:]) >= 128:
159 | w = self.weight * self.w_mul
160 | w = F.pad(w, [1, 1, 1, 1])
161 | # in contrast to upscale, this is a mean...
162 | w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]) * 0.25 # avg_pool?
163 | x = F.conv2d(x, w, stride=2, padding=(w.size(-1) - 1) // 2)
164 | have_convolution = True
165 | downscale = None
166 | elif downscale is not None:
167 | assert intermediate is None
168 | intermediate = downscale
169 |
170 | if not have_convolution and intermediate is None:
171 | return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size // 2)
172 | elif not have_convolution:
173 | x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size // 2)
174 |
175 | if intermediate is not None:
176 | x = intermediate(x)
177 |
178 | if bias is not None:
179 | x = x + bias.view(1, -1, 1, 1)
180 | return x
181 |
182 |
183 | class NoiseLayer(nn.Module):
184 | """adds noise. noise is per pixel (constant over channels) with per-channel weight"""
185 |
186 | def __init__(self, channels):
187 | super().__init__()
188 | self.weight = nn.Parameter(torch.zeros(channels))
189 | self.noise = None
190 |
191 | def forward(self, x, noise=None):
192 | if noise is None and self.noise is None:
193 | noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
194 | elif noise is None:
195 |             # here is a little trick: if you collect all the noise layers and set each
196 |             # module's .noise attribute, you can use pre-defined noise.
197 |             # Very useful for analysis.
198 | noise = self.noise
199 | x = x + self.weight.view(1, -1, 1, 1) * noise
200 | return x
201 |
202 |
203 | class StyleMod(nn.Module):
204 | def __init__(self, latent_size, channels, use_wscale):
205 | super(StyleMod, self).__init__()
206 | self.lin = EqualizedLinear(latent_size,
207 | channels * 2,
208 | gain=1.0, use_wscale=use_wscale)
209 |
210 | def forward(self, x, latent):
211 | style = self.lin(latent) # style => [batch_size, n_channels*2]
212 |
213 | shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
214 | style = style.view(shape) # [batch_size, 2, n_channels, ...]
215 | x = x * (style[:, 0] + 1.) + style[:, 1]
216 | return x
217 |
218 |
219 | class LayerEpilogue(nn.Module):
220 | """Things to do at the end of each layer."""
221 |
222 | def __init__(self, channels, dlatent_size, use_wscale,
223 | use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
224 | super().__init__()
225 |
226 | layers = []
227 | if use_noise:
228 | layers.append(('noise', NoiseLayer(channels)))
229 | layers.append(('activation', activation_layer))
230 | if use_pixel_norm:
231 | layers.append(('pixel_norm', PixelNormLayer()))
232 | if use_instance_norm:
233 | layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
234 |
235 | self.top_epi = nn.Sequential(OrderedDict(layers))
236 |
237 | if use_styles:
238 | self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
239 | else:
240 | self.style_mod = None
241 |
242 | def forward(self, x, dlatents_in_slice=None):
243 | x = self.top_epi(x)
244 | if self.style_mod is not None:
245 | x = self.style_mod(x, dlatents_in_slice)
246 | else:
247 | assert dlatents_in_slice is None
248 | return x
249 |
250 |
251 | class BlurLayer(nn.Module):
252 | def __init__(self, kernel=None, normalize=True, flip=False, stride=1):
253 | super(BlurLayer, self).__init__()
254 | if kernel is None:
255 | kernel = [1, 2, 1]
256 | kernel = torch.tensor(kernel, dtype=torch.float32)
257 | kernel = kernel[:, None] * kernel[None, :]
258 | kernel = kernel[None, None]
259 | if normalize:
260 | kernel = kernel / kernel.sum()
261 | if flip:
262 | kernel = kernel[:, :, ::-1, ::-1]
263 | self.register_buffer('kernel', kernel)
264 | self.stride = stride
265 |
266 | def forward(self, x):
267 | # expand kernel channels
268 | kernel = self.kernel.expand(x.size(1), -1, -1, -1)
269 | x = F.conv2d(
270 | x,
271 | kernel,
272 | stride=self.stride,
273 | padding=int((self.kernel.size(2) - 1) / 2),
274 | groups=x.size(1)
275 | )
276 | return x
277 |
278 |
279 | class View(nn.Module):
280 | def __init__(self, *shape):
281 | super().__init__()
282 | self.shape = shape
283 |
284 | def forward(self, x):
285 | return x.view(x.size(0), *self.shape)
286 |
287 |
288 | class StddevLayer(nn.Module):
289 | def __init__(self, group_size=4, num_new_features=1):
290 | super().__init__()
291 | self.group_size = group_size
292 | self.num_new_features = num_new_features
293 |
294 | def forward(self, x):
295 | b, c, h, w = x.shape
296 | group_size = min(self.group_size, b)
297 | y = x.reshape([group_size, -1, self.num_new_features,
298 | c // self.num_new_features, h, w])
299 | y = y - y.mean(0, keepdim=True)
300 | y = (y ** 2).mean(0, keepdim=True)
301 | y = (y + 1e-8) ** 0.5
302 | y = y.mean([3, 4, 5], keepdim=True).squeeze(3) # don't keep the meaned-out channels
303 | y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)
304 | z = torch.cat([x, y], dim=1)
305 | return z
306 |
307 |
308 | class Truncation(nn.Module):
309 | def __init__(self, avg_latent, max_layer=8, threshold=0.7, beta=0.995):
310 | super().__init__()
311 | self.max_layer = max_layer
312 | self.threshold = threshold
313 | self.beta = beta
314 | self.register_buffer('avg_latent', avg_latent)
315 |
316 | def update(self, last_avg):
317 | self.avg_latent.copy_(self.beta * self.avg_latent + (1. - self.beta) * last_avg)
318 |
319 |     def forward(self, x):
320 |         assert x.dim() == 3
321 |         interp = torch.lerp(self.avg_latent, x, self.threshold)  # pull each dlatent towards the average (psi = threshold)
322 |         do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1).to(x.device)
323 |         return torch.where(do_trunc, interp, x)  # truncate only the first max_layer style layers
324 |
--------------------------------------------------------------------------------
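The trick shared by EqualizedLinear and EqualizedConv2d: with use_wscale=True the weights are stored at unit scale and multiplied by the He constant (w_mul) at runtime, so all layers see comparable effective learning rates under Adam. A small demonstration:

    import torch
    from models.CustomLayers import EqualizedLinear

    layer = EqualizedLinear(512, 256, gain=2 ** 0.5, use_wscale=True)
    print(float(layer.weight.std()))         # ~1.0: weights stored at unit scale
    print(layer.w_mul)                       # he_std = sqrt(2) / sqrt(512), applied in forward()
    print(layer(torch.randn(4, 512)).shape)  # torch.Size([4, 256])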
/models/Losses.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: Losses.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/21
6 | Description: Module implementing various loss functions
7 |              Copied from: https://github.com/akanimax/pro_gan_pytorch
8 | -------------------------------------------------
9 | """
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn as nn
14 | from torch.nn import BCEWithLogitsLoss
15 |
16 | # =============================================================
17 | # Interface for the losses
18 | # =============================================================
19 |
20 | class GANLoss:
21 | """ Base class for all losses
22 |
23 | @args:
24 | dis: Discriminator used for calculating the loss
25 | Note this must be a part of the GAN framework
26 | """
27 |
28 | def __init__(self, dis):
29 | self.dis = dis
30 |
31 | def dis_loss(self, real_samps, fake_samps, height, alpha):
32 | """
33 | calculate the discriminator loss using the following data
34 | :param real_samps: batch of real samples
35 | :param fake_samps: batch of generated (fake) samples
36 |         :param height: current height (resolution stage) of training
37 |         :param alpha: current value of the fade-in alpha
38 | :return: loss => calculated loss Tensor
39 | """
40 | raise NotImplementedError("dis_loss method has not been implemented")
41 |
42 | def gen_loss(self, real_samps, fake_samps, height, alpha):
43 | """
44 | calculate the generator loss
45 | :param real_samps: batch of real samples
46 | :param fake_samps: batch of generated (fake) samples
47 |         :param height: current height (resolution stage) of training
48 |         :param alpha: current value of the fade-in alpha
49 | :return: loss => calculated loss Tensor
50 | """
51 | raise NotImplementedError("gen_loss method has not been implemented")
52 |
53 |
54 | class ConditionalGANLoss:
55 | """ Base class for all conditional losses """
56 |
57 | def __init__(self, dis):
58 | self.criterion = BCEWithLogitsLoss()
59 | self.dis = dis
60 |
61 | def dis_loss(self, real_samps, fake_samps, labels, height, alpha):
62 | # small assertion:
63 | assert real_samps.device == fake_samps.device, \
64 | "Real and Fake samples are not on the same device"
65 |
66 | # device for computations:
67 | device = fake_samps.device
68 |
69 | # predictions for real images and fake images separately :
70 | r_preds = self.dis(real_samps, height, alpha, labels_in=labels)
71 | f_preds = self.dis(fake_samps, height, alpha, labels_in=labels)
72 |
73 | # calculate the real loss:
74 | real_loss = self.criterion(
75 | torch.squeeze(r_preds),
76 | torch.ones(real_samps.shape[0]).to(device))
77 |
78 | # calculate the fake loss:
79 | fake_loss = self.criterion(
80 | torch.squeeze(f_preds),
81 | torch.zeros(fake_samps.shape[0]).to(device))
82 |
83 | # return final losses
84 | return (real_loss + fake_loss) / 2
85 |
86 | def gen_loss(self, _, fake_samps, labels, height, alpha):
87 | preds = self.dis(fake_samps, height, alpha, labels_in=labels)
88 | return self.criterion(torch.squeeze(preds),
89 | torch.ones(fake_samps.shape[0]).to(fake_samps.device))
90 |
91 |
92 | # =============================================================
93 | # Normal versions of the Losses:
94 | # =============================================================
95 |
96 | class StandardGAN(GANLoss):
97 |
98 | def __init__(self, dis):
99 |
100 | super().__init__(dis)
101 |
102 | # define the criterion and activation used for object
103 | self.criterion = BCEWithLogitsLoss()
104 |
105 | def dis_loss(self, real_samps, fake_samps, height, alpha):
106 | # small assertion:
107 | assert real_samps.device == fake_samps.device, \
108 | "Real and Fake samples are not on the same device"
109 |
110 | # device for computations:
111 | device = fake_samps.device
112 |
113 | # predictions for real images and fake images separately :
114 | r_preds = self.dis(real_samps, height, alpha)
115 | f_preds = self.dis(fake_samps, height, alpha)
116 |
117 | # calculate the real loss:
118 | real_loss = self.criterion(
119 | torch.squeeze(r_preds),
120 | torch.ones(real_samps.shape[0]).to(device))
121 |
122 | # calculate the fake loss:
123 | fake_loss = self.criterion(
124 | torch.squeeze(f_preds),
125 | torch.zeros(fake_samps.shape[0]).to(device))
126 |
127 | # return final losses
128 | return (real_loss + fake_loss) / 2
129 |
130 | def gen_loss(self, _, fake_samps, height, alpha):
131 |         preds = self.dis(fake_samps, height, alpha)
132 | return self.criterion(torch.squeeze(preds),
133 | torch.ones(fake_samps.shape[0]).to(fake_samps.device))
134 |
135 |
136 | class HingeGAN(GANLoss):
137 |
138 | def __init__(self, dis):
139 | super().__init__(dis)
140 |
141 | def dis_loss(self, real_samps, fake_samps, height, alpha):
142 | r_preds = self.dis(real_samps, height, alpha)
143 | f_preds = self.dis(fake_samps, height, alpha)
144 |
145 | loss = (torch.mean(nn.ReLU()(1 - r_preds)) +
146 | torch.mean(nn.ReLU()(1 + f_preds)))
147 |
148 | return loss
149 |
150 | def gen_loss(self, _, fake_samps, height, alpha):
151 | return -torch.mean(self.dis(fake_samps, height, alpha))
152 |
153 |
154 | class RelativisticAverageHingeGAN(GANLoss):
155 |
156 | def __init__(self, dis):
157 | super().__init__(dis)
158 |
159 | def dis_loss(self, real_samps, fake_samps, height, alpha):
160 | # Obtain predictions
161 | r_preds = self.dis(real_samps, height, alpha)
162 | f_preds = self.dis(fake_samps, height, alpha)
163 |
164 | # difference between real and fake:
165 | r_f_diff = r_preds - torch.mean(f_preds)
166 |
167 | # difference between fake and real samples
168 | f_r_diff = f_preds - torch.mean(r_preds)
169 |
170 | # return the loss
171 | loss = (torch.mean(nn.ReLU()(1 - r_f_diff))
172 | + torch.mean(nn.ReLU()(1 + f_r_diff)))
173 |
174 | return loss
175 |
176 | def gen_loss(self, real_samps, fake_samps, height, alpha):
177 | # Obtain predictions
178 | r_preds = self.dis(real_samps, height, alpha)
179 | f_preds = self.dis(fake_samps, height, alpha)
180 |
181 | # difference between real and fake:
182 | r_f_diff = r_preds - torch.mean(f_preds)
183 |
184 | # difference between fake and real samples
185 | f_r_diff = f_preds - torch.mean(r_preds)
186 |
187 | # return the loss
188 | return (torch.mean(nn.ReLU()(1 + r_f_diff))
189 | + torch.mean(nn.ReLU()(1 - f_r_diff)))
190 |
191 |
192 | class LogisticGAN(GANLoss):
193 | def __init__(self, dis):
194 | super().__init__(dis)
195 |
196 | # gradient penalty
197 | def R1Penalty(self, real_img, height, alpha):
198 |
199 | # TODO: use_loss_scaling, for fp16
200 | apply_loss_scaling = lambda x: x * torch.exp(x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device))
201 | undo_loss_scaling = lambda x: x * torch.exp(-x * torch.Tensor([np.float32(np.log(2.0))]).to(real_img.device))
202 |
203 | real_img = torch.autograd.Variable(real_img, requires_grad=True)
204 | real_logit = self.dis(real_img, height, alpha)
205 | # real_logit = apply_loss_scaling(torch.sum(real_logit))
206 | real_grads = torch.autograd.grad(outputs=real_logit, inputs=real_img,
207 | grad_outputs=torch.ones(real_logit.size()).to(real_img.device),
208 | create_graph=True, retain_graph=True)[0].view(real_img.size(0), -1)
209 | # real_grads = undo_loss_scaling(real_grads)
210 | r1_penalty = torch.sum(torch.mul(real_grads, real_grads))
211 | return r1_penalty
212 |
213 | def dis_loss(self, real_samps, fake_samps, height, alpha, r1_gamma=10.0):
214 | # Obtain predictions
215 | r_preds = self.dis(real_samps, height, alpha)
216 | f_preds = self.dis(fake_samps, height, alpha)
217 |
218 | loss = torch.mean(nn.Softplus()(f_preds)) + torch.mean(nn.Softplus()(-r_preds))
219 |
220 | if r1_gamma != 0.0:
221 | r1_penalty = self.R1Penalty(real_samps.detach(), height, alpha) * (r1_gamma * 0.5)
222 | loss += r1_penalty
223 |
224 | return loss
225 |
226 | def gen_loss(self, _, fake_samps, height, alpha):
227 | f_preds = self.dis(fake_samps, height, alpha)
228 |
229 | return torch.mean(nn.Softplus()(-f_preds))
230 |
--------------------------------------------------------------------------------
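LogisticGAN above is the non-saturating logistic loss used by the StyleGAN paper plus the R1 penalty: L_D = E[softplus(D(fake))] + E[softplus(-D(real))] + (gamma / 2) * ||grad_x D(x)||^2 evaluated at real samples. A toy check with a hypothetical stand-in discriminator, mirroring the reductions used above:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    dis = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))  # stand-in discriminator
    real = torch.randn(4, 3, 8, 8, requires_grad=True)
    fake = torch.randn(4, 3, 8, 8)

    real_logit = dis(real)
    loss = F.softplus(dis(fake)).mean() + F.softplus(-real_logit).mean()

    grads = torch.autograd.grad(real_logit.sum(), real, create_graph=True)[0]
    r1 = grads.pow(2).view(4, -1).sum()  # summed over the batch, as in R1Penalty
    loss = loss + (10.0 * 0.5) * r1      # r1_gamma = 10.0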
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: __init__.py.py
4 | Date: 2019/10/17
5 | Description:
6 | -------------------------------------------------
7 | """
8 |
9 |
10 | # function to calculate the Exponential moving averages for the Generator weights
11 | # This function updates the exponential average weights based on the current training
12 | # Copied from: https://github.com/akanimax/pro_gan_pytorch
13 | def update_average(model_tgt, model_src, beta):
14 | """
15 | update the model_target using exponential moving averages
16 | :param model_tgt: target model
17 | :param model_src: source model
18 | :param beta: value of decay beta
19 | :return: None (updates the target model)
20 | """
21 |
22 | # utility function for toggling the gradient requirements of the models
23 | def toggle_grad(model, requires_grad):
24 | for p in model.parameters():
25 | p.requires_grad_(requires_grad)
26 |
27 | # turn off gradient calculation
28 | toggle_grad(model_tgt, False)
29 | toggle_grad(model_src, False)
30 |
31 | param_dict_src = dict(model_src.named_parameters())
32 |
33 | for p_name, p_tgt in model_tgt.named_parameters():
34 | p_src = param_dict_src[p_name]
35 | assert (p_src is not p_tgt)
36 | p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src)
37 |
38 | # turn back on the gradient calculation
39 | toggle_grad(model_tgt, True)
40 | toggle_grad(model_src, True)
41 |
--------------------------------------------------------------------------------
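update_average is what keeps gen_shadow (the EMA copy used for sampling) in sync during training. A minimal usage sketch with a hypothetical stand-in module:

    import copy
    import torch.nn as nn
    from models import update_average

    gen = nn.Linear(8, 8)                      # stand-in for the real generator
    gen_shadow = copy.deepcopy(gen)
    update_average(gen_shadow, gen, beta=0.0)  # beta=0 copies the source weights exactly

    # then, after every generator step during training:
    update_average(gen_shadow, gen, beta=0.999)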
/requirements.txt:
--------------------------------------------------------------------------------
1 | yacs
2 | tqdm
3 | numpy
4 | torchvision
5 | torch
6 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: __init__.py.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/17
6 | Description:
7 | -------------------------------------------------
8 | """
--------------------------------------------------------------------------------
/test/test_Blocks.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: test_Blocks.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/17
6 | Description:
7 | -------------------------------------------------
8 | """
9 |
10 | from unittest import TestCase
11 |
12 |
13 | class TestGMapping(TestCase):
14 | def setUp(self) -> None:
15 | pass
16 |
17 | def test_forward(self):
18 | pass
19 |
20 | def tearDown(self) -> None:
21 | pass
22 |
--------------------------------------------------------------------------------
/test/test_CustomLayers.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: test_CustomLayers.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/17
6 | Description:
7 | -------------------------------------------------
8 | """
9 |
10 |
--------------------------------------------------------------------------------
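Both test modules above are empty stubs. As an illustration of what a test here could check, a minimal case for PixelNormLayer (output should have unit RMS across the channel dimension):

    import torch
    from unittest import TestCase

    from models.CustomLayers import PixelNormLayer


    class TestPixelNorm(TestCase):
        def test_unit_rms(self):
            x = torch.randn(2, 16, 4, 4)
            y = PixelNormLayer()(x)
            rms = (y ** 2).mean(dim=1).sqrt()
            self.assertTrue(torch.allclose(rms, torch.ones_like(rms), atol=1e-4))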
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: train.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/18
6 | Description: Training script for StyleGAN.
7 | -------------------------------------------------
8 | """
9 |
10 | import argparse
11 | import os
12 | import shutil
13 |
14 | import torch
15 | from torch.backends import cudnn
16 |
17 | from data import make_dataset
18 | from models.GAN import StyleGAN
19 | from utils import (copy_files_and_create_dirs,
20 | list_dir_recursively_with_ignore, make_logger)
21 |
22 |
23 | # Load fewer layers of pre-trained models if possible
24 | def load(model, cpk_file):
25 | pretrained_dict = torch.load(cpk_file)
26 | model_dict = model.state_dict()
27 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
28 | model_dict.update(pretrained_dict)
29 | model.load_state_dict(model_dict)
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser(description="StyleGAN pytorch implementation.")
34 | parser.add_argument('--config', default='./configs/sample.yaml')
35 |
36 | parser.add_argument("--start_depth", action="store", type=int, default=0,
37 | help="Starting depth for training the network")
38 |
39 | parser.add_argument("--generator_file", action="store", type=str, default=None,
40 | help="pretrained Generator file (compatible with my code)")
41 | parser.add_argument("--gen_shadow_file", action="store", type=str, default=None,
42 | help="pretrained gen_shadow file")
43 | parser.add_argument("--discriminator_file", action="store", type=str, default=None,
44 | help="pretrained Discriminator file (compatible with my code)")
45 | parser.add_argument("--gen_optim_file", action="store", type=str, default=None,
46 | help="saved state of generator optimizer")
47 | parser.add_argument("--dis_optim_file", action="store", type=str, default=None,
48 | help="saved_state of discriminator optimizer")
49 | args = parser.parse_args()
50 |
51 | from config import cfg as opt
52 |
53 | opt.merge_from_file(args.config)
54 | opt.freeze()
55 |
56 | # make output dir
57 | output_dir = opt.output_dir
58 | if os.path.exists(output_dir):
59 |         raise FileExistsError("Existing path: " + output_dir)
60 | os.makedirs(output_dir)
61 |
62 | # copy codes and config file
63 | files = list_dir_recursively_with_ignore('.', ignores=['diagrams', 'configs'])
64 | files = [(f[0], os.path.join(output_dir, "src", f[1])) for f in files]
65 | copy_files_and_create_dirs(files)
66 | shutil.copy2(args.config, output_dir)
67 |
68 | # logger
69 | logger = make_logger("project", opt.output_dir, 'log')
70 |
71 | # device
72 | if opt.device == 'cuda':
73 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.device_id
74 | num_gpus = len(opt.device_id.split(','))
75 | logger.info("Using {} GPUs.".format(num_gpus))
76 | logger.info("Training on {}.\n".format(torch.cuda.get_device_name(0)))
77 | cudnn.benchmark = True
78 | device = torch.device(opt.device)
79 |
80 | # create the dataset for training
81 | dataset = make_dataset(opt.dataset, conditional=opt.conditional)
82 |
83 | # init the network
84 | style_gan = StyleGAN(structure=opt.structure,
85 | conditional=opt.conditional,
86 | n_classes=opt.n_classes,
87 | resolution=opt.dataset.resolution,
88 | num_channels=opt.dataset.channels,
89 | latent_size=opt.model.gen.latent_size,
90 | g_args=opt.model.gen,
91 | d_args=opt.model.dis,
92 | g_opt_args=opt.model.g_optim,
93 | d_opt_args=opt.model.d_optim,
94 | loss=opt.loss,
95 | drift=opt.drift,
96 | d_repeats=opt.d_repeats,
97 | use_ema=opt.use_ema,
98 | ema_decay=opt.ema_decay,
99 | device=device)
100 |
101 | # Resume training from checkpoints
102 | if args.generator_file is not None:
103 | logger.info("Loading generator from: %s", args.generator_file)
104 | # style_gan.gen.load_state_dict(torch.load(args.generator_file))
105 | # Load fewer layers of pre-trained models if possible
106 | load(style_gan.gen, args.generator_file)
107 | else:
108 | logger.info("Training from scratch...")
109 |
110 | if args.discriminator_file is not None:
111 | logger.info("Loading discriminator from: %s", args.discriminator_file)
112 | style_gan.dis.load_state_dict(torch.load(args.discriminator_file))
113 |
114 | if args.gen_shadow_file is not None and opt.use_ema:
115 | logger.info("Loading shadow generator from: %s", args.gen_shadow_file)
116 | # style_gan.gen_shadow.load_state_dict(torch.load(args.gen_shadow_file))
117 | # Load fewer layers of pre-trained models if possible
118 | load(style_gan.gen_shadow, args.gen_shadow_file)
119 |
120 | if args.gen_optim_file is not None:
121 | logger.info("Loading generator optimizer from: %s", args.gen_optim_file)
122 | style_gan.gen_optim.load_state_dict(torch.load(args.gen_optim_file))
123 |
124 | if args.dis_optim_file is not None:
125 | logger.info("Loading discriminator optimizer from: %s", args.dis_optim_file)
126 | style_gan.dis_optim.load_state_dict(torch.load(args.dis_optim_file))
127 |
128 | # train the network
129 | style_gan.train(dataset=dataset,
130 | num_workers=opt.num_works,
131 | epochs=opt.sched.epochs,
132 | batch_sizes=opt.sched.batch_sizes,
133 | fade_in_percentage=opt.sched.fade_in_percentage,
134 | logger=logger,
135 | output=output_dir,
136 | num_samples=opt.num_samples,
137 | start_depth=args.start_depth,
138 | feedback_factor=opt.feedback_factor,
139 | checkpoint_factor=opt.checkpoint_factor)
140 |
--------------------------------------------------------------------------------
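The load() helper at the top of train.py filters a checkpoint down to the keys the current model actually has, which is what lets training resume at a greater depth than the checkpoint was saved at. The same recipe in a self-contained sketch with hypothetical toy modules:

    import io

    import torch
    import torch.nn as nn

    small = nn.ModuleDict({'b0': nn.Linear(4, 4)})                       # checkpoint from an earlier depth
    big = nn.ModuleDict({'b0': nn.Linear(4, 4), 'b1': nn.Linear(4, 4)})  # current, deeper model

    buf = io.BytesIO()
    torch.save(small.state_dict(), buf)
    buf.seek(0)

    pretrained = torch.load(buf)
    model_dict = big.state_dict()
    pretrained = {k: v for k, v in pretrained.items() if k in model_dict}
    model_dict.update(pretrained)
    big.load_state_dict(model_dict)  # b0 comes from the checkpoint, b1 keeps its fresh init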
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | -------------------------------------------------
3 | File Name: __init__.py.py
4 | Author: Zhonghao Huang
5 | Date: 2019/10/24
6 | Description:
7 | -------------------------------------------------
8 | """
9 |
10 | from .logger import make_logger
11 | from .copy import list_dir_recursively_with_ignore, copy_files_and_create_dirs
12 |
--------------------------------------------------------------------------------
/utils/copy.py:
--------------------------------------------------------------------------------
1 | import fnmatch
2 | import os
3 | import shutil
4 | from typing import List, Tuple
5 |
6 |
7 | def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
8 | """Takes in a list of tuples of (src, dst) paths and copies files.
9 | Will create all necessary directories."""
10 | for file in files:
11 | target_dir_name = os.path.dirname(file[1])
12 |
13 | # will create all intermediate-level directories
14 | if not os.path.exists(target_dir_name):
15 | os.makedirs(target_dir_name)
16 |
17 | shutil.copyfile(file[0], file[1])
18 |
19 |
20 | def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> \
21 | List[Tuple[str, str]]:
22 | """List all files recursively in a given directory while ignoring given file and directory names.
23 | Returns list of tuples containing both absolute and relative paths."""
24 | assert os.path.isdir(dir_path)
25 | base_name = os.path.basename(os.path.normpath(dir_path))
26 |
27 | if ignores is None:
28 | ignores = []
29 |
30 | result = []
31 |
32 | for root, dirs, files in os.walk(dir_path, topdown=True):
33 | for ignore_ in ignores:
34 | dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
35 |
36 | # dirs need to be edited in-place
37 | for d in dirs_to_remove:
38 | dirs.remove(d)
39 |
40 | files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
41 |
42 | absolute_paths = [os.path.join(root, f) for f in files]
43 | relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
44 |
45 | if add_base_to_relative:
46 | relative_paths = [os.path.join(base_name, p) for p in relative_paths]
47 |
48 | assert len(absolute_paths) == len(relative_paths)
49 | result += zip(absolute_paths, relative_paths)
50 |
51 | return result
52 |
53 |
54 | if __name__ == '__main__':
55 | output = '../checkpoint/Exp-copy'
56 | ignores = ['checkpoint', 'configs']
57 | files = list_dir_recursively_with_ignore('..', ignores=ignores)
58 | files = [(f[0], os.path.join(output, "src", f[1])) for f in files]
59 | copy_files_and_create_dirs(files)
60 |
61 | print('Done.')
62 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 |
6 | def make_logger(name, save_dir, save_filename):
7 | DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
8 |
9 | logger = logging.getLogger(name)
10 | logger.setLevel(logging.DEBUG)
11 |
12 | ch = logging.StreamHandler(stream=sys.stdout)
13 | ch.setLevel(logging.DEBUG)
14 | formatter = logging.Formatter("%(asctime)s %(levelname)s: %(message)s", datefmt=DATE_FORMAT)
15 | ch.setFormatter(formatter)
16 |
17 | logger.addHandler(ch)
18 |
19 | if save_dir:
20 | fh = logging.FileHandler(os.path.join(save_dir, save_filename + ".txt"), mode='w')
21 | fh.setLevel(logging.DEBUG)
22 | fh.setFormatter(formatter)
23 | logger.addHandler(fh)
24 |
25 | return logger
26 |
--------------------------------------------------------------------------------
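make_logger is how train.py tees output to stdout and to <save_dir>/<save_filename>.txt. A minimal, hypothetical invocation (the directory must already exist):

    from utils import make_logger

    logger = make_logger("project", ".", "log")  # writes ./log.txt in addition to stdout
    logger.info("Using %d GPUs.", 1)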