├── .gitignore ├── LICENSE ├── LICENSE-FID ├── LICENSE-LPIPS ├── LICENSE-NVIDIA ├── README.md ├── apply_factor.py ├── calc_inception.py ├── checkpoint └── .gitignore ├── closed_form_factorization.py ├── convert_weight.py ├── dataset.py ├── distributed.py ├── doc ├── sample-metfaces.png ├── sample.png ├── stylegan2-church-config-f.png └── stylegan2-ffhq-config-f.png ├── factor_index-13_degree-5.0.png ├── fid.py ├── generate.py ├── inception.py ├── inception_ffhq.pkl ├── lpips ├── __init__.py ├── base_model.py ├── dist_model.py ├── networks_basic.py ├── pretrained_networks.py └── weights │ ├── v0.0 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── model.py ├── non_leaking.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── ppl.py ├── prepare_data.py ├── projector.py ├── sample └── .gitignore └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | wandb/ 132 | *.lmdb/ 133 | *.pkl 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE-FID: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 
100 | 
101 | =======================================================================
102 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# StyleGAN 2 in PyTorch

Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch

## Notice

I have tried to match the official implementation as closely as possible, but there may be some details I have missed, so please use this implementation with care.

## Requirements

I have tested on:

* PyTorch 1.3.1
* CUDA 10.1/10.2

## Usage

First, create lmdb datasets:

> python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH

This will convert the images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list of sizes, in case you want to try other resolutions later.

Then you can train the model in a distributed setting:

> python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH

train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script.

### Convert weights from official checkpoints

You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints.

For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this:

> python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl

This will create a converted stylegan2-ffhq-config-f.pt file.

### Generate samples

> python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT

You should change the size argument (--size 256, for example) if you trained at another resolution.

### Project images to latent spaces

> python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...

### Closed-Form Factorization (https://arxiv.org/abs/2007.06600)

You can use `closed_form_factorization.py` and `apply_factor.py` to discover meaningful latent semantic factors or directions in an unsupervised manner.

First, extract the eigenvectors of the weight matrices using `closed_form_factorization.py`:

> python closed_form_factorization.py [CHECKPOINT]

This will create a factor file that contains the eigenvectors (default: factor.pt). You can then use `apply_factor.py` to test the meaning of the extracted directions:

> python apply_factor.py -i [INDEX_OF_EIGENVECTOR] -d [DEGREE_OF_MOVE] -n [NUMBER_OF_SAMPLES] --ckpt [CHECKPOINT] [FACTOR_FILE]

For example,

> python apply_factor.py -i 19 -d 5 -n 10 --ckpt [CHECKPOINT] factor.pt

will generate 10 random samples, together with samples generated from latents moved along the 19th eigenvector by a degree of +-5.

![Sample of closed form factorization](factor_index-13_degree-5.0.png)

## Pretrained Checkpoints

[Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO)

I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5. Differences in data preprocessing, resolution, or the training loop may account for the gap to the official results, but I currently don't know the exact reason for the FID difference.
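FID for your own checkpoints can be measured with the scripts bundled in this repository: `calc_inception.py` precomputes Inception statistics for the real dataset, and `fid.py` compares generated samples against them. A minimal sketch using the flags defined in those scripts (adjust `--size`, batch, and sample counts to your setup):

> python calc_inception.py --size 256 LMDB_PATH

> python fid.py --inception INCEPTION_PKL --size 256 [CHECKPOINT]

`calc_inception.py` writes its statistics to `inception_<lmdb name>.pkl`; pass that file as the INCEPTION_PKL argument to `fid.py`.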
## Samples

![Sample with truncation](doc/sample.png)

Sample from FFHQ. At 110,000 iterations. (trained on 3.52M images)

![MetFaces sample with non-leaking augmentations](doc/sample-metfaces.png)

Sample from MetFaces with non-leaking augmentations. At 150,000 iterations. (trained on 4.8M images)

### Samples from converted weights

![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png)

Sample from FFHQ (1024px)

![Sample from LSUN Church](doc/stylegan2-church-config-f.png)

Sample from LSUN Church (256px)

## License

Model details and custom CUDA kernel codes are from the official repository: https://github.com/NVlabs/stylegan2

Code for Learned Perceptual Image Patch Similarity (LPIPS) came from https://github.com/richzhang/PerceptualSimilarity

To match FID scores more closely to the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid

--------------------------------------------------------------------------------
/apply_factor.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | from torchvision import utils
5 | 
6 | from model import Generator
7 | 
8 | 
9 | if __name__ == "__main__":
10 |     torch.set_grad_enabled(False)
11 | 
12 |     parser = argparse.ArgumentParser(description="Apply closed form factorization")
13 | 
14 |     parser.add_argument(
15 |         "-i", "--index", type=int, default=0, help="index of eigenvector"
16 |     )
17 |     parser.add_argument(
18 |         "-d",
19 |         "--degree",
20 |         type=float,
21 |         default=5,
22 |         help="scalar factors for moving latent vectors along eigenvector",
23 |     )
24 |     parser.add_argument(
25 |         "--channel_multiplier",
26 |         type=int,
27 |         default=2,
28 |         help='channel multiplier factor. 
config-f = 2, else = 1', 29 | ) 30 | parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints") 31 | parser.add_argument( 32 | "--size", type=int, default=256, help="output image size of the generator" 33 | ) 34 | parser.add_argument( 35 | "-n", "--n_sample", type=int, default=7, help="number of samples created" 36 | ) 37 | parser.add_argument( 38 | "--truncation", type=float, default=0.7, help="truncation factor" 39 | ) 40 | parser.add_argument( 41 | "--device", type=str, default="cuda", help="device to run the model" 42 | ) 43 | parser.add_argument( 44 | "--out_prefix", 45 | type=str, 46 | default="factor", 47 | help="filename prefix to result samples", 48 | ) 49 | parser.add_argument( 50 | "factor", 51 | type=str, 52 | help="name of the closed form factorization result factor file", 53 | ) 54 | 55 | args = parser.parse_args() 56 | 57 | eigvec = torch.load(args.factor)["eigvec"].to(args.device) 58 | ckpt = torch.load(args.ckpt) 59 | g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device) 60 | g.load_state_dict(ckpt["g_ema"], strict=False) 61 | 62 | trunc = g.mean_latent(4096) 63 | 64 | latent = torch.randn(args.n_sample, 512, device=args.device) 65 | latent = g.get_latent(latent) 66 | 67 | direction = args.degree * eigvec[:, args.index].unsqueeze(0) 68 | 69 | img, _ = g( 70 | [latent], 71 | truncation=args.truncation, 72 | truncation_latent=trunc, 73 | input_is_latent=True, 74 | ) 75 | img1, _ = g( 76 | [latent + direction], 77 | truncation=args.truncation, 78 | truncation_latent=trunc, 79 | input_is_latent=True, 80 | ) 81 | img2, _ = g( 82 | [latent - direction], 83 | truncation=args.truncation, 84 | truncation_latent=trunc, 85 | input_is_latent=True, 86 | ) 87 | 88 | grid = utils.save_image( 89 | torch.cat([img1, img, img2], 0), 90 | f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", 91 | normalize=True, 92 | range=(-1, 1), 93 | nrow=args.n_sample, 94 | ) 95 | -------------------------------------------------------------------------------- /calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.models import inception_v3, Inception3 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from inception import InceptionV3 15 | from dataset import MultiResolutionDataset 16 | 17 | 18 | class Inception3Feature(Inception3): 19 | def forward(self, x): 20 | if x.shape[2] != 299 or x.shape[3] != 299: 21 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True) 22 | 23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 27 | 28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 31 | 32 | x = self.Mixed_5b(x) # 35 x 35 x 192 33 | x = self.Mixed_5c(x) # 35 x 35 x 256 34 | x = self.Mixed_5d(x) # 35 x 35 x 288 35 | 36 | x = self.Mixed_6a(x) # 35 x 35 x 288 37 | x = self.Mixed_6b(x) # 17 x 17 x 768 38 | x = self.Mixed_6c(x) # 17 x 17 x 768 39 | x = self.Mixed_6d(x) # 17 x 17 x 768 40 | x = self.Mixed_6e(x) # 17 x 17 x 768 41 | 42 | x = self.Mixed_7a(x) # 17 x 17 
x 768 43 | x = self.Mixed_7b(x) # 8 x 8 x 1280 44 | x = self.Mixed_7c(x) # 8 x 8 x 2048 45 | 46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 47 | 48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 49 | 50 | 51 | def load_patched_inception_v3(): 52 | # inception = inception_v3(pretrained=True) 53 | # inception_feat = Inception3Feature() 54 | # inception_feat.load_state_dict(inception.state_dict()) 55 | inception_feat = InceptionV3([3], normalize_input=False) 56 | 57 | return inception_feat 58 | 59 | 60 | @torch.no_grad() 61 | def extract_features(loader, inception, device): 62 | pbar = tqdm(loader) 63 | 64 | feature_list = [] 65 | 66 | for img in pbar: 67 | img = img.to(device) 68 | feature = inception(img)[0].view(img.shape[0], -1) 69 | feature_list.append(feature.to("cpu")) 70 | 71 | features = torch.cat(feature_list, 0) 72 | 73 | return features 74 | 75 | 76 | if __name__ == "__main__": 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | 79 | parser = argparse.ArgumentParser( 80 | description="Calculate Inception v3 features for datasets" 81 | ) 82 | parser.add_argument( 83 | "--size", 84 | type=int, 85 | default=256, 86 | help="image sizes used for embedding calculation", 87 | ) 88 | parser.add_argument( 89 | "--batch", default=64, type=int, help="batch size for inception networks" 90 | ) 91 | parser.add_argument( 92 | "--n_sample", 93 | type=int, 94 | default=50000, 95 | help="number of samples used for embedding calculation", 96 | ) 97 | parser.add_argument( 98 | "--flip", action="store_true", help="apply random flipping to real images" 99 | ) 100 | parser.add_argument("path", metavar="PATH", help="path to datset lmdb file") 101 | 102 | args = parser.parse_args() 103 | 104 | inception = load_patched_inception_v3() 105 | inception = nn.DataParallel(inception).eval().to(device) 106 | 107 | transform = transforms.Compose( 108 | [ 109 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 112 | ] 113 | ) 114 | 115 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 116 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 117 | 118 | features = extract_features(loader, inception, device).numpy() 119 | 120 | features = features[: args.n_sample] 121 | 122 | print(f"extracted {features.shape[0]} features") 123 | 124 | mean = np.mean(features, 0) 125 | cov = np.cov(features, rowvar=False) 126 | 127 | name = os.path.splitext(os.path.basename(args.path))[0] 128 | 129 | with open(f"inception_{name}.pkl", "wb") as f: 130 | pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f) 131 | -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /closed_form_factorization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser( 8 | description="Extract factor/eigenvectors of latent spaces using closed form factorization" 9 | ) 10 | 11 | parser.add_argument( 12 | "--out", type=str, default="factor.pt", help="name of the result factor file" 13 | ) 14 | parser.add_argument("ckpt", type=str, help="name of the model checkpoint") 15 | 16 | 
args = parser.parse_args() 17 | 18 | ckpt = torch.load(args.ckpt) 19 | modulate = { 20 | k: v 21 | for k, v in ckpt["g_ema"].items() 22 | if "modulation" in k and "to_rgbs" not in k and "weight" in k 23 | } 24 | 25 | weight_mat = [] 26 | for k, v in modulate.items(): 27 | weight_mat.append(v) 28 | 29 | W = torch.cat(weight_mat, 0) 30 | eigvec = torch.svd(W).V.to("cpu") 31 | 32 | torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out) 33 | 34 | -------------------------------------------------------------------------------- /convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pickle 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | from torchvision import utils 10 | 11 | from model import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + "/weight"].value().eval() 16 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 17 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 18 | noise = vars[source_name + "/noise_strength"].value().eval() 19 | bias = vars[source_name + "/bias"].value().eval() 20 | 21 | dic = { 22 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 24 | "conv.modulation.bias": mod_bias + 1, 25 | "noise.weight": np.array([noise]), 26 | "activate.bias": bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + ".conv.weight"] = torch.flip( 36 | dic_torch[target_name + ".conv.weight"], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + "/weight"].value().eval() 44 | 45 | dic = {"weight": weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic["bias"] = vars[source_name + "/bias"].value().eval() 49 | 50 | dic_torch = {} 51 | 52 | dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) 53 | 54 | if bias: 55 | dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) 56 | 57 | return dic_torch 58 | 59 | 60 | def convert_torgb(vars, source_name, target_name): 61 | weight = vars[source_name + "/weight"].value().eval() 62 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 63 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 64 | bias = vars[source_name + "/bias"].value().eval() 65 | 66 | dic = { 67 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 68 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 69 | "conv.modulation.bias": mod_bias + 1, 70 | "bias": bias.reshape((1, 3, 1, 1)), 71 | } 72 | 73 | dic_torch = {} 74 | 75 | for k, v in dic.items(): 76 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 77 | 78 | return dic_torch 79 | 80 | 81 | def convert_dense(vars, source_name, target_name): 82 | weight = vars[source_name + "/weight"].value().eval() 83 | bias = vars[source_name + "/bias"].value().eval() 84 | 85 | dic = {"weight": weight.transpose((1, 0)), "bias": bias} 86 | 87 | dic_torch = {} 88 | 89 | for k, v in dic.items(): 90 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 91 | 92 | return dic_torch 93 | 94 | 95 | def update(state_dict, new): 96 | for k, v in new.items(): 97 | if k not in state_dict: 98 | raise KeyError(k + " is not found") 99 | 100 | if v.shape != state_dict[k].shape: 101 | raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") 102 | 103 | state_dict[k] = v 104 | 105 | 106 | def discriminator_fill_statedict(statedict, vars, size): 107 | log_size = int(math.log(size, 2)) 108 | 109 | update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) 110 | 111 | conv_i = 1 112 | 113 | for i in range(log_size - 2, 0, -1): 114 | reso = 4 * 2 ** i 115 | update( 116 | statedict, 117 | convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), 118 | ) 119 | update( 120 | statedict, 121 | convert_conv( 122 | vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 123 | ), 124 | ) 125 | update( 126 | statedict, 127 | convert_conv( 128 | vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False 129 | ), 130 | ) 131 | conv_i += 1 132 | 133 | update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) 134 | update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) 135 | update(statedict, convert_dense(vars, f"Output", "final_linear.1")) 136 | 137 | return statedict 138 | 139 | 140 | def fill_statedict(state_dict, vars, size,n_mlp=8): 141 | log_size = int(math.log(size, 2)) 142 | 143 | for i in range(n_mlp): 144 | update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) 145 | 146 | update( 147 | state_dict, 148 | { 149 | "input.input": torch.from_numpy( 150 | vars["G_synthesis/4x4/Const/const"].value().eval() 151 | ) 152 | }, 153 | ) 154 | 155 | update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1")) 156 | 157 | for i in range(log_size - 2): 158 | reso = 4 * 2 ** (i + 1) 159 | update( 160 | state_dict, 161 | convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), 162 | ) 163 | 164 | update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) 165 | 166 | conv_i = 0 167 | 168 | for i in range(log_size - 2): 169 | reso = 4 * 2 ** (i + 1) 170 | update( 171 | state_dict, 172 | convert_modconv( 173 | vars, 174 | f"G_synthesis/{reso}x{reso}/Conv0_up", 175 | f"convs.{conv_i}", 176 | flip=True, 177 | ), 178 | ) 179 | update( 180 | state_dict, 181 | convert_modconv( 182 | vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}" 183 | ), 184 | ) 185 | conv_i += 2 186 | 187 | for i in range(0, (log_size - 2) * 2 + 1): 188 | update( 189 | state_dict, 190 | { 191 | f"noises.noise_{i}": torch.from_numpy( 192 | vars[f"G_synthesis/noise{i}"].value().eval() 193 | ) 194 | }, 195 | ) 196 | 197 | return state_dict 198 | 199 | 200 | def convertStyleGan2(_G,_D,Gs,channel_multiplier = 4,style_dim=1024,n_mlp=4,max_channel_size=1024): 201 | generator, discriminator, g_ema = _G, _D, Gs 202 | 203 | size = g_ema.output_shape[2] 204 | 205 | g = Generator(size, style_dim, n_mlp, channel_multiplier=channel_multiplier,max_channel_size=max_channel_size) 206 | state_dict = g.state_dict() 207 | state_dict = fill_statedict(state_dict, g_ema.vars, size,n_mlp) 208 | 209 | g.load_state_dict(state_dict) 210 | 211 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 212 | 213 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 214 | 215 | 216 | #convert _G 217 | g_train = Generator(size, style_dim, n_mlp, channel_multiplier=channel_multiplier,max_channel_size=max_channel_size) 218 | 
g_train_state = g_train.state_dict() 219 | g_train_state = fill_statedict(g_train_state, generator.vars, size,n_mlp) 220 | ckpt["g"] = g_train_state 221 | 222 | 223 | #convert discriminator 224 | channel_multiplier=2 225 | disc = Discriminator(size, 226 | channel_multiplier=channel_multiplier, 227 | stddev_group = 32, 228 | stddev_feat = 4) 229 | 230 | d_state = disc.state_dict() 231 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 232 | 233 | disc.load_state_dict(d_state) 234 | 235 | ckpt["d"] = d_state 236 | 237 | 238 | 239 | 240 | return ckpt, g, disc, g_train 241 | 242 | 243 | 244 | 245 | 246 | if __name__ == "__main__": 247 | device = "cuda" 248 | 249 | parser = argparse.ArgumentParser( 250 | description="Tensorflow to pytorch model checkpoint converter" 251 | ) 252 | parser.add_argument( 253 | "--repo", 254 | type=str, 255 | required=True, 256 | help="path to the offical StyleGAN2 repository with dnnlib/ folder", 257 | ) 258 | parser.add_argument( 259 | "--gen", action="store_true", help="convert the generator weights" 260 | ) 261 | parser.add_argument( 262 | "--disc", action="store_true", help="convert the discriminator weights" 263 | ) 264 | parser.add_argument( 265 | "--channel_multiplier", 266 | type=int, 267 | default=2, 268 | help="channel multiplier factor. config-f = 2, else = 1", 269 | ) 270 | parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") 271 | 272 | args = parser.parse_args() 273 | 274 | sys.path.append(args.repo) 275 | 276 | import dnnlib 277 | from dnnlib import tflib 278 | 279 | tflib.init_tf() 280 | 281 | with open(args.path, "rb") as f: 282 | generator, discriminator, g_ema = pickle.load(f) 283 | 284 | size = g_ema.output_shape[2] 285 | 286 | g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 287 | state_dict = g.state_dict() 288 | state_dict = fill_statedict(state_dict, g_ema.vars, size) 289 | 290 | g.load_state_dict(state_dict) 291 | 292 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 293 | 294 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 295 | 296 | if args.gen: 297 | g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 298 | g_train_state = g_train.state_dict() 299 | g_train_state = fill_statedict(g_train_state, generator.vars, size) 300 | ckpt["g"] = g_train_state 301 | 302 | if args.disc: 303 | disc = Discriminator(size, channel_multiplier=args.channel_multiplier) 304 | d_state = disc.state_dict() 305 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 306 | ckpt["d"] = d_state 307 | 308 | name = os.path.splitext(os.path.basename(args.path))[0] 309 | torch.save(ckpt, name + ".pt") 310 | 311 | batch_size = {256: 16, 512: 9, 1024: 4} 312 | n_sample = batch_size.get(size, 25) 313 | 314 | g = g.to(device) 315 | 316 | z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") 317 | 318 | with torch.no_grad(): 319 | img_pt, _ = g( 320 | [torch.from_numpy(z).to(device)], 321 | truncation=0.5, 322 | truncation_latent=latent_avg.to(device), 323 | randomize_noise=False, 324 | ) 325 | 326 | Gs_kwargs = dnnlib.EasyDict() 327 | Gs_kwargs.randomize_noise = False 328 | img_tf = g_ema.run(z, None, **Gs_kwargs) 329 | img_tf = torch.from_numpy(img_tf).to(device) 330 | 331 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp( 332 | 0.0, 1.0 333 | ) 334 | 335 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 336 | 337 | print(img_diff.abs().max()) 338 | 339 | 
utils.save_image( 340 | img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) 341 | ) 342 | 343 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MultiResolutionDataset(Dataset): 9 | def __init__(self, path, transform, resolution=256): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | 40 | return img 41 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, 
padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /doc/sample-metfaces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/doc/sample-metfaces.png -------------------------------------------------------------------------------- /doc/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/doc/sample.png -------------------------------------------------------------------------------- /doc/stylegan2-church-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/doc/stylegan2-church-config-f.png -------------------------------------------------------------------------------- /doc/stylegan2-ffhq-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/doc/stylegan2-ffhq-config-f.png -------------------------------------------------------------------------------- /factor_index-13_degree-5.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/factor_index-13_degree-5.0.png -------------------------------------------------------------------------------- /fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import linalg 8 | from tqdm import tqdm 9 | 10 | from model import Generator 11 | from calc_inception import load_patched_inception_v3 12 | 13 | 14 | @torch.no_grad() 15 | def extract_feature_from_samples( 16 | generator, inception, truncation, truncation_latent, batch_size, n_sample, device 17 | ): 18 | n_batch = n_sample // batch_size 19 | resid = n_sample - (n_batch * batch_size) 20 | batch_sizes = [batch_size] * n_batch + [resid] 21 | features = [] 22 | 23 | for batch in tqdm(batch_sizes): 24 | latent = torch.randn(batch, 512, device=device) 25 | img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) 26 | feat = inception(img)[0].view(img.shape[0], -1) 27 | features.append(feat.to("cpu")) 28 | 29 | features = torch.cat(features, 0) 30 | 31 | return features 32 | 33 | 34 | def 
calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 35 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 36 | 37 | if not np.isfinite(cov_sqrt).all(): 38 | print("product of cov matrices is singular") 39 | offset = np.eye(sample_cov.shape[0]) * eps 40 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 41 | 42 | if np.iscomplexobj(cov_sqrt): 43 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 44 | m = np.max(np.abs(cov_sqrt.imag)) 45 | 46 | raise ValueError(f"Imaginary component {m}") 47 | 48 | cov_sqrt = cov_sqrt.real 49 | 50 | mean_diff = sample_mean - real_mean 51 | mean_norm = mean_diff @ mean_diff 52 | 53 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 54 | 55 | fid = mean_norm + trace 56 | 57 | return fid 58 | 59 | 60 | if __name__ == "__main__": 61 | device = "cuda" 62 | 63 | parser = argparse.ArgumentParser(description="Calculate FID scores") 64 | 65 | parser.add_argument("--truncation", type=float, default=1, help="truncation factor") 66 | parser.add_argument( 67 | "--truncation_mean", 68 | type=int, 69 | default=4096, 70 | help="number of samples to calculate mean for truncation", 71 | ) 72 | parser.add_argument( 73 | "--batch", type=int, default=64, help="batch size for the generator" 74 | ) 75 | parser.add_argument( 76 | "--n_sample", 77 | type=int, 78 | default=50000, 79 | help="number of the samples for calculating FID", 80 | ) 81 | parser.add_argument( 82 | "--size", type=int, default=256, help="image sizes for generator" 83 | ) 84 | parser.add_argument( 85 | "--inception", 86 | type=str, 87 | default=None, 88 | required=True, 89 | help="path to precomputed inception embedding", 90 | ) 91 | parser.add_argument( 92 | "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint" 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | ckpt = torch.load(args.ckpt) 98 | 99 | g = Generator(args.size, 512, 8).to(device) 100 | g.load_state_dict(ckpt["g_ema"]) 101 | g = nn.DataParallel(g) 102 | g.eval() 103 | 104 | if args.truncation < 1: 105 | with torch.no_grad(): 106 | mean_latent = g.mean_latent(args.truncation_mean) 107 | 108 | else: 109 | mean_latent = None 110 | 111 | inception = nn.DataParallel(load_patched_inception_v3()).to(device) 112 | inception.eval() 113 | 114 | features = extract_feature_from_samples( 115 | g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device 116 | ).numpy() 117 | print(f"extracted {features.shape[0]} features") 118 | 119 | sample_mean = np.mean(features, 0) 120 | sample_cov = np.cov(features, rowvar=False) 121 | 122 | with open(args.inception, "rb") as f: 123 | embeds = pickle.load(f) 124 | real_mean = embeds["mean"] 125 | real_cov = embeds["cov"] 126 | 127 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 128 | 129 | print("fid:", fid) 130 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model import Generator 6 | from tqdm import tqdm 7 | 8 | 9 | def generate(args, g_ema, device, mean_latent): 10 | 11 | with torch.no_grad(): 12 | g_ema.eval() 13 | for i in tqdm(range(args.pics)): 14 | sample_z = torch.randn(args.sample, args.latent, device=device) 15 | 16 | sample, _ = g_ema( 17 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 18 | ) 19 | 20 | utils.save_image( 21 | sample, 22 | 
f"sample/{str(i).zfill(6)}.png", 23 | nrow=1, 24 | normalize=True, 25 | range=(-1, 1), 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | device = "cuda" 31 | 32 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 33 | 34 | parser.add_argument( 35 | "--size", type=int, default=1024, help="output image size of the generator" 36 | ) 37 | parser.add_argument( 38 | "--sample", 39 | type=int, 40 | default=1, 41 | help="number of samples to be generated for each image", 42 | ) 43 | parser.add_argument( 44 | "--pics", type=int, default=20, help="number of images to be generated" 45 | ) 46 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 47 | parser.add_argument( 48 | "--truncation_mean", 49 | type=int, 50 | default=4096, 51 | help="number of vectors to calculate mean for the truncation", 52 | ) 53 | parser.add_argument( 54 | "--ckpt", 55 | type=str, 56 | default="stylegan2-ffhq-config-f.pt", 57 | help="path to the model checkpoint", 58 | ) 59 | parser.add_argument( 60 | "--channel_multiplier", 61 | type=int, 62 | default=2, 63 | help="channel multiplier of the generator. config-f = 2, else = 1", 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | args.latent = 512 69 | args.n_mlp = 8 70 | 71 | g_ema = Generator( 72 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 73 | ).to(device) 74 | checkpoint = torch.load(args.ckpt) 75 | 76 | g_ema.load_state_dict(checkpoint["g_ema"]) 77 | 78 | if args.truncation < 1: 79 | with torch.no_grad(): 80 | mean_latent = g_ema.mean_latent(args.truncation_mean) 81 | else: 82 | mean_latent = None 83 | 84 | generate(args, g_ema, device, mean_latent) 85 | -------------------------------------------------------------------------------- /inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. 
Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 189 | inception.load_state_dict(state_dict) 190 | return inception 191 | 192 | 193 | class FIDInceptionA(models.inception.InceptionA): 194 | """InceptionA block patched for FID computation""" 195 | def __init__(self, in_channels, pool_features): 196 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 197 | 198 | def forward(self, x): 199 | branch1x1 = self.branch1x1(x) 200 | 201 | branch5x5 = self.branch5x5_1(x) 202 | branch5x5 = self.branch5x5_2(branch5x5) 203 | 204 | branch3x3dbl = self.branch3x3dbl_1(x) 205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 207 | 208 | # Patch: Tensorflow's average pool does not use the padded zero's in 209 | # its average calculation 210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 211 | count_include_pad=False) 212 | branch_pool = self.branch_pool(branch_pool) 213 | 214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 215 | return torch.cat(outputs, 1) 216 | 217 | 218 | class FIDInceptionC(models.inception.InceptionC): 219 | """InceptionC block patched for FID computation""" 220 | def __init__(self, in_channels, channels_7x7): 221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 222 | 223 | def forward(self, x): 224 | branch1x1 = self.branch1x1(x) 225 | 226 | branch7x7 = self.branch7x7_1(x) 227 | branch7x7 = self.branch7x7_2(branch7x7) 228 | branch7x7 = self.branch7x7_3(branch7x7) 229 | 230 | branch7x7dbl = self.branch7x7dbl_1(x) 231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 234 | branch7x7dbl = 
self.branch7x7dbl_5(branch7x7dbl) 235 | 236 | # Patch: Tensorflow's average pool does not use the padded zero's in 237 | # its average calculation 238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 239 | count_include_pad=False) 240 | branch_pool = self.branch_pool(branch_pool) 241 | 242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class FIDInceptionE_1(models.inception.InceptionE): 247 | """First InceptionE block patched for FID computation""" 248 | def __init__(self, in_channels): 249 | super(FIDInceptionE_1, self).__init__(in_channels) 250 | 251 | def forward(self, x): 252 | branch1x1 = self.branch1x1(x) 253 | 254 | branch3x3 = self.branch3x3_1(x) 255 | branch3x3 = [ 256 | self.branch3x3_2a(branch3x3), 257 | self.branch3x3_2b(branch3x3), 258 | ] 259 | branch3x3 = torch.cat(branch3x3, 1) 260 | 261 | branch3x3dbl = self.branch3x3dbl_1(x) 262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 263 | branch3x3dbl = [ 264 | self.branch3x3dbl_3a(branch3x3dbl), 265 | self.branch3x3dbl_3b(branch3x3dbl), 266 | ] 267 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 268 | 269 | # Patch: Tensorflow's average pool does not use the padded zero's in 270 | # its average calculation 271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 272 | count_include_pad=False) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | 278 | 279 | class FIDInceptionE_2(models.inception.InceptionE): 280 | """Second InceptionE block patched for FID computation""" 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /inception_ffhq.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nagolinc/stylegan2-pytorch/dcbe4a33e22b244be6d2a735dbf3a347dafc73f5/inception_ffhq.pkl -------------------------------------------------------------------------------- /lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from lpips import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 
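# A minimal usage sketch of the PerceptualLoss module defined above, assuming a
# CUDA device; img0 / img1 are illustrative batches already scaled to [-1, 1]
# (pass normalize=True instead if the images live in [0, 1]). This helper is
# illustrative only and is never called by the library itself.
def _example_perceptual_loss():
    loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
    img0 = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
    img1 = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
    dist = loss_fn(img0, img1)  # one LPIPS distance per image pair
    return dist.view(-1)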
54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 
133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | 
import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 
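# Note on the ranking target above: var_judge holds the human 2AFC judgments
# (fractions in [0, 1]; 0 = left patch p0 preferred, 1 = right patch p1, as
# described in score_2afc_dataset below); scaling by 2 and subtracting 1 maps
# them to [-1, 1] before they are passed to the BCE ranking loss.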
159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer("kernel", kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, 
input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 166 | ) 167 | 168 | 169 | class ModulatedConv2d(nn.Module): 170 | def __init__( 171 | self, 172 | in_channel, 173 | out_channel, 174 | kernel_size, 175 | style_dim, 176 | demodulate=True, 177 | upsample=False, 178 | downsample=False, 179 | blur_kernel=[1, 3, 3, 1], 180 | ): 181 | super().__init__() 182 | 183 | self.eps = 1e-8 184 | self.kernel_size = kernel_size 185 | self.in_channel = in_channel 186 | self.out_channel = out_channel 187 | self.upsample = upsample 188 | self.downsample = downsample 189 | 190 | if upsample: 191 | factor = 2 192 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 193 | pad0 = (p + 1) // 2 + factor - 1 194 | pad1 = p // 2 + 1 195 | 196 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 197 | 198 | if downsample: 199 | factor = 2 200 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 201 | pad0 = (p + 1) // 2 202 | pad1 = p // 2 203 | 204 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 205 | 206 | fan_in = in_channel * kernel_size ** 2 207 | self.scale = 1 / math.sqrt(fan_in) 208 | self.padding = kernel_size // 2 209 | 210 | self.weight = nn.Parameter( 211 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 212 | ) 213 | 214 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 215 | 216 | self.demodulate = demodulate 217 | 218 | def __repr__(self): 219 | return ( 220 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 221 | f"upsample={self.upsample}, downsample={self.downsample})" 222 | ) 223 | 224 | def forward(self, input, style): 225 | batch, in_channel, height, width = input.shape 226 | 227 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 228 | weight = self.scale * self.weight * style 229 | 230 | if self.demodulate: 231 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 232 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 233 | 234 | weight = weight.view( 235 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 236 | ) 237 | 238 | if self.upsample: 239 | input = input.view(1, batch * in_channel, height, width) 240 | weight = weight.view( 241 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 242 | ) 243 | weight = weight.transpose(1, 2).reshape( 244 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 245 | ) 246 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 247 | _, _, height, width = out.shape 248 | out = out.view(batch, self.out_channel, height, width) 249 | out = self.blur(out) 250 | 251 | elif self.downsample: 252 | input = self.blur(input) 253 | _, _, height, width = input.shape 254 | input = input.view(1, batch * in_channel, height, width) 255 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 256 | _, _, height, width = out.shape 257 | out = out.view(batch, self.out_channel, height, width) 258 | 259 | else: 260 | input = input.view(1, batch * in_channel, height, width) 261 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 262 | _, _, height, width = out.shape 263 | out = 
out.view(batch, self.out_channel, height, width) 264 | 265 | return out 266 | 267 | 268 | class NoiseInjection(nn.Module): 269 | def __init__(self): 270 | super().__init__() 271 | 272 | self.weight = nn.Parameter(torch.zeros(1)) 273 | 274 | def forward(self, image, noise=None): 275 | if noise is None: 276 | batch, _, height, width = image.shape 277 | noise = image.new_empty(batch, 1, height, width).normal_() 278 | 279 | return image + self.weight * noise 280 | 281 | 282 | class ConstantInput(nn.Module): 283 | def __init__(self, channel, size=4): 284 | super().__init__() 285 | 286 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 287 | 288 | def forward(self, input): 289 | batch = input.shape[0] 290 | out = self.input.repeat(batch, 1, 1, 1) 291 | 292 | return out 293 | 294 | 295 | class StyledConv(nn.Module): 296 | def __init__( 297 | self, 298 | in_channel, 299 | out_channel, 300 | kernel_size, 301 | style_dim, 302 | upsample=False, 303 | blur_kernel=[1, 3, 3, 1], 304 | demodulate=True, 305 | ): 306 | super().__init__() 307 | 308 | self.conv = ModulatedConv2d( 309 | in_channel, 310 | out_channel, 311 | kernel_size, 312 | style_dim, 313 | upsample=upsample, 314 | blur_kernel=blur_kernel, 315 | demodulate=demodulate, 316 | ) 317 | 318 | self.noise = NoiseInjection() 319 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 320 | # self.activate = ScaledLeakyReLU(0.2) 321 | self.activate = FusedLeakyReLU(out_channel) 322 | 323 | def forward(self, input, style, noise=None): 324 | out = self.conv(input, style) 325 | out = self.noise(out, noise=noise) 326 | # out = out + self.bias 327 | out = self.activate(out) 328 | 329 | return out 330 | 331 | 332 | class ToRGB(nn.Module): 333 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 334 | super().__init__() 335 | 336 | if upsample: 337 | self.upsample = Upsample(blur_kernel) 338 | 339 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 340 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 341 | 342 | def forward(self, input, style, skip=None): 343 | out = self.conv(input, style) 344 | out = out + self.bias 345 | 346 | if skip is not None: 347 | skip = self.upsample(skip) 348 | 349 | out = out + skip 350 | 351 | return out 352 | 353 | 354 | class Generator(nn.Module): 355 | def __init__( 356 | self, 357 | size, 358 | style_dim, 359 | n_mlp, 360 | channel_multiplier=2, 361 | blur_kernel=[1, 3, 3, 1], 362 | lr_mlp=0.01, 363 | max_channel_size=512 364 | ): 365 | super().__init__() 366 | 367 | self.size = size 368 | 369 | self.style_dim = style_dim 370 | 371 | layers = [PixelNorm()] 372 | 373 | for i in range(n_mlp): 374 | layers.append( 375 | EqualLinear( 376 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 377 | ) 378 | ) 379 | 380 | self.style = nn.Sequential(*layers) 381 | 382 | self.channels = { 383 | 4: min(max_channel_size, 4096 * channel_multiplier), 384 | 8: min(max_channel_size, 2048 * channel_multiplier), 385 | 16: min(max_channel_size, 1024 * channel_multiplier), 386 | 32: min(max_channel_size, 512 * channel_multiplier), 387 | 64: 256 * channel_multiplier, 388 | 128: 128 * channel_multiplier, 389 | 256: 64 * channel_multiplier, 390 | 512: 32 * channel_multiplier, 391 | 1024: 16 * channel_multiplier, 392 | } 393 | 394 | self.input = ConstantInput(self.channels[4]) 395 | self.conv1 = StyledConv( 396 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 397 | ) 398 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, 
upsample=False) 399 | 400 | self.log_size = int(math.log(size, 2)) 401 | self.num_layers = (self.log_size - 2) * 2 + 1 402 | 403 | self.convs = nn.ModuleList() 404 | self.upsamples = nn.ModuleList() 405 | self.to_rgbs = nn.ModuleList() 406 | self.noises = nn.Module() 407 | 408 | in_channel = self.channels[4] 409 | 410 | for layer_idx in range(self.num_layers): 411 | res = (layer_idx + 5) // 2 412 | shape = [1, 1, 2 ** res, 2 ** res] 413 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 414 | 415 | for i in range(3, self.log_size + 1): 416 | out_channel = self.channels[2 ** i] 417 | 418 | self.convs.append( 419 | StyledConv( 420 | in_channel, 421 | out_channel, 422 | 3, 423 | style_dim, 424 | upsample=True, 425 | blur_kernel=blur_kernel, 426 | ) 427 | ) 428 | 429 | self.convs.append( 430 | StyledConv( 431 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 432 | ) 433 | ) 434 | 435 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 436 | 437 | in_channel = out_channel 438 | 439 | self.n_latent = self.log_size * 2 - 2 440 | 441 | def make_noise(self): 442 | device = self.input.input.device 443 | 444 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 445 | 446 | for i in range(3, self.log_size + 1): 447 | for _ in range(2): 448 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 449 | 450 | return noises 451 | 452 | def mean_latent(self, n_latent): 453 | latent_in = torch.randn( 454 | n_latent, self.style_dim, device=self.input.input.device 455 | ) 456 | latent = self.style(latent_in).mean(0, keepdim=True) 457 | 458 | return latent 459 | 460 | def get_latent(self, input): 461 | return self.style(input) 462 | 463 | def forward( 464 | self, 465 | styles, 466 | return_latents=False, 467 | inject_index=None, 468 | truncation=1, 469 | truncation_latent=None, 470 | input_is_latent=False, 471 | noise=None, 472 | randomize_noise=True, 473 | ): 474 | if not input_is_latent: 475 | styles = [self.style(s) for s in styles] 476 | 477 | if noise is None: 478 | if randomize_noise: 479 | noise = [None] * self.num_layers 480 | else: 481 | noise = [ 482 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 483 | ] 484 | 485 | if truncation < 1: 486 | style_t = [] 487 | 488 | for style in styles: 489 | style_t.append( 490 | truncation_latent + truncation * (style - truncation_latent) 491 | ) 492 | 493 | styles = style_t 494 | 495 | if len(styles) < 2: 496 | inject_index = self.n_latent 497 | 498 | if styles[0].ndim < 3: 499 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 500 | 501 | else: 502 | latent = styles[0] 503 | 504 | else: 505 | if inject_index is None: 506 | inject_index = random.randint(1, self.n_latent - 1) 507 | 508 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 509 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 510 | 511 | latent = torch.cat([latent, latent2], 1) 512 | 513 | out = self.input(latent) 514 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 515 | 516 | skip = self.to_rgb1(out, latent[:, 1]) 517 | 518 | i = 1 519 | for conv1, conv2, noise1, noise2, to_rgb in zip( 520 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 521 | ): 522 | out = conv1(out, latent[:, i], noise=noise1) 523 | out = conv2(out, latent[:, i + 1], noise=noise2) 524 | skip = to_rgb(out, latent[:, i + 2], skip) 525 | 526 | i += 2 527 | 528 | image = skip 529 | 530 | if return_latents: 531 | return image, latent 532 | 533 | else: 534 | return image, None 535 
| 536 | 537 | class ConvLayer(nn.Sequential): 538 | def __init__( 539 | self, 540 | in_channel, 541 | out_channel, 542 | kernel_size, 543 | downsample=False, 544 | blur_kernel=[1, 3, 3, 1], 545 | bias=True, 546 | activate=True, 547 | ): 548 | layers = [] 549 | 550 | if downsample: 551 | factor = 2 552 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 553 | pad0 = (p + 1) // 2 554 | pad1 = p // 2 555 | 556 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 557 | 558 | stride = 2 559 | self.padding = 0 560 | 561 | else: 562 | stride = 1 563 | self.padding = kernel_size // 2 564 | 565 | layers.append( 566 | EqualConv2d( 567 | in_channel, 568 | out_channel, 569 | kernel_size, 570 | padding=self.padding, 571 | stride=stride, 572 | bias=bias and not activate, 573 | ) 574 | ) 575 | 576 | if activate: 577 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 578 | 579 | super().__init__(*layers) 580 | 581 | 582 | class ResBlock(nn.Module): 583 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 584 | super().__init__() 585 | 586 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 587 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 588 | 589 | self.skip = ConvLayer( 590 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 591 | ) 592 | 593 | def forward(self, input): 594 | out = self.conv1(input) 595 | out = self.conv2(out) 596 | 597 | skip = self.skip(input) 598 | out = (out + skip) / math.sqrt(2) 599 | 600 | return out 601 | 602 | 603 | class Discriminator(nn.Module): 604 | def __init__(self, 605 | size, 606 | channel_multiplier=2, 607 | blur_kernel=[1, 3, 3, 1], 608 | stddev_group = 4, 609 | stddev_feat = 1 610 | ): 611 | 612 | super().__init__() 613 | 614 | self.stddev_group = stddev_group 615 | self.stddev_feat = stddev_feat 616 | 617 | channels = { 618 | 4: 512, 619 | 8: 512, 620 | 16: 512, 621 | 32: 512, 622 | 64: 256 * channel_multiplier, 623 | 128: 128 * channel_multiplier, 624 | 256: 64 * channel_multiplier, 625 | 512: 32 * channel_multiplier, 626 | 1024: 16 * channel_multiplier, 627 | } 628 | 629 | convs = [ConvLayer(3, channels[size], 1)] 630 | 631 | log_size = int(math.log(size, 2)) 632 | 633 | in_channel = channels[size] 634 | 635 | for i in range(log_size, 2, -1): 636 | out_channel = channels[2 ** (i - 1)] 637 | 638 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 639 | 640 | in_channel = out_channel 641 | 642 | self.convs = nn.Sequential(*convs) 643 | 644 | 645 | 646 | self.final_conv = ConvLayer(in_channel + stddev_feat, channels[4], 3) 647 | self.final_linear = nn.Sequential( 648 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 649 | EqualLinear(channels[4], 1), 650 | ) 651 | 652 | def forward(self, input): 653 | out = self.convs(input) 654 | 655 | batch, channel, height, width = out.shape 656 | group = min(batch, self.stddev_group) 657 | stddev = out.view( 658 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 659 | ) 660 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 661 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 662 | stddev = stddev.repeat(group, 1, height, width) 663 | out = torch.cat([out, stddev], 1) 664 | 665 | out = self.final_conv(out) 666 | 667 | out = out.view(batch, -1) 668 | out = self.final_linear(out) 669 | 670 | return out 671 | 672 | -------------------------------------------------------------------------------- /non_leaking.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | from distributed import reduce_sum 7 | from op import upfirdn2d 8 | 9 | 10 | class AdaptiveAugment: 11 | def __init__(self, ada_aug_target, ada_aug_len, update_every, device): 12 | self.ada_aug_target = ada_aug_target 13 | self.ada_aug_len = ada_aug_len 14 | self.update_every = update_every 15 | 16 | self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device) 17 | self.r_t_stat = 0 18 | self.ada_aug_p = 0 19 | 20 | @torch.no_grad() 21 | def tune(self, real_pred): 22 | ada_aug_data = torch.tensor( 23 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]), 24 | device=real_pred.device, 25 | ) 26 | self.ada_aug_buf += reduce_sum(ada_aug_data) 27 | 28 | if self.ada_aug_buf[1] > self.update_every - 1: 29 | pred_signs, n_pred = self.ada_aug_buf.tolist() 30 | 31 | self.r_t_stat = pred_signs / n_pred 32 | 33 | if self.r_t_stat > self.ada_aug_target: 34 | sign = 1 35 | 36 | else: 37 | sign = -1 38 | 39 | self.ada_aug_p += sign * n_pred / self.ada_aug_len 40 | self.ada_aug_p = min(1, max(0, self.ada_aug_p)) 41 | self.ada_aug_buf.mul_(0) 42 | 43 | return self.ada_aug_p 44 | 45 | 46 | SYM6 = ( 47 | 0.015404109327027373, 48 | 0.0034907120842174702, 49 | -0.11799011114819057, 50 | -0.048311742585633, 51 | 0.4910559419267466, 52 | 0.787641141030194, 53 | 0.3379294217276218, 54 | -0.07263752278646252, 55 | -0.021060292512300564, 56 | 0.04472490177066578, 57 | 0.0017677118642428036, 58 | -0.007800708325034148, 59 | ) 60 | 61 | 62 | def translate_mat(t_x, t_y): 63 | batch = t_x.shape[0] 64 | 65 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 66 | translate = torch.stack((t_x, t_y), 1) 67 | mat[:, :2, 2] = translate 68 | 69 | return mat 70 | 71 | 72 | def rotate_mat(theta): 73 | batch = theta.shape[0] 74 | 75 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 76 | sin_t = torch.sin(theta) 77 | cos_t = torch.cos(theta) 78 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) 79 | mat[:, :2, :2] = rot 80 | 81 | return mat 82 | 83 | 84 | def scale_mat(s_x, s_y): 85 | batch = s_x.shape[0] 86 | 87 | mat = torch.eye(3).unsqueeze(0).repeat(batch, 1, 1) 88 | mat[:, 0, 0] = s_x 89 | mat[:, 1, 1] = s_y 90 | 91 | return mat 92 | 93 | 94 | def translate3d_mat(t_x, t_y, t_z): 95 | batch = t_x.shape[0] 96 | 97 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 98 | translate = torch.stack((t_x, t_y, t_z), 1) 99 | mat[:, :3, 3] = translate 100 | 101 | return mat 102 | 103 | 104 | def rotate3d_mat(axis, theta): 105 | batch = theta.shape[0] 106 | 107 | u_x, u_y, u_z = axis 108 | 109 | eye = torch.eye(3).unsqueeze(0) 110 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) 111 | outer = torch.tensor(axis) 112 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0) 113 | 114 | sin_t = torch.sin(theta).view(-1, 1, 1) 115 | cos_t = torch.cos(theta).view(-1, 1, 1) 116 | 117 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer 118 | 119 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 120 | eye_4[:, :3, :3] = rot 121 | 122 | return eye_4 123 | 124 | 125 | def scale3d_mat(s_x, s_y, s_z): 126 | batch = s_x.shape[0] 127 | 128 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 129 | mat[:, 0, 0] = s_x 130 | mat[:, 1, 1] = s_y 131 | mat[:, 2, 2] = s_z 132 | 133 | return mat 134 | 135 | 136 | def luma_flip_mat(axis, i): 137 | batch = i.shape[0] 138 | 139 | eye = 
torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 140 | axis = torch.tensor(axis + (0,)) 141 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) 142 | 143 | return eye - flip 144 | 145 | 146 | def saturation_mat(axis, i): 147 | batch = i.shape[0] 148 | 149 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 150 | axis = torch.tensor(axis + (0,)) 151 | axis = torch.ger(axis, axis) 152 | saturate = axis + (eye - axis) * i.view(-1, 1, 1) 153 | 154 | return saturate 155 | 156 | 157 | def lognormal_sample(size, mean=0, std=1): 158 | return torch.empty(size).log_normal_(mean=mean, std=std) 159 | 160 | 161 | def category_sample(size, categories): 162 | category = torch.tensor(categories) 163 | sample = torch.randint(high=len(categories), size=(size,)) 164 | 165 | return category[sample] 166 | 167 | 168 | def uniform_sample(size, low, high): 169 | return torch.empty(size).uniform_(low, high) 170 | 171 | 172 | def normal_sample(size, mean=0, std=1): 173 | return torch.empty(size).normal_(mean, std) 174 | 175 | 176 | def bernoulli_sample(size, p): 177 | return torch.empty(size).bernoulli_(p) 178 | 179 | 180 | def random_mat_apply(p, transform, prev, eye): 181 | size = transform.shape[0] 182 | select = bernoulli_sample(size, p).view(size, 1, 1) 183 | select_transform = select * transform + (1 - select) * eye 184 | 185 | return select_transform @ prev 186 | 187 | 188 | def sample_affine(p, size, height, width): 189 | G = torch.eye(3).unsqueeze(0).repeat(size, 1, 1) 190 | eye = G 191 | 192 | # flip 193 | param = category_sample(size, (0, 1)) 194 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size)) 195 | G = random_mat_apply(p, Gc, G, eye) 196 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') 197 | 198 | # 90 rotate 199 | param = category_sample(size, (0, 3)) 200 | Gc = rotate_mat(-math.pi / 2 * param) 201 | G = random_mat_apply(p, Gc, G, eye) 202 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') 203 | 204 | # integer translate 205 | param = uniform_sample(size, -0.125, 0.125) 206 | param_height = torch.round(param * height) / height 207 | param_width = torch.round(param * width) / width 208 | Gc = translate_mat(param_width, param_height) 209 | G = random_mat_apply(p, Gc, G, eye) 210 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') 211 | 212 | # isotropic scale 213 | param = lognormal_sample(size, std=0.2 * math.log(2)) 214 | Gc = scale_mat(param, param) 215 | G = random_mat_apply(p, Gc, G, eye) 216 | # print('isotropic scale', G, scale_mat(param, param), sep='\n') 217 | 218 | p_rot = 1 - math.sqrt(1 - p) 219 | 220 | # pre-rotate 221 | param = uniform_sample(size, -math.pi, math.pi) 222 | Gc = rotate_mat(-param) 223 | G = random_mat_apply(p_rot, Gc, G, eye) 224 | # print('pre-rotate', G, rotate_mat(-param), sep='\n') 225 | 226 | # anisotropic scale 227 | param = lognormal_sample(size, std=0.2 * math.log(2)) 228 | Gc = scale_mat(param, 1 / param) 229 | G = random_mat_apply(p, Gc, G, eye) 230 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') 231 | 232 | # post-rotate 233 | param = uniform_sample(size, -math.pi, math.pi) 234 | Gc = rotate_mat(-param) 235 | G = random_mat_apply(p_rot, Gc, G, eye) 236 | # print('post-rotate', G, rotate_mat(-param), sep='\n') 237 | 238 | # fractional translate 239 | param = normal_sample(size, std=0.125) 240 | Gc = translate_mat(param, param) 241 | G = random_mat_apply(p, Gc, G, eye) 242 | # print('fractional translate', G, translate_mat(param, param), sep='\n') 243 | 244 
| return G 245 | 246 | 247 | def sample_color(p, size): 248 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) 249 | eye = C 250 | axis_val = 1 / math.sqrt(3) 251 | axis = (axis_val, axis_val, axis_val) 252 | 253 | # brightness 254 | param = normal_sample(size, std=0.2) 255 | Cc = translate3d_mat(param, param, param) 256 | C = random_mat_apply(p, Cc, C, eye) 257 | 258 | # contrast 259 | param = lognormal_sample(size, std=0.5 * math.log(2)) 260 | Cc = scale3d_mat(param, param, param) 261 | C = random_mat_apply(p, Cc, C, eye) 262 | 263 | # luma flip 264 | param = category_sample(size, (0, 1)) 265 | Cc = luma_flip_mat(axis, param) 266 | C = random_mat_apply(p, Cc, C, eye) 267 | 268 | # hue rotation 269 | param = uniform_sample(size, -math.pi, math.pi) 270 | Cc = rotate3d_mat(axis, param) 271 | C = random_mat_apply(p, Cc, C, eye) 272 | 273 | # saturation 274 | param = lognormal_sample(size, std=1 * math.log(2)) 275 | Cc = saturation_mat(axis, param) 276 | C = random_mat_apply(p, Cc, C, eye) 277 | 278 | return C 279 | 280 | 281 | def make_grid(shape, x0, x1, y0, y1, device): 282 | n, c, h, w = shape 283 | grid = torch.empty(n, h, w, 3, device=device) 284 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) 285 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) 286 | grid[:, :, :, 2] = 1 287 | 288 | return grid 289 | 290 | 291 | def affine_grid(grid, mat): 292 | n, h, w, _ = grid.shape 293 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) 294 | 295 | 296 | def get_padding(G, height, width): 297 | extreme = ( 298 | G[:, :2, :] 299 | @ torch.tensor([(-1.0, -1, 1), (-1, 1, 1), (1, -1, 1), (1, 1, 1)]).t() 300 | ) 301 | 302 | size = torch.tensor((width, height)) 303 | 304 | pad_low = ( 305 | ((extreme.min(-1).values + 1) * size) 306 | .clamp(max=0) 307 | .abs() 308 | .ceil() 309 | .max(0) 310 | .values.to(torch.int64) 311 | .tolist() 312 | ) 313 | pad_high = ( 314 | (extreme.max(-1).values * size - size) 315 | .clamp(min=0) 316 | .ceil() 317 | .max(0) 318 | .values.to(torch.int64) 319 | .tolist() 320 | ) 321 | 322 | return pad_low[0], pad_high[0], pad_low[1], pad_high[1] 323 | 324 | 325 | def try_sample_affine_and_pad(img, p, pad_k, G=None): 326 | batch, _, height, width = img.shape 327 | 328 | G_try = G 329 | 330 | while True: 331 | if G is None: 332 | G_try = sample_affine(p, batch, height, width) 333 | 334 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding( 335 | torch.inverse(G_try), height, width 336 | ) 337 | 338 | try: 339 | img_pad = F.pad( 340 | img, 341 | (pad_x1 + pad_k, pad_x2 + pad_k, pad_y1 + pad_k, pad_y2 + pad_k), 342 | mode="reflect", 343 | ) 344 | 345 | except RuntimeError: 346 | continue 347 | 348 | break 349 | 350 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) 351 | 352 | 353 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): 354 | kernel = antialiasing_kernel 355 | len_k = len(kernel) 356 | pad_k = (len_k + 1) // 2 357 | 358 | kernel = torch.as_tensor(kernel) 359 | kernel = torch.ger(kernel, kernel).to(img) 360 | kernel_flip = torch.flip(kernel, (0, 1)) 361 | 362 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( 363 | img, p, pad_k, G 364 | ) 365 | 366 | p_ux1 = pad_x1 367 | p_ux2 = pad_x2 + 1 368 | p_uy1 = pad_y1 369 | p_uy2 = pad_y2 + 1 370 | w_p = img_pad.shape[3] - len_k + 1 371 | h_p = img_pad.shape[2] - len_k + 1 372 | h_o = img.shape[2] 373 | w_o = img.shape[3] 374 | 375 | img_2x = upfirdn2d(img_pad, kernel_flip, up=2) 376 | 377 | grid = make_grid( 378 | 
img_2x.shape, 379 | -2 * p_ux1 / w_o - 1, 380 | 2 * (w_p - p_ux1) / w_o - 1, 381 | -2 * p_uy1 / h_o - 1, 382 | 2 * (h_p - p_uy1) / h_o - 1, 383 | device=img_2x.device, 384 | ).to(img_2x) 385 | grid = affine_grid(grid, torch.inverse(G)[:, :2, :].to(img_2x)) 386 | grid = grid * torch.tensor( 387 | [w_o / w_p, h_o / h_p], device=grid.device 388 | ) + torch.tensor( 389 | [(w_o + 2 * p_ux1) / w_p - 1, (h_o + 2 * p_uy1) / h_p - 1], device=grid.device 390 | ) 391 | 392 | img_affine = F.grid_sample( 393 | img_2x, grid, mode="bilinear", align_corners=False, padding_mode="zeros" 394 | ) 395 | 396 | img_down = upfirdn2d(img_affine, kernel, down=2) 397 | 398 | end_y = -pad_y2 - 1 399 | if end_y == 0: 400 | end_y = img_down.shape[2] 401 | 402 | end_x = -pad_x2 - 1 403 | if end_x == 0: 404 | end_x = img_down.shape[3] 405 | 406 | img = img_down[:, :, pad_y1:end_y, pad_x1:end_x] 407 | 408 | return img, G 409 | 410 | 411 | def apply_color(img, mat): 412 | batch = img.shape[0] 413 | img = img.permute(0, 2, 3, 1) 414 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) 415 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) 416 | img = img @ mat_mul + mat_add 417 | img = img.permute(0, 3, 1, 2) 418 | 419 | return img 420 | 421 | 422 | def random_apply_color(img, p, C=None): 423 | if C is None: 424 | C = sample_color(p, img.shape[0]) 425 | 426 | img = apply_color(img, C.to(img)) 427 | 428 | return img, C 429 | 430 | 431 | def augment(img, p, transform_matrix=(None, None)): 432 | img, G = random_apply_affine(img, p, transform_matrix[0]) 433 | img, C = random_apply_color(img, p, transform_matrix[1]) 434 | 435 | return img, (G, C) 436 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | if torch.cuda.is_available(): 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | 
@staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu" or input.device.type == "xla": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=negative_slope) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | }
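Note (editor's addition, not a file in this repository): the .cpp sources above only bind the raw CUDA entry points; the Python wrappers in op/fused_act.py (above) and op/upfirdn2d.py (below) JIT-compile them on first import when CUDA is available (a working nvcc/ninja is needed) and otherwise fall back to pure PyTorch implementations. A rough usage sketch follows, assuming it is run from the repository root so the op package is importable; the 4-tap blur kernel, the factor**2 gain, and the (2, 1) padding are illustrative values chosen so that the output is exactly 2x the input, not values mandated by these files.

import torch

from op import FusedLeakyReLU, upfirdn2d

# Toy input; on a CPU-only machine both ops take the pure PyTorch fallback paths.
x = torch.randn(2, 8, 16, 16)

# Fused bias + leaky ReLU with the default sqrt(2) gain (scale=2 ** 0.5 above).
act = FusedLeakyReLU(channel=8)
y = act(x)

# Separable binomial blur kernel, normalized to sum to 1.
k1 = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = k1[:, None] * k1[None, :]
k = k / k.sum()

# 2x upsample with the FIR filter; out_h = (16 * 2 + 2 + 1 - 4) // 1 + 1 = 32.
up = upfirdn2d(x, k * 4, up=2, down=1, pad=(2, 1))

print(y.shape, up.shape)  # torch.Size([2, 8, 16, 16]) torch.Size([2, 8, 32, 32])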
-------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | if torch.cuda.is_available(): 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, 
channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu" or input.device.type=="xla": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z *
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64> 315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 316 | x.data_ptr<scalar_t>(), 317 | k.data_ptr<scalar_t>(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64> 323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 324 | x.data_ptr<scalar_t>(), 325 | k.data_ptr<scalar_t>(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64> 331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 332 | x.data_ptr<scalar_t>(), 333 | k.data_ptr<scalar_t>(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64> 339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 340 | x.data_ptr<scalar_t>(), 341 | k.data_ptr<scalar_t>(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32> 347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 348 | x.data_ptr<scalar_t>(), 349 | k.data_ptr<scalar_t>(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32> 355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(), 356 | x.data_ptr<scalar_t>(), 357 | k.data_ptr<scalar_t>(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>( 363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), 364 | k.data_ptr<scalar_t>(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 |
import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import lpips 9 | from model import Generator 10 | 11 | 12 | def normalize(x): 13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 14 | 15 | 16 | def slerp(a, b, t): 17 | a = normalize(a) 18 | b = normalize(b) 19 | d = (a * b).sum(-1, keepdim=True) 20 | p = t * torch.acos(d) 21 | c = normalize(b - d * a) 22 | d = a * torch.cos(p) + c * torch.sin(p) 23 | 24 | return normalize(d) 25 | 26 | 27 | def lerp(a, b, t): 28 | return a + (b - a) * t 29 | 30 | 31 | if __name__ == "__main__": 32 | device = "cuda" 33 | 34 | parser = argparse.ArgumentParser(description="Perceptual Path Length calculator") 35 | 36 | parser.add_argument( 37 | "--space", choices=["z", "w"], help="space that PPL calculated with" 38 | ) 39 | parser.add_argument( 40 | "--batch", type=int, default=64, help="batch size for the models" 41 | ) 42 | parser.add_argument( 43 | "--n_sample", 44 | type=int, 45 | default=5000, 46 | help="number of the samples for calculating PPL", 47 | ) 48 | parser.add_argument( 49 | "--size", type=int, default=256, help="output image sizes of the generator" 50 | ) 51 | parser.add_argument( 52 | "--eps", type=float, default=1e-4, help="epsilon for numerical stability" 53 | ) 54 | parser.add_argument( 55 | "--crop", action="store_true", help="apply center crop to the images" 56 | ) 57 | parser.add_argument( 58 | "--sampling", 59 | default="end", 60 | choices=["end", "full"], 61 | help="set endpoint sampling method", 62 | ) 63 | parser.add_argument( 64 | "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints" 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | latent_dim = 512 70 | 71 | ckpt = torch.load(args.ckpt) 72 | 73 | g = Generator(args.size, latent_dim, 8).to(device) 74 | g.load_state_dict(ckpt["g_ema"]) 75 | g.eval() 76 | 77 | percept = lpips.PerceptualLoss( 78 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 79 | ) 80 | 81 | distances = [] 82 | 83 | n_batch = args.n_sample // args.batch 84 | resid = args.n_sample - (n_batch * args.batch) 85 | batch_sizes = [args.batch] * n_batch + [resid] 86 | 87 | with torch.no_grad(): 88 | for batch in tqdm(batch_sizes): 89 | noise = g.make_noise() 90 | 91 | inputs = torch.randn([batch * 2, latent_dim], device=device) 92 | if args.sampling == "full": 93 | lerp_t = torch.rand(batch, device=device) 94 | else: 95 | lerp_t = torch.zeros(batch, device=device) 96 | 97 | if args.space == "w": 98 | latent = g.get_latent(inputs) 99 | latent_t0, latent_t1 = latent[::2], latent[1::2] 100 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 101 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 102 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 103 | 104 | image, _ = g([latent_e], input_is_latent=True, noise=noise) 105 | 106 | if args.crop: 107 | c = image.shape[2] // 8 108 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 109 | 110 | factor = image.shape[2] // 256 111 | 112 | if factor > 1: 113 | image = F.interpolate( 114 | image, size=(256, 256), mode="bilinear", align_corners=False 115 | ) 116 | 117 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / ( 118 | args.eps ** 2 119 | ) 120 | distances.append(dist.to("cpu").numpy()) 121 | 122 | distances = np.concatenate(distances, 0) 123 | 124 | lo = np.percentile(distances, 1, interpolation="lower") 125 | hi = np.percentile(distances, 99, interpolation="higher") 126 | filtered_dist = np.extract( 127 | 
np.logical_and(lo <= distances, distances <= hi), distances 128 | ) 129 | 130 | print("ppl:", filtered_dist.mean()) 131 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format="jpeg", quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple( 24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100 25 | ): 26 | imgs = [] 27 | 28 | for size in sizes: 29 | imgs.append(resize_and_convert(img, size, resample, quality)) 30 | 31 | return imgs 32 | 33 | 34 | def resize_worker(img_file, sizes, resample): 35 | i, file = img_file 36 | img = Image.open(file) 37 | img = img.convert("RGB") 38 | out = resize_multiple(img, sizes=sizes, resample=resample) 39 | 40 | return i, out 41 | 42 | 43 | def prepare( 44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS 45 | ): 46 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 47 | 48 | files = sorted(dataset.imgs, key=lambda x: x[0]) 49 | files = [(i, file) for i, (file, label) in enumerate(files)] 50 | total = 0 51 | 52 | with multiprocessing.Pool(n_worker) as pool: 53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 54 | for size, img in zip(sizes, imgs): 55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 56 | 57 | with env.begin(write=True) as txn: 58 | txn.put(key, img) 59 | 60 | total += 1 61 | 62 | with env.begin(write=True) as txn: 63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Preprocess images for model training") 68 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset") 69 | parser.add_argument( 70 | "--size", 71 | type=str, 72 | default="128,256,512,1024", 73 | help="resolutions of images for the dataset", 74 | ) 75 | parser.add_argument( 76 | "--n_worker", 77 | type=int, 78 | default=8, 79 | help="number of workers for preparing dataset", 80 | ) 81 | parser.add_argument( 82 | "--resample", 83 | type=str, 84 | default="lanczos", 85 | help="resampling methods for resizing images", 86 | ) 87 | parser.add_argument("path", type=str, help="path to the image dataset") 88 | 89 | args = parser.parse_args() 90 | 91 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 92 | resample = resample_map[args.resample] 93 | 94 | sizes = [int(s.strip()) for s in args.size.split(",")] 95 | 96 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 97 | 98 | imgset = datasets.ImageFolder(args.path) 99 | 100 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 101 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 102 | -------------------------------------------------------------------------------- /projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
math 3 | import os 4 | 5 | import torch 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torchvision import transforms 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import lpips 13 | from model import Generator 14 | 15 | 16 | def noise_regularize(noises): 17 | loss = 0 18 | 19 | for noise in noises: 20 | size = noise.shape[2] 21 | 22 | while True: 23 | loss = ( 24 | loss 25 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 26 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 27 | ) 28 | 29 | if size <= 8: 30 | break 31 | 32 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) 33 | noise = noise.mean([3, 5]) 34 | size //= 2 35 | 36 | return loss 37 | 38 | 39 | def noise_normalize_(noises): 40 | for noise in noises: 41 | mean = noise.mean() 42 | std = noise.std() 43 | 44 | noise.data.add_(-mean).div_(std) 45 | 46 | 47 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 48 | lr_ramp = min(1, (1 - t) / rampdown) 49 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 50 | lr_ramp = lr_ramp * min(1, t / rampup) 51 | 52 | return initial_lr * lr_ramp 53 | 54 | 55 | def latent_noise(latent, strength): 56 | noise = torch.randn_like(latent) * strength 57 | 58 | return latent + noise 59 | 60 | 61 | def make_image(tensor): 62 | return ( 63 | tensor.detach() 64 | .clamp_(min=-1, max=1) 65 | .add(1) 66 | .div_(2) 67 | .mul(255) 68 | .type(torch.uint8) 69 | .permute(0, 2, 3, 1) 70 | .to("cpu") 71 | .numpy() 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | device = "cuda" 77 | 78 | parser = argparse.ArgumentParser( 79 | description="Image projector to the generator latent spaces" 80 | ) 81 | parser.add_argument( 82 | "--ckpt", type=str, required=True, help="path to the model checkpoint" 83 | ) 84 | parser.add_argument( 85 | "--size", type=int, default=256, help="output image sizes of the generator" 86 | ) 87 | parser.add_argument( 88 | "--lr_rampup", 89 | type=float, 90 | default=0.05, 91 | help="duration of the learning rate warmup", 92 | ) 93 | parser.add_argument( 94 | "--lr_rampdown", 95 | type=float, 96 | default=0.25, 97 | help="duration of the learning rate decay", 98 | ) 99 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate") 100 | parser.add_argument( 101 | "--noise", type=float, default=0.05, help="strength of the noise level" 102 | ) 103 | parser.add_argument( 104 | "--noise_ramp", 105 | type=float, 106 | default=0.75, 107 | help="duration of the noise level decay", 108 | ) 109 | parser.add_argument("--step", type=int, default=1000, help="optimize iterations") 110 | parser.add_argument( 111 | "--noise_regularize", 112 | type=float, 113 | default=1e5, 114 | help="weight of the noise regularization", 115 | ) 116 | parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss") 117 | parser.add_argument( 118 | "--w_plus", 119 | action="store_true", 120 | help="allow to use distinct latent codes to each layers", 121 | ) 122 | parser.add_argument( 123 | "files", metavar="FILES", nargs="+", help="path to image files to be projected" 124 | ) 125 | 126 | args = parser.parse_args() 127 | 128 | n_mean_latent = 10000 129 | 130 | resize = min(args.size, 256) 131 | 132 | transform = transforms.Compose( 133 | [ 134 | transforms.Resize(resize), 135 | transforms.CenterCrop(resize), 136 | transforms.ToTensor(), 137 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 138 | ] 139 | ) 140 | 141 | imgs = [] 142 | 143 | for imgfile in args.files: 144 | img = 
transform(Image.open(imgfile).convert("RGB")) 145 | imgs.append(img) 146 | 147 | imgs = torch.stack(imgs, 0).to(device) 148 | 149 | g_ema = Generator(args.size, 512, 8) 150 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 151 | g_ema.eval() 152 | g_ema = g_ema.to(device) 153 | 154 | with torch.no_grad(): 155 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 156 | latent_out = g_ema.style(noise_sample) 157 | 158 | latent_mean = latent_out.mean(0) 159 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 160 | 161 | percept = lpips.PerceptualLoss( 162 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 163 | ) 164 | 165 | noises_single = g_ema.make_noise() 166 | noises = [] 167 | for noise in noises_single: 168 | noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_()) 169 | 170 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1) 171 | 172 | if args.w_plus: 173 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) 174 | 175 | latent_in.requires_grad = True 176 | 177 | for noise in noises: 178 | noise.requires_grad = True 179 | 180 | optimizer = optim.Adam([latent_in] + noises, lr=args.lr) 181 | 182 | pbar = tqdm(range(args.step)) 183 | latent_path = [] 184 | 185 | for i in pbar: 186 | t = i / args.step 187 | lr = get_lr(t, args.lr) 188 | optimizer.param_groups[0]["lr"] = lr 189 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 190 | latent_n = latent_noise(latent_in, noise_strength.item()) 191 | 192 | img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises) 193 | 194 | batch, channel, height, width = img_gen.shape 195 | 196 | if height > 256: 197 | factor = height // 256 198 | 199 | img_gen = img_gen.reshape( 200 | batch, channel, height // factor, factor, width // factor, factor 201 | ) 202 | img_gen = img_gen.mean([3, 5]) 203 | 204 | p_loss = percept(img_gen, imgs).sum() 205 | n_loss = noise_regularize(noises) 206 | mse_loss = F.mse_loss(img_gen, imgs) 207 | 208 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss 209 | 210 | optimizer.zero_grad() 211 | loss.backward() 212 | optimizer.step() 213 | 214 | noise_normalize_(noises) 215 | 216 | if (i + 1) % 100 == 0: 217 | latent_path.append(latent_in.detach().clone()) 218 | 219 | pbar.set_description( 220 | ( 221 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};" 222 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" 223 | ) 224 | ) 225 | 226 | img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises) 227 | 228 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt" 229 | 230 | img_ar = make_image(img_gen) 231 | 232 | result_file = {} 233 | for i, input_name in enumerate(args.files): 234 | noise_single = [] 235 | for noise in noises: 236 | noise_single.append(noise[i : i + 1]) 237 | 238 | result_file[input_name] = { 239 | "img": img_gen[i], 240 | "latent": latent_in[i], 241 | "noise": noise_single, 242 | } 243 | 244 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png" 245 | pil_img = Image.fromarray(img_ar[i]) 246 | pil_img.save(img_name) 247 | 248 | torch.save(result_file, filename) 249 | -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | 15 | try: 16 | import wandb 17 | 18 | except ImportError: 19 | wandb = None 20 | 21 | from model import Generator, Discriminator 22 | from dataset import MultiResolutionDataset 23 | from distributed import ( 24 | get_rank, 25 | synchronize, 26 | reduce_loss_dict, 27 | reduce_sum, 28 | get_world_size, 29 | ) 30 | from non_leaking import augment, AdaptiveAugment 31 | 32 | 33 | def data_sampler(dataset, shuffle, distributed): 34 | if distributed: 35 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 36 | 37 | if shuffle: 38 | return data.RandomSampler(dataset) 39 | 40 | else: 41 | return data.SequentialSampler(dataset) 42 | 43 | 44 | def requires_grad(model, flag=True): 45 | for p in model.parameters(): 46 | p.requires_grad = flag 47 | 48 | 49 | def accumulate(model1, model2, decay=0.999): 50 | par1 = dict(model1.named_parameters()) 51 | par2 = dict(model2.named_parameters()) 52 | 53 | for k in par1.keys(): 54 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 55 | 56 | 57 | def sample_data(loader): 58 | while True: 59 | for batch in loader: 60 | yield batch 61 | 62 | 63 | def d_logistic_loss(real_pred, fake_pred): 64 | real_loss = F.softplus(-real_pred) 65 | fake_loss = F.softplus(fake_pred) 66 | 67 | return real_loss.mean() + fake_loss.mean() 68 | 69 | 70 | def d_r1_loss(real_pred, real_img): 71 | grad_real, = autograd.grad( 72 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 73 | ) 74 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 75 | 76 | return grad_penalty 77 | 78 | 79 | def g_nonsaturating_loss(fake_pred): 80 | loss = F.softplus(-fake_pred).mean() 81 | 82 | return loss 83 | 84 | 85 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 86 | noise = torch.randn_like(fake_img) / math.sqrt( 87 | fake_img.shape[2] * fake_img.shape[3] 88 | ) 89 | grad, = autograd.grad( 90 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 91 | ) 92 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 93 | 94 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 95 | 96 | path_penalty = (path_lengths - path_mean).pow(2).mean() 97 | 98 | return path_penalty, path_mean.detach(), path_lengths 99 | 100 | 101 | def make_noise(batch, latent_dim, n_noise, device): 102 | if n_noise == 1: 103 | return torch.randn(batch, latent_dim, device=device) 104 | 105 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 106 | 107 | return noises 108 | 109 | 110 | def mixing_noise(batch, latent_dim, prob, device): 111 | if prob > 0 and random.random() < prob: 112 | return make_noise(batch, latent_dim, 2, device) 113 | 114 | else: 115 | return [make_noise(batch, latent_dim, 1, device)] 116 | 117 | 118 | def set_grad_none(model, targets): 119 | for n, p in model.named_parameters(): 120 | if n in targets: 121 | p.grad = None 122 | 123 | 124 | def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): 125 | loader = sample_data(loader) 126 | 127 | pbar = range(args.iter) 128 | 129 | if get_rank() == 0: 130 | pbar = 
tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 131 | 132 | mean_path_length = 0 133 | 134 | d_loss_val = 0 135 | r1_loss = torch.tensor(0.0, device=device) 136 | g_loss_val = 0 137 | path_loss = torch.tensor(0.0, device=device) 138 | path_lengths = torch.tensor(0.0, device=device) 139 | mean_path_length_avg = 0 140 | loss_dict = {} 141 | 142 | if args.distributed: 143 | g_module = generator.module 144 | d_module = discriminator.module 145 | 146 | else: 147 | g_module = generator 148 | d_module = discriminator 149 | 150 | accum = 0.5 ** (32 / (10 * 1000)) 151 | ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0 152 | r_t_stat = 0 153 | 154 | if args.augment and args.augment_p == 0: 155 | ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 256, device) 156 | 157 | sample_z = torch.randn(args.n_sample, args.latent, device=device) 158 | 159 | for idx in pbar: 160 | i = idx + args.start_iter 161 | 162 | if i > args.iter: 163 | print("Done!") 164 | 165 | break 166 | 167 | real_img = next(loader) 168 | real_img = real_img.to(device) 169 | 170 | requires_grad(generator, False) 171 | requires_grad(discriminator, True) 172 | 173 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 174 | fake_img, _ = generator(noise) 175 | 176 | if args.augment: 177 | real_img_aug, _ = augment(real_img, ada_aug_p) 178 | fake_img, _ = augment(fake_img, ada_aug_p) 179 | 180 | else: 181 | real_img_aug = real_img 182 | 183 | fake_pred = discriminator(fake_img) 184 | real_pred = discriminator(real_img_aug) 185 | d_loss = d_logistic_loss(real_pred, fake_pred) 186 | 187 | loss_dict["d"] = d_loss 188 | loss_dict["real_score"] = real_pred.mean() 189 | loss_dict["fake_score"] = fake_pred.mean() 190 | 191 | discriminator.zero_grad() 192 | d_loss.backward() 193 | d_optim.step() 194 | 195 | if args.augment and args.augment_p == 0: 196 | ada_aug_p = ada_augment.tune(real_pred) 197 | r_t_stat = ada_augment.r_t_stat 198 | 199 | d_regularize = i % args.d_reg_every == 0 200 | 201 | if d_regularize: 202 | real_img.requires_grad = True 203 | real_pred = discriminator(real_img) 204 | r1_loss = d_r1_loss(real_pred, real_img) 205 | 206 | discriminator.zero_grad() 207 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward() 208 | 209 | d_optim.step() 210 | 211 | loss_dict["r1"] = r1_loss 212 | 213 | requires_grad(generator, True) 214 | requires_grad(discriminator, False) 215 | 216 | noise = mixing_noise(args.batch, args.latent, args.mixing, device) 217 | fake_img, _ = generator(noise) 218 | 219 | if args.augment: 220 | fake_img, _ = augment(fake_img, ada_aug_p) 221 | 222 | fake_pred = discriminator(fake_img) 223 | g_loss = g_nonsaturating_loss(fake_pred) 224 | 225 | loss_dict["g"] = g_loss 226 | 227 | generator.zero_grad() 228 | g_loss.backward() 229 | g_optim.step() 230 | 231 | g_regularize = i % args.g_reg_every == 0 232 | 233 | if g_regularize: 234 | path_batch_size = max(1, args.batch // args.path_batch_shrink) 235 | noise = mixing_noise(path_batch_size, args.latent, args.mixing, device) 236 | fake_img, latents = generator(noise, return_latents=True) 237 | 238 | path_loss, mean_path_length, path_lengths = g_path_regularize( 239 | fake_img, latents, mean_path_length 240 | ) 241 | 242 | generator.zero_grad() 243 | weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss 244 | 245 | if args.path_batch_shrink: 246 | weighted_path_loss += 0 * fake_img[0, 0, 0, 0] 247 | 248 | weighted_path_loss.backward() 249 | 250 | g_optim.step() 251 | 252 | 
mean_path_length_avg = ( 253 | reduce_sum(mean_path_length).item() / get_world_size() 254 | ) 255 | 256 | loss_dict["path"] = path_loss 257 | loss_dict["path_length"] = path_lengths.mean() 258 | 259 | accumulate(g_ema, g_module, accum) 260 | 261 | loss_reduced = reduce_loss_dict(loss_dict) 262 | 263 | d_loss_val = loss_reduced["d"].mean().item() 264 | g_loss_val = loss_reduced["g"].mean().item() 265 | r1_val = loss_reduced["r1"].mean().item() 266 | path_loss_val = loss_reduced["path"].mean().item() 267 | real_score_val = loss_reduced["real_score"].mean().item() 268 | fake_score_val = loss_reduced["fake_score"].mean().item() 269 | path_length_val = loss_reduced["path_length"].mean().item() 270 | 271 | if get_rank() == 0: 272 | pbar.set_description( 273 | ( 274 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " 275 | f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; " 276 | f"augment: {ada_aug_p:.4f}" 277 | ) 278 | ) 279 | 280 | if wandb and args.wandb: 281 | wandb.log( 282 | { 283 | "Generator": g_loss_val, 284 | "Discriminator": d_loss_val, 285 | "Augment": ada_aug_p, 286 | "Rt": r_t_stat, 287 | "R1": r1_val, 288 | "Path Length Regularization": path_loss_val, 289 | "Mean Path Length": mean_path_length, 290 | "Real Score": real_score_val, 291 | "Fake Score": fake_score_val, 292 | "Path Length": path_length_val, 293 | } 294 | ) 295 | 296 | if i % 100 == 0: 297 | with torch.no_grad(): 298 | g_ema.eval() 299 | sample, _ = g_ema([sample_z]) 300 | utils.save_image( 301 | sample, 302 | f"sample/{str(i).zfill(6)}.png", 303 | nrow=int(args.n_sample ** 0.5), 304 | normalize=True, 305 | range=(-1, 1), 306 | ) 307 | 308 | if i % 10000 == 0: 309 | torch.save( 310 | { 311 | "g": g_module.state_dict(), 312 | "d": d_module.state_dict(), 313 | "g_ema": g_ema.state_dict(), 314 | "g_optim": g_optim.state_dict(), 315 | "d_optim": d_optim.state_dict(), 316 | "args": args, 317 | "ada_aug_p": ada_aug_p, 318 | }, 319 | f"checkpoint/{str(i).zfill(6)}.pt", 320 | ) 321 | 322 | 323 | if __name__ == "__main__": 324 | device = "cuda" 325 | 326 | parser = argparse.ArgumentParser(description="StyleGAN2 trainer") 327 | 328 | parser.add_argument("path", type=str, help="path to the lmdb dataset") 329 | parser.add_argument( 330 | "--iter", type=int, default=800000, help="total training iterations" 331 | ) 332 | parser.add_argument( 333 | "--batch", type=int, default=16, help="batch sizes for each gpus" 334 | ) 335 | parser.add_argument( 336 | "--n_sample", 337 | type=int, 338 | default=64, 339 | help="number of the samples generated during training", 340 | ) 341 | parser.add_argument( 342 | "--size", type=int, default=256, help="image sizes for the model" 343 | ) 344 | parser.add_argument( 345 | "--r1", type=float, default=10, help="weight of the r1 regularization" 346 | ) 347 | parser.add_argument( 348 | "--path_regularize", 349 | type=float, 350 | default=2, 351 | help="weight of the path length regularization", 352 | ) 353 | parser.add_argument( 354 | "--path_batch_shrink", 355 | type=int, 356 | default=2, 357 | help="batch size reducing factor for the path length regularization (reduce memory consumption)", 358 | ) 359 | parser.add_argument( 360 | "--d_reg_every", 361 | type=int, 362 | default=16, 363 | help="interval of the applying r1 regularization", 364 | ) 365 | parser.add_argument( 366 | "--g_reg_every", 367 | type=int, 368 | default=4, 369 | help="interval of the applying path length regularization", 370 | ) 371 | parser.add_argument( 372 | "--mixing", type=float, default=0.9, 
help="probability of latent code mixing" 373 | ) 374 | parser.add_argument( 375 | "--ckpt", 376 | type=str, 377 | default=None, 378 | help="path to the checkpoints to resume training", 379 | ) 380 | parser.add_argument("--lr", type=float, default=0.002, help="learning rate") 381 | parser.add_argument( 382 | "--channel_multiplier", 383 | type=int, 384 | default=2, 385 | help="channel multiplier factor for the model. config-f = 2, else = 1", 386 | ) 387 | parser.add_argument( 388 | "--wandb", action="store_true", help="use weights and biases logging" 389 | ) 390 | parser.add_argument( 391 | "--local_rank", type=int, default=0, help="local rank for distributed training" 392 | ) 393 | parser.add_argument( 394 | "--augment", action="store_true", help="apply non leaking augmentation" 395 | ) 396 | parser.add_argument( 397 | "--augment_p", 398 | type=float, 399 | default=0, 400 | help="probability of applying augmentation. 0 = use adaptive augmentation", 401 | ) 402 | parser.add_argument( 403 | "--ada_target", 404 | type=float, 405 | default=0.6, 406 | help="target augmentation probability for adaptive augmentation", 407 | ) 408 | parser.add_argument( 409 | "--ada_length", 410 | type=int, 411 | default=500 * 1000, 412 | help="target duraing to reach augmentation probability for adaptive augmentation", 413 | ) 414 | parser.add_argument( 415 | "--ada_every", 416 | type=int, 417 | default=256, 418 | help="probability update interval of the adaptive augmentation", 419 | ) 420 | 421 | args = parser.parse_args() 422 | 423 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 424 | args.distributed = n_gpu > 1 425 | 426 | if args.distributed: 427 | torch.cuda.set_device(args.local_rank) 428 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 429 | synchronize() 430 | 431 | args.latent = 512 432 | args.n_mlp = 8 433 | 434 | args.start_iter = 0 435 | 436 | generator = Generator( 437 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 438 | ).to(device) 439 | discriminator = Discriminator( 440 | args.size, channel_multiplier=args.channel_multiplier 441 | ).to(device) 442 | g_ema = Generator( 443 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 444 | ).to(device) 445 | g_ema.eval() 446 | accumulate(g_ema, generator, 0) 447 | 448 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 449 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 450 | 451 | g_optim = optim.Adam( 452 | generator.parameters(), 453 | lr=args.lr * g_reg_ratio, 454 | betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), 455 | ) 456 | d_optim = optim.Adam( 457 | discriminator.parameters(), 458 | lr=args.lr * d_reg_ratio, 459 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 460 | ) 461 | 462 | if args.ckpt is not None: 463 | print("load model:", args.ckpt) 464 | 465 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 466 | 467 | try: 468 | ckpt_name = os.path.basename(args.ckpt) 469 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 470 | 471 | except ValueError: 472 | pass 473 | 474 | generator.load_state_dict(ckpt["g"]) 475 | discriminator.load_state_dict(ckpt["d"]) 476 | g_ema.load_state_dict(ckpt["g_ema"]) 477 | 478 | g_optim.load_state_dict(ckpt["g_optim"]) 479 | d_optim.load_state_dict(ckpt["d_optim"]) 480 | 481 | if args.distributed: 482 | generator = nn.parallel.DistributedDataParallel( 483 | generator, 484 | device_ids=[args.local_rank], 485 | output_device=args.local_rank, 486 | 
broadcast_buffers=False, 487 | ) 488 | 489 | discriminator = nn.parallel.DistributedDataParallel( 490 | discriminator, 491 | device_ids=[args.local_rank], 492 | output_device=args.local_rank, 493 | broadcast_buffers=False, 494 | ) 495 | 496 | transform = transforms.Compose( 497 | [ 498 | transforms.RandomHorizontalFlip(), 499 | transforms.ToTensor(), 500 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 501 | ] 502 | ) 503 | 504 | dataset = MultiResolutionDataset(args.path, transform, args.size) 505 | loader = data.DataLoader( 506 | dataset, 507 | batch_size=args.batch, 508 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 509 | drop_last=True, 510 | ) 511 | 512 | if get_rank() == 0 and wandb is not None and args.wandb: 513 | wandb.init(project="stylegan 2") 514 | 515 | train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device) 516 | --------------------------------------------------------------------------------
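Closing note (editor's addition, not a file in this repository): once the training loop above has written a checkpoint (it stores "g", "d", "g_ema", both optimizer states, the args and the current ADA probability every 10000 iterations), the EMA generator can be reloaded for sampling, much as generate.py in this repository does. A minimal sketch follows; the checkpoint name, image size, latent dimension and sample count are placeholders that must match how train.py was actually run.

import torch
from torchvision import utils

from model import Generator

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical checkpoint written by the training loop above.
ckpt = torch.load("checkpoint/010000.pt", map_location="cpu")

# Same constructor arguments as in train.py: size, latent dim, MLP depth, channel multiplier.
g_ema = Generator(256, 512, 8, channel_multiplier=2).to(device)
g_ema.load_state_dict(ckpt["g_ema"])
g_ema.eval()

with torch.no_grad():
    z = torch.randn(16, 512, device=device)
    sample, _ = g_ema([z])
    utils.save_image(sample, "ema_samples.png", nrow=4, normalize=True)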