├── .gitignore ├── LICENSE ├── README.md ├── classifier.py ├── data └── preprocess.py ├── images ├── interpolation.jpg ├── multi_attr.jpg ├── swap.jpg └── v3.png ├── interpolate.py ├── models └── download.sh ├── src ├── __init__.py ├── evaluation.py ├── loader.py ├── logger.py ├── model.py ├── training.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.gitignore.io/api/python 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # End of https://www.gitignore.io/api/python 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. 
The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. 
Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. 
produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. 
identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. 
Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FaderNetworks 2 | 3 | PyTorch implementation of [Fader Networks](https://arxiv.org/pdf/1706.00409.pdf) (NIPS 2017). 4 | 5 |

6 | 7 | Fader Networks can generate different realistic versions of images by modifying attributes such as gender or age group. They can swap multiple attributes at a time, and continuously interpolate between each attribute value. In this repository we provide the code to reproduce the results presented in the paper, as well as trained models. 8 | 9 | ### Single-attribute swap 10 | 11 | Below are some examples of different attribute swaps: 12 | 13 |
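Programmatically, a swap amounts to re-decoding the latent code of an image with a different attribute value. Below is a minimal sketch that follows the same calls as `interpolate.py`; it assumes the pretrained models have been downloaded with `models/download.sh` and that a GPU is available, and it uses a random tensor as a stand-in for a batch of preprocessed CelebA images:

```python
import torch
from torch.autograd import Variable

ae = torch.load('models/male.pth').eval()            # trained Fader Network autoencoder
# Stand-in for a batch of images normalized to [-1, 1]; see src/loader.py for real data loading.
images = Variable(torch.rand(4, 3, 256, 256).mul_(2).add_(-1).cuda())

z = ae.encode(images)                                 # latent states E(x)
alpha = torch.FloatTensor([0., 1.])                   # one-hot value of the attribute to impose
alpha = Variable(alpha.unsqueeze(0).expand((len(images), 2)).cuda())
swapped = ae.decode(z, alpha)[-1]                     # images decoded with the new attribute value
```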

14 | 15 | ### Multi-attribute swap 16 | 17 | Fader Networks are also designed to disentangle multiple attributes at a time: 18 | 19 |

20 | 21 | ## Model 22 | 23 |

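The paragraph below describes the training objective in prose. As a rough illustration only, here is a minimal PyTorch sketch of its two competing losses (image reconstruction for the autoencoder, attribute prediction for the latent discriminator), with toy layer sizes and hypothetical variable names; the actual networks are defined in `src/model.py` and the real training loop lives in `src/training.py` and `train.py`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the real encoder / decoder / latent discriminator.
enc = nn.Sequential(nn.Conv2d(3, 32, 4, 2, 1), nn.ReLU())               # E: image -> latent
dec = nn.Sequential(nn.ConvTranspose2d(32 + 2, 3, 4, 2, 1), nn.Tanh())  # D: (latent, y) -> image
dis = nn.Sequential(nn.Flatten(), nn.Linear(32 * 128 * 128, 2))         # predicts y from E(x)

x = torch.rand(8, 3, 256, 256) * 2 - 1           # stand-in for a batch of images in [-1, 1]
y = torch.eye(2)[torch.randint(0, 2, (8,))]      # one-hot attribute, e.g. Male

z = enc(x)                                       # latent state E(x)
y_map = y.view(-1, 2, 1, 1).expand(-1, -1, z.size(2), z.size(3))
x_rec = dec(torch.cat([z, y_map], 1))            # reconstruction from (E(x), y)

# Discriminator step: predict the true attribute from the (detached) latent state.
lat_dis_loss = F.binary_cross_entropy_with_logits(dis(z.detach()), y)

# Autoencoder step: reconstruct the image and fool the discriminator (target 1 - y);
# the small weight plays the role of the --lambda_lat_dis parameter of train.py.
ae_loss = ((x_rec - x) ** 2).mean() + 0.0001 * F.binary_cross_entropy_with_logits(dis(z), 1 - y)
```

In the actual implementation, the two steps are optimized alternately with separate optimizers (`--ae_optimizer`, `--dis_optimizer`), and the discriminator weight is increased progressively over training (see `--lambda_schedule`).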
24 | 25 | The main branch of the model (Inference Model) is an autoencoder of images. Given an image `x` and an attribute `y` (e.g. male/female), the decoder is trained to reconstruct the image from the latent state `E(x)` and `y`. The other branch (Adversarial Component) is composed of a discriminator trained to predict the attribute from the latent state. The encoder of the Inference Model is trained not only to reconstruct the image, but also to fool the discriminator, by removing from `E(x)` the information related to the attribute. As a result, the decoder needs to consider `y` to properly reconstruct the image. During training, the model uses the real attribute values, but at test time `y` can be manipulated to generate variations of the original image. 26 | 27 | ## Dependencies 28 | * Python 2/3 with [NumPy](http://www.numpy.org/)/[SciPy](https://www.scipy.org/) 29 | * [PyTorch](http://pytorch.org/) 30 | * OpenCV 31 | * CUDA 32 | 33 | 34 | ## Installation 35 | 36 | Simply clone the repository: 37 | 38 | ```bash 39 | git clone https://github.com/facebookresearch/FaderNetworks.git 40 | cd FaderNetworks 41 | ``` 42 | 43 | ## Dataset 44 | Download the aligned and cropped CelebA dataset from http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html. Extract all images and move them to the `data/img_align_celeba/` folder. There should be 202599 images. The dataset also provides a file `list_attr_celeba.txt` containing the list of the 40 attributes associated with each image. Move it to `data/`. Then simply run: 45 | 46 | ```bash 47 | cd data 48 | ./preprocess.py 49 | ``` 50 | 51 | This will resize the images and create two files: `images_256_256.pth` and `attributes.pth`. The first one contains a tensor of size `(202599, 3, 256, 256)` with the concatenation of all resized images. Note that you can update the image size in `preprocess.py` to work with different resolutions. The second file is a pre-processed version of the attributes. 52 | 53 | ## Pretrained models 54 | You can download pretrained classifiers and Fader Networks by running: 55 | 56 | ```bash 57 | cd models 58 | ./download.sh 59 | ``` 60 | 61 | ## Train your own models 62 | 63 | ### Train a classifier 64 | To train your own models, you first need a classifier that evaluates the swap quality during training. Training a good classifier is relatively simple for most attributes, and a good model can be trained in a few minutes. We provide a trained classifier for all attributes in `models/classifier256.pth`. Note that the classifier does not need to be state-of-the-art: it is not used in the training process itself, but only to monitor the swap quality. If you want to train your own classifier, you can run `classifier.py` with the following parameters: 65 | 66 | 67 | ```bash 68 | python classifier.py 69 | 70 | # Main parameters 71 | --img_sz 256 # image size 72 | --img_fm 3 # number of feature maps 73 | --attr "*" # attributes list.
"*" for all attributes 74 | 75 | # Network architecture 76 | --init_fm 32 # number of feature maps in the first layer 77 | --max_fm 512 # maximum number of feature maps 78 | --hid_dim 512 # hidden layer size 79 | 80 | # Training parameters 81 | --v_flip False # randomly flip images vertically (data augmentation) 82 | --h_flip True # randomly flip images horizontally (data augmentation) 83 | --batch_size 32 # batch size 84 | --optimizer "adam,lr=0.0002" # optimizer 85 | --clip_grad_norm 5 # clip gradient L2 norm 86 | --n_epochs 1000 # number of epochs 87 | --epoch_size 50000 # number of images per epoch 88 | 89 | # Reload 90 | --reload "" # reload a trained classifier 91 | --debug False # debug mode (if True, load a small subset of the dataset) 92 | ``` 93 | 94 | 95 | ### Train a Fader Network 96 | 97 | You can train a Fader Network with `train.py`. The autoencoder can receive feedback from: 98 | - The image reconstruction loss 99 | - The latent discriminator loss 100 | - The PatchGAN discriminator loss 101 | - The classifier loss 102 | 103 | In the paper, only the first two losses are used, but the two others could improve the results further. You can tune the impact of each of these losses with the lambda_ae, lambda_lat_dis, lambda_ptc_dis, and lambda_clf_dis coefficients. Below is a complete list of all parameters: 104 | 105 | ```bash 106 | # Main parameters 107 | --img_sz 256 # image size 108 | --img_fm 3 # number of feature maps 109 | --attr "Male" # attributes list. "*" for all attributes 110 | 111 | # Networks architecture 112 | --instance_norm False # use instance normalization instead of batch normalization 113 | --init_fm 32 # number of feature maps in the first layer 114 | --max_fm 512 # maximum number of feature maps 115 | --n_layers 6 # number of layers in the encoder / decoder 116 | --n_skip 0 # number of skip connections 117 | --deconv_method "convtranspose" # deconvolution method 118 | --hid_dim 512 # hidden layer size 119 | --dec_dropout 0 # dropout in the decoder 120 | --lat_dis_dropout 0.3 # dropout in the latent discriminator 121 | 122 | # Training parameters 123 | --n_lat_dis 1 # number of latent discriminator training steps 124 | --n_ptc_dis 0 # number of PatchGAN discriminator training steps 125 | --n_clf_dis 0 # number of classifier training steps 126 | --smooth_label 0.2 # smooth discriminator labels 127 | --lambda_ae 1 # autoencoder loss coefficient 128 | --lambda_lat_dis 0.0001 # latent discriminator loss coefficient 129 | --lambda_ptc_dis 0 # PatchGAN discriminator loss coefficient 130 | --lambda_clf_dis 0 # classifier loss coefficient 131 | --lambda_schedule 500000 # lambda scheduling (0 to disable) 132 | --v_flip False # randomly flip images vertically (data augmentation) 133 | --h_flip True # randomly flip images horizontally (data augmentation) 134 | --batch_size 32 # batch size 135 | --ae_optimizer "adam,lr=0.0002" # autoencoder optimizer 136 | --dis_optimizer "adam,lr=0.0002" # discriminator optimizer 137 | --clip_grad_norm 5 # clip gradient L2 norm 138 | --n_epochs 1000 # number of epochs 139 | --epoch_size 50000 # number of images per epoch 140 | 141 | # Reload 142 | --ae_reload "" # reload pretrained autoencoder 143 | --lat_dis_reload "" # reload pretrained latent discriminator 144 | --ptc_dis_reload "" # reload pretrained PatchGAN discriminator 145 | --clf_dis_reload "" # reload pretrained classifier 146 | --eval_clf "" # evaluation classifier (trained with classifier.py) 147 | --debug False # debug mode (if True, load a small subset of the dataset) 148 
| ``` 149 | 150 | ## Generate interpolations 151 | 152 | Given a trained model, you can use it to swap attributes of images in the dataset. Below are examples using the pretrained models: 153 | 154 | ```bash 155 | # Narrow Eyes 156 | python interpolate.py --model_path models/narrow_eyes.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path narrow_eyes.png 157 | 158 | # Eyeglasses 159 | python interpolate.py --model_path models/eyeglasses.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path eyeglasses.png 160 | 161 | # Age 162 | python interpolate.py --model_path models/young.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path young.png 163 | 164 | # Gender 165 | python interpolate.py --model_path models/male.pth --n_images 10 --n_interpolations 10 --alpha_min 2.0 --alpha_max 2.0 --output_path male.png 166 | 167 | # Pointy nose 168 | python interpolate.py --model_path models/pointy_nose.pth --n_images 10 --n_interpolations 10 --alpha_min 10.0 --alpha_max 10.0 --output_path pointy_nose.png 169 | ``` 170 | 171 | Each of these commands generates an image with 10 rows and 12 columns. The first column corresponds to the original image, the second to the reconstructed image (without alteration of the attribute), and the remaining ones to the interpolated images. `alpha_min` and `alpha_max` represent the range of the interpolation. Values greater than 1 correspond to generations beyond the True / False range of the boolean attribute in the model. Note that the variations of some attributes may only be noticeable for large values of alpha. For instance, for the "eyeglasses" or "gender" attributes, alpha_max=2 is usually enough, while for the "age" or "narrow eyes" attributes, it is better to go up to alpha_max=10. 172 | 173 | 174 | ## References 175 | 176 | If you find this code useful, please consider citing: 177 | 178 | [*Fader Networks: Manipulating Images by Sliding Attributes*](https://arxiv.org/pdf/1706.00409.pdf) - G. Lample, N. Zeghidour, N. Usunier, A. Bordes, L. Denoyer, M'A. Ranzato 179 | 180 | ``` 181 | @inproceedings{lample2017fader, 182 | title={Fader Networks: Manipulating Images by Sliding Attributes}, 183 | author={Lample, Guillaume and Zeghidour, Neil and Usunier, Nicolas and Bordes, Antoine and DENOYER, Ludovic and others}, 184 | booktitle={Advances in Neural Information Processing Systems}, 185 | pages={5963--5972}, 186 | year={2017} 187 | } 188 | ``` 189 | 190 | Contact: [gl@fb.com](mailto:gl@fb.com), [neilz@fb.com](mailto:neilz@fb.com) 191 | -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | # 7 | 8 | import os 9 | import json 10 | import argparse 11 | import numpy as np 12 | import torch 13 | 14 | from src.loader import load_images, DataSampler 15 | from src.utils import initialize_exp, bool_flag, attr_flag, check_attr 16 | from src.utils import get_optimizer, reload_model, print_accuracies 17 | from src.model import Classifier 18 | from src.training import classifier_step 19 | from src.evaluation import compute_accuracy 20 | 21 | 22 | # parse parameters 23 | parser = argparse.ArgumentParser(description='Classifier') 24 | parser.add_argument("--name", type=str, default="default", 25 | help="Experiment name") 26 | parser.add_argument("--img_sz", type=int, default=256, 27 | help="Image sizes (images have to be squared)") 28 | parser.add_argument("--img_fm", type=int, default=3, 29 | help="Number of feature maps (1 for grayscale, 3 for RGB)") 30 | parser.add_argument("--attr", type=attr_flag, default="Smiling", 31 | help="Attributes to classify") 32 | parser.add_argument("--init_fm", type=int, default=32, 33 | help="Number of initial filters in the encoder") 34 | parser.add_argument("--max_fm", type=int, default=512, 35 | help="Number maximum of filters in the autoencoder") 36 | parser.add_argument("--hid_dim", type=int, default=512, 37 | help="Last hidden layer dimension") 38 | parser.add_argument("--v_flip", type=bool_flag, default=False, 39 | help="Random vertical flip for data augmentation") 40 | parser.add_argument("--h_flip", type=bool_flag, default=True, 41 | help="Random horizontal flip for data augmentation") 42 | parser.add_argument("--batch_size", type=int, default=32, 43 | help="Batch size") 44 | parser.add_argument("--optimizer", type=str, default="adam", 45 | help="Classifier optimizer (SGD / RMSprop / Adam, etc.)") 46 | parser.add_argument("--clip_grad_norm", type=float, default=5, 47 | help="Clip gradient norms (0 to disable)") 48 | parser.add_argument("--n_epochs", type=int, default=1000, 49 | help="Total number of epochs") 50 | parser.add_argument("--epoch_size", type=int, default=50000, 51 | help="Number of samples per epoch") 52 | parser.add_argument("--reload", type=str, default="", 53 | help="Reload a pretrained classifier") 54 | parser.add_argument("--debug", type=bool_flag, default=False, 55 | help="Debug mode (only load a subset of the whole dataset)") 56 | params = parser.parse_args() 57 | 58 | # check parameters 59 | check_attr(params) 60 | assert len(params.name.strip()) > 0 61 | assert not params.reload or os.path.isfile(params.reload) 62 | 63 | # initialize experiment / load dataset 64 | logger = initialize_exp(params) 65 | data, attributes = load_images(params) 66 | train_data = DataSampler(data[0], attributes[0], params) 67 | valid_data = DataSampler(data[1], attributes[1], params) 68 | test_data = DataSampler(data[2], attributes[2], params) 69 | 70 | # build the model / reload / optimizer 71 | classifier = Classifier(params).cuda() 72 | if params.reload: 73 | reload_model(classifier, params.reload, 74 | ['img_sz', 'img_fm', 'init_fm', 'hid_dim', 'attr', 'n_attr']) 75 | optimizer = get_optimizer(classifier, params.optimizer) 76 | 77 | 78 | def save_model(name): 79 | """ 80 | Save the model. 81 | """ 82 | path = os.path.join(params.dump_path, '%s.pth' % name) 83 | logger.info('Saving the classifier to %s ...' % path) 84 | torch.save(classifier, path) 85 | 86 | 87 | # best accuracy 88 | best_accu = -1e12 89 | 90 | 91 | for n_epoch in range(params.n_epochs): 92 | 93 | logger.info('Starting epoch %i...' 
% n_epoch) 94 | costs = [] 95 | 96 | classifier.train() 97 | 98 | for n_iter in range(0, params.epoch_size, params.batch_size): 99 | 100 | # classifier training 101 | classifier_step(classifier, optimizer, train_data, params, costs) 102 | 103 | # average loss 104 | if len(costs) >= 25: 105 | logger.info('%06i - Classifier loss: %.5f' % (n_iter, np.mean(costs))) 106 | del costs[:] 107 | 108 | # compute accuracy 109 | valid_accu = compute_accuracy(classifier, valid_data, params) 110 | test_accu = compute_accuracy(classifier, test_data, params) 111 | 112 | # log classifier accuracy 113 | log_accu = [('valid_accu', np.mean(valid_accu)), ('test_accu', np.mean(test_accu))] 114 | for accu, (name, _) in zip(valid_accu, params.attr): 115 | log_accu.append(('valid_accu_%s' % name, accu)) 116 | for accu, (name, _) in zip(test_accu, params.attr): 117 | log_accu.append(('test_accu_%s' % name, accu)) 118 | logger.info('Classifier accuracy:') 119 | print_accuracies(log_accu) 120 | 121 | # JSON log 122 | logger.debug("__log__:%s" % json.dumps(dict([('n_epoch', n_epoch)] + log_accu))) 123 | 124 | # save best or periodic model 125 | if np.mean(valid_accu) > best_accu: 126 | best_accu = np.mean(valid_accu) 127 | logger.info('Best validation average accuracy: %.5f' % best_accu) 128 | save_model('best') 129 | elif n_epoch % 10 == 0 and n_epoch > 0: 130 | save_model('periodic-%i' % n_epoch) 131 | 132 | logger.info('End of epoch %i.\n' % n_epoch) 133 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import matplotlib.image as mpimg 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | 9 | N_IMAGES = 202599 10 | IMG_SIZE = 256 11 | IMG_PATH = 'images_%i_%i.pth' % (IMG_SIZE, IMG_SIZE) 12 | ATTR_PATH = 'attributes.pth' 13 | 14 | 15 | def preprocess_images(): 16 | 17 | if os.path.isfile(IMG_PATH): 18 | print("%s exists, nothing to do." % IMG_PATH) 19 | return 20 | 21 | print("Reading images from img_align_celeba/ ...") 22 | raw_images = [] 23 | for i in range(1, N_IMAGES + 1): 24 | if i % 10000 == 0: 25 | print(i) 26 | raw_images.append(mpimg.imread('img_align_celeba/%06i.jpg' % i)[20:-20]) 27 | 28 | if len(raw_images) != N_IMAGES: 29 | raise Exception("Found %i images. Expected %i" % (len(raw_images), N_IMAGES)) 30 | 31 | print("Resizing images ...") 32 | all_images = [] 33 | for i, image in enumerate(raw_images): 34 | if i % 10000 == 0: 35 | print(i) 36 | assert image.shape == (178, 178, 3) 37 | if IMG_SIZE < 178: 38 | image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA) 39 | elif IMG_SIZE > 178: 40 | image = cv2.resize(image, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LANCZOS4) 41 | assert image.shape == (IMG_SIZE, IMG_SIZE, 3) 42 | all_images.append(image) 43 | 44 | data = np.concatenate([img.transpose((2, 0, 1))[None] for img in all_images], 0) 45 | data = torch.from_numpy(data) 46 | assert data.size() == (N_IMAGES, 3, IMG_SIZE, IMG_SIZE) 47 | 48 | print("Saving images to %s ..." % IMG_PATH) 49 | torch.save(data[:20000].clone(), 'images_%i_%i_20000.pth' % (IMG_SIZE, IMG_SIZE)) 50 | torch.save(data, IMG_PATH) 51 | 52 | 53 | def preprocess_attributes(): 54 | 55 | if os.path.isfile(ATTR_PATH): 56 | print("%s exists, nothing to do." 
% ATTR_PATH) 57 | return 58 | 59 | attr_lines = [line.rstrip() for line in open('list_attr_celeba.txt', 'r')] 60 | assert len(attr_lines) == N_IMAGES + 2 61 | 62 | attr_keys = attr_lines[1].split() 63 | attributes = {k: np.zeros(N_IMAGES, dtype=np.bool) for k in attr_keys} 64 | 65 | for i, line in enumerate(attr_lines[2:]): 66 | image_id = i + 1 67 | split = line.split() 68 | assert len(split) == 41 69 | assert split[0] == ('%06i.jpg' % image_id) 70 | assert all(x in ['-1', '1'] for x in split[1:]) 71 | for j, value in enumerate(split[1:]): 72 | attributes[attr_keys[j]][i] = value == '1' 73 | 74 | print("Saving attributes to %s ..." % ATTR_PATH) 75 | torch.save(attributes, ATTR_PATH) 76 | 77 | 78 | preprocess_images() 79 | preprocess_attributes() 80 | -------------------------------------------------------------------------------- /images/interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FaderNetworks/cdd9e50659b635a6e04311e1cf4b9a6e6683319b/images/interpolation.jpg -------------------------------------------------------------------------------- /images/multi_attr.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FaderNetworks/cdd9e50659b635a6e04311e1cf4b9a6e6683319b/images/multi_attr.jpg -------------------------------------------------------------------------------- /images/swap.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FaderNetworks/cdd9e50659b635a6e04311e1cf4b9a6e6683319b/images/swap.jpg -------------------------------------------------------------------------------- /images/v3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FaderNetworks/cdd9e50659b635a6e04311e1cf4b9a6e6683319b/images/v3.png -------------------------------------------------------------------------------- /interpolate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # 7 | 8 | import os 9 | import argparse 10 | import numpy as np 11 | import torch 12 | from torch.autograd import Variable 13 | from torchvision.utils import make_grid 14 | import matplotlib.image 15 | 16 | from src.logger import create_logger 17 | from src.loader import load_images, DataSampler 18 | from src.utils import bool_flag 19 | 20 | 21 | # parse parameters 22 | parser = argparse.ArgumentParser(description='Attributes swapping') 23 | parser.add_argument("--model_path", type=str, default="", 24 | help="Trained model path") 25 | parser.add_argument("--n_images", type=int, default=10, 26 | help="Number of images to modify") 27 | parser.add_argument("--offset", type=int, default=0, 28 | help="First image index") 29 | parser.add_argument("--n_interpolations", type=int, default=10, 30 | help="Number of interpolations per image") 31 | parser.add_argument("--alpha_min", type=float, default=1, 32 | help="Min interpolation value") 33 | parser.add_argument("--alpha_max", type=float, default=1, 34 | help="Max interpolation value") 35 | parser.add_argument("--plot_size", type=int, default=5, 36 | help="Size of images in the grid") 37 | parser.add_argument("--row_wise", type=bool_flag, default=True, 38 | help="Represent image interpolations horizontally") 39 | parser.add_argument("--output_path", type=str, default="output.png", 40 | help="Output path") 41 | params = parser.parse_args() 42 | 43 | # check parameters 44 | assert os.path.isfile(params.model_path) 45 | assert params.n_images >= 1 and params.n_interpolations >= 2 46 | 47 | # create logger / load trained model 48 | logger = create_logger(None) 49 | ae = torch.load(params.model_path).eval() 50 | 51 | # restore main parameters 52 | params.debug = True 53 | params.batch_size = 32 54 | params.v_flip = False 55 | params.h_flip = False 56 | params.img_sz = ae.img_sz 57 | params.attr = ae.attr 58 | params.n_attr = ae.n_attr 59 | if not (len(params.attr) == 1 and params.n_attr == 2): 60 | raise Exception("The model must use a single boolean attribute only.") 61 | 62 | # load dataset 63 | data, attributes = load_images(params) 64 | test_data = DataSampler(data[2], attributes[2], params) 65 | 66 | 67 | def get_interpolations(ae, images, attributes, params): 68 | """ 69 | Reconstruct images / create interpolations 70 | """ 71 | assert len(images) == len(attributes) 72 | enc_outputs = ae.encode(images) 73 | 74 | # interpolation values 75 | alphas = np.linspace(1 - params.alpha_min, params.alpha_max, params.n_interpolations) 76 | alphas = [torch.FloatTensor([1 - alpha, alpha]) for alpha in alphas] 77 | 78 | # original image / reconstructed image / interpolations 79 | outputs = [] 80 | outputs.append(images) 81 | outputs.append(ae.decode(enc_outputs, attributes)[-1]) 82 | for alpha in alphas: 83 | alpha = Variable(alpha.unsqueeze(0).expand((len(images), 2)).cuda()) 84 | outputs.append(ae.decode(enc_outputs, alpha)[-1]) 85 | 86 | # return stacked images 87 | return torch.cat([x.unsqueeze(1) for x in outputs], 1).data.cpu() 88 | 89 | 90 | interpolations = [] 91 | 92 | for k in range(0, params.n_images, 100): 93 | i = params.offset + k 94 | j = params.offset + min(params.n_images, k + 100) 95 | images, attributes = test_data.eval_batch(i, j) 96 | interpolations.append(get_interpolations(ae, images, attributes, params)) 97 | 98 | interpolations = torch.cat(interpolations, 0) 99 | assert interpolations.size() == (params.n_images, 2 + params.n_interpolations, 100 | 3, params.img_sz, params.img_sz) 101 | 102 | 103 | def get_grid(images, row_wise, 
plot_size=5): 104 | """ 105 | Create a grid with all images. 106 | """ 107 | n_images, n_columns, img_fm, img_sz, _ = images.size() 108 | if not row_wise: 109 | images = images.transpose(0, 1).contiguous() 110 | images = images.view(n_images * n_columns, img_fm, img_sz, img_sz) 111 | images.add_(1).div_(2.0) 112 | return make_grid(images, nrow=(n_columns if row_wise else n_images)) 113 | 114 | 115 | # generate the grid / save it to a PNG file 116 | grid = get_grid(interpolations, params.row_wise, params.plot_size) 117 | matplotlib.image.imsave(params.output_path, grid.numpy().transpose((1, 2, 0))) 118 | -------------------------------------------------------------------------------- /models/download.sh: -------------------------------------------------------------------------------- 1 | curl -Lo classifier128.pth https://dl.fbaipublicfiles.com/FaderNetworks/classifier128.pth 2 | curl -Lo classifier256.pth https://dl.fbaipublicfiles.com/FaderNetworks/classifier256.pth 3 | curl -Lo eyeglasses.pth https://dl.fbaipublicfiles.com/FaderNetworks/eyeglasses.pth 4 | curl -Lo male.pth https://dl.fbaipublicfiles.com/FaderNetworks/male.pth 5 | curl -Lo narrow_eyes.pth https://dl.fbaipublicfiles.com/FaderNetworks/narrow_eyes.pth 6 | curl -Lo young.pth https://dl.fbaipublicfiles.com/FaderNetworks/young.pth 7 | curl -Lo pointy_nose.pth https://dl.fbaipublicfiles.com/FaderNetworks/pointy_nose.pth 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/FaderNetworks/cdd9e50659b635a6e04311e1cf4b9a6e6683319b/src/__init__.py -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import json 9 | import numpy as np 10 | from logging import getLogger 11 | 12 | from .model import update_predictions, flip_attributes 13 | from .utils import print_accuracies 14 | 15 | 16 | logger = getLogger() 17 | 18 | 19 | class Evaluator(object): 20 | 21 | def __init__(self, ae, lat_dis, ptc_dis, clf_dis, eval_clf, data, params): 22 | """ 23 | Evaluator initialization. 24 | """ 25 | # data / parameters 26 | self.data = data 27 | self.params = params 28 | 29 | # modules 30 | self.ae = ae 31 | self.lat_dis = lat_dis 32 | self.ptc_dis = ptc_dis 33 | self.clf_dis = clf_dis 34 | self.eval_clf = eval_clf 35 | assert eval_clf.img_sz == params.img_sz 36 | assert all(attr in eval_clf.attr for attr in params.attr) 37 | 38 | def eval_reconstruction_loss(self): 39 | """ 40 | Compute the autoencoder reconstruction perplexity. 41 | """ 42 | data = self.data 43 | params = self.params 44 | self.ae.eval() 45 | bs = params.batch_size 46 | 47 | costs = [] 48 | for i in range(0, len(data), bs): 49 | batch_x, batch_y = data.eval_batch(i, i + bs) 50 | _, dec_outputs = self.ae(batch_x, batch_y) 51 | costs.append(((dec_outputs[-1] - batch_x) ** 2).mean().data[0]) 52 | 53 | return np.mean(costs) 54 | 55 | def eval_lat_dis_accuracy(self): 56 | """ 57 | Compute the latent discriminator prediction accuracy. 
58 | """ 59 | data = self.data 60 | params = self.params 61 | self.ae.eval() 62 | self.lat_dis.eval() 63 | bs = params.batch_size 64 | 65 | all_preds = [[] for _ in range(len(params.attr))] 66 | for i in range(0, len(data), bs): 67 | batch_x, batch_y = data.eval_batch(i, i + bs) 68 | enc_outputs = self.ae.encode(batch_x) 69 | preds = self.lat_dis(enc_outputs[-1 - params.n_skip]).data.cpu() 70 | update_predictions(all_preds, preds, batch_y.data.cpu(), params) 71 | 72 | return [np.mean(x) for x in all_preds] 73 | 74 | def eval_ptc_dis_accuracy(self): 75 | """ 76 | Compute the patch discriminator prediction accuracy. 77 | """ 78 | data = self.data 79 | params = self.params 80 | self.ae.eval() 81 | self.ptc_dis.eval() 82 | bs = params.batch_size 83 | 84 | real_preds = [] 85 | fake_preds = [] 86 | 87 | for i in range(0, len(data), bs): 88 | # batch / encode / decode 89 | batch_x, batch_y = data.eval_batch(i, i + bs) 90 | flipped = flip_attributes(batch_y, params, 'all') 91 | _, dec_outputs = self.ae(batch_x, flipped) 92 | # predictions 93 | real_preds.extend(self.ptc_dis(batch_x).data.tolist()) 94 | fake_preds.extend(self.ptc_dis(dec_outputs[-1]).data.tolist()) 95 | 96 | return real_preds, fake_preds 97 | 98 | def eval_clf_dis_accuracy(self): 99 | """ 100 | Compute the classifier discriminator prediction accuracy. 101 | """ 102 | data = self.data 103 | params = self.params 104 | self.ae.eval() 105 | self.clf_dis.eval() 106 | bs = params.batch_size 107 | 108 | all_preds = [[] for _ in range(params.n_attr)] 109 | for i in range(0, len(data), bs): 110 | # batch / encode / decode 111 | batch_x, batch_y = data.eval_batch(i, i + bs) 112 | enc_outputs = self.ae.encode(batch_x) 113 | # flip all attributes one by one 114 | k = 0 115 | for j, (_, n_cat) in enumerate(params.attr): 116 | for value in range(n_cat): 117 | flipped = flip_attributes(batch_y, params, j, new_value=value) 118 | dec_outputs = self.ae.decode(enc_outputs, flipped) 119 | # classify 120 | clf_dis_preds = self.clf_dis(dec_outputs[-1])[:, j:j + n_cat].max(1)[1].view(-1) 121 | all_preds[k].extend((clf_dis_preds.data.cpu() == value).tolist()) 122 | k += 1 123 | assert k == params.n_attr 124 | 125 | return [np.mean(x) for x in all_preds] 126 | 127 | def eval_clf_accuracy(self): 128 | """ 129 | Compute the accuracy of flipped attributes according to the trained classifier. 130 | """ 131 | data = self.data 132 | params = self.params 133 | self.ae.eval() 134 | bs = params.batch_size 135 | 136 | idx = [] 137 | for j in range(len(params.attr)): 138 | attr_index = self.eval_clf.attr.index(params.attr[j]) 139 | idx.append(sum([x[1] for x in self.eval_clf.attr[:attr_index]])) 140 | 141 | all_preds = [[] for _ in range(params.n_attr)] 142 | for i in range(0, len(data), bs): 143 | # batch / encode / decode 144 | batch_x, batch_y = data.eval_batch(i, i + bs) 145 | enc_outputs = self.ae.encode(batch_x) 146 | # flip all attributes one by one 147 | k = 0 148 | for j, (_, n_cat) in enumerate(params.attr): 149 | for value in range(n_cat): 150 | flipped = flip_attributes(batch_y, params, j, new_value=value) 151 | dec_outputs = self.ae.decode(enc_outputs, flipped) 152 | # classify 153 | clf_preds = self.eval_clf(dec_outputs[-1])[:, idx[j]:idx[j] + n_cat].max(1)[1].view(-1) 154 | all_preds[k].extend((clf_preds.data.cpu() == value).tolist()) 155 | k += 1 156 | assert k == params.n_attr 157 | 158 | return [np.mean(x) for x in all_preds] 159 | 160 | def evaluate(self, n_epoch): 161 | """ 162 | Evaluate all models / log evaluation results. 
163 | """ 164 | params = self.params 165 | logger.info('') 166 | 167 | # reconstruction loss 168 | ae_loss = self.eval_reconstruction_loss() 169 | 170 | # latent discriminator accuracy 171 | log_lat_dis = [] 172 | if params.n_lat_dis: 173 | lat_dis_accu = self.eval_lat_dis_accuracy() 174 | log_lat_dis.append(('lat_dis_accu', np.mean(lat_dis_accu))) 175 | for accu, (name, _) in zip(lat_dis_accu, params.attr): 176 | log_lat_dis.append(('lat_dis_accu_%s' % name, accu)) 177 | logger.info('Latent discriminator accuracy:') 178 | print_accuracies(log_lat_dis) 179 | 180 | # patch discriminator accuracy 181 | log_ptc_dis = [] 182 | if params.n_ptc_dis: 183 | ptc_dis_real_preds, ptc_dis_fake_preds = self.eval_ptc_dis_accuracy() 184 | accu_real = (np.array(ptc_dis_real_preds).astype(np.float32) >= 0.5).mean() 185 | accu_fake = (np.array(ptc_dis_fake_preds).astype(np.float32) <= 0.5).mean() 186 | log_ptc_dis.append(('ptc_dis_preds_real', np.mean(ptc_dis_real_preds))) 187 | log_ptc_dis.append(('ptc_dis_preds_fake', np.mean(ptc_dis_fake_preds))) 188 | log_ptc_dis.append(('ptc_dis_accu_real', accu_real)) 189 | log_ptc_dis.append(('ptc_dis_accu_fake', accu_fake)) 190 | log_ptc_dis.append(('ptc_dis_accu', (accu_real + accu_fake) / 2)) 191 | logger.info('Patch discriminator accuracy:') 192 | print_accuracies(log_ptc_dis) 193 | 194 | # classifier discriminator accuracy 195 | log_clf_dis = [] 196 | if params.n_clf_dis: 197 | clf_dis_accu = self.eval_clf_dis_accuracy() 198 | k = 0 199 | log_clf_dis += [('clf_dis_accu', np.mean(clf_dis_accu))] 200 | for name, n_cat in params.attr: 201 | log_clf_dis.append(('clf_dis_accu_%s' % name, np.mean(clf_dis_accu[k:k + n_cat]))) 202 | log_clf_dis.extend([('clf_dis_accu_%s_%i' % (name, j), clf_dis_accu[k + j]) 203 | for j in range(n_cat)]) 204 | k += n_cat 205 | logger.info('Classifier discriminator accuracy:') 206 | print_accuracies(log_clf_dis) 207 | 208 | # classifier accuracy 209 | log_clf = [] 210 | clf_accu = self.eval_clf_accuracy() 211 | k = 0 212 | log_clf += [('clf_accu', np.mean(clf_accu))] 213 | for name, n_cat in params.attr: 214 | log_clf.append(('clf_accu_%s' % name, np.mean(clf_accu[k:k + n_cat]))) 215 | log_clf.extend([('clf_accu_%s_%i' % (name, j), clf_accu[k + j]) 216 | for j in range(n_cat)]) 217 | k += n_cat 218 | logger.info('Classifier accuracy:') 219 | print_accuracies(log_clf) 220 | 221 | # log autoencoder loss 222 | logger.info('Autoencoder loss: %.5f' % ae_loss) 223 | 224 | # JSON log 225 | to_log = dict([ 226 | ('n_epoch', n_epoch), 227 | ('ae_loss', ae_loss) 228 | ] + log_lat_dis + log_ptc_dis + log_clf_dis + log_clf) 229 | logger.debug("__log__:%s" % json.dumps(to_log)) 230 | 231 | return to_log 232 | 233 | 234 | def compute_accuracy(classifier, data, params): 235 | """ 236 | Compute the classifier prediction accuracy. 237 | """ 238 | classifier.eval() 239 | bs = params.batch_size 240 | 241 | all_preds = [[] for _ in range(len(classifier.attr))] 242 | for i in range(0, len(data), bs): 243 | batch_x, batch_y = data.eval_batch(i, i + bs) 244 | preds = classifier(batch_x).data.cpu() 245 | update_predictions(all_preds, preds, batch_y.data.cpu(), params) 246 | 247 | return [np.mean(x) for x in all_preds] 248 | -------------------------------------------------------------------------------- /src/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch.autograd import Variable 12 | from logging import getLogger 13 | 14 | 15 | logger = getLogger() 16 | 17 | 18 | AVAILABLE_ATTR = [ 19 | "5_o_Clock_Shadow", "Arched_Eyebrows", "Attractive", "Bags_Under_Eyes", "Bald", 20 | "Bangs", "Big_Lips", "Big_Nose", "Black_Hair", "Blond_Hair", "Blurry", "Brown_Hair", 21 | "Bushy_Eyebrows", "Chubby", "Double_Chin", "Eyeglasses", "Goatee", "Gray_Hair", 22 | "Heavy_Makeup", "High_Cheekbones", "Male", "Mouth_Slightly_Open", "Mustache", 23 | "Narrow_Eyes", "No_Beard", "Oval_Face", "Pale_Skin", "Pointy_Nose", 24 | "Receding_Hairline", "Rosy_Cheeks", "Sideburns", "Smiling", "Straight_Hair", 25 | "Wavy_Hair", "Wearing_Earrings", "Wearing_Hat", "Wearing_Lipstick", 26 | "Wearing_Necklace", "Wearing_Necktie", "Young" 27 | ] 28 | 29 | DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data') 30 | 31 | 32 | def log_attributes_stats(train_attributes, valid_attributes, test_attributes, params): 33 | """ 34 | Log attributes distributions. 35 | """ 36 | k = 0 37 | for (attr_name, n_cat) in params.attr: 38 | logger.debug('Train %s: %s' % (attr_name, ' / '.join(['%.5f' % train_attributes[:, k + i].mean() for i in range(n_cat)]))) 39 | logger.debug('Valid %s: %s' % (attr_name, ' / '.join(['%.5f' % valid_attributes[:, k + i].mean() for i in range(n_cat)]))) 40 | logger.debug('Test %s: %s' % (attr_name, ' / '.join(['%.5f' % test_attributes[:, k + i].mean() for i in range(n_cat)]))) 41 | assert train_attributes[:, k:k + n_cat].sum() == train_attributes.size(0) 42 | assert valid_attributes[:, k:k + n_cat].sum() == valid_attributes.size(0) 43 | assert test_attributes[:, k:k + n_cat].sum() == test_attributes.size(0) 44 | k += n_cat 45 | assert k == params.n_attr 46 | 47 | 48 | def load_images(params): 49 | """ 50 | Load celebA dataset. 
51 | """ 52 | # load data 53 | images_filename = 'images_%i_%i_20000.pth' if params.debug else 'images_%i_%i.pth' 54 | images_filename = images_filename % (params.img_sz, params.img_sz) 55 | images = torch.load(os.path.join(DATA_PATH, images_filename)) 56 | attributes = torch.load(os.path.join(DATA_PATH, 'attributes.pth')) 57 | 58 | # parse attributes 59 | attrs = [] 60 | for name, n_cat in params.attr: 61 | for i in range(n_cat): 62 | attrs.append(torch.FloatTensor((attributes[name] == i).astype(np.float32))) 63 | attributes = torch.cat([x.unsqueeze(1) for x in attrs], 1) 64 | # split train / valid / test 65 | if params.debug: 66 | train_index = 10000 67 | valid_index = 15000 68 | test_index = 20000 69 | else: 70 | train_index = 162770 71 | valid_index = 162770 + 19867 72 | test_index = len(images) 73 | train_images = images[:train_index] 74 | valid_images = images[train_index:valid_index] 75 | test_images = images[valid_index:test_index] 76 | train_attributes = attributes[:train_index] 77 | valid_attributes = attributes[train_index:valid_index] 78 | test_attributes = attributes[valid_index:test_index] 79 | # log dataset statistics / return dataset 80 | logger.info('%i / %i / %i images with attributes for train / valid / test sets' 81 | % (len(train_images), len(valid_images), len(test_images))) 82 | log_attributes_stats(train_attributes, valid_attributes, test_attributes, params) 83 | images = train_images, valid_images, test_images 84 | attributes = train_attributes, valid_attributes, test_attributes 85 | return images, attributes 86 | 87 | 88 | def normalize_images(images): 89 | """ 90 | Normalize image values. 91 | """ 92 | return images.float().div_(255.0).mul_(2.0).add_(-1) 93 | 94 | 95 | class DataSampler(object): 96 | 97 | def __init__(self, images, attributes, params): 98 | """ 99 | Initialize the data sampler with training data. 100 | """ 101 | assert images.size(0) == attributes.size(0), (images.size(), attributes.size()) 102 | self.images = images 103 | self.attributes = attributes 104 | self.batch_size = params.batch_size 105 | self.v_flip = params.v_flip 106 | self.h_flip = params.h_flip 107 | 108 | def __len__(self): 109 | """ 110 | Number of images in the object dataset. 111 | """ 112 | return self.images.size(0) 113 | 114 | def train_batch(self, bs): 115 | """ 116 | Get a batch of random images with their attributes. 117 | """ 118 | # image IDs 119 | idx = torch.LongTensor(bs).random_(len(self.images)) 120 | 121 | # select images / attributes 122 | batch_x = normalize_images(self.images.index_select(0, idx).cuda()) 123 | batch_y = self.attributes.index_select(0, idx).cuda() 124 | 125 | # data augmentation 126 | if self.v_flip and np.random.rand() <= 0.5: 127 | batch_x = batch_x.index_select(2, torch.arange(batch_x.size(2) - 1, -1, -1).long().cuda()) 128 | if self.h_flip and np.random.rand() <= 0.5: 129 | batch_x = batch_x.index_select(3, torch.arange(batch_x.size(3) - 1, -1, -1).long().cuda()) 130 | 131 | return Variable(batch_x, volatile=False), Variable(batch_y, volatile=False) 132 | 133 | def eval_batch(self, i, j): 134 | """ 135 | Get a batch of images in a range with their attributes. 
136 | """ 137 | assert i < j 138 | batch_x = normalize_images(self.images[i:j].cuda()) 139 | batch_y = self.attributes[i:j].cuda() 140 | return Variable(batch_x, volatile=True), Variable(batch_y, volatile=True) 141 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import logging 9 | import time 10 | from datetime import timedelta 11 | 12 | 13 | class LogFormatter(): 14 | 15 | def __init__(self): 16 | self.start_time = time.time() 17 | 18 | def format(self, record): 19 | elapsed_seconds = round(record.created - self.start_time) 20 | 21 | prefix = "%s - %s - %s" % ( 22 | record.levelname, 23 | time.strftime('%x %X'), 24 | timedelta(seconds=elapsed_seconds) 25 | ) 26 | message = record.getMessage() 27 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 28 | return "%s - %s" % (prefix, message) 29 | 30 | 31 | def create_logger(filepath): 32 | """ 33 | Create a logger. 34 | """ 35 | # create log formatter 36 | log_formatter = LogFormatter() 37 | 38 | # create file handler and set level to debug 39 | if filepath is not None: 40 | file_handler = logging.FileHandler(filepath, "a") 41 | file_handler.setLevel(logging.DEBUG) 42 | file_handler.setFormatter(log_formatter) 43 | 44 | # create console handler and set level to info 45 | console_handler = logging.StreamHandler() 46 | console_handler.setLevel(logging.INFO) 47 | console_handler.setFormatter(log_formatter) 48 | 49 | # create logger and set level to debug 50 | logger = logging.getLogger() 51 | logger.handlers = [] 52 | logger.setLevel(logging.DEBUG) 53 | logger.propagate = False 54 | if filepath is not None: 55 | logger.addHandler(file_handler) 56 | logger.addHandler(console_handler) 57 | 58 | # reset logger elapsed time 59 | def reset_time(): 60 | log_formatter.start_time = time.time() 61 | logger.reset_time = reset_time 62 | 63 | return logger 64 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | from torch.autograd import Variable 12 | from torch.nn import functional as F 13 | 14 | 15 | def build_layers(img_sz, img_fm, init_fm, max_fm, n_layers, n_attr, n_skip, 16 | deconv_method, instance_norm, enc_dropout, dec_dropout): 17 | """ 18 | Build auto-encoder layers. 
19 | """ 20 | assert init_fm <= max_fm 21 | assert n_skip <= n_layers - 1 22 | assert np.log2(img_sz).is_integer() 23 | assert n_layers <= int(np.log2(img_sz)) 24 | assert type(instance_norm) is bool 25 | assert 0 <= enc_dropout < 1 26 | assert 0 <= dec_dropout < 1 27 | norm_fn = nn.InstanceNorm2d if instance_norm else nn.BatchNorm2d 28 | 29 | enc_layers = [] 30 | dec_layers = [] 31 | 32 | n_in = img_fm 33 | n_out = init_fm 34 | 35 | for i in range(n_layers): 36 | enc_layer = [] 37 | dec_layer = [] 38 | skip_connection = n_layers - (n_skip + 1) <= i < n_layers - 1 39 | n_dec_in = n_out + n_attr + (n_out if skip_connection else 0) 40 | n_dec_out = n_in 41 | 42 | # encoder layer 43 | enc_layer.append(nn.Conv2d(n_in, n_out, 4, 2, 1)) 44 | if i > 0: 45 | enc_layer.append(norm_fn(n_out, affine=True)) 46 | enc_layer.append(nn.LeakyReLU(0.2, inplace=True)) 47 | if enc_dropout > 0: 48 | enc_layer.append(nn.Dropout(enc_dropout)) 49 | 50 | # decoder layer 51 | if deconv_method == 'upsampling': 52 | dec_layer.append(nn.UpsamplingNearest2d(scale_factor=2)) 53 | dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out, 3, 1, 1)) 54 | elif deconv_method == 'convtranspose': 55 | dec_layer.append(nn.ConvTranspose2d(n_dec_in, n_dec_out, 4, 2, 1, bias=False)) 56 | else: 57 | assert deconv_method == 'pixelshuffle' 58 | dec_layer.append(nn.Conv2d(n_dec_in, n_dec_out * 4, 3, 1, 1)) 59 | dec_layer.append(nn.PixelShuffle(2)) 60 | if i > 0: 61 | dec_layer.append(norm_fn(n_dec_out, affine=True)) 62 | if dec_dropout > 0 and i >= n_layers - 3: 63 | dec_layer.append(nn.Dropout(dec_dropout)) 64 | dec_layer.append(nn.ReLU(inplace=True)) 65 | else: 66 | dec_layer.append(nn.Tanh()) 67 | 68 | # update 69 | n_in = n_out 70 | n_out = min(2 * n_out, max_fm) 71 | enc_layers.append(nn.Sequential(*enc_layer)) 72 | dec_layers.insert(0, nn.Sequential(*dec_layer)) 73 | 74 | return enc_layers, dec_layers 75 | 76 | 77 | class AutoEncoder(nn.Module): 78 | 79 | def __init__(self, params): 80 | super(AutoEncoder, self).__init__() 81 | 82 | self.img_sz = params.img_sz 83 | self.img_fm = params.img_fm 84 | self.instance_norm = params.instance_norm 85 | self.init_fm = params.init_fm 86 | self.max_fm = params.max_fm 87 | self.n_layers = params.n_layers 88 | self.n_skip = params.n_skip 89 | self.deconv_method = params.deconv_method 90 | self.dropout = params.dec_dropout 91 | self.attr = params.attr 92 | self.n_attr = params.n_attr 93 | 94 | enc_layers, dec_layers = build_layers(self.img_sz, self.img_fm, self.init_fm, 95 | self.max_fm, self.n_layers, self.n_attr, 96 | self.n_skip, self.deconv_method, 97 | self.instance_norm, 0, self.dropout) 98 | self.enc_layers = nn.ModuleList(enc_layers) 99 | self.dec_layers = nn.ModuleList(dec_layers) 100 | 101 | def encode(self, x): 102 | assert x.size()[1:] == (self.img_fm, self.img_sz, self.img_sz) 103 | 104 | enc_outputs = [x] 105 | for layer in self.enc_layers: 106 | enc_outputs.append(layer(enc_outputs[-1])) 107 | 108 | assert len(enc_outputs) == self.n_layers + 1 109 | return enc_outputs 110 | 111 | def decode(self, enc_outputs, y): 112 | bs = enc_outputs[0].size(0) 113 | assert len(enc_outputs) == self.n_layers + 1 114 | assert y.size() == (bs, self.n_attr) 115 | 116 | dec_outputs = [enc_outputs[-1]] 117 | y = y.unsqueeze(2).unsqueeze(3) 118 | for i, layer in enumerate(self.dec_layers): 119 | size = dec_outputs[-1].size(2) 120 | # attributes 121 | input = [dec_outputs[-1], y.expand(bs, self.n_attr, size, size)] 122 | # skip connection 123 | if 0 < i <= self.n_skip: 124 | input.append(enc_outputs[-1 - i]) 
125 | input = torch.cat(input, 1) 126 | dec_outputs.append(layer(input)) 127 | 128 | assert len(dec_outputs) == self.n_layers + 1 129 | assert dec_outputs[-1].size() == (bs, self.img_fm, self.img_sz, self.img_sz) 130 | return dec_outputs 131 | 132 | def forward(self, x, y): 133 | enc_outputs = self.encode(x) 134 | dec_outputs = self.decode(enc_outputs, y) 135 | return enc_outputs, dec_outputs 136 | 137 | 138 | class LatentDiscriminator(nn.Module): 139 | 140 | def __init__(self, params): 141 | super(LatentDiscriminator, self).__init__() 142 | 143 | self.img_sz = params.img_sz 144 | self.img_fm = params.img_fm 145 | self.init_fm = params.init_fm 146 | self.max_fm = params.max_fm 147 | self.n_layers = params.n_layers 148 | self.n_skip = params.n_skip 149 | self.hid_dim = params.hid_dim 150 | self.dropout = params.lat_dis_dropout 151 | self.attr = params.attr 152 | self.n_attr = params.n_attr 153 | 154 | self.n_dis_layers = int(np.log2(self.img_sz)) 155 | self.conv_in_sz = self.img_sz / (2 ** (self.n_layers - self.n_skip)) 156 | self.conv_in_fm = min(self.init_fm * (2 ** (self.n_layers - self.n_skip - 1)), self.max_fm) 157 | self.conv_out_fm = min(self.init_fm * (2 ** (self.n_dis_layers - 1)), self.max_fm) 158 | 159 | # discriminator layers are identical to encoder, but convolve until size 1 160 | enc_layers, _ = build_layers(self.img_sz, self.img_fm, self.init_fm, self.max_fm, 161 | self.n_dis_layers, self.n_attr, 0, 'convtranspose', 162 | False, self.dropout, 0) 163 | 164 | self.conv_layers = nn.Sequential(*(enc_layers[self.n_layers - self.n_skip:])) 165 | self.proj_layers = nn.Sequential( 166 | nn.Linear(self.conv_out_fm, self.hid_dim), 167 | nn.LeakyReLU(0.2, inplace=True), 168 | nn.Linear(self.hid_dim, self.n_attr) 169 | ) 170 | 171 | def forward(self, x): 172 | assert x.size()[1:] == (self.conv_in_fm, self.conv_in_sz, self.conv_in_sz) 173 | conv_output = self.conv_layers(x) 174 | assert conv_output.size() == (x.size(0), self.conv_out_fm, 1, 1) 175 | return self.proj_layers(conv_output.view(x.size(0), self.conv_out_fm)) 176 | 177 | 178 | class PatchDiscriminator(nn.Module): 179 | def __init__(self, params): 180 | super(PatchDiscriminator, self).__init__() 181 | 182 | self.img_sz = params.img_sz 183 | self.img_fm = params.img_fm 184 | self.init_fm = params.init_fm 185 | self.max_fm = params.max_fm 186 | self.n_patch_dis_layers = 3 187 | 188 | layers = [] 189 | layers.append(nn.Conv2d(self.img_fm, self.init_fm, kernel_size=4, stride=2, padding=1)) 190 | layers.append(nn.LeakyReLU(0.2, True)) 191 | 192 | n_in = self.init_fm 193 | n_out = min(2 * n_in, self.max_fm) 194 | 195 | for n in range(self.n_patch_dis_layers): 196 | stride = 1 if n == self.n_patch_dis_layers - 1 else 2 197 | layers.append(nn.Conv2d(n_in, n_out, kernel_size=4, stride=stride, padding=1)) 198 | layers.append(nn.BatchNorm2d(n_out)) 199 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 200 | if n < self.n_patch_dis_layers - 1: 201 | n_in = n_out 202 | n_out = min(2 * n_out, self.max_fm) 203 | 204 | layers.append(nn.Conv2d(n_out, 1, kernel_size=4, stride=1, padding=1)) 205 | layers.append(nn.Sigmoid()) 206 | 207 | self.layers = nn.Sequential(*layers) 208 | 209 | def forward(self, x): 210 | assert x.dim() == 4 211 | return self.layers(x).view(x.size(0), -1).mean(1).view(x.size(0)) 212 | 213 | 214 | class Classifier(nn.Module): 215 | 216 | def __init__(self, params): 217 | super(Classifier, self).__init__() 218 | 219 | self.img_sz = params.img_sz 220 | self.img_fm = params.img_fm 221 | self.init_fm = params.init_fm 222 | 
self.max_fm = params.max_fm 223 | self.hid_dim = params.hid_dim 224 | self.attr = params.attr 225 | self.n_attr = params.n_attr 226 | 227 | self.n_clf_layers = int(np.log2(self.img_sz)) 228 | self.conv_out_fm = min(self.init_fm * (2 ** (self.n_clf_layers - 1)), self.max_fm) 229 | 230 | # classifier layers are identical to encoder, but convolve until size 1 231 | enc_layers, _ = build_layers(self.img_sz, self.img_fm, self.init_fm, self.max_fm, 232 | self.n_clf_layers, self.n_attr, 0, 'convtranspose', 233 | False, 0, 0) 234 | 235 | self.conv_layers = nn.Sequential(*enc_layers) 236 | self.proj_layers = nn.Sequential( 237 | nn.Linear(self.conv_out_fm, self.hid_dim), 238 | nn.LeakyReLU(0.2, inplace=True), 239 | nn.Linear(self.hid_dim, self.n_attr) 240 | ) 241 | 242 | def forward(self, x): 243 | assert x.size()[1:] == (self.img_fm, self.img_sz, self.img_sz) 244 | conv_output = self.conv_layers(x) 245 | assert conv_output.size() == (x.size(0), self.conv_out_fm, 1, 1) 246 | return self.proj_layers(conv_output.view(x.size(0), self.conv_out_fm)) 247 | 248 | 249 | def get_attr_loss(output, attributes, flip, params): 250 | """ 251 | Compute attributes loss. 252 | """ 253 | assert type(flip) is bool 254 | k = 0 255 | loss = 0 256 | for (_, n_cat) in params.attr: 257 | # categorical 258 | x = output[:, k:k + n_cat].contiguous() 259 | y = attributes[:, k:k + n_cat].max(1)[1].view(-1) 260 | if flip: 261 | # generate different categories 262 | shift = torch.LongTensor(y.size()).random_(n_cat - 1) + 1 263 | y = (y + Variable(shift.cuda())) % n_cat 264 | loss += F.cross_entropy(x, y) 265 | k += n_cat 266 | return loss 267 | 268 | 269 | def update_predictions(all_preds, preds, targets, params): 270 | """ 271 | Update discriminator / classifier predictions. 272 | """ 273 | assert len(all_preds) == len(params.attr) 274 | k = 0 275 | for j, (_, n_cat) in enumerate(params.attr): 276 | _preds = preds[:, k:k + n_cat].max(1)[1] 277 | _targets = targets[:, k:k + n_cat].max(1)[1] 278 | all_preds[j].extend((_preds == _targets).tolist()) 279 | k += n_cat 280 | assert k == params.n_attr 281 | 282 | 283 | def get_mappings(params): 284 | """ 285 | Create a mapping between attributes and their associated IDs. 286 | """ 287 | if not hasattr(params, 'mappings'): 288 | mappings = [] 289 | k = 0 290 | for (_, n_cat) in params.attr: 291 | assert n_cat >= 2 292 | mappings.append((k, k + n_cat)) 293 | k += n_cat 294 | assert k == params.n_attr 295 | params.mappings = mappings 296 | return params.mappings 297 | 298 | 299 | def flip_attributes(attributes, params, attribute_id, new_value=None): 300 | """ 301 | Randomly flip a set of attributes. 
302 | """ 303 | assert attributes.size(1) == params.n_attr 304 | mappings = get_mappings(params) 305 | attributes = attributes.data.clone().cpu() 306 | 307 | def flip_attribute(attribute_id, new_value=None): 308 | bs = attributes.size(0) 309 | i, j = mappings[attribute_id] 310 | attributes[:, i:j].zero_() 311 | if new_value is None: 312 | y = torch.LongTensor(bs).random_(j - i) 313 | else: 314 | assert new_value in range(j - i) 315 | y = torch.LongTensor(bs).fill_(new_value) 316 | attributes[:, i:j].scatter_(1, y.unsqueeze(1), 1) 317 | 318 | if attribute_id == 'all': 319 | assert new_value is None 320 | for attribute_id in range(len(params.attr)): 321 | flip_attribute(attribute_id) 322 | else: 323 | assert type(new_value) is int 324 | flip_attribute(attribute_id, new_value) 325 | 326 | return Variable(attributes.cuda()) 327 | -------------------------------------------------------------------------------- /src/training.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch.autograd import Variable 12 | from torch.nn import functional as F 13 | from logging import getLogger 14 | 15 | from .utils import get_optimizer, clip_grad_norm, get_lambda, reload_model 16 | from .model import get_attr_loss, flip_attributes 17 | 18 | 19 | logger = getLogger() 20 | 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self, ae, lat_dis, ptc_dis, clf_dis, data, params): 25 | """ 26 | Trainer initialization. 27 | """ 28 | # data / parameters 29 | self.data = data 30 | self.params = params 31 | 32 | # modules 33 | self.ae = ae 34 | self.lat_dis = lat_dis 35 | self.ptc_dis = ptc_dis 36 | self.clf_dis = clf_dis 37 | 38 | # optimizers 39 | self.ae_optimizer = get_optimizer(ae, params.ae_optimizer) 40 | logger.info(ae) 41 | logger.info('%i parameters in the autoencoder. ' 42 | % sum([p.nelement() for p in ae.parameters()])) 43 | if params.n_lat_dis: 44 | logger.info(lat_dis) 45 | logger.info('%i parameters in the latent discriminator. ' 46 | % sum([p.nelement() for p in lat_dis.parameters()])) 47 | self.lat_dis_optimizer = get_optimizer(lat_dis, params.dis_optimizer) 48 | if params.n_ptc_dis: 49 | logger.info(ptc_dis) 50 | logger.info('%i parameters in the patch discriminator. ' 51 | % sum([p.nelement() for p in ptc_dis.parameters()])) 52 | self.ptc_dis_optimizer = get_optimizer(ptc_dis, params.dis_optimizer) 53 | if params.n_clf_dis: 54 | logger.info(clf_dis) 55 | logger.info('%i parameters in the classifier discriminator. 
' 56 | % sum([p.nelement() for p in clf_dis.parameters()])) 57 | self.clf_dis_optimizer = get_optimizer(clf_dis, params.dis_optimizer) 58 | 59 | # reload pretrained models 60 | if params.ae_reload: 61 | reload_model(ae, params.ae_reload, 62 | ['img_sz', 'img_fm', 'init_fm', 'n_layers', 'n_skip', 'attr', 'n_attr']) 63 | if params.lat_dis_reload: 64 | reload_model(lat_dis, params.lat_dis_reload, 65 | ['enc_dim', 'attr', 'n_attr']) 66 | if params.ptc_dis_reload: 67 | reload_model(ptc_dis, params.ptc_dis_reload, 68 | ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'n_patch_dis_layers']) 69 | if params.clf_dis_reload: 70 | reload_model(clf_dis, params.clf_dis_reload, 71 | ['img_sz', 'img_fm', 'init_fm', 'max_fm', 'hid_dim', 'attr', 'n_attr']) 72 | 73 | # training statistics 74 | self.stats = {} 75 | self.stats['rec_costs'] = [] 76 | self.stats['lat_dis_costs'] = [] 77 | self.stats['ptc_dis_costs'] = [] 78 | self.stats['clf_dis_costs'] = [] 79 | 80 | # best reconstruction loss / best accuracy 81 | self.best_loss = 1e12 82 | self.best_accu = -1e12 83 | self.params.n_total_iter = 0 84 | 85 | def lat_dis_step(self): 86 | """ 87 | Train the latent discriminator. 88 | """ 89 | data = self.data 90 | params = self.params 91 | self.ae.eval() 92 | self.lat_dis.train() 93 | bs = params.batch_size 94 | # batch / encode / discriminate 95 | batch_x, batch_y = data.train_batch(bs) 96 | enc_outputs = self.ae.encode(Variable(batch_x.data, volatile=True)) 97 | preds = self.lat_dis(Variable(enc_outputs[-1 - params.n_skip].data)) 98 | # loss / optimize 99 | loss = get_attr_loss(preds, batch_y, False, params) 100 | self.stats['lat_dis_costs'].append(loss.data[0]) 101 | self.lat_dis_optimizer.zero_grad() 102 | loss.backward() 103 | if params.clip_grad_norm: 104 | clip_grad_norm(self.lat_dis.parameters(), params.clip_grad_norm) 105 | self.lat_dis_optimizer.step() 106 | 107 | def ptc_dis_step(self): 108 | """ 109 | Train the patch discriminator. 110 | """ 111 | data = self.data 112 | params = self.params 113 | self.ae.eval() 114 | self.ptc_dis.train() 115 | bs = params.batch_size 116 | # batch / encode / discriminate 117 | batch_x, batch_y = data.train_batch(bs) 118 | flipped = flip_attributes(batch_y, params, 'all') 119 | _, dec_outputs = self.ae(Variable(batch_x.data, volatile=True), flipped) 120 | real_preds = self.ptc_dis(batch_x) 121 | fake_preds = self.ptc_dis(Variable(dec_outputs[-1].data)) 122 | y_fake = Variable(torch.FloatTensor(real_preds.size()) 123 | .fill_(params.smooth_label).cuda()) 124 | # loss / optimize 125 | loss = F.binary_cross_entropy(real_preds, 1 - y_fake) 126 | loss += F.binary_cross_entropy(fake_preds, y_fake) 127 | self.stats['ptc_dis_costs'].append(loss.data[0]) 128 | self.ptc_dis_optimizer.zero_grad() 129 | loss.backward() 130 | if params.clip_grad_norm: 131 | clip_grad_norm(self.ptc_dis.parameters(), params.clip_grad_norm) 132 | self.ptc_dis_optimizer.step() 133 | 134 | def clf_dis_step(self): 135 | """ 136 | Train the classifier discriminator. 
137 | """ 138 | data = self.data 139 | params = self.params 140 | self.clf_dis.train() 141 | bs = params.batch_size 142 | # batch / predict 143 | batch_x, batch_y = data.train_batch(bs) 144 | preds = self.clf_dis(batch_x) 145 | # loss / optimize 146 | loss = get_attr_loss(preds, batch_y, False, params) 147 | self.stats['clf_dis_costs'].append(loss.data[0]) 148 | self.clf_dis_optimizer.zero_grad() 149 | loss.backward() 150 | if params.clip_grad_norm: 151 | clip_grad_norm(self.clf_dis.parameters(), params.clip_grad_norm) 152 | self.clf_dis_optimizer.step() 153 | 154 | def autoencoder_step(self): 155 | """ 156 | Train the autoencoder with cross-entropy loss. 157 | Train the encoder with discriminator loss. 158 | """ 159 | data = self.data 160 | params = self.params 161 | self.ae.train() 162 | if params.n_lat_dis: 163 | self.lat_dis.eval() 164 | if params.n_ptc_dis: 165 | self.ptc_dis.eval() 166 | if params.n_clf_dis: 167 | self.clf_dis.eval() 168 | bs = params.batch_size 169 | # batch / encode / decode 170 | batch_x, batch_y = data.train_batch(bs) 171 | enc_outputs, dec_outputs = self.ae(batch_x, batch_y) 172 | # autoencoder loss from reconstruction 173 | loss = params.lambda_ae * ((batch_x - dec_outputs[-1]) ** 2).mean() 174 | self.stats['rec_costs'].append(loss.data[0]) 175 | # encoder loss from the latent discriminator 176 | if params.lambda_lat_dis: 177 | lat_dis_preds = self.lat_dis(enc_outputs[-1 - params.n_skip]) 178 | lat_dis_loss = get_attr_loss(lat_dis_preds, batch_y, True, params) 179 | loss = loss + get_lambda(params.lambda_lat_dis, params) * lat_dis_loss 180 | # decoding with random labels 181 | if params.lambda_ptc_dis + params.lambda_clf_dis > 0: 182 | flipped = flip_attributes(batch_y, params, 'all') 183 | dec_outputs_flipped = self.ae.decode(enc_outputs, flipped) 184 | # autoencoder loss from the patch discriminator 185 | if params.lambda_ptc_dis: 186 | ptc_dis_preds = self.ptc_dis(dec_outputs_flipped[-1]) 187 | y_fake = Variable(torch.FloatTensor(ptc_dis_preds.size()) 188 | .fill_(params.smooth_label).cuda()) 189 | ptc_dis_loss = F.binary_cross_entropy(ptc_dis_preds, 1 - y_fake) 190 | loss = loss + get_lambda(params.lambda_ptc_dis, params) * ptc_dis_loss 191 | # autoencoder loss from the classifier discriminator 192 | if params.lambda_clf_dis: 193 | clf_dis_preds = self.clf_dis(dec_outputs_flipped[-1]) 194 | clf_dis_loss = get_attr_loss(clf_dis_preds, flipped, False, params) 195 | loss = loss + get_lambda(params.lambda_clf_dis, params) * clf_dis_loss 196 | # check NaN 197 | if (loss != loss).data.any(): 198 | logger.error("NaN detected") 199 | exit() 200 | # optimize 201 | self.ae_optimizer.zero_grad() 202 | loss.backward() 203 | if params.clip_grad_norm: 204 | clip_grad_norm(self.ae.parameters(), params.clip_grad_norm) 205 | self.ae_optimizer.step() 206 | 207 | def step(self, n_iter): 208 | """ 209 | End training iteration / print training statistics. 
210 | """ 211 | # average loss 212 | if len(self.stats['rec_costs']) >= 25: 213 | mean_loss = [ 214 | ('Latent discriminator', 'lat_dis_costs'), 215 | ('Patch discriminator', 'ptc_dis_costs'), 216 | ('Classifier discriminator', 'clf_dis_costs'), 217 | ('Reconstruction loss', 'rec_costs'), 218 | ] 219 | logger.info(('%06i - ' % n_iter) + 220 | ' / '.join(['%s : %.5f' % (a, np.mean(self.stats[b])) 221 | for a, b in mean_loss if len(self.stats[b]) > 0])) 222 | del self.stats['rec_costs'][:] 223 | del self.stats['lat_dis_costs'][:] 224 | del self.stats['ptc_dis_costs'][:] 225 | del self.stats['clf_dis_costs'][:] 226 | 227 | self.params.n_total_iter += 1 228 | 229 | def save_model(self, name): 230 | """ 231 | Save the model. 232 | """ 233 | def save(model, filename): 234 | path = os.path.join(self.params.dump_path, '%s_%s.pth' % (name, filename)) 235 | logger.info('Saving %s to %s ...' % (filename, path)) 236 | torch.save(model, path) 237 | save(self.ae, 'ae') 238 | if self.params.n_lat_dis: 239 | save(self.lat_dis, 'lat_dis') 240 | if self.params.n_ptc_dis: 241 | save(self.ptc_dis, 'ptc_dis') 242 | if self.params.n_clf_dis: 243 | save(self.clf_dis, 'clf_dis') 244 | 245 | def save_best_periodic(self, to_log): 246 | """ 247 | Save the best models / periodically save the models. 248 | """ 249 | if to_log['ae_loss'] < self.best_loss: 250 | self.best_loss = to_log['ae_loss'] 251 | logger.info('Best reconstruction loss: %.5f' % self.best_loss) 252 | self.save_model('best_rec') 253 | if self.params.eval_clf and np.mean(to_log['clf_accu']) > self.best_accu: 254 | self.best_accu = np.mean(to_log['clf_accu']) 255 | logger.info('Best evaluation accuracy: %.5f' % self.best_accu) 256 | self.save_model('best_accu') 257 | if to_log['n_epoch'] % 5 == 0 and to_log['n_epoch'] > 0: 258 | self.save_model('periodic-%i' % to_log['n_epoch']) 259 | 260 | 261 | def classifier_step(classifier, optimizer, data, params, costs): 262 | """ 263 | Train the classifier. 264 | """ 265 | classifier.train() 266 | bs = params.batch_size 267 | 268 | # batch / classify 269 | batch_x, batch_y = data.train_batch(bs) 270 | preds = classifier(batch_x) 271 | # loss / optimize 272 | loss = get_attr_loss(preds, batch_y, False, params) 273 | costs.append(loss.data[0]) 274 | optimizer.zero_grad() 275 | loss.backward() 276 | if params.clip_grad_norm: 277 | clip_grad_norm(classifier.parameters(), params.clip_grad_norm) 278 | optimizer.step() 279 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import re 10 | import pickle 11 | import random 12 | import inspect 13 | import argparse 14 | import subprocess 15 | import torch 16 | from torch import optim 17 | from logging import getLogger 18 | 19 | from .logger import create_logger 20 | from .loader import AVAILABLE_ATTR 21 | 22 | 23 | FALSY_STRINGS = {'off', 'false', '0'} 24 | TRUTHY_STRINGS = {'on', 'true', '1'} 25 | 26 | MODELS_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'models') 27 | 28 | 29 | logger = getLogger() 30 | 31 | 32 | def initialize_exp(params): 33 | """ 34 | Experiment initialization. 
35 | """ 36 | # dump parameters 37 | params.dump_path = get_dump_path(params) 38 | pickle.dump(params, open(os.path.join(params.dump_path, 'params.pkl'), 'wb')) 39 | 40 | # create a logger 41 | logger = create_logger(os.path.join(params.dump_path, 'train.log')) 42 | logger.info('============ Initialized logger ============') 43 | logger.info('\n'.join('%s: %s' % (k, str(v)) for k, v 44 | in sorted(dict(vars(params)).items()))) 45 | return logger 46 | 47 | 48 | def bool_flag(s): 49 | """ 50 | Parse boolean arguments from the command line. 51 | """ 52 | if s.lower() in FALSY_STRINGS: 53 | return False 54 | elif s.lower() in TRUTHY_STRINGS: 55 | return True 56 | else: 57 | raise argparse.ArgumentTypeError("invalid value for a boolean flag. use 0 or 1") 58 | 59 | 60 | def attr_flag(s): 61 | """ 62 | Parse attributes parameters. 63 | """ 64 | if s == "*": 65 | return s 66 | attr = s.split(',') 67 | assert len(attr) == len(set(attr)) 68 | attributes = [] 69 | for x in attr: 70 | if '.' not in x: 71 | attributes.append((x, 2)) 72 | else: 73 | split = x.split('.') 74 | assert len(split) == 2 and len(split[0]) > 0 75 | assert split[1].isdigit() and int(split[1]) >= 2 76 | attributes.append((split[0], int(split[1]))) 77 | return sorted(attributes, key=lambda x: (x[1], x[0])) 78 | 79 | 80 | def check_attr(params): 81 | """ 82 | Check attributes validy. 83 | """ 84 | if params.attr == '*': 85 | params.attr = attr_flag(','.join(AVAILABLE_ATTR)) 86 | else: 87 | assert all(name in AVAILABLE_ATTR and n_cat >= 2 for name, n_cat in params.attr) 88 | params.n_attr = sum([n_cat for _, n_cat in params.attr]) 89 | 90 | 91 | def get_optimizer(model, s): 92 | """ 93 | Parse optimizer parameters. 94 | Input should be of the form: 95 | - "sgd,lr=0.01" 96 | - "adagrad,lr=0.1,lr_decay=0.05" 97 | """ 98 | if "," in s: 99 | method = s[:s.find(',')] 100 | optim_params = {} 101 | for x in s[s.find(',') + 1:].split(','): 102 | split = x.split('=') 103 | assert len(split) == 2 104 | assert re.match("^[+-]?(\d+(\.\d*)?|\.\d+)$", split[1]) is not None 105 | optim_params[split[0]] = float(split[1]) 106 | else: 107 | method = s 108 | optim_params = {} 109 | 110 | if method == 'adadelta': 111 | optim_fn = optim.Adadelta 112 | elif method == 'adagrad': 113 | optim_fn = optim.Adagrad 114 | elif method == 'adam': 115 | optim_fn = optim.Adam 116 | optim_params['betas'] = (optim_params.get('beta1', 0.5), optim_params.get('beta2', 0.999)) 117 | optim_params.pop('beta1', None) 118 | optim_params.pop('beta2', None) 119 | elif method == 'adamax': 120 | optim_fn = optim.Adamax 121 | elif method == 'asgd': 122 | optim_fn = optim.ASGD 123 | elif method == 'rmsprop': 124 | optim_fn = optim.RMSprop 125 | elif method == 'rprop': 126 | optim_fn = optim.Rprop 127 | elif method == 'sgd': 128 | optim_fn = optim.SGD 129 | assert 'lr' in optim_params 130 | else: 131 | raise Exception('Unknown optimization method: "%s"' % method) 132 | 133 | # check that we give good parameters to the optimizer 134 | expected_args = inspect.getargspec(optim_fn.__init__)[0] 135 | assert expected_args[:2] == ['self', 'params'] 136 | if not all(k in expected_args[2:] for k in optim_params.keys()): 137 | raise Exception('Unexpected parameters: expected "%s", got "%s"' % ( 138 | str(expected_args[2:]), str(optim_params.keys()))) 139 | 140 | return optim_fn(model.parameters(), **optim_params) 141 | 142 | 143 | def clip_grad_norm(parameters, max_norm, norm_type=2): 144 | """Clips gradient norm of an iterable of parameters. 
145 | The norm is computed over all gradients together, as if they were 146 | concatenated into a single vector. Gradients are modified in-place. 147 | Arguments: 148 | parameters (Iterable[Variable]): an iterable of Variables that will have 149 | gradients normalized 150 | max_norm (float or int): max norm of the gradients 151 | norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. 152 | """ 153 | parameters = list(parameters) 154 | max_norm = float(max_norm) 155 | norm_type = float(norm_type) 156 | if norm_type == float('inf'): 157 | total_norm = max(p.grad.data.abs().max() for p in parameters) 158 | else: 159 | total_norm = 0 160 | for p in parameters: 161 | param_norm = p.grad.data.norm(norm_type) 162 | total_norm += param_norm ** norm_type 163 | total_norm = total_norm ** (1. / norm_type) 164 | clip_coef = max_norm / (total_norm + 1e-6) 165 | if clip_coef >= 1: 166 | return 167 | for p in parameters: 168 | p.grad.data.mul_(clip_coef) 169 | 170 | 171 | def get_dump_path(params): 172 | """ 173 | Create a directory to store the experiment. 174 | """ 175 | assert os.path.isdir(MODELS_PATH) 176 | 177 | # create the sweep path if it does not exist 178 | sweep_path = os.path.join(MODELS_PATH, params.name) 179 | if not os.path.exists(sweep_path): 180 | subprocess.Popen("mkdir %s" % sweep_path, shell=True).wait() 181 | 182 | # create a random name for the experiment 183 | chars = 'abcdefghijklmnopqrstuvwxyz0123456789' 184 | while True: 185 | exp_id = ''.join(random.choice(chars) for _ in range(10)) 186 | dump_path = os.path.join(MODELS_PATH, params.name, exp_id) 187 | if not os.path.isdir(dump_path): 188 | break 189 | 190 | # create the dump folder 191 | if not os.path.isdir(dump_path): 192 | subprocess.Popen("mkdir %s" % dump_path, shell=True).wait() 193 | return dump_path 194 | 195 | 196 | def reload_model(model, to_reload, attributes=None): 197 | """ 198 | Reload a previously trained model. 199 | """ 200 | # reload the model 201 | assert os.path.isfile(to_reload) 202 | to_reload = torch.load(to_reload) 203 | 204 | # check parameters sizes 205 | model_params = set(model.state_dict().keys()) 206 | to_reload_params = set(to_reload.state_dict().keys()) 207 | assert model_params == to_reload_params, (model_params - to_reload_params, 208 | to_reload_params - model_params) 209 | 210 | # check attributes 211 | attributes = [] if attributes is None else attributes 212 | for k in attributes: 213 | if getattr(model, k, None) is None: 214 | raise Exception('Attribute "%s" not found in the current model' % k) 215 | if getattr(to_reload, k, None) is None: 216 | raise Exception('Attribute "%s" not found in the model to reload' % k) 217 | if getattr(model, k) != getattr(to_reload, k): 218 | raise Exception('Attribute "%s" differs between the current model (%s) ' 219 | 'and the one to reload (%s)' 220 | % (k, str(getattr(model, k)), str(getattr(to_reload, k)))) 221 | 222 | # copy saved parameters 223 | for k in model.state_dict().keys(): 224 | if model.state_dict()[k].size() != to_reload.state_dict()[k].size(): 225 | raise Exception("Expected tensor {} of size {}, but got {}".format( 226 | k, model.state_dict()[k].size(), 227 | to_reload.state_dict()[k].size() 228 | )) 229 | model.state_dict()[k].copy_(to_reload.state_dict()[k]) 230 | 231 | 232 | def print_accuracies(values): 233 | """ 234 | Pretty plot of accuracies. 
235 | """ 236 | assert all(len(x) == 2 for x in values) 237 | for name, value in values: 238 | logger.info('{:<20}: {:>6}'.format(name, '%.3f%%' % (100 * value))) 239 | logger.info('') 240 | 241 | 242 | def get_lambda(l, params): 243 | """ 244 | Compute discriminators' lambdas. 245 | """ 246 | s = params.lambda_schedule 247 | if s == 0: 248 | return l 249 | else: 250 | return l * float(min(params.n_total_iter, s)) / s 251 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | 8 | import os 9 | import argparse 10 | import torch 11 | 12 | from src.loader import load_images, DataSampler 13 | from src.utils import initialize_exp, bool_flag, attr_flag, check_attr 14 | from src.model import AutoEncoder, LatentDiscriminator, PatchDiscriminator, Classifier 15 | from src.training import Trainer 16 | from src.evaluation import Evaluator 17 | 18 | 19 | # parse parameters 20 | parser = argparse.ArgumentParser(description='Images autoencoder') 21 | parser.add_argument("--name", type=str, default="default", 22 | help="Experiment name") 23 | parser.add_argument("--img_sz", type=int, default=256, 24 | help="Image sizes (images have to be squared)") 25 | parser.add_argument("--img_fm", type=int, default=3, 26 | help="Number of feature maps (1 for grayscale, 3 for RGB)") 27 | parser.add_argument("--attr", type=attr_flag, default="Smiling,Male", 28 | help="Attributes to classify") 29 | parser.add_argument("--instance_norm", type=bool_flag, default=False, 30 | help="Use instance normalization instead of batch normalization") 31 | parser.add_argument("--init_fm", type=int, default=32, 32 | help="Number of initial filters in the encoder") 33 | parser.add_argument("--max_fm", type=int, default=512, 34 | help="Number maximum of filters in the autoencoder") 35 | parser.add_argument("--n_layers", type=int, default=6, 36 | help="Number of layers in the encoder / decoder") 37 | parser.add_argument("--n_skip", type=int, default=0, 38 | help="Number of skip connections") 39 | parser.add_argument("--deconv_method", type=str, default="convtranspose", 40 | help="Deconvolution method") 41 | parser.add_argument("--hid_dim", type=int, default=512, 42 | help="Last hidden layer dimension for discriminator / classifier") 43 | parser.add_argument("--dec_dropout", type=float, default=0., 44 | help="Dropout in the decoder") 45 | parser.add_argument("--lat_dis_dropout", type=float, default=0.3, 46 | help="Dropout in the latent discriminator") 47 | parser.add_argument("--n_lat_dis", type=int, default=1, 48 | help="Number of latent discriminator training steps") 49 | parser.add_argument("--n_ptc_dis", type=int, default=0, 50 | help="Number of patch discriminator training steps") 51 | parser.add_argument("--n_clf_dis", type=int, default=0, 52 | help="Number of classifier discriminator training steps") 53 | parser.add_argument("--smooth_label", type=float, default=0.2, 54 | help="Smooth label for patch discriminator") 55 | parser.add_argument("--lambda_ae", type=float, default=1, 56 | help="Autoencoder loss coefficient") 57 | parser.add_argument("--lambda_lat_dis", type=float, default=0.0001, 58 | help="Latent discriminator loss feedback coefficient") 59 | parser.add_argument("--lambda_ptc_dis", 
type=float, default=0, 60 | help="Patch discriminator loss feedback coefficient") 61 | parser.add_argument("--lambda_clf_dis", type=float, default=0, 62 | help="Classifier discriminator loss feedback coefficient") 63 | parser.add_argument("--lambda_schedule", type=float, default=500000, 64 | help="Progressively increase discriminators' lambdas (0 to disable)") 65 | parser.add_argument("--v_flip", type=bool_flag, default=False, 66 | help="Random vertical flip for data augmentation") 67 | parser.add_argument("--h_flip", type=bool_flag, default=True, 68 | help="Random horizontal flip for data augmentation") 69 | parser.add_argument("--batch_size", type=int, default=32, 70 | help="Batch size") 71 | parser.add_argument("--ae_optimizer", type=str, default="adam,lr=0.0002", 72 | help="Autoencoder optimizer (SGD / RMSprop / Adam, etc.)") 73 | parser.add_argument("--dis_optimizer", type=str, default="adam,lr=0.0002", 74 | help="Discriminator optimizer (SGD / RMSprop / Adam, etc.)") 75 | parser.add_argument("--clip_grad_norm", type=float, default=5, 76 | help="Clip gradient norms (0 to disable)") 77 | parser.add_argument("--n_epochs", type=int, default=1000, 78 | help="Total number of epochs") 79 | parser.add_argument("--epoch_size", type=int, default=50000, 80 | help="Number of samples per epoch") 81 | parser.add_argument("--ae_reload", type=str, default="", 82 | help="Reload a pretrained encoder") 83 | parser.add_argument("--lat_dis_reload", type=str, default="", 84 | help="Reload a pretrained latent discriminator") 85 | parser.add_argument("--ptc_dis_reload", type=str, default="", 86 | help="Reload a pretrained patch discriminator") 87 | parser.add_argument("--clf_dis_reload", type=str, default="", 88 | help="Reload a pretrained classifier discriminator") 89 | parser.add_argument("--eval_clf", type=str, default="", 90 | help="Load an external classifier for evaluation") 91 | parser.add_argument("--debug", type=bool_flag, default=False, 92 | help="Debug mode (only load a subset of the whole dataset)") 93 | params = parser.parse_args() 94 | 95 | # check parameters 96 | check_attr(params) 97 | assert len(params.name.strip()) > 0 98 | assert params.n_skip <= params.n_layers - 1 99 | assert params.deconv_method in ['convtranspose', 'upsampling', 'pixelshuffle'] 100 | assert 0 <= params.smooth_label < 0.5 101 | assert not params.ae_reload or os.path.isfile(params.ae_reload) 102 | assert not params.lat_dis_reload or os.path.isfile(params.lat_dis_reload) 103 | assert not params.ptc_dis_reload or os.path.isfile(params.ptc_dis_reload) 104 | assert not params.clf_dis_reload or os.path.isfile(params.clf_dis_reload) 105 | assert os.path.isfile(params.eval_clf) 106 | assert params.lambda_lat_dis == 0 or params.n_lat_dis > 0 107 | assert params.lambda_ptc_dis == 0 or params.n_ptc_dis > 0 108 | assert params.lambda_clf_dis == 0 or params.n_clf_dis > 0 109 | 110 | # initialize experiment / load dataset 111 | logger = initialize_exp(params) 112 | data, attributes = load_images(params) 113 | train_data = DataSampler(data[0], attributes[0], params) 114 | valid_data = DataSampler(data[1], attributes[1], params) 115 | 116 | # build the model 117 | ae = AutoEncoder(params).cuda() 118 | lat_dis = LatentDiscriminator(params).cuda() if params.n_lat_dis else None 119 | ptc_dis = PatchDiscriminator(params).cuda() if params.n_ptc_dis else None 120 | clf_dis = Classifier(params).cuda() if params.n_clf_dis else None 121 | eval_clf = torch.load(params.eval_clf).cuda().eval() 122 | 123 | # trainer / evaluator 124 | trainer = 
Trainer(ae, lat_dis, ptc_dis, clf_dis, train_data, params) 125 | evaluator = Evaluator(ae, lat_dis, ptc_dis, clf_dis, eval_clf, valid_data, params) 126 | 127 | 128 | for n_epoch in range(params.n_epochs): 129 | 130 | logger.info('Starting epoch %i...' % n_epoch) 131 | 132 | for n_iter in range(0, params.epoch_size, params.batch_size): 133 | 134 | # latent discriminator training 135 | for _ in range(params.n_lat_dis): 136 | trainer.lat_dis_step() 137 | 138 | # patch discriminator training 139 | for _ in range(params.n_ptc_dis): 140 | trainer.ptc_dis_step() 141 | 142 | # classifier discriminator training 143 | for _ in range(params.n_clf_dis): 144 | trainer.clf_dis_step() 145 | 146 | # autoencoder training 147 | trainer.autoencoder_step() 148 | 149 | # print training statistics 150 | trainer.step(n_iter) 151 | 152 | # run all evaluations / save best or periodic model 153 | to_log = evaluator.evaluate(n_epoch) 154 | trainer.save_best_periodic(to_log) 155 | logger.info('End of epoch %i.\n' % n_epoch) 156 | --------------------------------------------------------------------------------
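Usage note (appendix to the file dump above): the listings use early-PyTorch idioms such as Variable(..., volatile=...), loss.data[0], and inspect.getargspec inside get_optimizer, so they assume a correspondingly old Python/PyTorch environment; training itself additionally needs a CUDA device plus the preprocessed CelebA tensors under data/. The minimal sketch below is not part of the repository: it only exercises the CPU-side helpers from src/utils.py shown above (attribute-string and optimizer-string parsing) and assumes the repository root is on PYTHONPATH. The nn.Linear stand-in module is purely illustrative.

# Minimal sketch, assuming the repo root is on PYTHONPATH and an old
# Python/PyTorch install matching the code above (get_optimizer relies on
# inspect.getargspec, which recent Python versions have removed).
from torch import nn

from src.utils import attr_flag, get_optimizer

# "--attr Smiling,Male" is parsed into (name, n_categories) pairs, sorted by
# (n_cat, name); attributes without an explicit ".k" suffix are binary.
attrs = attr_flag("Smiling,Male")
print(attrs)                       # [('Male', 2), ('Smiling', 2)]
print(sum(n for _, n in attrs))    # 4, the length of the one-hot attribute vector y

# "--ae_optimizer adam,lr=0.0002" yields an optim.Adam instance; get_optimizer
# above sets betas to (beta1, beta2) with defaults (0.5, 0.999).
toy_model = nn.Linear(8, 4)        # stand-in module, only used to build an optimizer
optimizer = get_optimizer(toy_model, "adam,lr=0.0002")
print(type(optimizer).__name__)    # Adam

Two further points that follow directly from the listings: train.py asserts os.path.isfile(params.eval_clf), so a pretrained evaluation classifier (presumably produced with the top-level classifier.py) must always be passed via --eval_clf even though the flag defaults to an empty string; and with the default --lambda_lat_dis 0.0001 and --lambda_schedule 500000, get_lambda ramps the latent-discriminator feedback linearly, e.g. 0.0001 * 100000 / 500000 = 0.00002 at iteration 100000, saturating at 0.0001 from iteration 500000 onward.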