├── LICENSE.txt ├── README.md ├── __init__.py ├── __pycache__ ├── datasets.cpython-36.pyc ├── nets.cpython-36.pyc └── trainer.cpython-36.pyc ├── configs └── 001.yaml ├── data ├── supple │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ ├── 00008.jpg │ ├── latent_code_00001.npy │ ├── latent_code_00002.npy │ ├── latent_code_00003.npy │ ├── latent_code_00004.npy │ ├── latent_code_00005.npy │ ├── latent_code_00006.npy │ ├── latent_code_00007.npy │ └── latent_code_00008.npy ├── teaser │ ├── 00001.jpg │ ├── 00002.jpg │ ├── latent_code_00001.npy │ └── latent_code_00002.npy ├── test │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ ├── 00008.jpg │ ├── latent_code_00000.npy │ ├── latent_code_00001.npy │ ├── latent_code_00002.npy │ ├── latent_code_00003.npy │ ├── latent_code_00004.npy │ ├── latent_code_00005.npy │ ├── latent_code_00006.npy │ ├── latent_code_00007.npy │ └── latent_code_00008.npy └── video │ ├── FP006542HD02.mp4 │ └── FP006911MD02.mp4 ├── datasets.py ├── download.sh ├── environment.yml ├── evaluation.py ├── image ├── README.md ├── license_of_images_used.json ├── user_interface.jpg └── video_result.jpg ├── nets.py ├── notebooks ├── figure_sequential_edit.ipynb ├── figure_supplementary.ipynb └── visu_manipulation.ipynb ├── pretraining ├── latent_classifier.py └── latent_classifier_eval.py ├── run_video_manip.sh ├── test.py ├── train.py ├── trainer.py ├── utils ├── __pycache__ │ ├── functions.cpython-36.pyc │ ├── upfirdn_2d.cpython-36.pyc │ └── video_utils.cpython-36.pyc ├── functions.py └── video_utils.py └── video_processing.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | LIMITED SOFTWARE EVALUATION LICENSE AGREEMENT 2 | 3 | 4 | 5 | This Limited Software Evaluation License Agreement (the “Agreement”) is entered into as of April 9th 2020, (“Effective Date”) 6 | 7 | The following limited software evaluation license agreement (“the Agreement”) constitute an agreement between you (the “licensee”) and InterDigital R&D France, a French company existing and organized under the laws of France with its registered offices located at 975 avenue des champs blancs 35510 Cesson-Sévigné, FRANCE (hereinafter “InterDigital”) 8 | This Agreement governs the download and use of the Software (as defined below). Your use of the Software is subject to the terms and conditions set forth in this Agreement. By installing, using, accessing or copying the Software, you hereby irrevocably accept the terms and conditions of this Agreement. If you do not accept all or parts of the terms and conditions of this Agreement you cannot install, use, access nor copy the Software 9 | 10 | Article 1. Definitions 11 | 12 | “Affiliate” as used herein shall mean any entity that, directly or indirectly, through one or more intermediates, is controlled by, controls, or is under common control with InterDigital or The Licensee, as the case may be. For purposes of this definition only, the term “control” means the possession of the power to direct or cause the direction of the management and policies of an entity, whether by ownership of voting stock or partnership interest, by contract, or otherwise, including direct or indirect ownership of more than fifty percent (50%) of the voting interest in the entity in question. 
13 | 14 | “Authorized Purpose” means any use of the Software for research on the Software and evaluation of the Software exclusively, and academic research using the Software without any commercial use. For the avoidance of doubt, a commercial use includes, but is not limited to: 15 | - using the Software in advertisements of any kind, 16 | - licensing or selling of the Software, 17 | - use the Software to provide any service to any third Party 18 | - use the Software to develop a competitive product of the Software 19 | 20 | “Documentation” means textual materials delivered by InterDigital to the Licensee pursuant to this Agreement relating to the Software, in written or electronic format, including but not limited to: technical reference manuals, technical notes, user manuals, and application guides. 21 | 22 | “Limited Period” means the life of the copyright owned by InterDigital on the Software in each and every country where such copyright would exist. 23 | 24 | “Intellectual Property Rights” means all copyrights, trademarks, trade secrets, patents, mask works and other intellectual property rights recognized in any jurisdiction worldwide, including all applications and registrations with respect thereto. 25 | 26 | "Open Source software" shall mean any software, including where appropriate, any and all modifications, derivative works, enhancements, upgrades, improvements, fixed bugs, and/or statically linked to the source code of such software, released under a free software license, that requires as a condition of royalty-free usage, copy, modification and/or redistribution of the Open Source Software to: 27 | • Redistribute the Open Source Software royalty-free, and/or; 28 | • Redistribute the Open Source Software under the same license/distribution terms as those contained in the open source or free software license under which it has originally been released and/or; 29 | • Release to the public, disclose or otherwise make available the source code of the Open Source Software. 30 | 31 | For purposes of the Agreement, by means of example and without limitation, any software that is released or distributed under any of the following licenses shall be qualified as Open Source Software: (A) GNU General Public License (GPL), (B) GNU Lesser/Library GPL (LGPL), (C) the Artistic License, (D) the Mozilla Public License, (E) the Common Public License, (F) the Sun Community Source License (SCSL), (G) the Sun Industry Standards Source License (SISSL), (H) BSD License, (I) MIT License, (J) Apache Software License, (K) Open SSL License, (L) IBM Public License, (M) Open Software License. 32 | 33 | “Software” means any computer programming code, in object and/or source version, and related Documentation delivered by InterDigital to the Licensee pursuant to this Agreement as described in Exhibit A attached and incorporated herein by reference. 34 | 35 | Article 2. License 36 | 37 | InterDigital grants Licensee a free, worldwide, non-exclusive, license on copyright owned on the Software to download, use, modify and reproduce solely for the Authorized Purpose for the Limited Period. 38 | 39 | The Licensee shall not pay any royalty, license fee or maintenance fee, or other fee of any nature under this Agreement. 40 | 41 | The Licensee shall have the right to correct, adapt, modify, reverse engineer, disassemble, decompile and any action leading to the transformation of Software provided that such action is made to accomplish the Authorized Purpose. 
42 | 43 | Licensee shall have the right to make a demonstration of the Software, provided that it is in the Purpose and provided that Licensee shall maintain control of the Software at all time. This includes the control of any computer or server on which the Software is installed: no third party shall have access to such computer or server under any circumstances. No computer nor server containing the Software will be left in the possession of any third Party. 44 | 45 | Article 3. Restrictions on use of the Software 46 | 47 | Licensee shall not remove, obscure or modify any copyright, trademark or other proprietary rights notices, marks or labels contained on or within the Software, falsify or delete any author attributions, legal notices or other labels of the origin or source of the material. 48 | 49 | Licensee shall not have the right to distribute the Software, either modified or not, to any third Party. 50 | 51 | The rights granted here above do not include any rights to automatically obtain any upgrade or update of the Software, acquired or otherwise made available by InterDigital. Such deliverance shall be discussed on a case by case basis by the Parties. 52 | 53 | Article 4. Ownership 54 | 55 | Title to and ownership of the Software, the Documentation and/or any Intellectual Property Right protecting the Software or/and the Documentation shall, at all times, remain with InterDigital. Licensee agrees that except for the rights granted on copyright on the Software set forth in Section 2 above, in no event does anything in this Agreement grant, provide or convey any other rights, immunities or interest in or to any Intellectual Property Rights (including especially patents) of InterDigital or any of its Affiliates whether by implication, estoppel or otherwise. 56 | 57 | 58 | Article 5. Publication/Communication 59 | 60 | Any publication or oral communication resulting from the use of the Software shall be elaborated in good faith and shall not be driven by a deliberate will to denigrate InterDigital or any of its products. In any publication and on any support joined to an oral communication (for instance a PowerPoint document) resulting from the use of the Software, the following statement shall be inserted: 61 | 62 | “HRFAE is an InterDigital product” 63 | 64 | And in any publication, the latest publication about the software shall be properly cited. The latest publication currently is: 65 | "Arxiv preprint (ref to come shortly)” 66 | 67 | In any oral communication resulting from the use of the Software, the Licensee shall orally indicate that the Software is InterDigital’s property. 68 | 69 | Article 6. No Warranty - Disclaimer 70 | 71 | THE SOFTWARE AND DOCUMENTATION ARE PROVIDED TO LICENSEE ON AN “AS IS” BASIS. INTERDIGITAL MAKES NO WARRANTY THAT THE LICENSED TECHNOLOGY WILL OPERATE ON ANY PARTICULAR HARDWARE, PLATFORM, OR ENVIRONMENT. THERE IS NO WARRANTY THAT THE OPERATION OF THE LICENSED TECHNOLOGY SHALL BE UNINTERRUPTED, WITHOUT BUGS OR ERROR FREE. THE SOFTWARE AND DOCUMENTATION ARE PROVIDED HEREUNDER WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED LIABILITIES AND WARRANTIES OF NONINFRINGEMENT OF INTELLECTUAL PROPERTY, FREEDOM FROM INHERENT DEFECTS, CONFORMITY TO A SAMPLE OR MODEL, MERCHANTABILITY, FITNESS AND/OR SUITABILITY FOR A SPECIFIC OR GENERAL PURPOSE AND THOSE ARISING BY STATUTE OR BY LAW, OR FROM A CAUSE OF DEALING OR USAGE OF TRADE. 
72 | 73 | InterDigital shall not be obliged to perform any modifications, derivative works, enhancements, upgrades, updates or improvements of the Software or to fix any bug that could arise. 74 | 75 | Hence, the Licensee uses the Software at his own cost, risks and responsibility. InterDigital shall not be liable for any damage that could arise to Licensee by using the Software, either in accordance with this Agreement or not. 76 | 77 | InterDigital shall not be liable for any consequential or indirect losses, including any indirect loss of profits, revenues, business, and/or anticipated savings, whether or not in the contemplation of the Parties at the time of entering into the Agreement unless expressly set out in the Agreement, or arising from gross negligence, willful misconduct or fraud. 78 | 79 | Licensee agrees that it will defend, indemnify and hold harmless InterDigital and its Affiliates against any and all losses, damages, costs and expenses arising from a breach by the Licensee of any of its obligations or representations hereunder, including, without limitation, any third party, and/or any claims in connection with any such breach and/or any use of the Software, including any claim from third party arising from access, use or any other activity in relation to this Software. 80 | 81 | The Licensee shall not make any warranty, representation, or commitment on behalf of InterDigital to any other third party. 82 | 83 | Article 7. Open Source Software 84 | 85 | InterDigital hereby notifies the Licensee, and the Licensee hereby acknowledges and accepts, that the Software contains Open Source Software. The list of such Open Source Software is enclosed in exhibit B and the relevant license are contained at the root of the Software when downloaded. Hence, the Licensee shall comply with such license and agree on its terms on at its own risks. 86 | 87 | The Licensee hereby represents, warrants and covenants to InterDigital that The Licensee’s use of the Software shall not result in the Contamination of all or part of the Software, directly or indirectly, or of any Intellectual Property of InterDigital or its Affiliates. 88 | 89 | Contamination effect shall mean that the licensing terms under which one Open Source software, distinct from the Software, is released would also apply, by viral effect, to the software to which such Open Source software is linked to, combined with or otherwise connected to. 90 | 91 | Article 8. No Future Contract Obligation 92 | 93 | Neither this Agreement nor the furnishing of the Software, nor any other Confidential Information shall be construed to obligate either party to: (a) enter into any further agreement or negotiation concerning the deployment of the Software; (b) refrain from entering into any agreement or negotiation with any other third party regarding the same or any other subject matter; or (c) refrain from pursuing its business in whatever manner it elects even if this involves competing with the other party. 94 | 95 | Article 9. Term and Termination 96 | 97 | This Agreement shall terminate at the end of the Limited Period, unless earlier terminated by either party on the ground of material breach by the other party, which breach is not remedied after thirty (30) days advance written notice, specifying the breach with reasonable particularity and referencing this Agreement. 98 | 99 | Article 10. General Provisions 100 | 101 | 12.1 Severability. 
If any provision of this Agreement shall be held to be in contravention of applicable law, this Agreement shall be construed as if such provision were not a part thereof, and in all other respects the terms hereof shall remain in full force and effect. 102 | 103 | 12.2 Governing Law. Regardless of the place of execution, delivery, performance or any other aspect of this Agreement, this Agreement and all of the rights of the parties under this Agreement shall be governed by, construed under and enforced in accordance with the substantive law of the France without regard to conflicts of law principles. In case of a dispute that could not be settled amicably, the courts of Nanterre shall be exclusively competent. 104 | 105 | 12.3 Survival. The provisions of articles 1, 3, 4, 6, 7, 9, 10.2 and 10.6 shall survive termination of this Agreement. 106 | 12.4 Assignment. InterDigital may assign this license to any third Party. Such assignment will be announced on the website as defined in article 5. Licensee may not assign this agreement to any third party without the previous written agreement from InterDigital. 107 | 108 | 12.5 Entire Agreement. This Agreement constitutes the entire agreement between the parties hereto with respect to the subject matter hereof and supersedes any prior agreements or understanding. 109 | 110 | 12.6 Notices. To have legal effect, notices must be provided by registered or certified mail, return receipt requested, to the representatives of InterDigital at the following address: 111 | 112 | InterDigital 113 | Legal Dept 114 | 975 avenue des champs blancs 115 | 35510 Cesson-Sévigné 116 | FRANCE 117 | 118 | ======================================================================= 119 | 120 | Exhibit A 121 | Software 122 | 123 | 124 | The Software is comprised of the following software and Documentation: 125 | 126 | - README.md file that explains the content of the software and the procedure to use it. 127 | - Source python files, as well as pretrained models 128 | 129 | ======================================================================= 130 | 131 | Exhibit B 132 | Open Source licenses 133 | 134 | 135 | PIL http://www.pythonware.com/products/pil/license.htm 136 | 137 | numpy https://numpy.org/license.html 138 | 139 | tensorboardX https://github.com/lanpa/tensorboardX/blob/master/LICENSE 140 | 141 | pytorch https://github.com/pytorch/pytorch/blob/master/LICENSE 142 | 143 | torchvision https://github.com/pytorch/vision/blob/master/LICENSE 144 | 145 | tensorboard_logger https://github.com/TeamHG-Memex/tensorboard_logger/blob/master/LICENSE 146 | 147 | argparse https://github.com/ThomasWaldmann/argparse/blob/master/LICENSE.txt 148 | 149 | yaml https://github.com/yaml/pyyaml/blob/master/LICENSE 150 | 151 | 152 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Latent Transformer for Disentangled Face Editing in Images and Videos 2 | 3 | Official implementation for paper: A Latent Transformer for Disentangled Face Editing in Images and Videos. 
4 | 5 | [[Video Editing Results]](https://drive.google.com/drive/folders/1aIfmbgJL1CUFgZQzqDVaUtLHrqxS6QjP?usp=sharing) 6 | 7 | ## Requirements 8 | 9 | ### Dependencies 10 | 11 | - Python 3.6 12 | - PyTorch 1.8 13 | - OpenCV 14 | - Tensorboard_logger 15 | 16 | You can create a new conda environment for this repo by running 17 | ``` 18 | conda env create -f environment.yml 19 | conda activate lattrans 20 | ``` 21 | 22 | ### Prepare StyleGAN2 encoder and generator 23 | 24 | * We use the pretrained StyleGAN2 encoder and generator released with the paper [Encoding in Style: a StyleGAN Encoder for Image-to-Image Translation](https://arxiv.org/pdf/2008.00951.pdf). Download and save the [official implementation](https://github.com/eladrich/pixel2style2pixel.git) to the `pixel2style2pixel/` directory. Download and save the [pretrained model](https://drive.google.com/file/d/1bMTNWkh5LArlaWSc_wa8VKyq2V42T2z0/view) to `pixel2style2pixel/pretrained_models/`. 25 | 26 | * In order to save the latent codes to the desired path, we slightly modify `pixel2style2pixel/scripts/inference.py`. 27 | 28 | ``` 29 | # modify run_on_batch() 30 | if opts.latent_mask is None: 31 | result_batch = net(inputs, randomize_noise=False, resize=opts.resize_outputs, return_latents=True) 32 | 33 | # modify run() 34 | tic = time.time() 35 | result_batch, latent_batch = run_on_batch(input_cuda, net, opts) 36 | latent_save_path = os.path.join(test_opts.exp_dir, 'latent_code_%05d.npy'%global_i) 37 | np.save(latent_save_path, latent_batch.cpu().numpy()) 38 | toc = time.time() 39 | ``` 40 | 41 | 42 | ## Training 43 | 44 | * Prepare the training data 45 | 46 | To train the latent transformers, you can download [our prepared dataset](https://drive.google.com/drive/folders/1aXVc-q2ER7A9aACSwml5Wyw5ZgrgPq52?usp=sharing) to the directory `data/` and the [pretrained latent classifier](https://drive.google.com/file/d/1K_ShWBfTOCbxBcJfzti7vlYGmRbjXTfn/view?usp=sharing) to the directory `models/`. 47 | ``` 48 | sh download.sh 49 | ``` 50 | 51 | You can also prepare your own training data. To do so, you need to map your dataset to latent codes using the StyleGAN2 encoder. The corresponding label file is also required. You can continue to use our pretrained latent classifier; if you want to train your own latent classifier on new labels, you can use `pretraining/latent_classifier.py`. 52 | 53 | * Training 54 | 55 | You can modify the training options in the config file in the directory `configs/`. 56 | ``` 57 | python train.py --config 001 58 | ``` 59 | 60 | ## Testing 61 | 62 | ### Single Attribute Manipulation 63 | 64 | Make sure that the latent classifier is downloaded to the directory `models/` and the StyleGAN2 encoder is prepared as required. After training your latent transformers, you can use `test.py` to run the latent transformer on the images in the test directory `data/test/`. We also provide several pretrained models [here](https://drive.google.com/file/d/14uipafI5mena7LFFtvPh6r5HdzjBqFEt/view?usp=sharing) (run ```download.sh``` to download them). The output images will be saved in the folder `outputs/`. You can change the desired attribute with `--attr`. 65 | 66 | ``` 67 | python test.py --config 001 --attr Eyeglasses --out_path ./outputs/ 68 | ``` 69 | If you want to test the model on your custom images, you first need to encode the images into the latent space of StyleGAN using the pretrained encoder.
70 | ``` 71 | cd pixel2style2pixel/ 72 | python scripts/inference.py \ 73 | --checkpoint_path=pretrained_models/psp_ffhq_encode.pt \ 74 | --data_path=../data/test/ \ 75 | --exp_dir=../data/test/ \ 76 | --test_batch_size=1 77 | ``` 78 | 79 | ### Sequential Attribute Manipulation 80 | 81 | You can reproduce the sequential editing results in the paper using `notebooks/figure_sequential_edit.ipynb` and the results in the supplementary material using `notebooks/figure_supplementary.ipynb`. 82 | 83 | ![User Interface](./image/user_interface.jpg) 84 | 85 | We also provide an interactive visualization `notebooks/visu_manipulation.ipynb`, where the user can choose the desired attributes for manipulation and define the magnitude of edit for each attribute. 86 | 87 | 88 | ## Video Manipulation 89 | 90 | ![Video Result](./image/video_result.jpg) 91 | 92 | We provide a script to achieve attribute manipulation for the videos in the test directory `data/video/`. Please ensure that the StyleGAN2 encoder is prepared as required. You can upload your own video and modify the options in `run_video_manip.sh`. You can view our [video editing results](https://drive.google.com/drive/folders/1aIfmbgJL1CUFgZQzqDVaUtLHrqxS6QjP?usp=sharing) presented in the paper. 93 | 94 | ``` 95 | sh run_video_manip.sh 96 | ``` 97 | 98 | ## Citation 99 | ``` 100 | @article{yao2021latent, 101 | title={A Latent Transformer for Disentangled Face Editing in Images and Videos}, 102 | author={Yao, Xu and Newson, Alasdair and Gousseau, Yann and Hellier, Pierre}, 103 | journal={2021 International Conference on Computer Vision}, 104 | year={2021} 105 | } 106 | ``` 107 | ## License 108 | 109 | Copyright © 2021, InterDigital R&D France. All rights reserved. 110 | 111 | This source code is made available under the license found in the LICENSE.txt in the root directory of this source tree. 
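For reference, the core latent manipulation performed by `test.py`, `evaluation.py` and the notebooks boils down to a few steps: load a saved W+ latent code, pass it through the attribute-specific latent transformer with an edit strength `alpha`, and decode the edited code with the StyleGAN2 generator. The condensed sketch below is adapted from `evaluation.py` and the teaser notebook; the attribute, paths and `alpha` value are illustrative placeholders, not the exact contents of `test.py`.

```
import numpy as np
import torch
import yaml
from torchvision import utils
from trainer import Trainer
from utils.functions import clip_img

device = torch.device('cuda')
config = yaml.load(open('./configs/001.yaml', 'r'))

# Build the trainer and load the pretrained StyleGAN2 generator and latent classifier
trainer = Trainer(config, None, None, './data/celebahq_anno.npy')
trainer.initialize('./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt',
                   './models/latent_classifier_epoch_20.pth')
trainer.to(device)

# Load the latent transformer trained for one attribute ('Eyeglasses' has index 15 in the CelebA list)
trainer.attr_num = 15
trainer.load_model('./logs/001/')

with torch.no_grad():
    # W+ latent code previously saved by the pSp encoder, shape (1, 18, 512)
    w_0 = torch.tensor(np.load('./data/test/latent_code_00000.npy')).to(device)

    alpha = torch.tensor(1.0)  # edit strength; the sign controls the edit direction
    w_1 = trainer.T_net(w_0.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device))
    w_1 = w_1.view(w_0.size())
    w_1 = torch.cat((w_1[:, :11, :], w_0[:, 11:, :]), 1)  # keep the finest style layers unchanged

    x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)
    utils.save_image(clip_img(x_1), './outputs/edited_00000.jpg')
```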
112 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/__init__.py -------------------------------------------------------------------------------- /__pycache__/datasets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/__pycache__/datasets.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/nets.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/__pycache__/nets.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /configs/001.yaml: -------------------------------------------------------------------------------- 1 | # Input data 2 | resolution: 1024 3 | age_min: 20 4 | age_max: 70 5 | # Training hyperparameters 6 | batch_size: 32 7 | epochs: 10 8 | truncation_psi: 0.5 9 | # Optimizer parameters 10 | lr: 0.001 11 | beta_1: 0.9 12 | beta_2: 0.999 13 | weight_decay: 0.0005 14 | # Learning rate scheduler 15 | step_size: 10 16 | gamma: 0.1 17 | # Tensorboard log options 18 | image_save_iter: 3000 19 | image_log_iter: 1000 20 | log_iter: 10 21 | # Networks 22 | net_type: 17 23 | mapping_layers: 18 24 | mapping_fmaps: 512 25 | mapping_lrmul: 1 26 | mapping_nonlinearity: 'linear' 27 | # Classifier 28 | cls_type: 1 29 | latent_cls_type: 3013 30 | # Attributes 31 | attr: '5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,Black_Hair,Blond_Hair,Blurry,Brown_Hair,Bushy_Eyebrows,Chubby,Double_Chin,Eyeglasses,Goatee,Gray_Hair,Heavy_Makeup,High_Cheekbones,Male,Mouth_Slightly_Open,Mustache,Narrow_Eyes,No_Beard,Oval_Face,Pale_Skin,Pointy_Nose,Receding_Hairline,Rosy_Cheeks,Sideburns,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young' 32 | # Loss weight 33 | corr_threshold: 1 34 | w: 35 | pb: 1 36 | recon: 10 37 | reg: 1 -------------------------------------------------------------------------------- /data/supple/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00001.jpg -------------------------------------------------------------------------------- /data/supple/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00002.jpg -------------------------------------------------------------------------------- /data/supple/00003.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00003.jpg -------------------------------------------------------------------------------- /data/supple/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00004.jpg -------------------------------------------------------------------------------- /data/supple/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00005.jpg -------------------------------------------------------------------------------- /data/supple/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00006.jpg -------------------------------------------------------------------------------- /data/supple/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00007.jpg -------------------------------------------------------------------------------- /data/supple/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/00008.jpg -------------------------------------------------------------------------------- /data/supple/latent_code_00001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00001.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00002.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00002.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00003.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00003.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00004.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00004.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00005.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00005.npy 
-------------------------------------------------------------------------------- /data/supple/latent_code_00006.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00006.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00007.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00007.npy -------------------------------------------------------------------------------- /data/supple/latent_code_00008.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/supple/latent_code_00008.npy -------------------------------------------------------------------------------- /data/teaser/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/teaser/00001.jpg -------------------------------------------------------------------------------- /data/teaser/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/teaser/00002.jpg -------------------------------------------------------------------------------- /data/teaser/latent_code_00001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/teaser/latent_code_00001.npy -------------------------------------------------------------------------------- /data/teaser/latent_code_00002.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/teaser/latent_code_00002.npy -------------------------------------------------------------------------------- /data/test/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00000.jpg -------------------------------------------------------------------------------- /data/test/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00001.jpg -------------------------------------------------------------------------------- /data/test/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00002.jpg -------------------------------------------------------------------------------- /data/test/00003.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00003.jpg -------------------------------------------------------------------------------- /data/test/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00004.jpg -------------------------------------------------------------------------------- /data/test/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00005.jpg -------------------------------------------------------------------------------- /data/test/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00006.jpg -------------------------------------------------------------------------------- /data/test/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00007.jpg -------------------------------------------------------------------------------- /data/test/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/00008.jpg -------------------------------------------------------------------------------- /data/test/latent_code_00000.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00000.npy -------------------------------------------------------------------------------- /data/test/latent_code_00001.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00001.npy -------------------------------------------------------------------------------- /data/test/latent_code_00002.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00002.npy -------------------------------------------------------------------------------- /data/test/latent_code_00003.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00003.npy -------------------------------------------------------------------------------- /data/test/latent_code_00004.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00004.npy -------------------------------------------------------------------------------- /data/test/latent_code_00005.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00005.npy -------------------------------------------------------------------------------- /data/test/latent_code_00006.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00006.npy -------------------------------------------------------------------------------- /data/test/latent_code_00007.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00007.npy -------------------------------------------------------------------------------- /data/test/latent_code_00008.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/test/latent_code_00008.npy -------------------------------------------------------------------------------- /data/video/FP006542HD02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/video/FP006542HD02.mp4 -------------------------------------------------------------------------------- /data/video/FP006911MD02.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/data/video/FP006911MD02.mp4 -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
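# LatentDataset (defined below) loads pre-computed latent codes and their attribute labels
# from two .npy files, keeps the first 90% of the entries for training and the remaining 10%
# for validation, and returns (latent_code, label) pairs as torch tensors.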
5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.data as data 12 | 13 | from PIL import Image 14 | from torchvision import transforms, utils 15 | 16 | class LatentDataset(data.Dataset): 17 | def __init__(self, latent_dir, label_dir, training_set=True): 18 | dlatents = np.load(latent_dir) 19 | labels = np.load(label_dir) 20 | 21 | train_len = int(0.9*len(labels)) 22 | if training_set: 23 | self.dlatents = dlatents[:train_len] 24 | self.labels = labels[:train_len] 25 | #self.process_score() 26 | else: 27 | self.dlatents = dlatents[train_len:] 28 | self.labels = labels[train_len:] 29 | 30 | self.length = len(self.labels) 31 | 32 | def __len__(self): 33 | return self.length 34 | 35 | def __getitem__(self, idx): 36 | dlatent = torch.tensor(self.dlatents[idx]) 37 | lbl = torch.tensor(self.labels[idx]) 38 | 39 | return dlatent, lbl 40 | 41 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | # pip install gdown 2 | 3 | # download our prepared dataset 4 | gdown https://drive.google.com/uc?id=1qx-FRJV0yZmiF0GzQLsZFSJznUgNipf7 -O data/ 5 | gdown https://drive.google.com/uc?id=1Pz3wEpp6E7aFSUg5zh1Nm6EpWmRG_HnY -O data/ 6 | 7 | # download the pretrained latent classifier 8 | gdown https://drive.google.com/uc?id=1K_ShWBfTOCbxBcJfzti7vlYGmRbjXTfn -O models/ 9 | 10 | # download the pretrained models 11 | gdown https://drive.google.com/uc?id=14uipafI5mena7LFFtvPh6r5HdzjBqFEt -O logs/ 12 | unzip logs/pretrained_models.zip -d logs/ -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: lattrans 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - cudatoolkit=10.2.89 7 | - numpy=1.19.2 8 | - python=3.6.13 9 | - pip=21.1.1 10 | - pytorch=1.8.1 11 | - torchaudio=0.8.1 12 | - torchvision=0.9.1 13 | - pip: 14 | - face-alignment==1.3.4 15 | - opencv-python==4.5.2.52 16 | - scikit-image==0.17.2 17 | - scipy==1.2.0 18 | - pyyaml==5.4.1 19 | - tensorboard-logger==0.1.0 20 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
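# NOTE: the evaluation script below reads pre-computed FFHQ latent codes from a hard-coded,
# machine-specific 'testdata_dir'; point that variable (and 'save_dir' if desired) to your own
# directories before running.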
5 | 6 | import argparse 7 | import glob 8 | import os 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.utils.data as data 14 | import yaml 15 | 16 | from PIL import Image 17 | from torchvision import transforms, utils, models 18 | from tensorboard_logger import Logger 19 | 20 | from datasets import * 21 | from trainer import * 22 | from utils.functions import * 23 | 24 | torch.backends.cudnn.enabled = True 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = True 27 | torch.autograd.set_detect_anomaly(True) 28 | Image.MAX_IMAGE_PIXELS = None 29 | device = torch.device('cuda') 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--config', type=str, default='001', help='Path to the config file.') 33 | parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.') 34 | parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path') 35 | parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path') 36 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path') 37 | parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier') 38 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 39 | opts = parser.parse_args() 40 | 41 | # Celeba attribute list 42 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 43 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 44 | 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 45 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 46 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 47 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 48 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 49 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 50 | 51 | ######################################################################################################################## 52 | # Generate manipulated samples for evaluation 53 | # For the evaluation data, we project the first 1K images of FFHQ into the the latent space W+ of StyleGAN. 54 | # For each input image, we edit each attribute with 10 different scaling factors and generate the corresponding images. 
55 | ######################################################################################################################## 56 | 57 | # Load input latent codes 58 | testdata_dir = '/srv/tacm/users/yaox/ffhq_latents_psp/' 59 | n_steps = 11 60 | scale = 2.0 61 | 62 | with torch.no_grad(): 63 | 64 | save_dir = './outputs/evaluation/' 65 | os.makedirs(save_dir, exist_ok=True) 66 | 67 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 68 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r')) 69 | 70 | # Initialize trainer 71 | trainer = Trainer(config, None, None, opts.label_file) 72 | trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) 73 | trainer.to(device) 74 | 75 | for attr in list(attr_dict.keys()): 76 | 77 | attr_num = attr_dict[attr] 78 | trainer.attr_num = attr_dict[attr] 79 | trainer.load_model(log_dir) 80 | 81 | for k in range(1000): 82 | 83 | w_0 = np.load(testdata_dir + 'latent_code_%05d.npy' % k) 84 | w_0 = torch.tensor(w_0).to(device) 85 | 86 | predict_lbl_0 = trainer.Latent_Classifier(w_0.view(w_0.size(0), -1)) 87 | lbl_0 = F.sigmoid(predict_lbl_0) 88 | attr_pb_0 = lbl_0[:, attr_num] 89 | coeff = -1 if attr_pb_0 > 0.5 else 1 90 | 91 | range_alpha = torch.linspace(0, scale*coeff, n_steps) 92 | for i,alpha in enumerate(range_alpha): 93 | 94 | w_1 = trainer.T_net(w_0.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device)) 95 | w_1 = w_1.view(w_0.size()) 96 | w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1) 97 | x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False) 98 | utils.save_image(clip_img(x_1), save_dir + attr + '_%d'%k + '_alpha_'+ str(i) + '.jpg') 99 | -------------------------------------------------------------------------------- /image/README.md: -------------------------------------------------------------------------------- 1 | You can view our [video editing results](https://drive.google.com/drive/folders/1aIfmbgJL1CUFgZQzqDVaUtLHrqxS6QjP?usp=sharing) presented in the paper. 2 | 3 | In `license_of_images_used.json` we list the licenses for all the images presented in the paper. 
4 | -------------------------------------------------------------------------------- /image/license_of_images_used.json: -------------------------------------------------------------------------------- 1 | {"Paper Figure 1, 1": {"photo_url": "https://www.flickr.com/photos/flywithinsun/5946850626/in/photostream/", "photo_title": "Emma Watson_35", "author": "Kingsley Huang", "country": "", "license": "Attribution-NonCommercial License", "license_url": "https://creativecommons.org/licenses/by-nc/2.0/"}, "Paper Figure 1, 2": {"photo_url": "https://www.flickr.com/photos/stevegarfield/3197571945/in/photostream/", "photo_title": "Official portrait of President-elect Barack Obama", "author": "Steve Garfield", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/"}, "Paper Figure 2, 1": {"photo_url": "https://www.flickr.com/photos/matthewalmonroth/14807460745/", "photo_title": "Ethan one", "author": "Matthew Roth", "country": "", "license": "Attribution-NonCommercial License", "license_url": "https://creativecommons.org/licenses/by-nc/2.0/", "date_uploaded": "2014-08-02", "date_crawled": "2018-10-10"}, "Paper Figure 2, 2": {"photo_url": "https://www.flickr.com/photos/iv4quad/15508606518/", "photo_title": "A European Honeymoon: Day 4 (Rhine River and Cologne)", "author": "John Carkeet", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2014-11-02", "date_crawled": "2018-10-10"}, "Paper Figure 2, 3": {"photo_url": "https://www.flickr.com/photos/fotospresidencia_sv/39217060254/", "photo_title": "Festival para el Buen Vivir y Gobernando con la Gente-Cuscatancingo, San Salvador.", "author": "Presidencia El Salvador", "country": "", "license": "Public Domain Dedication (CC0)", "license_url": "https://creativecommons.org/publicdomain/zero/1.0/", "date_uploaded": "2018-01-27", "date_crawled": "2018-10-10"}, "Paper Figure 2, 4": {"photo_url": "https://www.flickr.com/photos/cd_1940/3088242838/", "photo_title": "20081129-115543-00170", "author": "cd_1940", "country": "", "license": "Attribution-NonCommercial License", "license_url": "https://creativecommons.org/licenses/by-nc/2.0/", "date_uploaded": "2008-12-07", "date_crawled": "2018-10-10"}, "Paper Figure 4, 1": {"photo_url": "https://www.flickr.com/photos/j_benson/35582782503/", "photo_title": "State Street", "author": "John Benson", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2017-08-05", "date_crawled": "2018-10-10"}, "Paper Figure 4, 2": {"photo_url": "https://www.flickr.com/photos/apocalyse/8095577481/", "photo_title": "P1140375", "author": "akuhlrock", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2012-10-17", "date_crawled": "2018-10-10"}, "Paper Figure 4, 3": {"photo_url": "https://www.flickr.com/photos/quakecon/3923570806/", "photo_title": "Tournament_0075", "author": "QuakeCon", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2009-09-15", "date_crawled": "2018-10-10"}, "Paper Figure 4, 4": {"photo_url": "https://www.flickr.com/photos/alphalab/8880216552/", "photo_title": "InnovationWorks_DEMO_DAY2013_017", "author": "AlphaLab Startup Accelerator", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", 
"date_uploaded": "2013-05-29", "date_crawled": "2018-10-10"}, "Paper Figure 5, 1": {"photo_url": "https://www.flickr.com/photos/147837124@N06/32545819081/", "photo_title": "Desembargadora Eleitoral Cristiane Chaves Frota", "author": "TRE - RJ", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2017-02-02", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 1": {"photo_url": "https://www.flickr.com/photos/colinbrown/2325314759/", "photo_title": "IMG_7208.jpg", "author": "Colin Brown", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2008-03-11", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 2": {"photo_url": "https://www.flickr.com/photos/12567328@N00/32701170786/", "photo_title": "IMG_6581", "author": "VcStyle", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2017-02-06", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 3": {"photo_url": "https://www.flickr.com/photos/davidberkowitz/5278053394/", "photo_title": "IMG_0847", "author": "David Berkowitz", "country": "Chile", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2010-12-20", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 4": {"photo_url": "https://www.flickr.com/photos/kathmandu/26378594036/", "photo_title": "Japanese Stone Lantern Lighting Ceremony", "author": "S Pakhrin", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2016-04-13", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 5": {"photo_url": "https://www.flickr.com/photos/9575673@N08/3666061081/", "photo_title": "Garcia Family Portraits", "author": "Jim Legans, Jr", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2009-06-28", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 6": {"photo_url": "https://www.flickr.com/photos/pfaseal/42344546234/", "photo_title": "DSC_7165", "author": "PFA SEAL", "country": "", "license": "Public Domain Mark", "license_url": "https://creativecommons.org/publicdomain/mark/1.0/", "date_uploaded": "2018-06-28", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 7": {"photo_url": "https://www.flickr.com/photos/catt1788/35888018712/", "photo_title": "Hamburg", "author": "Cat Burston", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2017-07-21", "date_crawled": "2018-10-10"}, "Supplementary Figure 1/2, 8": {"photo_url": "https://www.flickr.com/photos/mdgovpics/27076452033/", "photo_title": "PTECH Press Conference", "author": "Maryland GovPics", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2016-06-15", "date_crawled": "2018-10-10"}, "Supplementary Figure 4, 1": {"photo_url": "https://www.flickr.com/photos/21710901@N08/28545927034/", "photo_title": "Julia Senior Pics Album", "author": "seanccochran", "country": "", "license": "Attribution License", "license_url": "https://creativecommons.org/licenses/by/2.0/", "date_uploaded": "2016-08-23", "date_crawled": "2018-10-10"}, "Supplementary Figure 4, 2": {"photo_url": 
"https://www.flickr.com/photos/campact/27770801245/", "photo_title": "HandinHand LEIPZIG 19.6.16", "author": "campact", "country": "", "license": "Attribution-NonCommercial License", "license_url": "https://creativecommons.org/licenses/by-nc/2.0/", "date_uploaded": "2016-06-19", "date_crawled": "2018-10-10"}} -------------------------------------------------------------------------------- /image/user_interface.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/image/user_interface.jpg -------------------------------------------------------------------------------- /image/video_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/image/video_result.jpg -------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | #---------------------------------------------------------------------------- 12 | # Latent classification model 13 | 14 | class LCNet(nn.Module): 15 | def __init__(self, fmaps=[6048, 2048, 512, 40], activ='relu'): 16 | super().__init__() 17 | # Linear layers 18 | self.fcs = nn.ModuleList() 19 | for i in range(len(fmaps)-1): 20 | in_channel = fmaps[i] 21 | out_channel = fmaps[i+1] 22 | self.fcs.append(nn.Linear(in_channel, out_channel, bias=True)) 23 | # Activation 24 | if activ == 'relu': 25 | self.relu = nn.ReLU() 26 | elif activ == 'leakyrelu': 27 | self.relu = nn.LeakyReLU(0.2) 28 | 29 | def forward(self, x): 30 | for layer in self.fcs[:-1]: 31 | x = self.relu(layer(x)) 32 | x = self.fcs[-1](x) 33 | return x 34 | 35 | #---------------------------------------------------------------------------- 36 | # Get weight tensor for a convolution or fully-connected layer. 37 | 38 | def get_weight(weight, gain=1, use_wscale=True, lrmul=1): 39 | fan_in = np.prod(weight.size()[1:]) # [kernel, kernel, fmaps_in, fmaps_out] or [in, out] 40 | he_std = gain / np.sqrt(fan_in) # He init 41 | # Equalized learning rate and custom learning rate multiplier. 42 | if use_wscale: 43 | runtime_coef = he_std * lrmul 44 | else: 45 | runtime_coef = lrmul 46 | return weight * runtime_coef 47 | 48 | 49 | #---------------------------------------------------------------------------- 50 | # Apply activation func. 51 | 52 | def apply_bias_act(x, act='linear', alpha=None, gain=None): 53 | if act == 'linear': 54 | return x 55 | elif act == 'lrelu': 56 | if alpha is None: 57 | alpha = 0.2 58 | if gain is None: 59 | gain = np.sqrt(2) 60 | x = F.leaky_relu(x, negative_slope=alpha) 61 | x = x*gain 62 | return x 63 | 64 | 65 | #---------------------------------------------------------------------------- 66 | # Fully-connected layer. 
67 | 68 | class Dense_layer(nn.Module): 69 | def __init__(self, input_size, output_size, gain=1, use_wscale=True, lrmul=1): 70 | super(Dense_layer, self).__init__() 71 | self.weight = nn.Parameter(torch.Tensor(output_size, input_size)) 72 | self.bias = nn.Parameter(torch.Tensor(output_size)) 73 | self.gain = gain 74 | self.use_wscale = use_wscale 75 | self.lrmul = lrmul 76 | nn.init.xavier_uniform_(self.weight) 77 | nn.init.zeros_(self.bias) 78 | 79 | def forward(self, x): 80 | w = get_weight(self.weight, gain=self.gain, use_wscale=self.use_wscale, lrmul=self.lrmul) 81 | b = self.bias 82 | x = F.linear(x, w, bias=b) 83 | return x 84 | 85 | #---------------------------------------------------------------------------- 86 | # Mapping network to modify the disentangled latent w+. 87 | 88 | class F_mapping(nn.Module): 89 | def __init__( 90 | self, 91 | dlatent_size = 512, # Transformed latent (W) dimensionality. 92 | mapping_layers = 18, # Number of mapping layers. 93 | mapping_fmaps = 512, # Number of activations in the mapping layers. 94 | mapping_lrmul = 1, # Learning rate multiplier for the mapping layers. 95 | mapping_nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu', etc. 96 | dtype = torch.float32, # Data type to use for activations and outputs. 97 | **_kwargs): # Ignore unrecognized keyword args. 98 | super().__init__() 99 | 100 | self.mapping_layers = mapping_layers 101 | self.act = mapping_nonlinearity 102 | self.dtype = dtype 103 | 104 | self.dense = nn.ModuleList() 105 | # Mapping layers. 106 | for layer_idx in range(mapping_layers): 107 | self.dense.append(Dense_layer(mapping_fmaps, mapping_fmaps, lrmul=mapping_lrmul)) 108 | 109 | def forward(self, latents_in, coeff): 110 | # Inputs. 111 | latents_in = latents_in.type(self.dtype) 112 | 113 | x = latents_in.split(split_size=512, dim=1) 114 | out = [] 115 | # Mapping layers. 116 | for layer_idx in range(self.mapping_layers): 117 | out.append(apply_bias_act(self.dense[layer_idx](x[layer_idx]), act='linear')) 118 | x = torch.cat(out, dim=1) 119 | 120 | coeff = coeff.view(x.size(0), -1) 121 | x = coeff * x + latents_in 122 | 123 | # Output. 124 | assert x.dtype == self.dtype 125 | return x -------------------------------------------------------------------------------- /notebooks/figure_sequential_edit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Copyright (c) 2021, InterDigital R&D France. 
All rights reserved.\n", 10 | "#\n", 11 | "# This source code is made available under the license found in the\n", 12 | "# LICENSE.txt in the root directory of this source tree.\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "%matplotlib inline\n", 15 | "\n", 16 | "import argparse\n", 17 | "import glob\n", 18 | "import os\n", 19 | "import numpy as np\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.nn.functional as F\n", 23 | "import torch.utils.data as data\n", 24 | "import yaml\n", 25 | "\n", 26 | "from PIL import Image\n", 27 | "from torchvision import transforms, utils, models\n", 28 | "\n", 29 | "import sys\n", 30 | "\n", 31 | "if os.getcwd().split('/')[-1] == 'notebooks':\n", 32 | " sys.path.append('..')\n", 33 | " os.chdir('..')\n", 34 | "\n", 35 | "from datasets import *\n", 36 | "from trainer import *\n", 37 | "from utils.functions import *\n", 38 | "\n", 39 | "torch.backends.cudnn.enabled = True\n", 40 | "torch.backends.cudnn.deterministic = True\n", 41 | "torch.backends.cudnn.benchmark = True\n", 42 | "torch.autograd.set_detect_anomaly(True)\n", 43 | "Image.MAX_IMAGE_PIXELS = None\n", 44 | "device = torch.device('cuda')\n", 45 | "\n", 46 | "parser = argparse.ArgumentParser()\n", 47 | "parser.add_argument('--config', type=str, default='001', help='Path to the config file.')\n", 48 | "parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.')\n", 49 | "parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path')\n", 50 | "parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path')\n", 51 | "parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path')\n", 52 | "parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier')\n", 53 | "parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')\n", 54 | "opts = parser.parse_args([])\n", 55 | "\n", 56 | "# Celeba attribute list\n", 57 | "attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \\\n", 58 | " 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \\\n", 59 | " 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \\\n", 60 | " 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \\\n", 61 | " 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \\\n", 62 | " 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \\\n", 63 | " 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \\\n", 64 | " 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Initialize trainer model.\n", 74 | "log_dir = os.path.join(opts.log_path, opts.config) + '/'\n", 75 | "config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'))\n", 76 | "\n", 77 | "trainer = Trainer(config, None, None, opts.label_file)\n", 78 | "trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) \n", 79 | "trainer.to(device)\n", 80 | 
"print('Load model.')" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "### Figure 1. Teaser" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "%matplotlib inline\n", 97 | "testdata_dir = './data/teaser/'\n", 98 | "\n", 99 | "teaser_attrs = [{'Smiling':1, 'Bangs':1, 'Arched_Eyebrows':1, 'Young':-1}, \\\n", 100 | " {'Young':1, 'Smiling':1, 'No_Beard':-1, 'Eyeglasses':1} ]\n", 101 | "\n", 102 | "with torch.no_grad():\n", 103 | " \n", 104 | " for k, idx in enumerate([1,2]):\n", 105 | "\n", 106 | " w_0 = np.load(testdata_dir + 'latent_code_%05d.npy'%idx)\n", 107 | " w_0 = torch.tensor(w_0).to(device)\n", 108 | " \n", 109 | " x_0 = img_to_tensor(Image.open(testdata_dir + '%05d.jpg'%idx))\n", 110 | " x_0 = x_0.unsqueeze(0).to(device)\n", 111 | " img_l = [x_0] # original image\n", 112 | " \n", 113 | " x_1, _ = trainer.StyleGAN([w_0], input_is_latent=True, randomize_noise=False)\n", 114 | " x_0 = torch.ones((x_1.size(0), x_1.size(1), x_1.size(2),x_1.size(3)+40)).type_as(x_1)\n", 115 | " x_0[:,:,:,20:1044] = x_1[:,:,:,:]\n", 116 | " img_l.append(x_0) # projected image\n", 117 | " \n", 118 | " w_1 = w_0\n", 119 | " attrs = teaser_attrs[k]\n", 120 | " for attr in list(attrs.keys()):\n", 121 | " \n", 122 | " trainer.attr_num = attr_dict[attr]\n", 123 | " trainer.load_model(log_dir)\n", 124 | " \n", 125 | " alpha = torch.tensor(1.0) * attrs[attr]\n", 126 | " w_1 = trainer.T_net(w_1.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device))\n", 127 | " w_1 = w_1.view(w_0.size())\n", 128 | " w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1)\n", 129 | " x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)\n", 130 | " img_l.append(x_1.data)\n", 131 | "\n", 132 | " img = img_l[0] if len(img_l)==1 else torch.cat(img_l, 3)\n", 133 | " img = np.clip(clip_img(img)[0].cpu().numpy()*255.,0,255).astype(np.uint8)\n", 134 | " img = Image.fromarray(img.transpose(1,2,0))\n", 135 | " plt.figure(figsize=(30,5))\n", 136 | " plt.imshow(img)\n", 137 | " plt.axis('off')\n", 138 | "plt.show()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "### Figure 4. 
Sequential facial attribute editing" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "testdata_dir = './data/test/'\n", 155 | "\n", 156 | "attrs_list = [{'Chubby':-0.5, 'Blond_Hair':1.5, 'Smiling':1, 'Wearing_Lipstick':1, 'Eyeglasses':1.5}, \\\n", 157 | " {'Eyeglasses':-1.5, 'Bangs':1, 'Bags_Under_Eyes':1, 'Smiling':-1, 'Young':-1}, \\\n", 158 | " {'Smiling':1, 'No_Beard':-1, 'Receding_Hairline':1, 'Eyeglasses':1, 'Arched_Eyebrows':1}, \\\n", 159 | " {'Smiling':-1, 'Chubby':-0.5, 'Goatee':1, 'Eyeglasses':1, 'Pale_Skin':1} ]\n", 160 | "\n", 161 | "with torch.no_grad():\n", 162 | " \n", 163 | " for k, idx in enumerate([4,5,6,7]):\n", 164 | "\n", 165 | " w_0 = np.load(testdata_dir + 'latent_code_%05d.npy'%idx)\n", 166 | " w_0 = torch.tensor(w_0).to(device)\n", 167 | " \n", 168 | " x_0 = img_to_tensor(Image.open(testdata_dir + '%05d.jpg'%idx))\n", 169 | " x_0 = x_0.unsqueeze(0).to(device)\n", 170 | " img_l = [x_0] # original image\n", 171 | " \n", 172 | " x_1, _ = trainer.StyleGAN([w_0], input_is_latent=True, randomize_noise=False)\n", 173 | " x_0 = torch.ones((x_1.size(0), x_1.size(1), x_1.size(2),x_1.size(3)+40)).type_as(x_1)\n", 174 | " x_0[:,:,:,20:1044] = x_1[:,:,:,:]\n", 175 | " img_l.append(x_0) # projected image\n", 176 | " \n", 177 | " w_1 = w_0\n", 178 | " attrs = attrs_list[k]\n", 179 | " for attr in list(attrs.keys()):\n", 180 | " \n", 181 | " trainer.attr_num = attr_dict[attr]\n", 182 | " trainer.load_model(log_dir)\n", 183 | " \n", 184 | " alpha = torch.tensor(1.0) * attrs[attr]\n", 185 | " w_1 = trainer.T_net(w_1.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device))\n", 186 | " w_1 = w_1.view(w_0.size())\n", 187 | " w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1)\n", 188 | " x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)\n", 189 | " img_l.append(x_1.data)\n", 190 | "\n", 191 | " img = img_l[0] if len(img_l)==1 else torch.cat(img_l, 3)\n", 192 | " img = np.clip(clip_img(img)[0].cpu().numpy()*255.,0,255).astype(np.uint8)\n", 193 | " img = Image.fromarray(img.transpose(1,2,0))\n", 194 | " plt.figure(figsize=(30,5))\n", 195 | " plt.imshow(img)\n", 196 | " plt.axis('off')\n", 197 | " plt.show()" 198 | ] 199 | } 200 | ], 201 | "metadata": { 202 | "kernelspec": { 203 | "display_name": "geo", 204 | "language": "python", 205 | "name": "geo" 206 | }, 207 | "language_info": { 208 | "codemirror_mode": { 209 | "name": "ipython", 210 | "version": 3 211 | }, 212 | "file_extension": ".py", 213 | "mimetype": "text/x-python", 214 | "name": "python", 215 | "nbconvert_exporter": "python", 216 | "pygments_lexer": "ipython3", 217 | "version": "3.6.13" 218 | } 219 | }, 220 | "nbformat": 4, 221 | "nbformat_minor": 4 222 | } 223 | -------------------------------------------------------------------------------- /notebooks/figure_supplementary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Copyright (c) 2021, InterDigital R&D France. 
All rights reserved.\n", 10 | "#\n", 11 | "# This source code is made available under the license found in the\n", 12 | "# LICENSE.txt in the root directory of this source tree.\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "%matplotlib inline\n", 15 | "\n", 16 | "import argparse\n", 17 | "import glob\n", 18 | "import os\n", 19 | "import numpy as np\n", 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "import torch.nn.functional as F\n", 23 | "import torch.utils.data as data\n", 24 | "import yaml\n", 25 | "\n", 26 | "from PIL import Image\n", 27 | "from torchvision import transforms, utils, models\n", 28 | "\n", 29 | "import sys\n", 30 | "\n", 31 | "if os.getcwd().split('/')[-1] == 'notebooks':\n", 32 | " sys.path.append('..')\n", 33 | " os.chdir('..')\n", 34 | " \n", 35 | "from datasets import *\n", 36 | "from trainer import *\n", 37 | "from utils.functions import *\n", 38 | "\n", 39 | "torch.backends.cudnn.enabled = True\n", 40 | "torch.backends.cudnn.deterministic = True\n", 41 | "torch.backends.cudnn.benchmark = True\n", 42 | "torch.autograd.set_detect_anomaly(True)\n", 43 | "Image.MAX_IMAGE_PIXELS = None\n", 44 | "device = torch.device('cuda')\n", 45 | "\n", 46 | "parser = argparse.ArgumentParser()\n", 47 | "parser.add_argument('--config', type=str, default='003', help='Path to the config file.')\n", 48 | "parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.')\n", 49 | "parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path')\n", 50 | "parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path')\n", 51 | "parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path')\n", 52 | "parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier')\n", 53 | "parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')\n", 54 | "opts = parser.parse_args([])\n", 55 | "\n", 56 | "# Celeba attribute list\n", 57 | "attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \\\n", 58 | " 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \\\n", 59 | " 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \\\n", 60 | " 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \\\n", 61 | " 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \\\n", 62 | " 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \\\n", 63 | " 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \\\n", 64 | " 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# Initialize trainer model.\n", 74 | "log_dir = os.path.join(opts.log_path, opts.config) + '/'\n", 75 | "config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r'))\n", 76 | "\n", 77 | "trainer = Trainer(config, None, None, opts.label_file)\n", 78 | "trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) \n", 79 | "trainer.to(device)\n", 80 | 
"\n", 81 | "testdata_dir = './data/supple/'\n", 82 | "latent_list = glob.glob1(testdata_dir,'*.npy')\n", 83 | "latent_list.sort()\n", 84 | "\n", 85 | "attrs_list = [{'Arched_Eyebrows':1, 'Chubby':-0.5, 'Pointy_Nose':-1, 'Smiling':-1, 'Pale_Skin':1}, \\\n", 86 | " {'Black_Hair':1, 'Smiling':1, 'Bags_Under_Eyes':1, 'Young':1, 'Eyeglasses':1.2}, \\\n", 87 | " {'Chubby':-0.5, 'Bangs':1, 'Smiling':1, 'Young':-1, 'Heavy_Makeup':1}, \\\n", 88 | " {'Chubby':-0.5, 'Smiling':-1, 'Narrow_Eyes':-0.5, 'Heavy_Makeup':1, 'Eyeglasses':1}, \\\n", 89 | " {'Eyeglasses':-1.2, 'Smiling':1, 'Goatee':1, 'Arched_Eyebrows':1, 'Young':1}, \\\n", 90 | " {'No_Beard':-1, 'Bushy_Eyebrows':1, 'Mouth_Slightly_Open':1, 'Receding_Hairline':1, 'Eyeglasses':1.2}, \\\n", 91 | " {'Smiling':-1, 'Chubby':-0.5, 'No_Beard':-1, 'Eyeglasses':1.2, 'Receding_Hairline':1}, \\\n", 92 | " {'Smiling':1, 'Goatee':1, 'Eyeglasses':-1, 'Arched_Eyebrows':-1, 'Young':1} ]" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "### Single attribute manipulation" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "%matplotlib inline\n", 109 | "with torch.no_grad():\n", 110 | " \n", 111 | " for k, latent in enumerate(latent_list):\n", 112 | "\n", 113 | " w_0 = np.load(testdata_dir + latent)\n", 114 | " w_0 = torch.tensor(w_0).to(device)\n", 115 | " \n", 116 | " idx = [int(s) for s in latent.replace('_','.').split('.') if s.isdigit()][0]\n", 117 | " x_0 = img_to_tensor(Image.open(testdata_dir + '%05d.jpg'%idx))\n", 118 | " x_0 = x_0.unsqueeze(0).to(device)\n", 119 | " img_l = [x_0] # original image\n", 120 | " \n", 121 | " x_1, _ = trainer.StyleGAN([w_0], input_is_latent=True, randomize_noise=False)\n", 122 | " x_0 = torch.ones((x_1.size(0), x_1.size(1), x_1.size(2),x_1.size(3)+40)).type_as(x_1)\n", 123 | " x_0[:,:,:,20:1044] = x_1[:,:,:,:]\n", 124 | " img_l.append(x_0) # projected image\n", 125 | " \n", 126 | " attrs = attrs_list[k]\n", 127 | " for attr in list(attrs.keys()):\n", 128 | " \n", 129 | " trainer.attr_num = attr_dict[attr]\n", 130 | " trainer.load_model(log_dir)\n", 131 | " \n", 132 | " alpha = torch.tensor(1.0) * attrs[attr]\n", 133 | " w_1 = trainer.T_net(w_0.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device))\n", 134 | " w_1 = w_1.view(w_0.size())\n", 135 | " w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1)\n", 136 | " x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)\n", 137 | " img_l.append(x_1.data)\n", 138 | "\n", 139 | " img = img_l[0] if len(img_l)==1 else torch.cat(img_l, 3)\n", 140 | " img = np.clip(clip_img(img)[0].cpu().numpy()*255.,0,255).astype(np.uint8)\n", 141 | " img = Image.fromarray(img.transpose(1,2,0))\n", 142 | " plt.figure(figsize=(30,5))\n", 143 | " plt.imshow(img)\n", 144 | " plt.axis('off')\n", 145 | " plt.show()" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "### Sequential attribute manipulation" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "with torch.no_grad():\n", 162 | " \n", 163 | " for k, latent in enumerate(latent_list):\n", 164 | "\n", 165 | " w_0 = np.load(testdata_dir + latent)\n", 166 | " w_0 = torch.tensor(w_0).to(device)\n", 167 | " \n", 168 | " idx = [int(s) for s in latent.replace('_','.').split('.') if s.isdigit()][0]\n", 169 | " x_0 = 
img_to_tensor(Image.open(testdata_dir + '%05d.jpg'%idx))\n", 170 | " x_0 = x_0.unsqueeze(0).to(device)\n", 171 | " img_l = [x_0] # original image\n", 172 | " \n", 173 | " x_1, _ = trainer.StyleGAN([w_0], input_is_latent=True, randomize_noise=False)\n", 174 | " x_0 = torch.ones((x_1.size(0), x_1.size(1), x_1.size(2),x_1.size(3)+40)).type_as(x_1)\n", 175 | " x_0[:,:,:,20:1044] = x_1[:,:,:,:]\n", 176 | " img_l.append(x_0) # projected image\n", 177 | " \n", 178 | " attrs = attrs_list[k]\n", 179 | " w_1 = w_0\n", 180 | " for attr in list(attrs.keys()):\n", 181 | " \n", 182 | " trainer.attr_num = attr_dict[attr]\n", 183 | " trainer.load_model(log_dir)\n", 184 | " \n", 185 | " alpha = torch.tensor(1.0) * attrs[attr]\n", 186 | " w_1 = trainer.T_net(w_1.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device))\n", 187 | " w_1 = w_1.view(w_0.size())\n", 188 | " w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1)\n", 189 | " x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)\n", 190 | " img_l.append(x_1.data)\n", 191 | "\n", 192 | " img = img_l[0] if len(img_l)==1 else torch.cat(img_l, 3)\n", 193 | " img = np.clip(clip_img(img)[0].cpu().numpy()*255.,0,255).astype(np.uint8)\n", 194 | " img = Image.fromarray(img.transpose(1,2,0))\n", 195 | " plt.figure(figsize=(30,5))\n", 196 | " plt.imshow(img)\n", 197 | " plt.axis('off')\n", 198 | " plt.show()" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "geo", 205 | "language": "python", 206 | "name": "geo" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.6.13" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 4 223 | } 224 | -------------------------------------------------------------------------------- /notebooks/visu_manipulation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Copyright (c) 2021, InterDigital R&D France. 
All rights reserved.\n", 10 | "\n", 11 | "# This source code is made available under the license found in the\n", 12 | "# LICENSE.txt in the root directory of this source tree.\n", 13 | "\n", 14 | "from __future__ import print_function\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "%matplotlib inline\n", 18 | "\n", 19 | "from ipywidgets import interact, interactive, fixed, interact_manual\n", 20 | "import ipywidgets as widgets \n", 21 | "\n", 22 | "import argparse\n", 23 | "import copy\n", 24 | "import glob\n", 25 | "import os\n", 26 | "import numpy as np\n", 27 | "import torch\n", 28 | "import torch.nn as nn\n", 29 | "import torch.nn.functional as F\n", 30 | "import torch.utils.data as data\n", 31 | "import yaml\n", 32 | "\n", 33 | "from PIL import Image\n", 34 | "from torchvision import transforms, utils, models\n", 35 | "\n", 36 | "import sys\n", 37 | "\n", 38 | "if os.getcwd().split('/')[-1] == 'notebooks':\n", 39 | " sys.path.append('..')\n", 40 | " os.chdir('..')\n", 41 | " \n", 42 | "from datasets import *\n", 43 | "from trainer import *\n", 44 | "from utils.functions import *\n", 45 | "\n", 46 | "torch.backends.cudnn.enabled = True\n", 47 | "torch.backends.cudnn.deterministic = True\n", 48 | "torch.backends.cudnn.benchmark = True\n", 49 | "torch.autograd.set_detect_anomaly(True)\n", 50 | "Image.MAX_IMAGE_PIXELS = None\n", 51 | "device = torch.device('cuda')\n", 52 | "\n", 53 | "parser = argparse.ArgumentParser()\n", 54 | "parser.add_argument('--config', type=str, default='001', help='Path to the config file.')\n", 55 | "parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.')\n", 56 | "parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path')\n", 57 | "parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path')\n", 58 | "parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path')\n", 59 | "parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier')\n", 60 | "parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')\n", 61 | "opts = parser.parse_args([])\n", 62 | "\n", 63 | "# Celeba attribute list\n", 64 | "attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \\\n", 65 | " 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \\\n", 66 | " 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \\\n", 67 | " 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \\\n", 68 | " 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \\\n", 69 | " 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \\\n", 70 | " 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \\\n", 71 | " 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39}" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Initialize trainer model.\n", 81 | "log_dir = os.path.join(opts.log_path, opts.config) + '/'\n", 82 | "config = yaml.load(open('./configs/' + 
opts.config + '.yaml', 'r'))\n", 83 | "\n", 84 | "trainer = Trainer(config, None, None, opts.label_file)\n", 85 | "trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) \n", 86 | "trainer.to(device)\n", 87 | "print('Load model.')\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "### Visualization of attribute manipulation" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "# Set desired attributes for manipulation in attr_list\n", 104 | "attr_list = ['Male','Eyeglasses','Young','Smiling']\n", 105 | "testdata_dir = './data/test/'\n", 106 | "\n", 107 | "# Load latent transformer models\n", 108 | "T_net_dict = {}\n", 109 | "for attr in attr_list:\n", 110 | "    trainer.attr_num = attr_dict[attr]\n", 111 | "    trainer.load_model(log_dir)\n", 112 | "    T_net_dict[attr] = copy.deepcopy(trainer.T_net)\n", 113 | "    \n", 114 | "# Visualization function\n", 115 | "def visu_manipulation(seed, **attr_scale):\n", 116 | "    with torch.no_grad():\n", 117 | "        w_0 = np.load(testdata_dir + 'latent_code_%05d.npy'%int(seed))\n", 118 | "        w_0 = torch.tensor(w_0).to(device)\n", 119 | "        w_1 = w_0\n", 120 | "        for key in attr_scale.keys():\n", 121 | "            if attr_scale[key] != 0:\n", 122 | "                w_1 = T_net_dict[key](w_1.view(w_0.size(0),-1), torch.tensor(attr_scale[key]).unsqueeze(0).to(device))\n", 123 | "        w_1 = w_1.view(w_0.size())\n", 124 | "        w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1)\n", 125 | "        x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False)\n", 126 | "        img = np.clip(clip_img(x_1)[0].cpu().numpy()*255.,0,255).astype(np.uint8)\n", 127 | "        img = Image.fromarray(img.transpose(1,2,0))\n", 128 | "        plt.figure(figsize=(10,10))\n", 129 | "        plt.imshow(img)\n", 130 | "        plt.axis('off')\n", 131 | "        plt.show()" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "metadata": { 138 | "scrolled": false 139 | }, 140 | "outputs": [], 141 | "source": [ 142 | "# User interface\n", 143 | "%matplotlib inline\n", 144 | "attr_scale = {key: (-1.5,1.5,0.3) for key in attr_list}\n", 145 | "interact(visu_manipulation, seed=[0,1,2,3,4,5,6,7,8], **attr_scale)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": null, 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [] 154 | } 155 | ], 156 | "metadata": { 157 | "kernelspec": { 158 | "display_name": "geo", 159 | "language": "python", 160 | "name": "geo" 161 | }, 162 | "language_info": { 163 | "codemirror_mode": { 164 | "name": "ipython", 165 | "version": 3 166 | }, 167 | "file_extension": ".py", 168 | "mimetype": "text/x-python", 169 | "name": "python", 170 | "nbconvert_exporter": "python", 171 | "pygments_lexer": "ipython3", 172 | "version": "3.6.13" 173 | } 174 | }, 175 | "nbformat": 4, 176 | "nbformat_minor": 4 177 | } 178 | -------------------------------------------------------------------------------- /pretraining/latent_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
5 | 6 | import argparse 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.data as data 13 | import yaml 14 | 15 | from PIL import Image 16 | from torchvision import transforms, utils 17 | from tensorboard_logger import Logger 18 | 19 | import sys 20 | sys.path.append(".") 21 | sys.path.append("..") 22 | 23 | from datasets import * 24 | from nets import * 25 | 26 | torch.backends.cudnn.enabled = True 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = True 29 | torch.autograd.set_detect_anomaly(True) 30 | Image.MAX_IMAGE_PIXELS = None 31 | device = torch.device('cuda') 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--config', type=str, default='1001', help='Path to the config file.') 35 | parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path') 36 | parser.add_argument('--label_path', type=str, default='./data/celebahq_anno.npy', help='label file path') 37 | parser.add_argument('--mapping_layers', type=int, default=3, help='mapping layers num') 38 | parser.add_argument('--fmaps', type=int, default=512, help='fmaps num') 39 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 40 | parser.add_argument('--resume', type=bool, default=False, help='resume from checkpoint') 41 | parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path') 42 | parser.add_argument('--multigpu', type=bool, default=False, help='use multiple gpus') 43 | opts = parser.parse_args() 44 | 45 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 46 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 47 | 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 48 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 49 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 50 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 51 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 52 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 53 | 54 | batch_size = 256 55 | epochs = 20 56 | 57 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 58 | os.makedirs(log_dir, exist_ok=True) 59 | logger = Logger(log_dir) 60 | 61 | Latent_Classifier = LCNet(fmaps=[9216, 2048, 512, 40], activ='leakyrelu') 62 | Latent_Classifier.to(device) 63 | 64 | dataset_A = LatentDataset(opts.latent_path, opts.label_path, training_set=True) 65 | loader_A = data.DataLoader(dataset_A, batch_size=batch_size, shuffle=True) 66 | 67 | params = list(Latent_Classifier.parameters()) 68 | optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=0.0005) 69 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 70 | 71 | BCEloss = nn.BCEWithLogitsLoss(reduction='none') 72 | n_iter = 0 73 | for n_epoch in range(epochs): 74 | 75 | scheduler.step() 76 | 77 | for i, list_A in enumerate(loader_A): 78 | 79 | dlatent_A, lbl_A = list_A 80 | dlatent_A, lbl_A = dlatent_A.to(device), lbl_A.to(device) 81 | 82 | predict_lbl_A = Latent_Classifier(dlatent_A.view(dlatent_A.size(0), -1)) 83 | predict_lbl = F.sigmoid(predict_lbl_A) 84 | 85 | loss = 
BCEloss(predict_lbl_A, lbl_A.float()) 86 | loss = loss.mean() 87 | 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | 92 | if (n_iter + 1) % 10 == 0: 93 | logger.log_value('loss', loss.item(), n_iter + 1) 94 | n_iter += 1 95 | 96 | torch.save(Latent_Classifier.state_dict(),'{:s}/latent_classifier_epoch_{:d}.pth'.format(log_dir, n_epoch + 1)) 97 | 98 | -------------------------------------------------------------------------------- /pretraining/latent_classifier_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import argparse 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.data as data 13 | import yaml 14 | 15 | from PIL import Image 16 | from torchvision import transforms, utils 17 | from tensorboard_logger import Logger 18 | 19 | import sys 20 | sys.path.append(".") 21 | sys.path.append("..") 22 | 23 | from datasets import * 24 | from nets import * 25 | 26 | torch.backends.cudnn.enabled = True 27 | torch.backends.cudnn.deterministic = True 28 | torch.backends.cudnn.benchmark = True 29 | torch.autograd.set_detect_anomaly(True) 30 | Image.MAX_IMAGE_PIXELS = None 31 | device = torch.device('cuda') 32 | 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--config', type=str, default='1001', help='Path to the config file.') 35 | parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path') 36 | parser.add_argument('--label_path', type=str, default='./data/celebahq_anno.npy', help='label file path') 37 | parser.add_argument('--mapping_layers', type=int, default=3, help='mapping layers num') 38 | parser.add_argument('--fmaps', type=int, default=512, help='fmaps num') 39 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 40 | parser.add_argument('--resume', type=bool, default=False, help='resume from checkpoint') 41 | parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path') 42 | parser.add_argument('--multigpu', type=bool, default=False, help='use multiple gpus') 43 | opts = parser.parse_args() 44 | 45 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 46 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 47 | 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 48 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 49 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 50 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 51 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 52 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 53 | 54 | batch_size = 256 55 | epochs = 20 56 | 57 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 58 | 59 | def last_model(log_dir, model_name=None): 60 | if model_name == None: 61 | files_pth = [i for i in os.listdir(log_dir) if i.endswith('.pth')] 62 | files_pth.sort() 63 | return torch.load(log_dir + 
files_pth[-1]) 64 |     else: 65 |         return torch.load(log_dir + model_name) 66 | 67 | dataset_A = LatentDataset(opts.latent_path, opts.label_path, training_set=False) 68 | loader_A = data.DataLoader(dataset_A, batch_size=batch_size, shuffle=True) 69 | 70 | Latent_Classifier = LCNet(fmaps=[9216, 2048, 512, 40], activ='leakyrelu') 71 | Latent_Classifier.to(device) 72 | Latent_Classifier.load_state_dict(last_model(log_dir, None)) 73 | Latent_Classifier.eval() 74 | 75 | total_num = dataset_A.__len__() 76 | valid_num = [] 77 | real_lbl = [] 78 | pred_lbl = [] 79 | 80 | with torch.no_grad(): 81 |     for i, list_A in enumerate(loader_A): 82 | 83 |         dlatent_A, lbl_A = list_A 84 |         dlatent_A, lbl_A = dlatent_A.to(device), lbl_A.to(device) 85 | 86 |         predict_lbl_A = Latent_Classifier(dlatent_A.view(dlatent_A.size(0), -1)) 87 |         predict_lbl = F.sigmoid(predict_lbl_A).round().long() 88 | 89 |         real_lbl.append(lbl_A) 90 |         pred_lbl.append(predict_lbl.data) 91 |         valid_num.append((lbl_A == predict_lbl).long()) 92 | 93 | real_lbl = torch.cat(real_lbl, dim=0) 94 | pred_lbl = torch.cat(pred_lbl, dim=0) 95 | valid_num = torch.cat(valid_num, dim=0) 96 | 97 | T_num = torch.sum(real_lbl, dim=0) 98 | F_num = total_num - T_num 99 | 100 | pred_T_num = torch.sum(pred_lbl, dim=0) 101 | pred_F_num = total_num - pred_T_num 102 | 103 | True_Positive = torch.sum(valid_num * real_lbl, dim=0) 104 | True_Negative = torch.sum(valid_num * (1 - real_lbl), dim=0) 105 | 106 | # Recall 107 | recall_T = True_Positive.float()/(T_num.float() + 1e-8) 108 | recall_F = True_Negative.float()/(F_num.float() + 1e-8) 109 | 110 | # Precision 111 | precision_T = True_Positive.float()/(pred_T_num.float() + 1e-8) 112 | precision_F = True_Negative.float()/(pred_F_num.float() + 1e-8) 113 | 114 | # Accuracy 115 | valid_num = torch.sum(valid_num, dim=0) 116 | accuracy = valid_num.float()/total_num 117 | 118 | for i in range(40): 119 |     print('%0.3f'%recall_T[i].item(), '%0.3f'%recall_F[i].item(), '%0.3f'%precision_T[i].item(), '%0.3f'%precision_F[i].item(), '%0.3f'%accuracy[i].item()) -------------------------------------------------------------------------------- /run_video_manip.sh: -------------------------------------------------------------------------------- 1 | VideoName='FP006542HD02' 2 | Attribute='Smiling' 3 | Scale='1' 4 | Sigma='3' # Choose appropriate gaussian filter size 5 | VideoDir='./data/video' 6 | Path=${PWD} 7 | 8 | 9 | # Cut video to frames 10 | python video_processing.py --function 'video_to_frames' --video_path ${VideoDir}/${VideoName}.mp4 #--resize 11 | 12 | # Crop and align the faces in each frame 13 | python video_processing.py --function 'align_frames' --video_path ${VideoDir}/${VideoName}.mp4 --filter_size=${Sigma} --optical_flow 14 | 15 | # Project each frame to StyleGAN2 latent space 16 | cd pixel2style2pixel/ 17 | python scripts/inference.py --checkpoint_path=pretrained_models/psp_ffhq_encode.pt \ 18 | --data_path=${Path}/outputs/video/${VideoName}/${VideoName}_crop_align \ 19 | --exp_dir=${Path}/outputs/video/${VideoName}/${VideoName}_crop_align_latent \ 20 | --test_batch_size=1 21 | 22 | # Apply latent manipulation 23 | cd ${Path} 24 | python video_processing.py --function 'latent_manipulation' --video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --alpha=${Scale} 25 | 26 | # Reproject the manipulated frames to the original video 27 | python video_processing.py --function 'reproject_origin' --video_path ${VideoDir}/${VideoName}.mp4 --seamless 28 | python video_processing.py --function 'reproject_manipulate' 
--video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --seamless 29 | python video_processing.py --function 'compare_frames' --video_path ${VideoDir}/${VideoName}.mp4 --attr ${Attribute} --strs 'Original,Projected,Manipulated' 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import argparse 7 | import glob 8 | import os 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.utils.data as data 14 | import yaml 15 | 16 | from PIL import Image 17 | from torchvision import transforms, utils, models 18 | from tensorboard_logger import Logger 19 | 20 | from datasets import * 21 | from trainer import * 22 | from utils.functions import * 23 | 24 | torch.backends.cudnn.enabled = True 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = True 27 | torch.autograd.set_detect_anomaly(True) 28 | Image.MAX_IMAGE_PIXELS = None 29 | device = torch.device('cuda') 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--config', type=str, default='001', help='Path to the config file.') 33 | parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.') 34 | parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path') 35 | parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path') 36 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path') 37 | parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier') 38 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 39 | parser.add_argument('--out_path', type=str, default='./outputs/', help='output path') 40 | opts = parser.parse_args() 41 | 42 | # Celeba attribute list 43 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 44 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 45 | 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 46 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 47 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 48 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 49 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 50 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 51 | 52 | 53 | n_steps = 7 54 | scale = 1.5 55 | 56 | with torch.no_grad(): 57 | 58 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 59 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r')) 60 | 61 | save_dir = opts.out_path + 'test/' 62 | os.makedirs(save_dir, exist_ok=True) 63 | 64 | attr = opts.attr 65 | attr_num = attr_dict[attr] 66 | 67 | # Initialize trainer 68 | trainer = 
Trainer(config, attr_num, attr, opts.label_file) 69 | trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) 70 | trainer.load_model(log_dir) 71 | trainer.to(device) 72 | 73 | testdata_dir = './data/test/' 74 | img_list = [glob.glob1(testdata_dir, ext) for ext in ['*jpg','*png']] 75 | img_list = [item for sublist in img_list for item in sublist] 76 | img_list.sort() 77 | 78 | for idx in range(len(img_list)): 79 | 80 | x_0 = img_to_tensor(Image.open(testdata_dir + img_list[idx])) 81 | x_0 = x_0.unsqueeze(0).to(device) 82 | img_l = [x_0] # original image 83 | 84 | w_0 = np.load(testdata_dir + 'latent_code_%05d.npy'%idx) 85 | w_0 = torch.tensor(w_0).to(device) 86 | predict_lbl_0 = trainer.Latent_Classifier(w_0.view(w_0.size(0), -1)) 87 | lbl_0 = F.sigmoid(predict_lbl_0) 88 | attr_pb_0 = lbl_0[:, attr_num] 89 | coeff = -1 if attr_pb_0 > 0.5 else 1 90 | 91 | range_alpha = torch.linspace(-scale, scale, n_steps) 92 | for i, alpha in enumerate(range_alpha): 93 | 94 | w_1 = trainer.T_net(w_0.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device)) 95 | w_1 = w_1.view(w_0.size()) 96 | w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1) 97 | x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False) 98 | img_l.append(x_1.data) 99 | 100 | out = torch.cat(img_l, dim=3) 101 | utils.save_image(clip_img(out), save_dir + attr + '_' + '%05d.jpg'%idx) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import argparse 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.data as data 13 | import yaml 14 | 15 | from PIL import Image 16 | from torchvision import transforms, utils 17 | from tensorboard_logger import Logger 18 | 19 | from datasets import * 20 | from trainer import * 21 | 22 | torch.backends.cudnn.enabled = True 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = True 25 | torch.autograd.set_detect_anomaly(True) 26 | Image.MAX_IMAGE_PIXELS = None 27 | device = torch.device('cuda') 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--config', type=str, default='001', help='Path to the config file.') 31 | parser.add_argument('--latent_path', type=str, default='./data/celebahq_dlatents_psp.npy', help='dataset path') 32 | parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path') 33 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path') 34 | parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier') 35 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 36 | parser.add_argument('--resume', type=bool, default=False, help='resume from checkpoint') 37 | parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path') 38 | parser.add_argument('--multigpu', type=bool, default=False, help='use multiple gpus') 39 | opts = parser.parse_args() 40 | 41 | # Celeba attribute list 42 | attr_dict = {'5_o_Clock_Shadow': 0, 
'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 43 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 44 | 'Blurry': 10, 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 45 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 46 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 47 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 48 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 49 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 50 | 51 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 52 | if not os.path.exists(log_dir): 53 | os.makedirs(log_dir) 54 | logger = Logger(log_dir) 55 | 56 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r')) 57 | attr_l = config['attr'].split(',') 58 | batch_size = config['batch_size'] 59 | epochs = config['epochs'] 60 | 61 | dlatents = np.load(opts.latent_path) 62 | w = torch.tensor(dlatents).to(device) 63 | 64 | dataset_A = LatentDataset(opts.latent_path, opts.label_file, training_set=True) 65 | loader_A = data.DataLoader(dataset_A, batch_size=batch_size, shuffle=True) 66 | 67 | print('Start training!') 68 | for attr in attr_l: 69 | 70 | total_iter = 0 71 | attr_num = attr_dict[attr] 72 | 73 | # Initialize trainer 74 | trainer = Trainer(config, attr_num, attr, opts.label_file) 75 | trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) 76 | trainer.to(device) 77 | 78 | for n_epoch in range(epochs): 79 | 80 | for n_iter, list_A in enumerate(loader_A): 81 | 82 | w_A, lbl_A = list_A 83 | w_A, lbl_A = w_A.to(device), lbl_A.to(device) 84 | trainer.update(w_A, None, n_iter) 85 | 86 | if (total_iter+1) % config['log_iter'] == 0: 87 | trainer.log_loss(logger, total_iter) 88 | if (total_iter+1) % config['image_log_iter'] == 0: 89 | trainer.log_image(logger, w[total_iter%dataset_A.length].unsqueeze(0), total_iter) 90 | total_iter += 1 91 | 92 | trainer.save_model(log_dir) 93 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.data as data 12 | 13 | from PIL import Image 14 | from torch.autograd import grad 15 | from torchvision import transforms, utils 16 | 17 | from nets import * 18 | from utils.functions import * 19 | 20 | import sys 21 | sys.path.append('pixel2style2pixel/') 22 | 23 | from pixel2style2pixel.models.stylegan2.model import Generator 24 | from pixel2style2pixel.models.psp import get_keys 25 | 26 | class Trainer(nn.Module): 27 | def __init__(self, config, attr_num, attr, label_file): 28 | super(Trainer, self).__init__() 29 | # Load Hyperparameters 30 | self.accumulation_steps = 16 31 | self.config = config 32 | self.attr_num = attr_num 33 | self.attr = attr 34 | mapping_lrmul = self.config['mapping_lrmul'] 35 | mapping_layers = self.config['mapping_layers'] 36 | mapping_fmaps = self.config['mapping_fmaps'] 37 | mapping_nonlinearity = self.config['mapping_nonlinearity'] 38 | # Networks 39 | # Latent Transformer 40 | self.T_net = F_mapping(mapping_lrmul= mapping_lrmul, mapping_layers=mapping_layers, mapping_fmaps=mapping_fmaps, mapping_nonlinearity = mapping_nonlinearity) 41 | # Latent Classifier 42 | self.Latent_Classifier = LCNet([9216, 2048, 512, 40], activ='leakyrelu') 43 | # StyleGAN Model 44 | self.StyleGAN = Generator(1024, 512, 8) 45 | 46 | self.label_file = label_file 47 | self.corr_ma = None 48 | 49 | # Optimizers 50 | self.params = list(self.T_net.parameters()) 51 | self.optimizer = torch.optim.Adam(self.params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']), weight_decay=config['weight_decay']) 52 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=config['step_size'], gamma=config['gamma']) 53 | 54 | def initialize(self, stylegan_model_path, classifier_model_path): 55 | state_dict = torch.load(stylegan_model_path, map_location='cpu') 56 | self.StyleGAN.load_state_dict(get_keys(state_dict, 'decoder'), strict=True) 57 | self.Latent_Classifier.load_state_dict(torch.load(classifier_model_path)) 58 | self.Latent_Classifier.eval() 59 | 60 | def L1loss(self, input, target): 61 | return nn.L1Loss()(input,target) 62 | 63 | def MSEloss(self, input, target): 64 | if isinstance(input, list): 65 | return sum([nn.MSELoss()(input[i],target[i]) for i in range(len(input))])/len(input) 66 | else: 67 | return nn.MSELoss()(input,target) 68 | 69 | def SmoothL1loss(self, input, target): 70 | return nn.SmoothL1Loss()(input, target) 71 | 72 | def CEloss(self, x, target, reduction='mean'): 73 | return nn.CrossEntropyLoss(reduction=reduction)(x, target) 74 | 75 | def BCEloss(self, x, target, reduction='mean'): 76 | return nn.BCEWithLogitsLoss(reduction=reduction)(x, target) 77 | 78 | def GAN_loss(self, x, real=True): 79 | if real: 80 | target = torch.ones(x.size()).type_as(x) 81 | else: 82 | target = torch.zeros(x.size()).type_as(x) 83 | return nn.MSELoss()(x, target) 84 | 85 | def get_correlation(self, attr_num, threshold=1): 86 | if self.corr_ma is None: 87 | lbls = np.load(self.label_file) 88 | self.corr_ma = np.corrcoef(lbls.transpose()) 89 | self.corr_ma[np.isnan(self.corr_ma)] = 0 90 | corr_vec = np.abs(self.corr_ma[attr_num:attr_num+1]) 91 | corr_vec[corr_vec>=threshold] = 1 92 | return 1 - corr_vec 93 | 94 | def get_coeff(self, x): 95 | sign_0 = F.relu(x-0.5).sign() 96 | sign_1 = F.relu(0.5-x).sign() 97 | return sign_0*(-x) + sign_1*(1-x) 98 | 99 | def compute_loss(self, w, mask_input, n_iter): 100 | self.w_0 = w 101 | 
predict_lbl_0 = self.Latent_Classifier(self.w_0.view(w.size(0), -1)) 102 | lbl_0 = F.sigmoid(predict_lbl_0) 103 | attr_pb_0 = lbl_0[:, self.attr_num] 104 | # Get scaling factor 105 | coeff = self.get_coeff(attr_pb_0) 106 | target_pb = torch.clamp(attr_pb_0 + coeff, 0, 1).round() 107 | if 'alpha' in self.config and not self.config['alpha']: 108 | coeff = 2 * target_pb.type_as(attr_pb_0) - 1 109 | # Apply latent transformation 110 | self.w_1 = self.T_net(self.w_0.view(w.size(0), -1), coeff) 111 | self.w_1 = self.w_1.view(w.size()) 112 | predict_lbl_1 = self.Latent_Classifier(self.w_1.view(w.size(0), -1)) 113 | 114 | # Pb loss 115 | T_coeff = target_pb.size(0)/(target_pb.sum(0) + 1e-8) 116 | F_coeff = target_pb.size(0)/(target_pb.size(0) - target_pb.sum(0) + 1e-8) 117 | mask_pb = T_coeff.float() * target_pb + F_coeff.float() * (1-target_pb) 118 | self.loss_pb = self.BCEloss(predict_lbl_1[:, self.attr_num], target_pb, reduction='none')*mask_pb 119 | self.loss_pb = self.loss_pb.mean() 120 | 121 | # Latent code recon 122 | self.loss_recon = self.MSEloss(self.w_1, self.w_0) 123 | 124 | # Reg loss 125 | threshold_val = 1 if 'corr_threshold' not in self.config else self.config['corr_threshold'] 126 | mask = torch.tensor(self.get_correlation(self.attr_num, threshold=threshold_val)).type_as(predict_lbl_0) 127 | mask = mask.repeat(predict_lbl_0.size(0), 1) 128 | self.loss_reg = self.MSEloss(predict_lbl_1*mask, predict_lbl_0*mask) 129 | 130 | # Total loss 131 | w_recon, w_pb, w_reg = self.config['w']['recon'], self.config['w']['pb'], self.config['w']['reg'] 132 | self.loss = w_pb * self.loss_pb + w_recon*self.loss_recon + w_reg * self.loss_reg 133 | 134 | return self.loss 135 | 136 | def get_image(self, w): 137 | # Original image 138 | predict_lbl_0 = self.Latent_Classifier(w.view(w.size(0), -1)) 139 | lbl_0 = F.sigmoid(predict_lbl_0) 140 | attr_pb_0 = lbl_0[:, self.attr_num] 141 | coeff = self.get_coeff(attr_pb_0) 142 | target_pb = torch.clamp(attr_pb_0 + coeff, 0, 1).round() 143 | if 'alpha' in self.config and not self.config['alpha']: 144 | coeff = 2 * target_pb.type_as(attr_pb_0) - 1 145 | 146 | w_1 = self.T_net(w.view(w.size(0), -1), coeff) 147 | w_1 = w_1.view(w.size()) 148 | self.x_0, _ = self.StyleGAN([w], input_is_latent=True, randomize_noise=False) 149 | self.x_1, _ = self.StyleGAN([w_1], input_is_latent=True, randomize_noise=False) 150 | 151 | def log_image(self, logger, w, n_iter): 152 | with torch.no_grad(): 153 | self.get_image(w) 154 | logger.log_images('image_'+self.attr+'/iter'+str(n_iter+1)+'_input', clip_img(downscale(self.x_0, 2)), n_iter + 1) 155 | logger.log_images('image_'+self.attr+'/iter'+str(n_iter+1)+'_modif', clip_img(downscale(self.x_1, 2)), n_iter + 1) 156 | 157 | def log_loss(self, logger, n_iter): 158 | logger.log_value('loss_'+self.attr+'/class', self.loss_pb.item(), n_iter + 1) 159 | logger.log_value('loss_'+self.attr+'/latent_recon', self.loss_recon.item(), n_iter + 1) 160 | logger.log_value('loss_'+self.attr+'/attr_reg', self.loss_reg.item(), n_iter + 1) 161 | logger.log_value('loss_'+self.attr+'/total', self.loss.item(), n_iter + 1) 162 | 163 | def save_image(self, log_dir, n_iter): 164 | utils.save_image(clip_img(self.x_0), log_dir + 'iter' +str(n_iter+1)+ '_img.jpg') 165 | utils.save_image(clip_img(self.x_1), log_dir + 'iter' +str(n_iter+1)+ '_img_modif.jpg') 166 | 167 | def save_model(self, log_dir): 168 | torch.save(self.T_net.state_dict(),log_dir + '/tnet_' + str(self.attr_num) +'.pth.tar') 169 | 170 | def save_checkpoint(self, n_epoch, log_dir): 171 | 
checkpoint_state = { 172 | 'n_epoch': n_epoch, 173 | 'T_net_state_dict': self.T_net.state_dict(), 174 | 'opt_state_dict': self.optimizer.state_dict(), 175 | 'scheduler_state_dict': self.scheduler.state_dict() 176 | } 177 | if (n_epoch+1) % 10 == 0 : 178 | torch.save(checkpoint_state, '{:s}/checkpoint'.format(log_dir)+'_'+str(n_epoch+1)) 179 | else: 180 | torch.save(checkpoint_state, '{:s}/checkpoint'.format(log_dir)) 181 | 182 | def load_model(self, log_dir): 183 | self.T_net.load_state_dict(torch.load(log_dir + 'tnet_' + str(self.attr_num) +'.pth.tar')) 184 | 185 | def load_checkpoint(self, checkpoint_path): 186 | state_dict = torch.load(checkpoint_path) 187 | self.T_net.load_state_dict(state_dict['T_net_state_dict']) 188 | self.optimizer.load_state_dict(state_dict['opt_state_dict']) 189 | self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) 190 | return state_dict['n_epoch'] + 1 191 | 192 | def update(self, w, mask, n_iter): 193 | self.n_iter = n_iter 194 | self.optimizer.zero_grad() 195 | self.compute_loss(w, mask, n_iter).backward() 196 | self.optimizer.step() 197 | 198 | -------------------------------------------------------------------------------- /utils/__pycache__/functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/utils/__pycache__/functions.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/upfirdn_2d.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/utils/__pycache__/upfirdn_2d.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/video_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InterDigitalInc/latent-transformer/20089489ff2d69171f977d4ab316325f146ddd20/utils/__pycache__/video_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
5 | 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.utils.data as data 13 | 14 | from PIL import Image 15 | from torch.autograd import grad 16 | from torchvision import transforms, utils 17 | 18 | 19 | def clip_img(x): 20 | """Clip image to range(0,1)""" 21 | img_tmp = x.clone()[0] 22 | img_tmp = (img_tmp + 1) / 2 23 | img_tmp = torch.clamp(img_tmp, 0, 1) 24 | return [img_tmp.detach().cpu()] 25 | 26 | def stylegan_to_classifier(x): 27 | """Clip image to range(0,1)""" 28 | img_tmp = x.clone() 29 | img_tmp = torch.clamp((0.5*img_tmp + 0.5), 0, 1) 30 | img_tmp = F.interpolate(img_tmp, size=(224, 224), mode='bilinear') 31 | img_tmp[:,0] = (img_tmp[:,0] - 0.485)/0.229 32 | img_tmp[:,1] = (img_tmp[:,1] - 0.456)/0.224 33 | img_tmp[:,2] = (img_tmp[:,2] - 0.406)/0.225 34 | return img_tmp 35 | 36 | def img_to_tensor(x): 37 | out = transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 40 | ])(x) 41 | return out 42 | 43 | 44 | img_to_tensor = transforms.Compose([ 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 47 | ]) 48 | 49 | def downscale(x, scale_times=1): 50 | for i in range(scale_times): 51 | x = F.interpolate(x, scale_factor=0.5, mode='bilinear') 52 | return x 53 | 54 | def upscale(x, scale_times=1): 55 | for i in range(scale_times): 56 | x = F.interpolate(x, scale_factor=2, mode='bilinear') 57 | return x 58 | 59 | def hist_transform(source_tensor, target_tensor): 60 | """Histogram transformation""" 61 | c, h, w = source_tensor.size() 62 | s_t = source_tensor.view(c, -1) 63 | t_t = target_tensor.view(c, -1) 64 | s_t_sorted, s_t_indices = torch.sort(s_t) 65 | t_t_sorted, t_t_indices = torch.sort(t_t) 66 | for i in range(c): 67 | s_t[i, s_t_indices[i]] = t_t_sorted[i] 68 | return s_t.view(c, h, w) 69 | 70 | def init_weights(m): 71 | """Initialize layers with Xavier uniform distribution""" 72 | if type(m) == nn.Conv2d: 73 | nn.init.xavier_uniform_(m.weight) 74 | elif type(m) == nn.Linear: 75 | nn.init.uniform_(m.weight, 0.0, 1.0) 76 | if m.bias is not None: 77 | nn.init.constant_(m.bias, 0.01) 78 | 79 | def reg_loss(img): 80 | """Total variation""" 81 | reg_loss = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]))\ 82 | + torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :])) 83 | return reg_loss 84 | 85 | def vgg_transform(x): 86 | """Adapt image for vgg network, x: image of range(0,1) subtracting ImageNet mean""" 87 | r, g, b = torch.split(x, 1, 1) 88 | out = torch.cat((b, g, r), dim = 1) 89 | out = F.interpolate(out, size=(224, 224), mode='bilinear') 90 | out = out*255. 91 | return out 92 | 93 | def stylegan_to_vgg(x): 94 | """Clip image to range(0,1)""" 95 | img_tmp = x.clone() 96 | img_tmp = torch.clamp((0.5*img_tmp + 0.5), 0, 1) 97 | img_tmp = F.interpolate(x, size=(224, 224), mode='bilinear') 98 | img_tmp[:,0] = (img_tmp[:,0] - 0.485) 99 | img_tmp[:,1] = (img_tmp[:,1] - 0.456) 100 | img_tmp[:,2] = (img_tmp[:,2] - 0.406) 101 | r, g, b = torch.split(img_tmp, 1, 1) 102 | img_tmp = torch.cat((b, g, r), dim = 1) 103 | img_tmp = img_tmp*255. 104 | return img_tmp 105 | 106 | -------------------------------------------------------------------------------- /utils/video_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 
2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 5 | 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import os 10 | import face_alignment 11 | 12 | from PIL import Image, ImageFilter 13 | from scipy import ndimage 14 | from scipy.ndimage import gaussian_filter1d 15 | from skimage import io 16 | 17 | 18 | def pil_to_cv2(pil_image): 19 | open_cv_image = np.array(pil_image) 20 | return open_cv_image[:, :, ::-1].copy() 21 | 22 | 23 | def cv2_to_pil(open_cv_image): 24 | return Image.fromarray(open_cv_image[:, :, ::-1].copy()) 25 | 26 | 27 | def put_text(img, text): 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | bottomLeftCornerOfText = (10,50) 30 | fontScale = 1.5 31 | fontColor = (255,255,0) 32 | lineType = 2 33 | return cv2.putText(img, text, 34 | bottomLeftCornerOfText, 35 | font, 36 | fontScale, 37 | fontColor, 38 | lineType) 39 | 40 | 41 | # Compare frames in two directory 42 | def compare_frames(save_dir, origin_dir, target_dir, strs='Original,Projected,Manipulated', dim=1): 43 | 44 | os.makedirs(save_dir, exist_ok=True) 45 | try: 46 | if not isinstance(target_dir, list): 47 | target_dir = [target_dir] 48 | image_list = glob.glob1(origin_dir,'frame*') 49 | image_list.sort() 50 | for name in image_list: 51 | img_l = [] 52 | for idx, dir_path in enumerate([origin_dir] + list(target_dir)): 53 | img_1 = cv2.imread(dir_path + name) 54 | img_1 = put_text(img_1, strs.split(',')[idx]) 55 | img_l.append(img_1) 56 | if len(img_l)!=4: 57 | img = np.concatenate(img_l, dim) 58 | else: 59 | tmp_1 = np.concatenate(img_l[:2], dim) 60 | tmp_2 = np.concatenate(img_l[2:], dim) 61 | img = np.concatenate([tmp_1,tmp_2], 0) 62 | cv2.imwrite(save_dir + name, img) 63 | except FileNotFoundError: 64 | pass 65 | 66 | 67 | # Save frames into video 68 | def create_video(image_folder, fps=24, video_format='.mp4', resize_ratio=1): 69 | 70 | video_name = os.path.dirname(image_folder) + video_format 71 | img_list = glob.glob1(image_folder,'frame*') 72 | img_list.sort() 73 | frame = cv2.imread(os.path.join(image_folder, img_list[0])) 74 | frame = cv2.resize(frame, (0,0), fx=resize_ratio, fy=resize_ratio) 75 | height, width, layers = frame.shape 76 | if video_format == '.mp4': 77 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 78 | elif video_format == '.avi': 79 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 80 | video = cv2.VideoWriter(video_name, fourcc, fps, (width,height)) 81 | for image_name in img_list: 82 | frame = cv2.imread(os.path.join(image_folder, image_name)) 83 | frame = cv2.resize(frame, (0,0), fx=resize_ratio, fy=resize_ratio) 84 | video.write(frame) 85 | 86 | 87 | # Split video into frames 88 | def video_to_frames(video_path, frame_path, img_format='.jpg', resize=False): 89 | 90 | os.makedirs(frame_path, exist_ok=True) 91 | vidcap = cv2.VideoCapture(video_path) 92 | success,image = vidcap.read() 93 | count = 0 94 | while success: 95 | if resize: 96 | image = cv2.resize(image, (0,0), fx=0.5, fy=0.5) 97 | cv2.imwrite(frame_path + '/frame%04d' % count + img_format, image) 98 | success,image = vidcap.read() 99 | count += 1 100 | 101 | # Align faces 102 | def align_frames(img_dir, save_dir, output_size=1024, transform_size=1024, optical_flow=True, gaussian=True, filter_size=3): 103 | 104 | os.makedirs(save_dir, exist_ok=True) 105 | 106 | # load face landmark detector 107 | fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device='cuda') 108 | 109 | # list images in the directory 110 | 
img_list = glob.glob1(img_dir, 'frame*') 111 | img_list.sort() 112 | 113 | # save align statistics 114 | stat_dict = {'quad':[], 'qsize':[], 'coord':[], 'crop':[]} 115 | lms = [] 116 | for idx, img_name in enumerate(img_list): 117 | 118 | img_path = os.path.join(img_dir, img_name) 119 | img = io.imread(img_path) 120 | lm = [] 121 | 122 | preds = fa.get_landmarks(img) 123 | for kk in range(68): 124 | lm.append((preds[0][kk][0], preds[0][kk][1])) 125 | 126 | # Eye distance 127 | lm_eye_left = lm[36 : 42] # left-clockwise 128 | lm_eye_right = lm[42 : 48] # left-clockwise 129 | eye_left = np.mean(lm_eye_left, axis=0) 130 | eye_right = np.mean(lm_eye_right, axis=0) 131 | eye_to_eye = eye_right - eye_left 132 | 133 | if optical_flow: 134 | if idx > 0: 135 | s = int(np.hypot(*eye_to_eye)/4) 136 | lk_params = dict(winSize=(s, s), maxLevel=5, criteria = (cv2.TERM_CRITERIA_COUNT | cv2.TERM_CRITERIA_EPS, 10, 0.03)) 137 | points_arr = np.array(lm, np.float32) 138 | points_prevarr = np.array(prev_lm, np.float32) 139 | points_arr,status, err = cv2.calcOpticalFlowPyrLK(prev_img, img, points_prevarr, points_arr, **lk_params) 140 | sigma =100 141 | points_arr_float = np.array(points_arr,np.float32) 142 | points = points_arr_float.tolist() 143 | for k in range(0, len(lm)): 144 | d = cv2.norm(np.array(prev_lm[k]) - np.array(lm[k])) 145 | alpha = np.exp(-d*d/sigma) 146 | lm[k] = (1 - alpha) * np.array(lm[k]) + alpha * np.array(points[k]) 147 | prev_img = img 148 | prev_lm = lm 149 | 150 | lms.append(lm) 151 | 152 | # Apply gaussian filter on landmarks 153 | if gaussian: 154 | lm_filtered = np.array(lms) 155 | for kk in range(68): 156 | lm_filtered[:, kk, 0] = gaussian_filter1d(lm_filtered[:, kk, 0], filter_size) 157 | lm_filtered[:, kk, 1] = gaussian_filter1d(lm_filtered[:, kk, 1], filter_size) 158 | lms = lm_filtered.tolist() 159 | 160 | # save landmarks 161 | landmark_out_dir = os.path.dirname(img_dir) + '_landmark/' 162 | os.makedirs(landmark_out_dir, exist_ok=True) 163 | 164 | for idx, img_name in enumerate(img_list): 165 | 166 | img_path = os.path.join(img_dir, img_name) 167 | img = io.imread(img_path) 168 | 169 | lm = lms[idx] 170 | img_lm = img.copy() 171 | for kk in range(68): 172 | img_lm = cv2.circle(img_lm, (int(lm[kk][0]),int(lm[kk][1])), radius=3, color=(255, 0, 255), thickness=-1) 173 | # Save landmark images 174 | cv2.imwrite(landmark_out_dir + img_name, img_lm[:,:,::-1]) 175 | 176 | # Save mask images 177 | seg_mask = np.zeros(img.shape, img.dtype) 178 | poly = np.array(lm[0:17] + lm[17:27][::-1], np.int32) 179 | cv2.fillPoly(seg_mask, [poly], (255, 255, 255)) 180 | cv2.imwrite(img_dir + "mask%04d.jpg"%idx, seg_mask); 181 | 182 | # Parse landmarks. 183 | lm_eye_left = lm[36 : 42] # left-clockwise 184 | lm_eye_right = lm[42 : 48] # left-clockwise 185 | lm_mouth_outer = lm[48 : 60] # left-clockwise 186 | 187 | # Calculate auxiliary vectors. 188 | eye_left = np.mean([lm_eye_left[0], lm_eye_left[3]], axis=0) 189 | eye_right = np.mean([lm_eye_right[0], lm_eye_right[3]], axis=0) 190 | eye_avg = (eye_left + eye_right) * 0.5 191 | eye_to_eye = eye_right - eye_left 192 | mouth_left = np.array(lm_mouth_outer[0]) 193 | mouth_right = np.array(lm_mouth_outer[6]) 194 | mouth_avg = (mouth_left + mouth_right) * 0.5 195 | eye_to_mouth = mouth_avg - eye_avg 196 | 197 | # Choose oriented crop rectangle. 
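# (Descriptive note) The block below is essentially the FFHQ-style alignment used for StyleGAN
# face datasets: x is a vector along the eye axis, corrected by the 90-degree-rotated
# eye-to-mouth direction and scaled to the larger of 2.0x the inter-ocular distance and 1.8x
# the eye-to-mouth distance; y is x rotated by 90 degrees; c is a centre point slightly below
# the eye midpoint. The four corners c-x-y, c-x+y, c+x+y, c+x-y define the oriented crop
# rectangle, and qsize is its side length.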
198 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 199 | x /= np.hypot(*x) 200 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 201 | y = np.flipud(x) * [-1, 1] 202 | c = eye_avg + eye_to_mouth * 0.1 203 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 204 | qsize = np.hypot(*x) * 2 205 | 206 | stat_dict['coord'].append(quad) 207 | stat_dict['qsize'].append(qsize) 208 | 209 | # Apply gaussian filter on crops 210 | if gaussian: 211 | quads = np.array(stat_dict['coord']) 212 | quads = gaussian_filter1d(quads, 2*filter_size, axis=0) 213 | stat_dict['coord'] = quads.tolist() 214 | qsize = np.array(stat_dict['qsize']) 215 | qsize = gaussian_filter1d(qsize, 2*filter_size, axis=0) 216 | stat_dict['qsize'] = qsize.tolist() 217 | 218 | for idx, img_name in enumerate(img_list): 219 | 220 | img_path = os.path.join(img_dir, img_name) 221 | img = Image.open(img_path) 222 | 223 | qsize = stat_dict['qsize'][idx] 224 | quad = np.array(stat_dict['coord'][idx]) 225 | 226 | # Crop. 227 | border = max(int(np.rint(qsize * 0.1)), 3) 228 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 229 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 230 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 231 | img = img.crop(crop) 232 | quad -= crop[0:2] 233 | 234 | stat_dict['crop'].append(crop) 235 | stat_dict['quad'].append((quad + 0.5).flatten()) 236 | 237 | # Pad. 238 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 239 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 240 | if max(pad) > border - 4: 241 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 242 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 243 | h, w, _ = img.shape 244 | y, x, _ = np.ogrid[:h, :w, :1] 245 | img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 246 | quad += pad[:2] 247 | # Transform. 
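# (Descriptive note) PIL's Image.QUAD transform maps the four corners of the oriented quad,
# passed as the flattened 8-tuple (quad + 0.5).flatten(), onto the corners of a
# transform_size x transform_size square using bilinear sampling, i.e. it rotates and warps
# the cropped face into the canonical aligned pose; the Lanczos resize below then produces
# the final output_size image saved to save_dir.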
248 | img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR) 249 | 250 | # resizing 251 | img_pil = img.resize((output_size, output_size), Image.LANCZOS) 252 | img_pil.save(save_dir+img_name) 253 | 254 | create_video(landmark_out_dir) 255 | np.save(save_dir+'stat_dict.npy', stat_dict) 256 | 257 | 258 | def find_coeffs(pa, pb): 259 | 260 | matrix = [] 261 | for p1, p2 in zip(pa, pb): 262 | matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]]) 263 | matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]]) 264 | A = np.matrix(matrix, dtype=np.float) 265 | B = np.array(pb).reshape(8) 266 | res = np.dot(np.linalg.inv(A.T * A) * A.T, B) 267 | return np.array(res).reshape(8) 268 | 269 | # reproject aligned frames to the original video 270 | def video_reproject(orig_dir_path, recon_dir_path, save_dir_path, state_dir_path, seamless=False): 271 | 272 | if not os.path.exists(save_dir_path): 273 | os.makedirs(save_dir_path) 274 | 275 | img_list_0 = glob.glob1(orig_dir_path,'frame*') 276 | img_list_2 = glob.glob1(recon_dir_path,'frame*') 277 | img_list_0.sort() 278 | img_list_2.sort() 279 | stat_dict = np.load(state_dir_path + 'stat_dict.npy', allow_pickle=True).item() 280 | counter = len(img_list_2) 281 | 282 | for idx in range(counter): 283 | 284 | img_0 = Image.open(orig_dir_path + img_list_0[idx]) 285 | img_2 = Image.open(recon_dir_path + img_list_2[idx]) 286 | 287 | quad_f = stat_dict['quad'][idx] 288 | quad_0 = stat_dict['crop'][idx] 289 | 290 | coeffs = find_coeffs( 291 | [(quad_f[0], quad_f[1]), (quad_f[2] , quad_f[3]), (quad_f[4], quad_f[5]), (quad_f[6], quad_f[7])], 292 | [(0, 0), (0, 1024), (1024, 1024), (1024, 0)]) 293 | crop_size = (quad_0[2] - quad_0[0], quad_0[3] - quad_0[1]) 294 | img_2 = img_2.transform(crop_size, Image.PERSPECTIVE, coeffs, Image.BICUBIC) 295 | output = img_0.copy() 296 | output.paste(img_2, (int(quad_0[0]), int(quad_0[1]))) 297 | 298 | mask = cv2.imread(orig_dir_path + 'mask%04d.jpg'%idx) 299 | kernel = np.ones((10,10), np.uint8) 300 | mask = cv2.dilate(mask, kernel, iterations=5) 301 | # Apply mask 302 | if not seamless: 303 | mask = cv2_to_pil(mask).filter(ImageFilter.GaussianBlur(radius=10)).convert('L') 304 | mask = np.array(mask)[:, :, np.newaxis]/255. 305 | output = np.array(img_0)*(1-mask) + np.array(output)*mask 306 | output = Image.fromarray(output.astype(np.uint8)) 307 | output.save(save_dir_path + img_list_2[idx]) 308 | else: 309 | src = pil_to_cv2(output) 310 | dst = pil_to_cv2(img_0) 311 | # clone 312 | br = cv2.boundingRect(cv2.split(mask)[0]) # bounding rect (x,y,width,height) 313 | center = (br[0] + br[2] // 2, br[1] + br[3] // 2) 314 | output = cv2.seamlessClone(src, dst, mask, center, cv2.NORMAL_CLONE) 315 | cv2.imwrite(save_dir_path + img_list_2[idx], output) -------------------------------------------------------------------------------- /video_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, InterDigital R&D France. All rights reserved. 2 | # 3 | # This source code is made available under the license found in the 4 | # LICENSE.txt in the root directory of this source tree. 
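# Overview (descriptive comment): this script runs one stage of the video manipulation
# pipeline per invocation, selected with --function:
#   video_to_frames      - extract (and optionally downscale) frames from --video_path
#   align_frames         - FFHQ-style crop/align with landmark smoothing (see utils/video_utils.py)
#   latent_manipulation  - apply the pretrained latent transformers to the per-frame latent
#                          codes and decode the edited frames with StyleGAN
#   reproject_origin / reproject_manipulate - paste aligned results back into the original frames
#   compare_frames       - build a side-by-side comparison video
# The per-frame latent codes (latent_code_%05d.npy) are expected to already exist in the
# *_crop_align_latent/ directory, presumably produced by the pSp encoder referenced by
# --stylegan_model_path.
# Example invocation:
#   python video_processing.py --function align_frames --video_path ./data/video/FP006911MD02.mp4 --optical_flow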
5 | 6 | import argparse 7 | import copy 8 | import glob 9 | import numpy as np 10 | import os 11 | import torch 12 | import yaml 13 | import time 14 | 15 | from PIL import Image 16 | from torchvision import transforms, utils, models 17 | 18 | from utils.video_utils import * 19 | from trainer import * 20 | 21 | torch.backends.cudnn.enabled = True 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = True 24 | torch.autograd.set_detect_anomaly(True) 25 | Image.MAX_IMAGE_PIXELS = None 26 | device = torch.device('cuda') 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--config', type=str, default='001', help='Path to the config file.') 30 | parser.add_argument('--attr', type=str, default='Eyeglasses', help='attribute for manipulation.') 31 | parser.add_argument('--alpha', type=str, default='1.', help='scale for manipulation.') 32 | parser.add_argument('--label_file', type=str, default='./data/celebahq_anno.npy', help='label file path') 33 | parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='stylegan model path') 34 | parser.add_argument('--classifier_model_path', type=str, default='./models/latent_classifier_epoch_20.pth', help='pretrained attribute classifier') 35 | parser.add_argument('--log_path', type=str, default='./logs/', help='log file path') 36 | parser.add_argument('--function', type=str, default='', help='Calling function by name.') 37 | parser.add_argument('--video_path', type=str, default='./data/video/FP006911MD02.mp4', help='video file path') 38 | parser.add_argument('--output_path', type=str, default='./outputs/video/', help='output video file path') 39 | parser.add_argument('--optical_flow', action='store_true', help='use optical flow') 40 | parser.add_argument('--resize', action='store_true', help='downscale image size') 41 | parser.add_argument('--seamless', action='store_true', help='seamless cloning') 42 | parser.add_argument('--filter_size', type=float, default=3, help='filter size') 43 | parser.add_argument('--strs', type=str, default='Original,Projected,Manipulated', help='strs to be added on video') 44 | opts = parser.parse_args() 45 | 46 | # Celeba attribute list 47 | attr_dict = {'5_o_Clock_Shadow': 0, 'Arched_Eyebrows': 1, 'Attractive': 2, 'Bags_Under_Eyes': 3, \ 48 | 'Bald': 4, 'Bangs': 5, 'Big_Lips': 6, 'Big_Nose': 7, 'Black_Hair': 8, 'Blond_Hair': 9, \ 49 | 'Brown_Hair': 11, 'Bushy_Eyebrows': 12, 'Chubby': 13, 'Double_Chin': 14, \ 50 | 'Eyeglasses': 15, 'Goatee': 16, 'Gray_Hair': 17, 'Heavy_Makeup': 18, 'High_Cheekbones': 19, \ 51 | 'Male': 20, 'Mouth_Slightly_Open': 21, 'Mustache': 22, 'Narrow_Eyes': 23, 'No_Beard': 24, \ 52 | 'Oval_Face': 25, 'Pale_Skin': 26, 'Pointy_Nose': 27, 'Receding_Hairline': 28, 'Rosy_Cheeks': 29, \ 53 | 'Sideburns': 30, 'Smiling': 31, 'Straight_Hair': 32, 'Wavy_Hair': 33, 'Wearing_Earrings': 34, \ 54 | 'Wearing_Hat': 35, 'Wearing_Lipstick': 36, 'Wearing_Necklace': 37, 'Wearing_Necktie': 38, 'Young': 39} 55 | 56 | # Latent code manipulation 57 | def latent_manipulation(opts, latent_dir_path, process_dir_path): 58 | 59 | attrs = opts.attr.split(',') 60 | alphas = opts.alpha.split(',') 61 | os.makedirs(process_dir_path, exist_ok=True) 62 | 63 | with torch.no_grad(): 64 | 65 | log_dir = os.path.join(opts.log_path, opts.config) + '/' 66 | config = yaml.load(open('./configs/' + opts.config + '.yaml', 'r')) 67 | 68 | # Initialize trainer 69 | trainer = Trainer(config, None, None, opts.label_file) 70 | 
trainer.initialize(opts.stylegan_model_path, opts.classifier_model_path) 71 | trainer.to(device) 72 | 73 | latent_num = len(glob.glob1(latent_dir_path,'*.npy')) 74 | 75 | T_nets = [] 76 | for attr_idx, attr in enumerate(attrs): 77 | trainer.attr_num = attr_dict[attr] 78 | trainer.load_model(log_dir) 79 | T_nets.append(copy.deepcopy(trainer.T_net)) 80 | 81 | for k in range(latent_num): 82 | w_0 = np.load(latent_dir_path + 'latent_code_%05d.npy'%k) 83 | w_0 = torch.tensor(w_0).to(device) 84 | w_1 = w_0.clone() 85 | for attr_idx, attr in enumerate(attrs): 86 | alpha = torch.tensor(float(alphas[attr_idx])) 87 | w_1 = T_nets[attr_idx](w_1.view(w_0.size(0), -1), alpha.unsqueeze(0).to(device)) 88 | w_1 = w_1.view(w_0.size()) 89 | w_1 = torch.cat((w_1[:,:11,:], w_0[:,11:,:]), 1) 90 | x_1, _ = trainer.StyleGAN([w_1], input_is_latent=True, randomize_noise=False) 91 | utils.save_image(clip_img(x_1), process_dir_path + 'frame%04d'%k+'.jpg') 92 | 93 | 94 | video_path = opts.video_path 95 | video_name = video_path.split('/')[-1] 96 | orig_dir_path = opts.output_path + video_name.split('.')[0] + '/' + video_name.split('.')[0] + '/' 97 | align_dir_path = os.path.dirname(orig_dir_path) + '_crop_align/' 98 | latent_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_latent/' 99 | process_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '/' 100 | reproject_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '_reproject/' 101 | 102 | 103 | print(opts.function) 104 | start_time = time.perf_counter() 105 | 106 | if opts.function == 'video_to_frames': 107 | video_to_frames(video_path, orig_dir_path, resize=opts.resize) 108 | create_video(orig_dir_path) 109 | elif opts.function == 'align_frames': 110 | align_frames(orig_dir_path, align_dir_path, output_size=1024, optical_flow=opts.optical_flow, filter_size=opts.filter_size) 111 | elif opts.function == 'latent_manipulation': 112 | latent_manipulation(opts, latent_dir_path, process_dir_path) 113 | elif opts.function == 'reproject_origin': 114 | process_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_latent/inference_results/' 115 | reproject_dir_path = os.path.dirname(orig_dir_path) + '_crop_align_origin_reproject/' 116 | video_reproject(orig_dir_path, process_dir_path, reproject_dir_path, align_dir_path, seamless=opts.seamless) 117 | create_video(reproject_dir_path) 118 | elif opts.function == 'reproject_manipulate': 119 | video_reproject(orig_dir_path, process_dir_path, reproject_dir_path, align_dir_path, seamless=opts.seamless) 120 | create_video(reproject_dir_path) 121 | elif opts.function == 'compare_frames': 122 | process_dir_paths = [] 123 | process_dir_paths.append(os.path.dirname(orig_dir_path) + '_crop_align_origin_reproject/') 124 | process_dir_paths.append(os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.split(',')[0] + '_reproject/') 125 | if len(opts.attr.split(','))>1: 126 | process_dir_paths.append(reproject_dir_path) 127 | save_dir = os.path.dirname(orig_dir_path) + '_crop_align_' + opts.attr.replace(',','_') + '_compare/' 128 | compare_frames(save_dir, orig_dir_path, process_dir_paths, strs=opts.strs, dim=1) 129 | create_video(save_dir, video_format='.avi', resize_ratio=1) 130 | 131 | count_time = time.perf_counter() - start_time 132 | print("Elapsed time: %0.4f seconds"%count_time) --------------------------------------------------------------------------------
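For reference, a minimal driver sketch (not part of the repository) that chains the stages above in a plausible order; the projection of aligned frames to latent codes is handled by an external StyleGAN encoder (pixel2style2pixel) and is only indicated as a comment. The video path, attribute and alpha value are illustrative.

import subprocess

VIDEO = './data/video/FP006911MD02.mp4'

def run_stage(stage, *extra):
    # Invoke one pipeline stage of video_processing.py as a subprocess.
    subprocess.run(['python', 'video_processing.py', '--config', '001',
                    '--video_path', VIDEO, '--attr', 'Eyeglasses', '--alpha', '1.',
                    '--function', stage, *extra], check=True)

run_stage('video_to_frames')
run_stage('align_frames', '--optical_flow')
# ... project the aligned frames to per-frame latent codes here (e.g. with the pSp encoder
# from pixel2style2pixel), writing latent_code_%05d.npy files into *_crop_align_latent/ ...
run_stage('reproject_origin')
run_stage('latent_manipulation')
run_stage('reproject_manipulate')
run_stage('compare_frames')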