├── .DS_Store ├── .idea ├── .gitignore ├── DCT-Net.iml ├── deployment.xml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── vcs.xml └── webServers.xml ├── LICENSE ├── README.md ├── assets ├── .DS_Store ├── demo.gif ├── sim1.png ├── sim2.png ├── sim3.png ├── sim4.png ├── sim5.png ├── sim_3d.png ├── sim_anime.png ├── sim_artstyle.png ├── sim_design.png ├── sim_handdrawn.png ├── sim_illu.png ├── sim_sketch.png ├── styles.png └── video.gif ├── celeb.txt ├── download.py ├── export.py ├── extract_align_faces.py ├── generate_data.py ├── input.mp4 ├── input.png ├── multi-style ├── download.py ├── run.py └── run_sdk.py ├── notebooks ├── .DS_Store ├── fastTrain.ipynb └── inference.ipynb ├── prepare_data.sh ├── run.py ├── run_sdk.py ├── run_vid.py ├── source ├── .DS_Store ├── __init__.py ├── cartoonize.py ├── facelib │ ├── LICENSE │ ├── LK │ │ ├── __init__.py │ │ └── lk.py │ ├── __init__.py │ ├── config.py │ ├── face_detector.py │ ├── face_landmark.py │ └── facer.py ├── image_flip_agument_parallel.py ├── image_rotation_agument_parallel_flat.py ├── image_scale_agument_parallel_flat.py ├── mtcnn_pytorch │ ├── .DS_Store │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ └── src │ │ ├── __init__.py │ │ ├── align_trans.py │ │ └── matlab_cp2tform.py ├── stylegan2 │ ├── .DS_Store │ ├── __init__.py │ ├── config │ │ ├── .DS_Store │ │ ├── conf_server_test_blend_shell.json │ │ └── conf_server_train_condition_shell.json │ ├── criteria │ │ ├── __init__.py │ │ ├── helpers.py │ │ ├── id_loss.py │ │ ├── lpips │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ ├── networks.py │ │ │ └── utils.py │ │ ├── moco_loss.py │ │ ├── model_irse.py │ │ ├── vgg.py │ │ ├── vgg_loss.py │ │ └── w_norm.py │ ├── dataset.py │ ├── distributed.py │ ├── generate_blendmodel.py │ ├── model.py │ ├── noise.pt │ ├── non_leaking.py │ ├── op │ │ ├── __init__.py │ │ ├── conv2d_gradfix.py │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu 
│ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ ├── prepare_data.py │ ├── style_blend.py │ └── train_condition.py └── utils.py └── train_localtoon.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/.DS_Store -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/DCT-Net.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 110 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/webServers.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. 
You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. 
(Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DCT-Net: Domain-Calibrated Translation for Portrait Stylization 2 | 3 | ### [Project page](https://menyifang.github.io/projects/DCTNet/DCTNet.html) | [Video](https://www.youtube.com/watch?v=Y8BrfOjXYQM) | [Paper](https://arxiv.org/abs/2207.02426) 4 | 5 | Official implementation of DCT-Net for Full-body Portrait Stylization. 
6 | 7 | 8 | > [**DCT-Net: Domain-Calibrated Translation for Portrait Stylization**](arxiv_url_coming_soon), 9 | > [Yifang Men](https://menyifang.github.io/)1, Yuan Yao1, Miaomiao Cui1, [Zhouhui Lian](https://www.icst.pku.edu.cn/zlian/)2, Xuansong Xie1, 10 | > _1[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Beijing, China_ 11 | > _2[Wangxuan Institute of Computer Technology, Peking University](https://www.icst.pku.edu.cn/), China_ 12 | > In: SIGGRAPH 2022 (**TOG**) 13 | > *[arXiv preprint](https://arxiv.org/abs/2207.02426)* 14 | 15 | google colab logo 16 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/SIGGRAPH2022/DCT-Net) 17 | 18 | 19 | ## Demo 20 | ![demo](assets/demo.gif) 21 | 22 | 23 | ## News 24 | 25 | (2023-03-14) The training guidance has been released, train DCT-Net with your own style data. 26 | 27 | (2023-02-20) Two new style pre-trained models (design, illustration) trained with combined DCT-Net and Stable-Diffusion are provided. The training guidance will be released soon. 28 | 29 | (2022-10-09) The multi-style pre-trained models (3d, handdrawn, sketch, artstyle) and usage are available now. 30 | 31 | (2022-08-08) The pertained model and infer code of 'anime' style is available now. More styles coming soon. 32 | 33 | (2022-08-08) cartoon function can be directly call from pythonSDK. 34 | 35 | (2022-07-07) The paper is available now at arxiv(https://arxiv.org/abs/2207.02426). 36 | 37 | 38 | ## Web Demo 39 | - Integrated into [Colab notebook](https://colab.research.google.com/github/menyifang/DCT-Net/blob/main/notebooks/inference.ipynb). Try out the colab demo.google colab logo 40 | 41 | - Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). 
Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/SIGGRAPH2022/DCT-Net) 42 | 43 | - [Chinese version] Integrated into [ModelScope](https://modelscope.cn/#/models). Try out the Web Demo [![ModelScope Spaces]( 44 | https://img.shields.io/badge/ModelScope-Spaces-blue)](https://modelscope.cn/#/models/damo/cv_unet_person-image-cartoon_compound-models/summary) 45 | 46 | ## Requirements 47 | * python 3 48 | * tensorflow (>=1.14, training only support tf1.x) 49 | * easydict 50 | * numpy 51 | * both CPU/GPU are supported 52 | 53 | 54 | ## Quick Start 55 | google colab logo 56 | 57 | 58 | ```bash 59 | git clone https://github.com/menyifang/DCT-Net.git 60 | cd DCT-Net 61 | 62 | ``` 63 | 64 | ### Installation 65 | ```bash 66 | conda create -n dctnet python=3.7 67 | conda activate dctnet 68 | pip install --upgrade tensorflow-gpu==1.15 # GPU support, use tensorflow for CPU only 69 | pip install "modelscope[cv]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 70 | pip install "modelscope[multi-modal]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html 71 | ``` 72 | 73 | ### Downloads 74 | 75 | | [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models/summary) | [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary) | [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-handdrawn_compound-models/summary)| [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sketch_compound-models/summary)| [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-artstyle_compound-models/summary)| 76 | |:--:|:--:|:--:|:--:|:--:| 77 | | [anime](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models/summary) | [3d](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary) | 
[handdrawn](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-handdrawn_compound-models/summary) | [sketch](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sketch_compound-models/summary) | [artstyle](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-artstyle_compound-models/summary) | 78 | 79 | | [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-design_compound-models/summary) | [](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-illustration_compound-models/summary) | 80 | |:--:|:--:| 81 | | [design](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-design_compound-models/summary) | [illustration](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-illustration_compound-models/summary) 82 | 83 | Pre-trained models in different styles can be downloaded by 84 | ```bash 85 | python download.py 86 | ``` 87 | 88 | ### Inference 89 | 90 | - from python SDK 91 | ```bash 92 | python run_sdk.py 93 | ``` 94 | 95 | - from source code 96 | ```bash 97 | python run.py 98 | ``` 99 | 100 | ### Video cartoonization 101 | 102 | ![demo_vid](assets/video.gif) 103 | 104 | video can be directly processed as image sequences, style choice [option: anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration] 105 | 106 | ```bash 107 | python run_vid.py --style anime 108 | ``` 109 | 110 | 111 | ## Training 112 | google colab logo 113 | 114 | ### Data preparation 115 | ``` 116 | face_photo: face dataset such as [FFHQ](https://github.com/NVlabs/ffhq-dataset) or other collected real faces. 117 | face_cartoon: 100-300 cartoon face images in a specific style, which can be self-collected or synthsized with generative models. 118 | ``` 119 | Due to the copyrighe issues, we can not provide collected cartoon exemplar for training. You can produce cartoon exemplars with the style-finetuned Stable-Diffusion (SD) models, which can be downloaded from modelscope or huggingface hubs. 
120 | 121 | The effects of some style-finetune SD models are as follows: 122 | 123 | | [](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_design/summary) | [](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_watercolor) | [](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_illustration/summary)| [](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_clipart/summary)| [](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_flat/summary)| 124 | |:--:|:--:|:--:|:--:|:--:| 125 | | [design](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_design/summary) | [watercolor](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_watercolor/summary) | [illustration](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_illustration/summary) | [clipart](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_clipart/summary) | [flat](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_flat/summary) | 126 | 127 | - Generate stylized data, style choice [option: clipart, design, illustration, watercolor, flat] 128 | ```bash 129 | python generate_data.py --style clipart 130 | ``` 131 | 132 | - preprocess 133 | 134 | extract aligned faces from raw style images: 135 | ```bash 136 | python extract_align_faces.py --src_dir 'data/raw_style_data' 137 | ``` 138 | 139 | - train content calibration network 140 | 141 | install environment required by [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch) 142 | ```bash 143 | cd source/stylegan2 144 | python prepare_data.py '../../data/face_cartoon' --size 256 --out '../../data/stylegan2/traindata' 145 | python train_condition.py --name 'ffhq_style_s256' --path '../../data/stylegan2/traindata' --config config/conf_server_train_condition_shell.json 146 | ``` 147 | 148 | after training, generated content calibrated samples via: 149 | ```bash 150 | python style_blend.py --name 'ffhq_style_s256' 151 | python generate_blendmodel.py 
--name 'ffhq_style_s256' --save_dir '../../data/face_cartoon/syn_style_faces' 152 | ``` 153 | 154 | - geometry calibration 155 | 156 | run geometry calibration for both photo and cartoon: 157 | ```bash 158 | cd source 159 | python image_flip_agument_parallel.py --data_dir '../data/face_cartoon' 160 | python image_scale_agument_parallel_flat.py --data_dir '../data/face_cartoon' 161 | python image_rotation_agument_parallel_flat.py --data_dir '../data/face_cartoon' 162 | ``` 163 | 164 | - train texture translator 165 | 166 | The dataset structure is recommended as: 167 | ``` 168 | +—data 169 | | +—face_photo 170 | | +—face_cartoon 171 | ``` 172 | resume training from pretrained model in similar style, 173 | 174 | style can be chosen from 'anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration' 175 | 176 | ```bash 177 | python train_localtoon.py --data_dir PATH_TO_YOU_DATA --work_dir PATH_SAVE --style anime 178 | ``` 179 | 180 | 181 | 182 | 183 | ## Acknowledgments 184 | 185 | Face detector and aligner are adapted from [Peppa_Pig_Face_Engine](https://github.com/610265158/Peppa_Pig_Face_Engine 186 | ) and [InsightFace](https://github.com/TreB1eN/InsightFace_Pytorch). 187 | 188 | 189 | 190 | ## Citation 191 | 192 | If you find this code useful for your research, please use the following BibTeX entry. 
193 | 194 | ```bibtex 195 | @inproceedings{men2022dct, 196 | title={DCT-Net: Domain-Calibrated Translation for Portrait Stylization}, 197 | author={Men, Yifang and Yao, Yuan and Cui, Miaomiao and Lian, Zhouhui and Xie, Xuansong}, 198 | journal={ACM Transactions on Graphics (TOG)}, 199 | volume={41}, 200 | number={4}, 201 | pages={1--9}, 202 | year={2022}, 203 | publisher={ACM New York, NY, USA} 204 | } 205 | ``` 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/.DS_Store -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/demo.gif -------------------------------------------------------------------------------- /assets/sim1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim1.png -------------------------------------------------------------------------------- /assets/sim2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim2.png -------------------------------------------------------------------------------- /assets/sim3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim3.png -------------------------------------------------------------------------------- /assets/sim4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim4.png -------------------------------------------------------------------------------- /assets/sim5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim5.png -------------------------------------------------------------------------------- /assets/sim_3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_3d.png -------------------------------------------------------------------------------- /assets/sim_anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_anime.png -------------------------------------------------------------------------------- /assets/sim_artstyle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_artstyle.png -------------------------------------------------------------------------------- /assets/sim_design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_design.png -------------------------------------------------------------------------------- /assets/sim_handdrawn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_handdrawn.png 
-------------------------------------------------------------------------------- /assets/sim_illu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_illu.png -------------------------------------------------------------------------------- /assets/sim_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_sketch.png -------------------------------------------------------------------------------- /assets/styles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/styles.png -------------------------------------------------------------------------------- /assets/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/video.gif -------------------------------------------------------------------------------- /celeb.txt: -------------------------------------------------------------------------------- 1 | Beyoncé 2 | Taylor Swift 3 | Rihanna 4 | Ariana Grande 5 | Angelina Jolie 6 | Jennifer Lawrence 7 | Emma Watson 8 | Katy Perry 9 | Scarlett Johansson 10 | Lady Gaga 11 | Gal Gadot 12 | Selena Gomez 13 | Sandra Bullock 14 | Ellen DeGeneres 15 | Mila Kunis 16 | Jennifer Aniston 17 | Margot Robbie 18 | Blake Lively 19 | Miley Cyrus 20 | Charlize Theron 21 | Emma Stone 22 | Sofia Vergara 23 | Halle Berry 24 | Zendaya 25 | Dwayne Johnson 26 | Lady Amelia Windsor 27 | Brie Larson 28 | Adele 29 | Janelle Monae 30 | Shakira 31 | Priyanka Chopra 32 | Betty White 33 | Nina Dobrev 34 | Meghan Markle 35 | Lupita Nyong'o 36 | Emilia Clarke 37 | 
Kate Middleton 38 | Zooey Deschanel 39 | Sienna Miller 40 | Christina Aguilera 41 | Kate Hudson 42 | Gina Rodriguez 43 | Cardi B 44 | Yara Shahidi 45 | Michelle Obama 46 | Kourtney Kardashian 47 | Portia de Rossi 48 | Kerry Washington 49 | Jada Pinkett Smith 50 | Lucy Liu 51 | Victoria Beckham 52 | Gwyneth Paltrow 53 | Kim Kardashian 54 | Ellen Page 55 | Kerry Washington 56 | Maya Rudolph 57 | Alicia Keys 58 | Oprah Winfrey 59 | Tracee Ellis Ross 60 | Jennifer Lopez 61 | Rachel McAdams 62 | Pink 63 | Cameron Diaz 64 | Lily Collins 65 | Anne Hathaway 66 | Tyra Banks 67 | Ashley Tisdale 68 | Amanda Seyfried 69 | Jessica Alba 70 | Demi Lovato 71 | Keira Knightley 72 | Bella Hadid 73 | Kendall Jenner 74 | Emma Roberts 75 | Vanessa Hudgens 76 | Sofia Richie 77 | Hailey Bieber 78 | Gisele Bündchen 79 | Taylor Hill 80 | Kiki Layne 81 | Cate Blanchett 82 | Kate Winslet 83 | Gal Gadot 84 | Salma Hayek 85 | Julia Roberts 86 | Mariah Carey 87 | Scarlett Johansson 88 | Rosie Huntington-Whitely 89 | Marjane Satrapi 90 | Halle Berry 91 | Mariah Carey 92 | Selena Gomez 93 | Emma Watson 94 | Jennifer Aniston 95 | Rihanna 96 | Blake Lively 97 | Ariana Grande 98 | Angelina Jolie 99 | Lady Gaga 100 | Taylor Swift 101 | 102 | Robert Downey Jr. 103 | Tom Cruise 104 | George Clooney 105 | Brad Pitt 106 | Dwayne Johnson 107 | Leonardo DiCaprio 108 | Will Smith 109 | Johnny Depp 110 | Chris Evans 111 | Ryan Reynolds 112 | Tom Hanks 113 | Matt Damon 114 | Denzel Washington 115 | Hugh Jackman 116 | Chris Hemsworth 117 | Chris Pratt 118 | Idris Elba 119 | Daniel Craig 120 | Samuel L. 
Jackson 121 | Jeremy Renner 122 | Chris Pine 123 | Robert Pattinson 124 | Sebastian Stan 125 | Benedict Cumberbatch 126 | Paul Rudd 127 | Mark Wahlberg 128 | Zac Efron 129 | Jason Statham 130 | Michael Fassbender 131 | Joel Kinnaman 132 | Keanu Reeves 133 | Scarlett Johansson 134 | Vin Diesel 135 | Angelina Jolie 136 | Emma Watson 137 | Jennifer Lawrence 138 | Gal Gadot 139 | Margot Robbie 140 | Brie Larson 141 | Sofia Vergara 142 | Mila Kunis 143 | Emily Blunt 144 | Sandra Bullock 145 | Kate Winslet 146 | Nicole Kidman 147 | Charlize Theron 148 | Anne Hathaway 149 | Cate Blanchett 150 | Emma Stone 151 | Lupita Nyong'o 152 | Jennifer Aniston 153 | Halle Berry 154 | Rihanna 155 | Lady Gaga 156 | Beyoncé 157 | Taylor Swift 158 | Miley Cyrus 159 | Ariana Grande 160 | Meghan Markle 161 | Kate Middleton 162 | Angelina Jolie 163 | Jennifer Lopez 164 | Shakira 165 | Katy Perry 166 | Lady Gaga 167 | Britney Spears 168 | Adele 169 | Mariah Carey 170 | Madonna 171 | Janet Jackson 172 | Whitney Houston 173 | Tina Turner 174 | Celine Dion 175 | Barbra Streisand 176 | Cher 177 | Gloria Estefan 178 | Diana Ross 179 | Julie Andrews 180 | Liza Minnelli 181 | Bette Midler 182 | Elton John 183 | Freddie Mercury 184 | Paul McCartney 185 | Elvis Presley 186 | Michael Jackson 187 | Prince 188 | Madonna 189 | Mariah Carey 190 | Janet Jackson 191 | Whitney Houston 192 | Tina Turner 193 | Celine Dion 194 | Barbra Streisand 195 | Cher 196 | Gloria Estefan 197 | Diana Ross 198 | Julie Andrews 199 | Liza Minnelli 200 | Bette Midler 201 | Elton John -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | from modelscope.hub.snapshot_download import snapshot_download 2 | # pre-trained models in different style 3 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.') 4 | model_dir = 
snapshot_download('damo/cv_unet_person-image-cartoon-3d_compound-models', cache_dir='.') 5 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-handdrawn_compound-models', cache_dir='.') 6 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sketch_compound-models', cache_dir='.') 7 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-artstyle_compound-models', cache_dir='.') 8 | 9 | # pre-trained models trained with DCT-Net + Stable-Diffusion 10 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sd-design_compound-models', revision='v1.0.0', cache_dir='.') 11 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sd-illustration_compound-models', revision='v1.0.0', cache_dir='.') 12 | 13 | 14 | -------------------------------------------------------------------------------- /export.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import cv2 5 | 6 | from modelscope.exporters.cv import CartoonTranslationExporter 7 | from modelscope.msdatasets import MsDataset 8 | from modelscope.outputs import OutputKeys 9 | from modelscope.pipelines import pipeline 10 | from modelscope.pipelines.base import Pipeline 11 | from modelscope.trainers.cv import CartoonTranslationTrainer 12 | from modelscope.utils.constant import Tasks 13 | from modelscope.utils.test_utils import test_level 14 | 15 | 16 | class TestImagePortraitStylizationTrainer(unittest.TestCase): 17 | 18 | def setUp(self) -> None: 19 | self.task = Tasks.image_portrait_stylization 20 | self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png' 21 | 22 | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') 23 | def test_run_with_model_name(self): 24 | model_id = 'damo/cv_unet_person-image-cartoon_compound-models' 25 | 26 | data_dir = MsDataset.load( 27 | 'dctnet_train_clipart_mini_ms', 28 | namespace='menyifang', 29 
| split='train').config_kwargs['split_config']['train'] 30 | 31 | data_photo = os.path.join(data_dir, 'face_photo') 32 | data_cartoon = os.path.join(data_dir, 'face_cartoon') 33 | work_dir = 'exp_localtoon' 34 | max_steps = 10 35 | trainer = CartoonTranslationTrainer( 36 | model=model_id, 37 | work_dir=work_dir, 38 | photo=data_photo, 39 | cartoon=data_cartoon, 40 | max_steps=max_steps) 41 | trainer.train() 42 | 43 | # export pb file 44 | ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0)) 45 | pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb') 46 | exporter = CartoonTranslationExporter() 47 | exporter.export_frozen_graph_def( 48 | ckpt_path=ckpt_path, frozen_graph_path=pb_path) 49 | 50 | # infer with pb file 51 | self.pipeline_person_image_cartoon(trainer.model_dir) 52 | 53 | def pipeline_person_image_cartoon(self, model_dir): 54 | pipeline_cartoon = pipeline(task=self.task, model=model_dir) 55 | result = pipeline_cartoon(input=self.test_image) 56 | if result is not None: 57 | cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) 58 | print(f'Output written to {os.path.abspath("result.png")}') 59 | 60 | 61 | if __name__ == '__main__': 62 | unittest.main() -------------------------------------------------------------------------------- /extract_align_faces.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import argparse 5 | from source.facelib.facer import FaceAna 6 | import source.utils as utils 7 | from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points 8 | from modelscope.hub.snapshot_download import snapshot_download 9 | 10 | class FaceProcesser: 11 | def __init__(self, dataroot, crop_size = 256, max_face = 1): 12 | self.max_face = max_face 13 | self.crop_size = crop_size 14 | self.facer = FaceAna(dataroot) 15 | 16 | def filter_face(self, lm, crop_size): 17 | a = max(lm[:, 0])-min(lm[:, 0]) 18 | b = 
max(lm[:, 1])-min(lm[:, 1]) 19 | # print("a:%d, b:%d"%(a,b)) 20 | if max(a, b)0: 48 | continue 49 | 50 | if self.filter_face(landmark, self.crop_size)==0: 51 | print("filtered!") 52 | continue 53 | 54 | f5p = utils.get_f5p(landmark, img_bgr) 55 | # face alignment 56 | warped_face, _ = warp_and_crop_face( 57 | img_bgr, 58 | f5p, 59 | ratio=0.75, 60 | reference_pts=get_reference_facial_points(default_square=True), 61 | crop_size=(self.crop_size, self.crop_size), 62 | return_trans_inv=True) 63 | 64 | warped_faces.append(warped_face) 65 | i = i+1 66 | 67 | 68 | return warped_faces 69 | 70 | 71 | 72 | 73 | if __name__ == "__main__": 74 | 75 | 76 | parser = argparse.ArgumentParser(description="process remove bg result") 77 | parser.add_argument("--src_dir", type=str, default='', help="Path to src images.") 78 | parser.add_argument("--save_dir", type=str, default='', help="Path to save images.") 79 | parser.add_argument("--crop_size", type=int, default=256) 80 | parser.add_argument("--max_face", type=int, default=1) 81 | parser.add_argument("--overwrite", type=int, default=1) 82 | args = parser.parse_args() 83 | args.save_dir = os.path.dirname(args.src_dir) + '/face_cartoon/raw_style_faces' 84 | 85 | crop_size = args.crop_size 86 | max_face = args.max_face 87 | overwrite = args.overwrite 88 | 89 | # model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.') 90 | # print('model assets saved to %s'%model_dir) 91 | model_dir = 'damo/cv_unet_person-image-cartoon_compound-models' 92 | 93 | processer = FaceProcesser(dataroot=model_dir,crop_size=crop_size, max_face =max_face) 94 | 95 | src_dir = args.src_dir 96 | save_dir = args.save_dir 97 | 98 | # print('Step: start to extract aligned faces ... 
...') 99 | 100 | print('src_dir:%s'% src_dir) 101 | print('save_dir:%s'% save_dir) 102 | 103 | if not os.path.exists(save_dir): 104 | os.makedirs(save_dir) 105 | 106 | paths = utils.all_file(src_dir) 107 | print('to process %d images'% len(paths)) 108 | 109 | for path in sorted(paths): 110 | dirname = path[len(src_dir)+1:].split('/')[0] 111 | 112 | outpath = save_dir + path[len(src_dir):] 113 | if not overwrite: 114 | if os.path.exists(outpath): 115 | continue 116 | 117 | sub_dir = os.path.dirname(outpath) 118 | # print(sub_dir) 119 | if not os.path.exists(sub_dir): 120 | os.makedirs(sub_dir, exist_ok=True) 121 | 122 | imgb = None 123 | imgc = None 124 | img = cv2.imread(path, -1) 125 | if img is None: 126 | continue 127 | 128 | if len(img.shape)==2: 129 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 130 | 131 | # print(img.shape) 132 | h,w,c = img.shape 133 | if h<256 or w<256: 134 | continue 135 | imgs = [] 136 | 137 | # if need resize, resize here 138 | img_h, img_w, _ = img.shape 139 | warped_faces = processer.process(img) 140 | if warped_faces is None: 141 | continue 142 | # ### only for anime faces, single, not detect face 143 | # warped_face = imga 144 | 145 | i=0 146 | for res in warped_faces: 147 | # filter small faces 148 | h, w, c = res.shape 149 | if h < 256 or w < 256: 150 | continue 151 | outpath = os.path.join(os.path.dirname(outpath), os.path.basename(outpath)[:-4] + '_' + str(i) + '.png') 152 | 153 | cv2.imwrite(outpath, res) 154 | print('save %s' % outpath) 155 | i = i+1 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | from modelscope.pipelines import pipeline 2 | from modelscope.utils.constant import Tasks 3 | import torch 4 | import os, cv2 5 | import argparse 6 | 7 | def load_cele_txt(celeb_file='celeb.txt'): 8 | celeb = open(celeb_file, 'r') 9 | lines = 
celeb.readlines() 10 | name_list = [] 11 | for line in lines: 12 | name = line.strip('\n') 13 | if name != '': 14 | name_list.append(name) 15 | return name_list 16 | 17 | 18 | def main(args): 19 | style = args.style 20 | repeat_num = 5 21 | 22 | model_id = 'damo/cv_cartoon_stable_diffusion_' + style 23 | pipe = pipeline(Tasks.text_to_image_synthesis, model=model_id, 24 | model_revision='v1.0.0', torch_dtype=torch.float16) 25 | from diffusers.schedulers import EulerAncestralDiscreteScheduler 26 | pipe.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.pipeline.scheduler.config) 27 | print('model init finished!') 28 | 29 | 30 | save_dir = 'res_style_%s/syn_celeb' % (style) 31 | if not os.path.exists(save_dir): 32 | os.makedirs(save_dir) 33 | 34 | name_list = load_cele_txt('celeb.txt') 35 | person_num = len(name_list) 36 | for i in range(person_num): 37 | name = name_list[i] 38 | print('process %s' % name) 39 | 40 | if style == "clipart": 41 | prompt = 'archer style, a portrait painting of %s' % (name) 42 | else: 43 | prompt = 'sks style, a painting of a %s, no text' % (name) 44 | 45 | images = pipe({'text': prompt, 'num_images_per_prompt': repeat_num})['output_imgs'] 46 | idx = 0 47 | for image in images: 48 | outpath = os.path.join(save_dir, '%s_%d.png' % (name, idx)) 49 | cv2.imwrite(outpath, image) 50 | idx += 1 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('--style', type=str, default='clipart') 56 | 57 | args = parser.parse_args() 58 | main(args) 59 | 60 | -------------------------------------------------------------------------------- /input.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/input.mp4 -------------------------------------------------------------------------------- /input.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/input.png -------------------------------------------------------------------------------- /multi-style/download.py: -------------------------------------------------------------------------------- 1 | from modelscope.hub.snapshot_download import snapshot_download 2 | import argparse 3 | 4 | 5 | 6 | def process(args): 7 | style = args.style 8 | print('download %s model'%style) 9 | if style == "anime": 10 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.') 11 | 12 | elif style == "3d": 13 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-3d_compound-models', cache_dir='.') 14 | 15 | elif style == "handdrawn": 16 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-handdrawn_compound-models', cache_dir='.') 17 | 18 | elif style == "sketch": 19 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sketch_compound-models', cache_dir='.') 20 | 21 | elif style == "artstyle": 22 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-artstyle_compound-models', cache_dir='.') 23 | 24 | else: 25 | print('no such style %s'% style) 26 | 27 | 28 | 29 | 30 | if __name__ == '__main__': 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--style', type=str, default='anime') 33 | args = parser.parse_args() 34 | 35 | process(args) 36 | -------------------------------------------------------------------------------- /multi-style/run.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | 4 | import cv2 5 | from source.cartoonize import Cartoonizer 6 | import os 7 | import argparse 8 | 9 | 10 | def process(args): 11 | 12 | style = args.style 13 | if style == "anime": 14 | algo = 
Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon_compound-models') 15 | 16 | elif style == "3d": 17 | algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-3d_compound-models') 18 | 19 | elif style == "handdrawn": 20 | algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-handdrawn_compound-models') 21 | 22 | elif style == "sketch": 23 | algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-sketch_compound-models') 24 | 25 | elif style == "artstyle": 26 | algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-artstyle_compound-models') 27 | 28 | else: 29 | print('no such style %s' % style) 30 | return 0 31 | 32 | img = cv2.imread('input.png')[..., ::-1] 33 | result = algo.cartoonize(img) 34 | cv2.imwrite('result1_%s.png'%style, result) 35 | 36 | print('finished!') 37 | 38 | 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--style', type=str, default='anime') 44 | args = parser.parse_args() 45 | 46 | process(args) 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /multi-style/run_sdk.py: -------------------------------------------------------------------------------- 1 | import cv2, argparse 2 | from modelscope.outputs import OutputKeys 3 | from modelscope.pipelines import pipeline 4 | from modelscope.utils.constant import Tasks 5 | 6 | def process(args): 7 | style = args.style 8 | print('choose style %s'%style) 9 | if style == "anime": 10 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 11 | model='damo/cv_unet_person-image-cartoon_compound-models') 12 | elif style == "3d": 13 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 14 | model='damo/cv_unet_person-image-cartoon-3d_compound-models') 15 | elif style == "handdrawn": 16 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 17 | model='damo/cv_unet_person-image-cartoon-handdrawn_compound-models') 18 | elif style == "sketch": 19 | 
img_cartoon = pipeline(Tasks.image_portrait_stylization, 20 | model='damo/cv_unet_person-image-cartoon-sketch_compound-models') 21 | elif style == "artstyle": 22 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 23 | model='damo/cv_unet_person-image-cartoon-artstyle_compound-models') 24 | else: 25 | print('no such style %s'% style) 26 | return 0 27 | 28 | 29 | result = img_cartoon('input.png') 30 | 31 | cv2.imwrite('result_%s.png'%style, result[OutputKeys.OUTPUT_IMG]) 32 | print('finished!') 33 | 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--style', type=str, default='anime') 40 | args = parser.parse_args() 41 | 42 | process(args) 43 | -------------------------------------------------------------------------------- /notebooks/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/notebooks/.DS_Store -------------------------------------------------------------------------------- /prepare_data.sh: -------------------------------------------------------------------------------- 1 | 2 | data_root='data' 3 | align_dir='raw_style_data_faces' 4 | 5 | echo "STEP: start to prepare data for stylegan ..." 6 | cd $data_root 7 | if [ ! -d stylegan ]; then 8 | mkdir stylegan 9 | fi 10 | cd stylegan 11 | stylegan_data_dir=$(pwd) 12 | if [ ! -d "$(date +"%Y%m%d")" ]; then 13 | mkdir "$(date +"%Y%m%d")" 14 | fi 15 | cd "$(date +"%Y%m%d")" 16 | cp $align_dir . -r 17 | if [ -d $(echo $align_dir) ]; then 18 | cp $(echo $align_dir) . -r 19 | fi 20 | src_dir_sg=$(pwd) 21 | 22 | cd $data_root/../source 23 | outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")" 24 | echo $outdir_sg 25 | echo $src_dir_sg 26 | if [ ! 
-d "$outdir_sg" ]; then 27 | python prepare_data.py --size 256 --out $outdir_sg $src_dir_sg 28 | fi 29 | echo "prepare data for stylegan finished!" 30 | 31 | ### train model 32 | #cd $data_root 33 | #cd stylegan 34 | #stylegan_data_dir=$(pwd) 35 | #outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")" 36 | #echo "STEP:start to train the style learner ..." 37 | #echo $outdir_sg 38 | #exp_name="ffhq_$(echo $stylename)_s256_id01_$(date +"%m%d")" 39 | #cd /data/vdb/qingyao/cartoon/mycode/stylegan2-pytorch 40 | #model_path=face_generation/experiment_stylegan/$(echo $exp_name)/models/001000.pt 41 | #if [ ! -f "$model_path" ]; then 42 | # CUDA_VISIBLE_DEVICES=6 python train_condition.py --name $exp_name --path $outdir_sg --config config/conf_server_train_condition_shell.json 43 | #fi 44 | #### [training...] 45 | #echo "train the style learner finished!" -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | 2 | import cv2 3 | from source.cartoonize import Cartoonizer 4 | import os 5 | 6 | def process(): 7 | 8 | algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon_compound-models') 9 | img = cv2.imread('input.png')[...,::-1] 10 | 11 | result = algo.cartoonize(img) 12 | 13 | cv2.imwrite('res.png', result) 14 | print('finished!') 15 | 16 | 17 | 18 | 19 | if __name__ == '__main__': 20 | process() 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /run_sdk.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from modelscope.outputs import OutputKeys 3 | from modelscope.pipelines import pipeline 4 | from modelscope.utils.constant import Tasks 5 | 6 | ##### DCT-Net 7 | ## anime style 8 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 9 | model='damo/cv_unet_person-image-cartoon_compound-models') 10 | result = 
img_cartoon('input.png') 11 | cv2.imwrite('result_anime.png', result[OutputKeys.OUTPUT_IMG]) 12 | 13 | ## 3d style 14 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 15 | model='damo/cv_unet_person-image-cartoon-3d_compound-models') 16 | result = img_cartoon('input.png') 17 | cv2.imwrite('result_3d.png', result[OutputKeys.OUTPUT_IMG]) 18 | 19 | ## handdrawn style 20 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 21 | model='damo/cv_unet_person-image-cartoon-handdrawn_compound-models') 22 | result = img_cartoon('input.png') 23 | cv2.imwrite('result_handdrawn.png', result[OutputKeys.OUTPUT_IMG]) 24 | 25 | ## sketch style 26 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 27 | model='damo/cv_unet_person-image-cartoon-sketch_compound-models') 28 | result = img_cartoon('input.png') 29 | cv2.imwrite('result_sketch.png', result[OutputKeys.OUTPUT_IMG]) 30 | 31 | ## artstyle style 32 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 33 | model='damo/cv_unet_person-image-cartoon-artstyle_compound-models') 34 | result = img_cartoon('input.png') 35 | cv2.imwrite('result_artstyle.png', result[OutputKeys.OUTPUT_IMG]) 36 | 37 | #### DCT-Net + SD 38 | ## design style 39 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 40 | model='damo/cv_unet_person-image-cartoon-sd-design_compound-models') 41 | result = img_cartoon('input.png') 42 | cv2.imwrite('result_sd_design.png', result[OutputKeys.OUTPUT_IMG]) 43 | 44 | ## illustration style 45 | img_cartoon = pipeline(Tasks.image_portrait_stylization, 46 | model='damo/cv_unet_person-image-cartoon-sd-illustration_compound-models') 47 | result = img_cartoon('input.png') 48 | cv2.imwrite('result_sd_illustration.png', result[OutputKeys.OUTPUT_IMG]) 49 | 50 | 51 | print('finished!') 52 | -------------------------------------------------------------------------------- /run_vid.py: -------------------------------------------------------------------------------- 1 | import cv2, argparse 2 | 
import imageio 3 | from tqdm import tqdm 4 | import numpy as np 5 | from modelscope.outputs import OutputKeys 6 | from modelscope.pipelines import pipeline 7 | from modelscope.utils.constant import Tasks 8 | 9 | def process(args): 10 | style = args.style 11 | print('choose style %s'%style) 12 | 13 | reader = imageio.get_reader(args.video_path) 14 | fps = reader.get_meta_data()['fps'] 15 | writer = imageio.get_writer(args.save_path, mode='I', fps=fps, codec='libx264') 16 | 17 | if style == "anime": 18 | style = "" 19 | else: 20 | style = '-' + style 21 | 22 | model_name = 'damo/cv_unet_person-image-cartoon' + style + '_compound-models' 23 | img_cartoon = pipeline(Tasks.image_portrait_stylization, model=model_name) 24 | 25 | for _, img in tqdm(enumerate(reader)): 26 | result = img_cartoon(img[..., ::-1]) 27 | res = result[OutputKeys.OUTPUT_IMG] 28 | writer.append_data(res[..., ::-1].astype(np.uint8)) 29 | writer.close() 30 | print('finished!') 31 | print('result saved to %s'% args.save_path) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('--video_path', type=str, default='PATH/TO/YOUR/MP4') 37 | parser.add_argument('--save_path', type=str, default='res.mp4') 38 | parser.add_argument('--style', type=str, default='anime') 39 | 40 | args = parser.parse_args() 41 | 42 | process(args) 43 | -------------------------------------------------------------------------------- /source/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/.DS_Store -------------------------------------------------------------------------------- /source/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/__init__.py 
# --------------------------------------------------------------------------------
# /source/cartoonize.py
# --------------------------------------------------------------------------------
"""DCT-Net portrait cartoonization with two frozen TensorFlow graphs.

The background graph stylizes the whole (resized, padded) image; the head
graph stylizes each aligned face crop. Every stylized head is warped back
into the frame and alpha-blended onto the stylized background using a soft
mask loaded from ``alpha.jpg``.
"""
import os
import cv2
import tensorflow as tf
import numpy as np
from source.facelib.facer import FaceAna
import source.utils as utils
from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points

# Compare the *numeric* major version. A plain string comparison
# (tf.__version__ >= '2.0') is lexicographic, e.g. '10.0' >= '2.0' is False.
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
    tf.disable_eager_execution()


class Cartoonizer():
    """Cartoonize portrait images with a DCT-Net head model + background model.

    Args:
        dataroot: model directory. It must contain ``cartoon_anime_h.pb``,
            ``cartoon_anime_bg.pb`` and ``alpha.jpg``; it is also handed to
            ``FaceAna``, which presumably loads its detector/landmark assets
            from the same directory.

    Raises:
        FileNotFoundError: if ``alpha.jpg`` cannot be read.
    """

    def __init__(self, dataroot):
        self.facer = FaceAna(dataroot)
        self.sess_head = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_h.pb'), 'model_head')
        self.sess_bg = self.load_sess(
            os.path.join(dataroot, 'cartoon_anime_bg.pb'), 'model_bg')

        # Side length of the square aligned head crop fed to the head model.
        self.box_width = 288

        mask_path = os.path.join(dataroot, 'alpha.jpg')
        global_mask = cv2.imread(mask_path)
        if global_mask is None:
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail here with a clear message rather than
            # with an opaque cv2.error from cv2.resize below.
            raise FileNotFoundError(
                'cannot read blending mask: %s' % mask_path)
        global_mask = cv2.resize(
            global_mask, (self.box_width, self.box_width),
            interpolation=cv2.INTER_AREA)
        # Single-channel float32 blending weights in [0, 1], box_width x box_width.
        self.global_mask = cv2.cvtColor(
            global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0

    def load_sess(self, model_path, name):
        """Import a frozen ``GraphDef`` into a new session.

        Args:
            model_path: path to a serialized (frozen) ``GraphDef`` file.
            name: import prefix; tensors are then addressed as
                ``'<name>/<op_name>:0'``.

        Returns:
            A ``tf.Session`` whose graph contains the imported model.
        """
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        print(f'loading model from {model_path}')
        with tf.gfile.GFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # Graph.as_default() returns a context manager; calling it without
        # `with` has no effect. Enter it so the import and the initializer
        # are guaranteed to target sess.graph.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name=name)
            sess.run(tf.global_variables_initializer())
        print(f'load model {model_path} done.')
        return sess

    def detect_face(self, img):
        """Detect faces and return their landmarks.

        Args:
            img: 3-channel BGR image; it is converted to RGB before being
                passed to ``FaceAna.run``.

        Returns:
            The landmarks returned by ``FaceAna.run``, or ``None`` when no
            face box is found.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        boxes, landmarks, _ = self.facer.run(rgb)
        if boxes.shape[0] == 0:
            return None
        return landmarks

    def cartoonize(self, img):
        """Cartoonize an RGB portrait.

        Args:
            img: H x W x 3 image in RGB order.

        Returns:
            The stylized image, resized back to the input's H x W on every
            path. The channel order is whatever the models emit; callers in
            this repo write it directly with cv2.imwrite, i.e. presumably
            BGR -- TODO confirm.
        """
        ori_h, ori_w, _ = img.shape
        img = utils.resize_size(img, size=720)
        img_bgr = img[:, :, ::-1]

        # Background: stylize the whole frame (padded by utils.padTo16x),
        # then crop the padding away again.
        pad_bg, pad_h, pad_w = utils.padTo16x(img_bgr)
        bg_res = self.sess_bg.run(
            self.sess_bg.graph.get_tensor_by_name('model_bg/output_image:0'),
            feed_dict={'model_bg/input_image:0': pad_bg})
        res = bg_res[:pad_h, :pad_w, :]

        landmarks = self.detect_face(img_bgr)
        if landmarks is None:
            print('No face detected!')
        else:
            print('%d faces detected!' % len(landmarks))
            # cv2.warpAffine takes the output size as (width, height).
            out_size = (np.size(img, 1), np.size(img, 0))
            for landmark in landmarks:
                # Facial 5 points, then align and crop the head.
                f5p = utils.get_f5p(landmark, img_bgr)
                head_img, trans_inv = warp_and_crop_face(
                    img,
                    f5p,
                    ratio=0.75,
                    reference_pts=get_reference_facial_points(default_square=True),
                    crop_size=(self.box_width, self.box_width),
                    return_trans_inv=True)

                # Head: the model is fed the crop with channels reversed.
                head_res = self.sess_head.run(
                    self.sess_head.graph.get_tensor_by_name(
                        'model_head/output_image:0'),
                    feed_dict={
                        'model_head/input_image:0': head_img[:, :, ::-1]
                    })

                # Warp the stylized head and its mask back into the frame and
                # alpha-blend them over the current result.
                head_trans_inv = cv2.warpAffine(
                    head_res, trans_inv, out_size, borderValue=(0, 0, 0))
                mask_trans_inv = cv2.warpAffine(
                    self.global_mask, trans_inv, out_size,
                    borderValue=(0, 0, 0))
                mask_trans_inv = np.expand_dims(mask_trans_inv, 2)
                res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res

        # Restore the caller's resolution on BOTH paths; previously the
        # no-face path returned the working (resized) resolution instead.
        return cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)


# --------------------------------------------------------------------------------
# /source/facelib/LICENSE
# --------------------------------------------------------------------------------
1 | 2 | Copyright (c) Peppa_Pig_Face_Engine 3 | 4 | https://github.com/610265158/Peppa_Pig_Face_Engine 5 | -------------------------------------------------------------------------------- /source/facelib/LK/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/facelib/LK/__init__.py -------------------------------------------------------------------------------- /source/facelib/LK/lk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from modelscope.models.cv.cartoon.facelib.config import config as cfg 4 | 5 | 6 | class GroupTrack(): 7 | 8 | def __init__(self): 9 | self.old_frame = None 10 | self.previous_landmarks_set = None 11 | self.with_landmark = True 12 | self.thres = cfg.TRACE.pixel_thres 13 | self.alpha = cfg.TRACE.smooth_landmark 14 | self.iou_thres = cfg.TRACE.iou_thres 15 | 16 | def calculate(self, img, current_landmarks_set): 17 | if self.previous_landmarks_set is None: 18 | self.previous_landmarks_set = current_landmarks_set 19 | result = current_landmarks_set 20 | else: 21 | previous_lm_num = self.previous_landmarks_set.shape[0] 22 | if previous_lm_num == 0: 23 | self.previous_landmarks_set = current_landmarks_set 24 | result = current_landmarks_set 25 | return result 26 | else: 27 | result = [] 28 | for i in range(current_landmarks_set.shape[0]): 29 | not_in_flag = True 30 | for j in range(previous_lm_num): 31 | if self.iou(current_landmarks_set[i], 32 | self.previous_landmarks_set[j] 33 | ) > self.iou_thres: 34 | result.append( 35 | self.smooth(current_landmarks_set[i], 36 | self.previous_landmarks_set[j])) 37 | not_in_flag = False 38 | break 39 | if not_in_flag: 40 | result.append(current_landmarks_set[i]) 41 | 42 | result = np.array(result) 43 | self.previous_landmarks_set = result 44 | 45 | return result 46 | 47 | def iou(self, p_set0, p_set1): 
48 | rec1 = [ 49 | np.min(p_set0[:, 0]), 50 | np.min(p_set0[:, 1]), 51 | np.max(p_set0[:, 0]), 52 | np.max(p_set0[:, 1]) 53 | ] 54 | rec2 = [ 55 | np.min(p_set1[:, 0]), 56 | np.min(p_set1[:, 1]), 57 | np.max(p_set1[:, 0]), 58 | np.max(p_set1[:, 1]) 59 | ] 60 | 61 | # computing area of each rectangles 62 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) 63 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) 64 | 65 | # computing the sum_area 66 | sum_area = S_rec1 + S_rec2 67 | 68 | # find the each edge of intersect rectangle 69 | x1 = max(rec1[0], rec2[0]) 70 | y1 = max(rec1[1], rec2[1]) 71 | x2 = min(rec1[2], rec2[2]) 72 | y2 = min(rec1[3], rec2[3]) 73 | 74 | # judge if there is an intersect 75 | intersect = max(0, x2 - x1) * max(0, y2 - y1) 76 | 77 | iou = intersect / (sum_area - intersect) 78 | return iou 79 | 80 | def smooth(self, now_landmarks, previous_landmarks): 81 | result = [] 82 | for i in range(now_landmarks.shape[0]): 83 | x = now_landmarks[i][0] - previous_landmarks[i][0] 84 | y = now_landmarks[i][1] - previous_landmarks[i][1] 85 | dis = np.sqrt(np.square(x) + np.square(y)) 86 | if dis < self.thres: 87 | result.append(previous_landmarks[i]) 88 | else: 89 | result.append( 90 | self.do_moving_average(now_landmarks[i], 91 | previous_landmarks[i])) 92 | 93 | return np.array(result) 94 | 95 | def do_moving_average(self, p_now, p_previous): 96 | p = self.alpha * p_now + (1 - self.alpha) * p_previous 97 | return p 98 | -------------------------------------------------------------------------------- /source/facelib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/facelib/__init__.py -------------------------------------------------------------------------------- /source/facelib/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 
from easydict import EasyDict as edict 5 | 6 | config = edict() 7 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 8 | 9 | config.DETECT = edict() 10 | config.DETECT.topk = 10 11 | config.DETECT.thres = 0.8 12 | config.DETECT.input_shape = (512, 512, 3) 13 | config.KEYPOINTS = edict() 14 | config.KEYPOINTS.p_num = 68 15 | config.KEYPOINTS.base_extend_range = [0.2, 0.3] 16 | config.KEYPOINTS.input_shape = (160, 160, 3) 17 | config.TRACE = edict() 18 | config.TRACE.pixel_thres = 1 19 | config.TRACE.smooth_box = 0.3 20 | config.TRACE.smooth_landmark = 0.95 21 | config.TRACE.iou_thres = 0.5 22 | config.DATA = edict() 23 | config.DATA.pixel_means = np.array([123., 116., 103.]) # RGB 24 | -------------------------------------------------------------------------------- /source/facelib/face_detector.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from .config import config as cfg 8 | 9 | if tf.__version__ >= '2.0': 10 | tf = tf.compat.v1 11 | 12 | 13 | class FaceDetector: 14 | 15 | def __init__(self, dir): 16 | 17 | self.model_path = dir + '/detector.pb' 18 | self.thres = cfg.DETECT.thres 19 | self.input_shape = cfg.DETECT.input_shape 20 | 21 | self._graph = tf.Graph() 22 | 23 | with self._graph.as_default(): 24 | self._graph, self._sess = self.init_model(self.model_path) 25 | 26 | self.input_image = tf.get_default_graph().get_tensor_by_name( 27 | 'tower_0/images:0') 28 | self.training = tf.get_default_graph().get_tensor_by_name( 29 | 'training_flag:0') 30 | self.output_ops = [ 31 | tf.get_default_graph().get_tensor_by_name('tower_0/boxes:0'), 32 | tf.get_default_graph().get_tensor_by_name('tower_0/scores:0'), 33 | tf.get_default_graph().get_tensor_by_name( 34 | 'tower_0/num_detections:0'), 35 | ] 36 | 37 | def __call__(self, image): 38 | 39 | image, scale_x, scale_y = self.preprocess( 40 | image, 41 | target_width=self.input_shape[1], 42 | 
target_height=self.input_shape[0]) 43 | 44 | image = np.expand_dims(image, 0) 45 | 46 | boxes, scores, num_boxes = self._sess.run( 47 | self.output_ops, 48 | feed_dict={ 49 | self.input_image: image, 50 | self.training: False 51 | }) 52 | 53 | num_boxes = num_boxes[0] 54 | boxes = boxes[0][:num_boxes] 55 | 56 | scores = scores[0][:num_boxes] 57 | 58 | to_keep = scores > self.thres 59 | boxes = boxes[to_keep] 60 | scores = scores[to_keep] 61 | 62 | y1 = self.input_shape[0] / scale_y 63 | x1 = self.input_shape[1] / scale_x 64 | y2 = self.input_shape[0] / scale_y 65 | x2 = self.input_shape[1] / scale_x 66 | scaler = np.array([y1, x1, y2, x2], dtype='float32') 67 | boxes = boxes * scaler 68 | 69 | scores = np.expand_dims(scores, 0).reshape([-1, 1]) 70 | 71 | for i in range(boxes.shape[0]): 72 | boxes[i] = np.array( 73 | [boxes[i][1], boxes[i][0], boxes[i][3], boxes[i][2]]) 74 | return np.concatenate([boxes, scores], axis=1) 75 | 76 | def preprocess(self, image, target_height, target_width, label=None): 77 | 78 | h, w, c = image.shape 79 | 80 | bimage = np.zeros( 81 | shape=[target_height, target_width, c], 82 | dtype=image.dtype) + np.array( 83 | cfg.DATA.pixel_means, dtype=image.dtype) 84 | long_side = max(h, w) 85 | 86 | scale_x = scale_y = target_height / long_side 87 | 88 | image = cv2.resize(image, None, fx=scale_x, fy=scale_y) 89 | 90 | h_, w_, _ = image.shape 91 | bimage[:h_, :w_, :] = image 92 | 93 | return bimage, scale_x, scale_y 94 | 95 | def init_model(self, *args): 96 | pb_path = args[0] 97 | 98 | def init_pb(model_path): 99 | config = tf.ConfigProto() 100 | config.gpu_options.per_process_gpu_memory_fraction = 0.2 101 | compute_graph = tf.Graph() 102 | compute_graph.as_default() 103 | sess = tf.Session(config=config) 104 | with tf.gfile.GFile(model_path, 'rb') as fid: 105 | graph_def = tf.GraphDef() 106 | graph_def.ParseFromString(fid.read()) 107 | tf.import_graph_def(graph_def, name='') 108 | 109 | return (compute_graph, sess) 110 | 111 | model = 
init_pb(pb_path) 112 | 113 | graph = model[0] 114 | sess = model[1] 115 | 116 | return graph, sess 117 | -------------------------------------------------------------------------------- /source/facelib/face_landmark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | from .config import config as cfg 6 | 7 | if tf.__version__ >= '2.0': 8 | tf = tf.compat.v1 9 | 10 | 11 | class FaceLandmark: 12 | 13 | def __init__(self, dir): 14 | self.model_path = dir + '/keypoints.pb' 15 | self.min_face = 60 16 | self.keypoint_num = cfg.KEYPOINTS.p_num * 2 17 | 18 | self._graph = tf.Graph() 19 | 20 | with self._graph.as_default(): 21 | 22 | self._graph, self._sess = self.init_model(self.model_path) 23 | self.img_input = tf.get_default_graph().get_tensor_by_name( 24 | 'tower_0/images:0') 25 | self.embeddings = tf.get_default_graph().get_tensor_by_name( 26 | 'tower_0/prediction:0') 27 | self.training = tf.get_default_graph().get_tensor_by_name( 28 | 'training_flag:0') 29 | 30 | self.landmark = self.embeddings[:, :self.keypoint_num] 31 | self.headpose = self.embeddings[:, -7:-4] * 90. 
32 | self.state = tf.nn.sigmoid(self.embeddings[:, -4:]) 33 | 34 | def __call__(self, img, bboxes): 35 | landmark_result = [] 36 | state_result = [] 37 | for i, bbox in enumerate(bboxes): 38 | landmark, state = self._one_shot_run(img, bbox, i) 39 | if landmark is not None: 40 | landmark_result.append(landmark) 41 | state_result.append(state) 42 | return np.array(landmark_result), np.array(state_result) 43 | 44 | def simple_run(self, cropped_img): 45 | with self._graph.as_default(): 46 | 47 | cropped_img = np.expand_dims(cropped_img, axis=0) 48 | landmark, p, states = self._sess.run( 49 | [self.landmark, self.headpose, self.state], 50 | feed_dict={ 51 | self.img_input: cropped_img, 52 | self.training: False 53 | }) 54 | 55 | return landmark, states 56 | 57 | def _one_shot_run(self, image, bbox, i): 58 | 59 | bbox_width = bbox[2] - bbox[0] 60 | bbox_height = bbox[3] - bbox[1] 61 | if (bbox_width <= self.min_face and bbox_height <= self.min_face): 62 | return None, None 63 | add = int(max(bbox_width, bbox_height)) 64 | bimg = cv2.copyMakeBorder( 65 | image, 66 | add, 67 | add, 68 | add, 69 | add, 70 | borderType=cv2.BORDER_CONSTANT, 71 | value=cfg.DATA.pixel_means) 72 | bbox += add 73 | 74 | one_edge = (1 + 2 * cfg.KEYPOINTS.base_extend_range[0]) * bbox_width 75 | center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2] 76 | 77 | bbox[0] = center[0] - one_edge // 2 78 | bbox[1] = center[1] - one_edge // 2 79 | bbox[2] = center[0] + one_edge // 2 80 | bbox[3] = center[1] + one_edge // 2 81 | 82 | bbox = bbox.astype(np.int) 83 | crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :] 84 | h, w, _ = crop_image.shape 85 | crop_image = cv2.resize( 86 | crop_image, 87 | (cfg.KEYPOINTS.input_shape[1], cfg.KEYPOINTS.input_shape[0])) 88 | crop_image = crop_image.astype(np.float32) 89 | 90 | keypoints, state = self.simple_run(crop_image) 91 | 92 | res = keypoints[0][:self.keypoint_num].reshape((-1, 2)) 93 | res[:, 0] = res[:, 0] * w / cfg.KEYPOINTS.input_shape[1] 94 | 
res[:, 1] = res[:, 1] * h / cfg.KEYPOINTS.input_shape[0] 95 | 96 | landmark = [] 97 | for _index in range(res.shape[0]): 98 | x_y = res[_index] 99 | landmark.append([ 100 | int(x_y[0] * cfg.KEYPOINTS.input_shape[0] + bbox[0] - add), 101 | int(x_y[1] * cfg.KEYPOINTS.input_shape[1] + bbox[1] - add) 102 | ]) 103 | 104 | landmark = np.array(landmark, np.float32) 105 | 106 | return landmark, state 107 | 108 | def init_model(self, *args): 109 | 110 | if len(args) == 1: 111 | use_pb = True 112 | pb_path = args[0] 113 | else: 114 | use_pb = False 115 | meta_path = args[0] 116 | restore_model_path = args[1] 117 | 118 | def ini_ckpt(): 119 | graph = tf.Graph() 120 | graph.as_default() 121 | configProto = tf.ConfigProto() 122 | configProto.gpu_options.allow_growth = True 123 | sess = tf.Session(config=configProto) 124 | # load_model(model_path, sess) 125 | saver = tf.train.import_meta_graph(meta_path) 126 | saver.restore(sess, restore_model_path) 127 | 128 | print('Model restred!') 129 | return (graph, sess) 130 | 131 | def init_pb(model_path): 132 | config = tf.ConfigProto() 133 | config.gpu_options.per_process_gpu_memory_fraction = 0.2 134 | compute_graph = tf.Graph() 135 | compute_graph.as_default() 136 | sess = tf.Session(config=config) 137 | with tf.gfile.GFile(model_path, 'rb') as fid: 138 | graph_def = tf.GraphDef() 139 | graph_def.ParseFromString(fid.read()) 140 | tf.import_graph_def(graph_def, name='') 141 | 142 | # saver = tf.train.Saver(tf.global_variables()) 143 | # saver.save(sess, save_path='./tmp.ckpt') 144 | return (compute_graph, sess) 145 | 146 | if use_pb: 147 | model = init_pb(pb_path) 148 | else: 149 | model = ini_ckpt() 150 | 151 | graph = model[0] 152 | sess = model[1] 153 | 154 | return graph, sess 155 | -------------------------------------------------------------------------------- /source/facelib/facer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import numpy as np 5 | 6 
| from .config import config as cfg 7 | from .face_detector import FaceDetector 8 | from .face_landmark import FaceLandmark 9 | from .LK.lk import GroupTrack 10 | 11 | 12 | class FaceAna(): 13 | ''' 14 | by default the top3 facea sorted by area will be calculated for time reason 15 | ''' 16 | 17 | def __init__(self, model_dir): 18 | self.face_detector = FaceDetector(model_dir) 19 | self.face_landmark = FaceLandmark(model_dir) 20 | self.trace = GroupTrack() 21 | 22 | self.track_box = None 23 | self.previous_image = None 24 | self.previous_box = None 25 | 26 | self.diff_thres = 5 27 | self.top_k = cfg.DETECT.topk 28 | self.iou_thres = cfg.TRACE.iou_thres 29 | self.alpha = cfg.TRACE.smooth_box 30 | 31 | def run(self, image): 32 | 33 | boxes = self.face_detector(image) 34 | 35 | if boxes.shape[0] > self.top_k: 36 | boxes = self.sort(boxes) 37 | 38 | boxes_return = np.array(boxes) 39 | landmarks, states = self.face_landmark(image, boxes) 40 | 41 | if 1: 42 | track = [] 43 | for i in range(landmarks.shape[0]): 44 | track.append([ 45 | np.min(landmarks[i][:, 0]), 46 | np.min(landmarks[i][:, 1]), 47 | np.max(landmarks[i][:, 0]), 48 | np.max(landmarks[i][:, 1]) 49 | ]) 50 | tmp_box = np.array(track) 51 | 52 | self.track_box = self.judge_boxs(boxes_return, tmp_box) 53 | 54 | self.track_box, landmarks = self.sort_res(self.track_box, landmarks) 55 | return self.track_box, landmarks, states 56 | 57 | def sort_res(self, bboxes, points): 58 | area = [] 59 | for bbox in bboxes: 60 | bbox_width = bbox[2] - bbox[0] 61 | bbox_height = bbox[3] - bbox[1] 62 | area.append(bbox_height * bbox_width) 63 | 64 | area = np.array(area) 65 | picked = area.argsort()[::-1] 66 | sorted_bboxes = [bboxes[x] for x in picked] 67 | sorted_points = [points[x] for x in picked] 68 | return np.array(sorted_bboxes), np.array(sorted_points) 69 | 70 | def diff_frames(self, previous_frame, image): 71 | if previous_frame is None: 72 | return True 73 | else: 74 | _diff = cv2.absdiff(previous_frame, image) 75 | 
diff = np.sum( 76 | _diff) / previous_frame.shape[0] / previous_frame.shape[1] / 3. 77 | return diff > self.diff_thres 78 | 79 | def sort(self, bboxes): 80 | if self.top_k > 100: 81 | return bboxes 82 | area = [] 83 | for bbox in bboxes: 84 | 85 | bbox_width = bbox[2] - bbox[0] 86 | bbox_height = bbox[3] - bbox[1] 87 | area.append(bbox_height * bbox_width) 88 | 89 | area = np.array(area) 90 | 91 | picked = area.argsort()[-self.top_k:][::-1] 92 | sorted_bboxes = [bboxes[x] for x in picked] 93 | return np.array(sorted_bboxes) 94 | 95 | def judge_boxs(self, previuous_bboxs, now_bboxs): 96 | 97 | def iou(rec1, rec2): 98 | 99 | # computing area of each rectangles 100 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) 101 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) 102 | 103 | # computing the sum_area 104 | sum_area = S_rec1 + S_rec2 105 | 106 | # find the each edge of intersect rectangle 107 | x1 = max(rec1[0], rec2[0]) 108 | y1 = max(rec1[1], rec2[1]) 109 | x2 = min(rec1[2], rec2[2]) 110 | y2 = min(rec1[3], rec2[3]) 111 | 112 | # judge if there is an intersect 113 | intersect = max(0, x2 - x1) * max(0, y2 - y1) 114 | 115 | return intersect / (sum_area - intersect) 116 | 117 | if previuous_bboxs is None: 118 | return now_bboxs 119 | 120 | result = [] 121 | 122 | for i in range(now_bboxs.shape[0]): 123 | contain = False 124 | for j in range(previuous_bboxs.shape[0]): 125 | if iou(now_bboxs[i], previuous_bboxs[j]) > self.iou_thres: 126 | result.append( 127 | self.smooth(now_bboxs[i], previuous_bboxs[j])) 128 | contain = True 129 | break 130 | if not contain: 131 | result.append(now_bboxs[i]) 132 | 133 | return np.array(result) 134 | 135 | def smooth(self, now_box, previous_box): 136 | 137 | return self.do_moving_average(now_box[:4], previous_box[:4]) 138 | 139 | def do_moving_average(self, p_now, p_previous): 140 | p = self.alpha * p_now + (1 - self.alpha) * p_previous 141 | return p 142 | 143 | def reset(self): 144 | ''' 145 | reset the previous info used foe 
tracking, 146 | :return: 147 | ''' 148 | self.track_box = None 149 | self.previous_image = None 150 | self.previous_box = None 151 | -------------------------------------------------------------------------------- /source/image_flip_agument_parallel.py: -------------------------------------------------------------------------------- 1 | import oss2 2 | import argparse 3 | import cv2 4 | import glob 5 | import os 6 | import tqdm 7 | import numpy as np 8 | # from .utils import get_rmbg_alpha, get_img_from_url,reasonable_resize,major_detection,crop_img 9 | import tqdm 10 | import urllib 11 | import random 12 | from multiprocessing import Pool 13 | 14 | 15 | parser = argparse.ArgumentParser(description="process remove bg result") 16 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.") 17 | parser.add_argument("--save_dir", type=str, default="", help="Path to save images.") 18 | args = parser.parse_args() 19 | 20 | 21 | args.save_dir = os.path.join(args.data_dir, 'total_flip') 22 | form = 'single' 23 | 24 | 25 | def flipImage(image): 26 | new_image = cv2.flip(image, 1) 27 | return new_image 28 | 29 | def all_file(file_dir): 30 | L=[] 31 | for root, dirs, files in os.walk(file_dir): 32 | for file in files: 33 | extend = os.path.splitext(file)[1] 34 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG': 35 | L.append(os.path.join(root, file)) 36 | return L 37 | 38 | 39 | paths = all_file(args.data_dir) 40 | 41 | 42 | def process(path): 43 | 44 | print(path) 45 | outpath = args.save_dir+path[len(args.data_dir):] 46 | if os.path.exists(outpath): 47 | return 48 | 49 | sub_dir = os.path.dirname(outpath) 50 | # print(sub_dir) 51 | if not os.path.exists(sub_dir): 52 | os.makedirs(sub_dir,exist_ok=True) 53 | 54 | img = cv2.imread(path, -1) 55 | h, w, c = img.shape 56 | if form == "pair": 57 | imga = img[:, :int(w / 2), :] 58 | imgb = img[:, int(w / 2):, :] 59 | imga = flipImage(imga) 60 | imgb = flipImage(imgb) 61 | res = 
cv2.hconcat([imga, imgb]) # 水平拼接 62 | 63 | else: 64 | res = flipImage(img) 65 | 66 | cv2.imwrite(outpath, res) 67 | print('save %s' % outpath) 68 | 69 | 70 | 71 | 72 | 73 | if __name__ == "__main__": 74 | # main(args) 75 | pool = Pool(100) 76 | rl = pool.map(process, paths) 77 | pool.close() 78 | pool.join() -------------------------------------------------------------------------------- /source/image_rotation_agument_parallel_flat.py: -------------------------------------------------------------------------------- 1 | import oss2 2 | import argparse 3 | import cv2 4 | import glob 5 | import os 6 | import tqdm 7 | import numpy as np 8 | import tqdm 9 | import urllib 10 | import random 11 | from multiprocessing import Pool 12 | 13 | 14 | parser = argparse.ArgumentParser(description="process remove bg result") 15 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.") 16 | parser.add_argument("--save_dir", type=str, default="", help="Path to save images.") 17 | args = parser.parse_args() 18 | 19 | 20 | args.save_dir = os.path.join(args.data_dir, 'total_rotate') 21 | form = 'single' 22 | 23 | if not os.path.exists(args.save_dir): 24 | os.makedirs(args.save_dir,exist_ok=True) 25 | 26 | def all_file(file_dir): 27 | L=[] 28 | for root, dirs, files in os.walk(file_dir): 29 | for file in files: 30 | extend = os.path.splitext(file)[1] 31 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg': 32 | L.append(os.path.join(root, file)) 33 | return L 34 | 35 | def rotateImage(image, angle): 36 | row,col,_ = image.shape 37 | center=tuple(np.array([row,col])/2) 38 | rot_mat = cv2.getRotationMatrix2D(center,angle,1.0) 39 | new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT) 40 | return new_image 41 | 42 | 43 | paths = all_file(args.data_dir) 44 | 45 | 46 | def process(path): 47 | 48 | if 'total_scale' in path: 49 | return 50 | 51 | outpath = args.save_dir + path[len(args.data_dir):] 52 | sub_dir = 
os.path.dirname(outpath) 53 | if not os.path.exists(sub_dir): 54 | os.makedirs(sub_dir, exist_ok=True) 55 | 56 | img0 = cv2.imread(path, -1) 57 | h, w, c = img0.shape 58 | img = img0[:, :, :3].copy() 59 | if c == 4: 60 | alpha = img0[:, :, 3] 61 | mask = alpha[:, :, np.newaxis].copy() / 255. 62 | img = (img * mask + (1 - mask) * 255) 63 | 64 | imgb = None 65 | imgc = None 66 | if form is 'single': 67 | imga = img 68 | elif form is 'pair': 69 | imga = img[:, :int(w / 2), :] 70 | imgb = img[:, int(w / 2):, :] 71 | elif form is 'tuple': 72 | imga = img[:, :int(w / 3), :] 73 | imgb = img[:, int(w / 3): int(w * 2 / 3), :] 74 | imgc = img[:, int(w * 2 / 3):, :] 75 | 76 | angles = [ random.randint(-10, 0), random.randint(0, 10)] 77 | 78 | for angle in angles: 79 | 80 | imga_r = rotateImage(imga, angle) 81 | if form is 'single': 82 | res = imga_r 83 | elif form is 'pair': 84 | imgb_r = rotateImage(imgb, angle) 85 | res = cv2.hconcat([imga_r, imgb_r]) # 水平拼接 86 | else: 87 | imgb_r = rotateImage(imgb, angle) 88 | imgc_r = rotateImage(imgc, angle) 89 | res = cv2.hconcat([imga_r, imgb_r, imgc_r]) # 水平拼接 90 | 91 | cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res) 92 | print('save %s'% outpath) 93 | 94 | 95 | 96 | if __name__ == "__main__": 97 | # main(args) 98 | pool = Pool(100) 99 | rl = pool.map(process, paths) 100 | pool.close() 101 | pool.join() -------------------------------------------------------------------------------- /source/image_scale_agument_parallel_flat.py: -------------------------------------------------------------------------------- 1 | import oss2 2 | import argparse 3 | import cv2 4 | import glob 5 | import os 6 | import tqdm 7 | import numpy as np 8 | import tqdm 9 | import urllib 10 | import random 11 | from multiprocessing import Pool 12 | 13 | 14 | parser = argparse.ArgumentParser(description="process remove bg result") 15 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.") 16 | parser.add_argument("--save_dir", 
type=str, default="", help="Path to save images.") 17 | args = parser.parse_args() 18 | 19 | 20 | args.save_dir = os.path.join(args.data_dir, 'total_scale') 21 | form = 'single' 22 | 23 | if not os.path.exists(args.save_dir): 24 | os.makedirs(args.save_dir,exist_ok=True) 25 | 26 | def all_file(file_dir): 27 | L=[] 28 | for root, dirs, files in os.walk(file_dir): 29 | for file in files: 30 | extend = os.path.splitext(file)[1] 31 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg': 32 | L.append(os.path.join(root, file)) 33 | return L 34 | 35 | def scaleImage(image, degree): 36 | 37 | h, w, _ = image.shape 38 | canvas = np.ones((h, w, 3), dtype="uint8")*255 39 | nw, nh = (int(w*degree), int(h*degree)) 40 | image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) # w, h 41 | 42 | if degree<1: 43 | canvas[int((h-nh)/2):int((h-nh)/2)+nh, int((w-nw)/2):int((w-nw)/2)+nw,:] = image 44 | elif degree>1: 45 | canvas = image[int((nh-h)/2):int((nh-h)/2)+h, int((nw-w)/2):int((nw-w)/2)+w, :] 46 | else: 47 | canvas = image.copy() 48 | 49 | return canvas 50 | 51 | def scaleImage2(image, degree, angle=0): 52 | row,col,_ = image.shape 53 | center=tuple(np.array([row,col])/2) 54 | rot_mat = cv2.getRotationMatrix2D(center,angle,degree) 55 | new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT) 56 | return new_image 57 | 58 | 59 | paths = all_file(args.data_dir) 60 | 61 | 62 | def process(path): 63 | 64 | outpath = args.save_dir+path[len(args.data_dir):] 65 | sub_dir = os.path.dirname(outpath) 66 | if not os.path.exists(sub_dir): 67 | os.makedirs(sub_dir, exist_ok=True) 68 | 69 | 70 | img0 = cv2.imread(path, -1) 71 | h, w, c = img0.shape 72 | img = img0[:, :, :3].copy() 73 | if c==4: 74 | alpha = img0[:, :, 3] 75 | mask = alpha[:, :, np.newaxis].copy() / 255. 
76 | img = (img * mask + (1 - mask) * 255) 77 | 78 | imgb = None 79 | imgc = None 80 | if form is 'single': 81 | imga = img 82 | elif form is 'pair': 83 | imga = img[:, :int(w / 2), :] 84 | imgb = img[:, int(w / 2):, :] 85 | elif form is 'tuple': 86 | imga = img[:, :int(w / 3), :] 87 | imgb = img[:, int(w / 3): int(w * 2 / 3), :] 88 | imgc = img[:, int(w * 2 / 3):, :] 89 | 90 | if random.random()>0.9: 91 | angles = [random.uniform(1, 1.1)] 92 | else: 93 | angles = [random.uniform(0.8, 1)] 94 | 95 | for angle in angles: 96 | 97 | imga_r = scaleImage(imga, angle) 98 | if form is 'single': 99 | res = imga_r 100 | elif form is 'pair': 101 | imgb_r = scaleImage(imgb, angle) 102 | res = cv2.hconcat([imga_r, imgb_r]) # 水平拼接 103 | else: 104 | imgb_r = scaleImage(imgb, angle) 105 | imgc_r = scaleImage(imgc, angle) 106 | res = cv2.hconcat([imga_r, imgb_r, imgc_r]) # 水平拼接 107 | 108 | cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res) 109 | print('save %s'% outpath) 110 | 111 | 112 | if __name__ == "__main__": 113 | # main(args) 114 | pool = Pool(100) 115 | rl = pool.map(process, paths) 116 | pool.close() 117 | pool.join() -------------------------------------------------------------------------------- /source/mtcnn_pytorch/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/.DS_Store -------------------------------------------------------------------------------- /source/mtcnn_pytorch/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Dan Antoshchenko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, 
distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /source/mtcnn_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # MTCNN 2 | 3 | `pytorch` implementation of **inference stage** of face detection algorithm described in 4 | [Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks](https://arxiv.org/abs/1604.02878). 5 | 6 | ## Example 7 | ![example of a face detection](images/example.png) 8 | 9 | ## How to use it 10 | Just download the repository and then do this 11 | ```python 12 | from src import detect_faces 13 | from PIL import Image 14 | 15 | image = Image.open('image.jpg') 16 | bounding_boxes, landmarks = detect_faces(image) 17 | ``` 18 | For examples see `test_on_images.ipynb`. 
19 | 20 | ## Requirements 21 | * pytorch 0.2 22 | * Pillow, numpy 23 | 24 | ## Credit 25 | This implementation is heavily inspired by: 26 | * [pangyupo/mxnet_mtcnn_face_detection](https://github.com/pangyupo/mxnet_mtcnn_face_detection) 27 | -------------------------------------------------------------------------------- /source/mtcnn_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/__init__.py -------------------------------------------------------------------------------- /source/mtcnn_pytorch/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/src/__init__.py -------------------------------------------------------------------------------- /source/mtcnn_pytorch/src/align_trans.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Apr 24 15:43:29 2017 3 | @author: zhaoy 4 | """ 5 | import cv2 6 | import numpy as np 7 | 8 | from .matlab_cp2tform import get_similarity_transform_for_cv2 9 | 10 | # reference facial points, a list of coordinates (x,y) 11 | dx = 1 12 | dy = 1 13 | REFERENCE_FACIAL_POINTS = [ 14 | [30.29459953 + dx, 51.69630051 + dy], # left eye 15 | [65.53179932 + dx, 51.50139999 + dy], # right eye 16 | [48.02519989 + dx, 71.73660278 + dy], # nose 17 | [33.54930115 + dx, 92.3655014 + dy], # left mouth 18 | [62.72990036 + dx, 92.20410156 + dy] # right mouth 19 | ] 20 | 21 | DEFAULT_CROP_SIZE = (96, 112) 22 | 23 | global FACIAL_POINTS 24 | 25 | 26 | class FaceWarpException(Exception): 27 | 28 | def __str__(self): 29 | return 'In File {}:{}'.format(__file__, super.__str__(self)) 30 | 31 | 32 | def get_reference_facial_points(output_size=None, 33 | inner_padding_factor=0.0, 34 | 
outer_padding=(0, 0), 35 | default_square=False): 36 | 37 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) 38 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE) 39 | 40 | # 0) make the inner region a square 41 | if default_square: 42 | size_diff = max(tmp_crop_size) - tmp_crop_size 43 | tmp_5pts += size_diff / 2 44 | tmp_crop_size += size_diff 45 | 46 | h_crop = tmp_crop_size[0] 47 | w_crop = tmp_crop_size[1] 48 | if (output_size): 49 | if (output_size[0] == h_crop and output_size[1] == w_crop): 50 | return tmp_5pts 51 | 52 | if (inner_padding_factor == 0 and outer_padding == (0, 0)): 53 | if output_size is None: 54 | return tmp_5pts 55 | else: 56 | raise FaceWarpException( 57 | 'No paddings to do, output_size must be None or {}'.format( 58 | tmp_crop_size)) 59 | 60 | # check output size 61 | if not (0 <= inner_padding_factor <= 1.0): 62 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') 63 | 64 | factor = inner_padding_factor > 0 or outer_padding[0] > 0 65 | factor = factor or outer_padding[1] > 0 66 | if (factor and output_size is None): 67 | output_size = tmp_crop_size * \ 68 | (1 + inner_padding_factor * 2).astype(np.int32) 69 | output_size += np.array(outer_padding) 70 | 71 | cond1 = outer_padding[0] < output_size[0] 72 | cond2 = outer_padding[1] < output_size[1] 73 | if not (cond1 and cond2): 74 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]' 75 | 'and outer_padding[1] < output_size[1])') 76 | 77 | # 1) pad the inner region according inner_padding_factor 78 | if inner_padding_factor > 0: 79 | size_diff = tmp_crop_size * inner_padding_factor * 2 80 | tmp_5pts += size_diff / 2 81 | tmp_crop_size += np.round(size_diff).astype(np.int32) 82 | 83 | # 2) resize the padded inner region 84 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 85 | 86 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[ 87 | 1] * tmp_crop_size[0]: 88 | raise FaceWarpException( 89 | 'Must have (output_size - outer_padding)' 
90 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)') 91 | 92 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] 93 | tmp_5pts = tmp_5pts * scale_factor 94 | 95 | # 3) add outer_padding to make output_size 96 | reference_5point = tmp_5pts + np.array(outer_padding) 97 | 98 | return reference_5point 99 | 100 | 101 | def get_affine_transform_matrix(src_pts, dst_pts): 102 | 103 | tfm = np.float32([[1, 0, 0], [0, 1, 0]]) 104 | n_pts = src_pts.shape[0] 105 | ones = np.ones((n_pts, 1), src_pts.dtype) 106 | src_pts_ = np.hstack([src_pts, ones]) 107 | dst_pts_ = np.hstack([dst_pts, ones]) 108 | 109 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) 110 | 111 | if rank == 3: 112 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], 113 | [A[0, 1], A[1, 1], A[2, 1]]]) 114 | elif rank == 2: 115 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) 116 | 117 | return tfm 118 | 119 | 120 | def warp_and_crop_face(src_img, 121 | facial_pts, 122 | ratio=0.84, 123 | reference_pts=None, 124 | crop_size=(96, 112), 125 | align_type='similarity' 126 | '', 127 | return_trans_inv=False): 128 | 129 | if reference_pts is None: 130 | if crop_size[0] == 96 and crop_size[1] == 112: 131 | reference_pts = REFERENCE_FACIAL_POINTS 132 | else: 133 | default_square = False 134 | inner_padding_factor = 0 135 | outer_padding = (0, 0) 136 | output_size = crop_size 137 | 138 | reference_pts = get_reference_facial_points( 139 | output_size, inner_padding_factor, outer_padding, 140 | default_square) 141 | 142 | ref_pts = np.float32(reference_pts) 143 | 144 | factor = ratio 145 | ref_pts = (ref_pts - 112 / 2) * factor + 112 / 2 146 | ref_pts *= crop_size[0] / 112. 
147 | 148 | ref_pts_shp = ref_pts.shape 149 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: 150 | raise FaceWarpException( 151 | 'reference_pts.shape must be (K,2) or (2,K) and K>2') 152 | 153 | if ref_pts_shp[0] == 2: 154 | ref_pts = ref_pts.T 155 | 156 | src_pts = np.float32(facial_pts) 157 | src_pts_shp = src_pts.shape 158 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: 159 | raise FaceWarpException( 160 | 'facial_pts.shape must be (K,2) or (2,K) and K>2') 161 | 162 | if src_pts_shp[0] == 2: 163 | src_pts = src_pts.T 164 | 165 | if src_pts.shape != ref_pts.shape: 166 | raise FaceWarpException( 167 | 'facial_pts and reference_pts must have the same shape') 168 | 169 | if align_type == 'cv2_affine': 170 | tfm = cv2.getAffineTransform(src_pts, ref_pts) 171 | tfm_inv = cv2.getAffineTransform(ref_pts, src_pts) 172 | 173 | elif align_type == 'affine': 174 | tfm = get_affine_transform_matrix(src_pts, ref_pts) 175 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts) 176 | else: 177 | tfm, tfm_inv = get_similarity_transform_for_cv2(src_pts, ref_pts) 178 | 179 | face_img = cv2.warpAffine( 180 | src_img, 181 | tfm, (crop_size[0], crop_size[1]), 182 | borderValue=(255, 255, 255)) 183 | 184 | if return_trans_inv: 185 | return face_img, tfm_inv 186 | else: 187 | return face_img 188 | -------------------------------------------------------------------------------- /source/mtcnn_pytorch/src/matlab_cp2tform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Jul 11 06:54:28 2017 3 | 4 | @author: zhaoyafei 5 | """ 6 | 7 | import numpy as np 8 | from numpy.linalg import inv, lstsq 9 | from numpy.linalg import matrix_rank as rank 10 | from numpy.linalg import norm 11 | 12 | 13 | class MatlabCp2tormException(Exception): 14 | 15 | def __str__(self): 16 | return 'In File {}:{}'.format(__file__, super.__str__(self)) 17 | 18 | 19 | def tformfwd(trans, uv): 20 | """ 21 | Function: 22 | ---------- 23 | apply affine 
transform 'trans' to uv 24 | 25 | Parameters: 26 | ---------- 27 | @trans: 3x3 np.array 28 | transform matrix 29 | @uv: Kx2 np.array 30 | each row is a pair of coordinates (x, y) 31 | 32 | Returns: 33 | ---------- 34 | @xy: Kx2 np.array 35 | each row is a pair of transformed coordinates (x, y) 36 | """ 37 | uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) 38 | xy = np.dot(uv, trans) 39 | xy = xy[:, 0:-1] 40 | return xy 41 | 42 | 43 | def tforminv(trans, uv): 44 | """ 45 | Function: 46 | ---------- 47 | apply the inverse of affine transform 'trans' to uv 48 | 49 | Parameters: 50 | ---------- 51 | @trans: 3x3 np.array 52 | transform matrix 53 | @uv: Kx2 np.array 54 | each row is a pair of coordinates (x, y) 55 | 56 | Returns: 57 | ---------- 58 | @xy: Kx2 np.array 59 | each row is a pair of inverse-transformed coordinates (x, y) 60 | """ 61 | Tinv = inv(trans) 62 | xy = tformfwd(Tinv, uv) 63 | return xy 64 | 65 | 66 | def findNonreflectiveSimilarity(uv, xy, options=None): 67 | 68 | options = {'K': 2} 69 | 70 | K = options['K'] 71 | M = xy.shape[0] 72 | x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 73 | y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 74 | # print('--->x, y:\n', x, y 75 | 76 | tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) 77 | tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) 78 | X = np.vstack((tmp1, tmp2)) 79 | # print('--->X.shape: ', X.shape 80 | # print('X:\n', X 81 | 82 | u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector 83 | v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector 84 | U = np.vstack((u, v)) 85 | # print('--->U.shape: ', U.shape 86 | # print('U:\n', U 87 | 88 | # We know that X * r = U 89 | if rank(X) >= 2 * K: 90 | r, _, _, _ = lstsq(X, U) 91 | r = np.squeeze(r) 92 | else: 93 | raise Exception('cp2tform:twoUniquePointsReq') 94 | 95 | # print('--->r:\n', r 96 | 97 | sc = r[0] 98 | ss = r[1] 99 | tx = r[2] 100 | ty = r[3] 101 | 102 | 
Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) 103 | 104 | # print('--->Tinv:\n', Tinv 105 | 106 | T = inv(Tinv) 107 | # print('--->T:\n', T 108 | 109 | T[:, 2] = np.array([0, 0, 1]) 110 | 111 | return T, Tinv 112 | 113 | 114 | def findSimilarity(uv, xy, options=None): 115 | 116 | options = {'K': 2} 117 | 118 | # uv = np.array(uv) 119 | # xy = np.array(xy) 120 | 121 | # Solve for trans1 122 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) 123 | 124 | # Solve for trans2 125 | 126 | # manually reflect the xy data across the Y-axis 127 | xyR = xy 128 | xyR[:, 0] = -1 * xyR[:, 0] 129 | 130 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) 131 | 132 | # manually reflect the tform to undo the reflection done on xyR 133 | TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) 134 | 135 | trans2 = np.dot(trans2r, TreflectY) 136 | 137 | # Figure out if trans1 or trans2 is better 138 | xy1 = tformfwd(trans1, uv) 139 | norm1 = norm(xy1 - xy) 140 | 141 | xy2 = tformfwd(trans2, uv) 142 | norm2 = norm(xy2 - xy) 143 | 144 | if norm1 <= norm2: 145 | return trans1, trans1_inv 146 | else: 147 | trans2_inv = inv(trans2) 148 | return trans2, trans2_inv 149 | 150 | 151 | def get_similarity_transform(src_pts, dst_pts, reflective=True): 152 | """ 153 | Function: 154 | ---------- 155 | Find Similarity Transform Matrix 'trans': 156 | u = src_pts[:, 0] 157 | v = src_pts[:, 1] 158 | x = dst_pts[:, 0] 159 | y = dst_pts[:, 1] 160 | [x, y, 1] = [u, v, 1] * trans 161 | 162 | Parameters: 163 | ---------- 164 | @src_pts: Kx2 np.array 165 | source points, each row is a pair of coordinates (x, y) 166 | @dst_pts: Kx2 np.array 167 | destination points, each row is a pair of transformed 168 | coordinates (x, y) 169 | @reflective: True or False 170 | if True: 171 | use reflective similarity transform 172 | else: 173 | use non-reflective similarity transform 174 | 175 | Returns: 176 | ---------- 177 | @trans: 3x3 np.array 178 | transform matrix from uv 
to xy 179 | trans_inv: 3x3 np.array 180 | inverse of trans, transform matrix from xy to uv 181 | """ 182 | 183 | if reflective: 184 | trans, trans_inv = findSimilarity(src_pts, dst_pts) 185 | else: 186 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) 187 | 188 | return trans, trans_inv 189 | 190 | 191 | def cvt_tform_mat_for_cv2(trans): 192 | """ 193 | Function: 194 | ---------- 195 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be 196 | directly used by cv2.warpAffine(): 197 | u = src_pts[:, 0] 198 | v = src_pts[:, 1] 199 | x = dst_pts[:, 0] 200 | y = dst_pts[:, 1] 201 | [x, y].T = cv_trans * [u, v, 1].T 202 | 203 | Parameters: 204 | ---------- 205 | @trans: 3x3 np.array 206 | transform matrix from uv to xy 207 | 208 | Returns: 209 | ---------- 210 | @cv2_trans: 2x3 np.array 211 | transform matrix from src_pts to dst_pts, could be directly used 212 | for cv2.warpAffine() 213 | """ 214 | cv2_trans = trans[:, 0:2].T 215 | 216 | return cv2_trans 217 | 218 | 219 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): 220 | """ 221 | Function: 222 | ---------- 223 | Find Similarity Transform Matrix 'cv2_trans' which could be 224 | directly used by cv2.warpAffine(): 225 | u = src_pts[:, 0] 226 | v = src_pts[:, 1] 227 | x = dst_pts[:, 0] 228 | y = dst_pts[:, 1] 229 | [x, y].T = cv_trans * [u, v, 1].T 230 | 231 | Parameters: 232 | ---------- 233 | @src_pts: Kx2 np.array 234 | source points, each row is a pair of coordinates (x, y) 235 | @dst_pts: Kx2 np.array 236 | destination points, each row is a pair of transformed 237 | coordinates (x, y) 238 | reflective: True or False 239 | if True: 240 | use reflective similarity transform 241 | else: 242 | use non-reflective similarity transform 243 | 244 | Returns: 245 | ---------- 246 | @cv2_trans: 2x3 np.array 247 | transform matrix from src_pts to dst_pts, could be directly used 248 | for cv2.warpAffine() 249 | """ 250 | trans, trans_inv = get_similarity_transform(src_pts, 
dst_pts, reflective) 251 | cv2_trans = cvt_tform_mat_for_cv2(trans) 252 | cv2_trans_inv = cvt_tform_mat_for_cv2(trans_inv) 253 | 254 | return cv2_trans, cv2_trans_inv 255 | 256 | 257 | if __name__ == '__main__': 258 | """ 259 | u = [0, 6, -2] 260 | v = [0, 3, 5] 261 | x = [-1, 0, 4] 262 | y = [-1, -10, 4] 263 | 264 | # In Matlab, run: 265 | # 266 | # uv = [u'; v']; 267 | # xy = [x'; y']; 268 | # tform_sim=cp2tform(uv,xy,'similarity'); 269 | # 270 | # trans = tform_sim.tdata.T 271 | # ans = 272 | # -0.0764 -1.6190 0 273 | # 1.6190 -0.0764 0 274 | # -3.2156 0.0290 1.0000 275 | # trans_inv = tform_sim.tdata.Tinv 276 | # ans = 277 | # 278 | # -0.0291 0.6163 0 279 | # -0.6163 -0.0291 0 280 | # -0.0756 1.9826 1.0000 281 | # xy_m=tformfwd(tform_sim, u,v) 282 | # 283 | # xy_m = 284 | # 285 | # -3.2156 0.0290 286 | # 1.1833 -9.9143 287 | # 5.0323 2.8853 288 | # uv_m=tforminv(tform_sim, x,y) 289 | # 290 | # uv_m = 291 | # 292 | # 0.5698 1.3953 293 | # 6.0872 2.2733 294 | # -2.6570 4.3314 295 | """ 296 | u = [0, 6, -2] 297 | v = [0, 3, 5] 298 | x = [-1, 0, 4] 299 | y = [-1, -10, 4] 300 | 301 | uv = np.array((u, v)).T 302 | xy = np.array((x, y)).T 303 | 304 | print('\n--->uv:') 305 | print(uv) 306 | print('\n--->xy:') 307 | print(xy) 308 | 309 | trans, trans_inv = get_similarity_transform(uv, xy) 310 | 311 | print('\n--->trans matrix:') 312 | print(trans) 313 | 314 | print('\n--->trans_inv matrix:') 315 | print(trans_inv) 316 | 317 | print('\n---> apply transform to uv') 318 | print('\nxy_m = uv_augmented * trans') 319 | uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) 320 | xy_m = np.dot(uv_aug, trans) 321 | print(xy_m) 322 | 323 | print('\nxy_m = tformfwd(trans, uv)') 324 | xy_m = tformfwd(trans, uv) 325 | print(xy_m) 326 | 327 | print('\n---> apply inverse transform to xy') 328 | print('\nuv_m = xy_augmented * trans_inv') 329 | xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) 330 | uv_m = np.dot(xy_aug, trans_inv) 331 | print(uv_m) 332 | 333 | print('\nuv_m = 
tformfwd(trans_inv, xy)') 334 | uv_m = tformfwd(trans_inv, xy) 335 | print(uv_m) 336 | 337 | uv_m = tforminv(trans, xy) 338 | print('\nuv_m = tforminv(trans, xy)') 339 | print(uv_m) 340 | -------------------------------------------------------------------------------- /source/stylegan2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/.DS_Store -------------------------------------------------------------------------------- /source/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/__init__.py -------------------------------------------------------------------------------- /source/stylegan2/config/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/config/.DS_Store -------------------------------------------------------------------------------- /source/stylegan2/config/conf_server_test_blend_shell.json: -------------------------------------------------------------------------------- 1 | { 2 | "parameters": { 3 | "output": "face_generation/experiment_stylegan", 4 | "ffhq_ckpt": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt", 5 | "size": 256, 6 | "sample": 1, 7 | "pics": 5000, 8 | "truncation": 0.7, 9 | "form": "single" 10 | } 11 | } 12 | 13 | -------------------------------------------------------------------------------- /source/stylegan2/config/conf_server_train_condition_shell.json: -------------------------------------------------------------------------------- 1 | { 2 | "parameters": { 3 | "output": "face_generation/experiment_stylegan", 4 | "ckpt": 
"pretrained_models/stylegan2-ffhq-config-f-256-550000.pt", 5 | "ckpt_ffhq": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt", 6 | "size": 256, 7 | "batch": 8, 8 | "n_sample": 4, 9 | "iter": 1500, 10 | "sample_every": 100, 11 | "save_every": 100 12 | } 13 | } 14 | 15 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/criteria/__init__.py -------------------------------------------------------------------------------- /source/stylegan2/criteria/helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 4 | 5 | """ 6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 7 | """ 8 | 9 | 10 | class Flatten(Module): 11 | def forward(self, input): 12 | return input.view(input.size(0), -1) 13 | 14 | 15 | def l2_norm(input, axis=1): 16 | norm = torch.norm(input, 2, axis, True) 17 | output = torch.div(input, norm) 18 | return output 19 | 20 | 21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 22 | """ A named tuple describing a ResNet block. 
""" 23 | 24 | 25 | def get_block(in_channel, depth, num_units, stride=2): 26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 27 | 28 | 29 | def get_blocks(num_layers): 30 | if num_layers == 50: 31 | blocks = [ 32 | get_block(in_channel=64, depth=64, num_units=3), 33 | get_block(in_channel=64, depth=128, num_units=4), 34 | get_block(in_channel=128, depth=256, num_units=14), 35 | get_block(in_channel=256, depth=512, num_units=3) 36 | ] 37 | elif num_layers == 100: 38 | blocks = [ 39 | get_block(in_channel=64, depth=64, num_units=3), 40 | get_block(in_channel=64, depth=128, num_units=13), 41 | get_block(in_channel=128, depth=256, num_units=30), 42 | get_block(in_channel=256, depth=512, num_units=3) 43 | ] 44 | elif num_layers == 152: 45 | blocks = [ 46 | get_block(in_channel=64, depth=64, num_units=3), 47 | get_block(in_channel=64, depth=128, num_units=8), 48 | get_block(in_channel=128, depth=256, num_units=36), 49 | get_block(in_channel=256, depth=512, num_units=3) 50 | ] 51 | else: 52 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 53 | return blocks 54 | 55 | 56 | class SEModule(Module): 57 | def __init__(self, channels, reduction): 58 | super(SEModule, self).__init__() 59 | self.avg_pool = AdaptiveAvgPool2d(1) 60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 61 | self.relu = ReLU(inplace=True) 62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 63 | self.sigmoid = Sigmoid() 64 | 65 | def forward(self, x): 66 | module_input = x 67 | x = self.avg_pool(x) 68 | x = self.fc1(x) 69 | x = self.relu(x) 70 | x = self.fc2(x) 71 | x = self.sigmoid(x) 72 | return module_input * x 73 | 74 | 75 | class bottleneck_IR(Module): 76 | def __init__(self, in_channel, depth, stride): 77 | super(bottleneck_IR, self).__init__() 78 | if in_channel == depth: 79 | self.shortcut_layer = MaxPool2d(1, stride) 80 | else: 81 | self.shortcut_layer = Sequential( 82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 83 | BatchNorm2d(depth) 84 | ) 85 | self.res_layer = Sequential( 86 | BatchNorm2d(in_channel), 87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 89 | ) 90 | 91 | def forward(self, x): 92 | shortcut = self.shortcut_layer(x) 93 | res = self.res_layer(x) 94 | return res + shortcut 95 | 96 | 97 | class bottleneck_IR_SE(Module): 98 | def __init__(self, in_channel, depth, stride): 99 | super(bottleneck_IR_SE, self).__init__() 100 | if in_channel == depth: 101 | self.shortcut_layer = MaxPool2d(1, stride) 102 | else: 103 | self.shortcut_layer = Sequential( 104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 105 | BatchNorm2d(depth) 106 | ) 107 | self.res_layer = Sequential( 108 | BatchNorm2d(in_channel), 109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 110 | PReLU(depth), 111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 112 | 
BatchNorm2d(depth), 113 | SEModule(depth, 16) 114 | ) 115 | 116 | def forward(self, x): 117 | shortcut = self.shortcut_layer(x) 118 | res = self.res_layer(x) 119 | return res + shortcut 120 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/id_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .model_irse import Backbone 4 | 5 | 6 | class IDLoss(nn.Module): 7 | def __init__(self): 8 | super(IDLoss, self).__init__() 9 | print('Loading ResNet ArcFace') 10 | model_paths = '/data/vdb/qingyao/cartoon/mycode/pretrained_models/model_ir_se50.pth' 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_paths)) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def extract_feats(self, x): 17 | x = x[:, :, 35:223, 32:220] # Crop interesting region 18 | x = self.face_pool(x) 19 | x_feats = self.facenet(x) 20 | return x_feats 21 | 22 | def forward(self, y_hat, x): 23 | n_samples = x.shape[0] 24 | x_feats = self.extract_feats(x) 25 | y_hat_feats = self.extract_feats(y_hat) 26 | loss = 0 27 | sim_improvement = 0 28 | id_logs = [] 29 | count = 0 30 | for i in range(n_samples): 31 | diff_input = y_hat_feats[i].dot(x_feats[i]) 32 | id_logs.append({ 33 | 'diff_input': float(diff_input) 34 | }) 35 | # loss += 1 - diff_target 36 | # modify 37 | loss += 1 - diff_input 38 | count += 1 39 | 40 | return loss / count 41 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/criteria/lpips/__init__.py 
-------------------------------------------------------------------------------- /source/stylegan2/criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1. 15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | 
elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class 
VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) -------------------------------------------------------------------------------- /source/stylegan2/criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/moco_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from configs.paths_config import model_paths 5 | 6 | 7 | class MocoLoss(nn.Module): 8 | 9 | def __init__(self): 10 | super(MocoLoss, self).__init__() 11 | print("Loading MOCO model from path: {}".format(model_paths["moco"])) 12 | self.model = self.__load_model() 13 | self.model.cuda() 14 | self.model.eval() 15 | 16 | 
@staticmethod 17 | def __load_model(): 18 | import torchvision.models as models 19 | model = models.__dict__["resnet50"]() 20 | # freeze all layers but the last fc 21 | for name, param in model.named_parameters(): 22 | if name not in ['fc.weight', 'fc.bias']: 23 | param.requires_grad = False 24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu") 25 | state_dict = checkpoint['state_dict'] 26 | # rename moco pre-trained keys 27 | for k in list(state_dict.keys()): 28 | # retain only encoder_q up to before the embedding layer 29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'): 30 | # remove prefix 31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k] 32 | # delete renamed or unused k 33 | del state_dict[k] 34 | msg = model.load_state_dict(state_dict, strict=False) 35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} 36 | # remove output layer 37 | model = nn.Sequential(*list(model.children())[:-1]).cuda() 38 | return model 39 | 40 | def extract_feats(self, x): 41 | x = F.interpolate(x, size=224) 42 | x_feats = self.model(x) 43 | x_feats = nn.functional.normalize(x_feats, dim=1) 44 | x_feats = x_feats.squeeze() 45 | return x_feats 46 | 47 | def forward(self, y_hat, y, x): 48 | n_samples = x.shape[0] 49 | x_feats = self.extract_feats(x) 50 | y_feats = self.extract_feats(y) 51 | y_hat_feats = self.extract_feats(y_hat) 52 | y_feats = y_feats.detach() 53 | loss = 0 54 | sim_improvement = 0 55 | sim_logs = [] 56 | count = 0 57 | for i in range(n_samples): 58 | diff_target = y_hat_feats[i].dot(y_feats[i]) 59 | diff_input = y_hat_feats[i].dot(x_feats[i]) 60 | diff_views = y_feats[i].dot(x_feats[i]) 61 | sim_logs.append({'diff_target': float(diff_target), 62 | 'diff_input': float(diff_input), 63 | 'diff_views': float(diff_views)}) 64 | loss += 1 - diff_target 65 | sim_diff = float(diff_target) - float(diff_views) 66 | sim_improvement += sim_diff 67 | count += 1 68 | 69 | return loss / count, sim_improvement / 
count, sim_logs 70 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/model_irse.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 2 | from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 3 | 4 | """ 5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 6 | """ 7 | 8 | 9 | class Backbone(Module): 10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 11 | super(Backbone, self).__init__() 12 | assert input_size in [112, 224], "input_size should be 112 or 224" 13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 15 | blocks = get_blocks(num_layers) 16 | if mode == 'ir': 17 | unit_module = bottleneck_IR 18 | elif mode == 'ir_se': 19 | unit_module = bottleneck_IR_SE 20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 21 | BatchNorm2d(64), 22 | PReLU(64)) 23 | if input_size == 112: 24 | self.output_layer = Sequential(BatchNorm2d(512), 25 | Dropout(drop_ratio), 26 | Flatten(), 27 | Linear(512 * 7 * 7, 512), 28 | BatchNorm1d(512, affine=affine)) 29 | else: 30 | self.output_layer = Sequential(BatchNorm2d(512), 31 | Dropout(drop_ratio), 32 | Flatten(), 33 | Linear(512 * 14 * 14, 512), 34 | BatchNorm1d(512, affine=affine)) 35 | 36 | modules = [] 37 | for block in blocks: 38 | for bottleneck in block: 39 | modules.append(unit_module(bottleneck.in_channel, 40 | bottleneck.depth, 41 | bottleneck.stride)) 42 | self.body = Sequential(*modules) 43 | 44 | def forward(self, x): 45 | x = self.input_layer(x) 46 | x = self.body(x) 47 | x = self.output_layer(x) 48 | return l2_norm(x) 49 | 50 | 51 | def IR_50(input_size): 52 | """Constructs a ir-50 model.""" 53 | model = 
Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 54 | return model 55 | 56 | 57 | def IR_101(input_size): 58 | """Constructs a ir-101 model.""" 59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 60 | return model 61 | 62 | 63 | def IR_152(input_size): 64 | """Constructs a ir-152 model.""" 65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 66 | return model 67 | 68 | 69 | def IR_SE_50(input_size): 70 | """Constructs a ir_se-50 model.""" 71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 72 | return model 73 | 74 | 75 | def IR_SE_101(input_size): 76 | """Constructs a ir_se-101 model.""" 77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 78 | return model 79 | 80 | 81 | def IR_SE_152(input_size): 82 | """Constructs a ir_se-152 model.""" 83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 84 | return model 85 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import optim 6 | 7 | from PIL import Image 8 | import os 9 | 10 | 11 | # gram matrix and loss 12 | class GramMatrix(nn.Module): 13 | def forward(self, input): 14 | b, c, h, w = input.size() 15 | F = input.view(b, c, h * w) 16 | G = torch.bmm(F, F.transpose(1, 2)) 17 | G.div_(h * w) 18 | return G 19 | 20 | 21 | class GramMSELoss(nn.Module): 22 | def forward(self, input, target): 23 | out = nn.MSELoss()(GramMatrix()(input), target) 24 | return (out) 25 | 26 | 27 | # vgg definition that conveniently let's you grab the outputs from any layer 28 | class VGG(nn.Module): 29 | def __init__(self, pool='max'): 30 | 
super(VGG, self).__init__() 31 | # vgg modules 32 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) 33 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 34 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 35 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 36 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 37 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 38 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 39 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 40 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 41 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 42 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 44 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 45 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 46 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 47 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 48 | 49 | if pool == 'max': 50 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 51 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 52 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 53 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 54 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 55 | 56 | elif pool == 'avg': 57 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 58 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 59 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 60 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 61 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 62 | 63 | 64 | def forward(self, x): 65 | out = {} 66 | out['r11'] = F.relu(self.conv1_1(x)) 67 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 68 | out['p1'] = self.pool1(out['r12']) 69 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 70 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 
71 | out['p2'] = self.pool2(out['r22']) 72 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 73 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 74 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 75 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 76 | out['p3'] = self.pool3(out['r34']) 77 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 78 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 79 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 80 | conv_4_4 = self.conv4_4(out['r43']) 81 | out['r44'] = F.relu(conv_4_4) 82 | out['p4'] = self.pool4(out['r44']) 83 | return conv_4_4 -------------------------------------------------------------------------------- /source/stylegan2/criteria/vgg_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .vgg import VGG 4 | import os 5 | from configs.paths_config import model_paths 6 | 7 | class VggLoss(nn.Module): 8 | def __init__(self): 9 | super(VggLoss, self).__init__() 10 | print("Loading VGG19 model from path: {}".format(model_paths["vgg"])) 11 | 12 | self.vgg_model = VGG() 13 | self.vgg_model.load_state_dict(torch.load(model_paths['vgg'])) 14 | self.vgg_model.cuda() 15 | self.vgg_model.eval() 16 | 17 | self.l1loss = torch.nn.L1Loss() 18 | 19 | 20 | 21 | 22 | 23 | def forward(self, input_photo, output): 24 | vgg_photo = self.vgg_model(input_photo) 25 | vgg_output = self.vgg_model(output) 26 | n, c, h, w = vgg_photo.shape 27 | # h, w, c = vgg_photo.get_shape().as_list()[1:] 28 | loss = self.l1loss(vgg_photo, vgg_output) 29 | 30 | return loss 31 | -------------------------------------------------------------------------------- /source/stylegan2/criteria/w_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class WNormLoss(nn.Module): 6 | 7 | def __init__(self, start_from_latent_avg=True): 8 | super(WNormLoss, self).__init__() 9 | self.start_from_latent_avg = 
start_from_latent_avg 10 | 11 | def forward(self, latent, latent_avg=None): 12 | if self.start_from_latent_avg: 13 | latent = latent - latent_avg 14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0] 15 | -------------------------------------------------------------------------------- /source/stylegan2/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MultiResolutionDataset(Dataset): 9 | def __init__(self, path, transform, resolution=256): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | 40 | return img 41 | -------------------------------------------------------------------------------- /source/stylegan2/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | 
return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def 
reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /source/stylegan2/generate_blendmodel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7" 4 | import torch 5 | from torchvision import utils 6 | from model import Generator 7 | from tqdm import tqdm 8 | import json 9 | import glob 10 | 11 | from PIL import Image 12 | 13 | def make_image(tensor): 14 | return ( 15 | tensor.detach() 16 | .clamp_(min=-1, max=1) 17 | .add(1) 18 | .div_(2) 19 | .mul(255) 20 | .type(torch.uint8) 21 | .permute(0, 2, 3, 1) 22 | .to("cpu") 23 | .numpy() 24 | ) 25 | 26 | def generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq): 27 | 28 | outdir = args.save_dir 29 | 30 | # print(outdir) 31 | # outdir = os.path.join(args.output, args.name, 'eval','toons_paired_0512') 32 | if not os.path.exists(outdir): 33 | os.makedirs(outdir) 34 | 35 | with torch.no_grad(): 36 | g_ema.eval() 37 | for i in tqdm(range(args.pics)): 38 | sample_z = torch.randn(args.sample, args.latent, device=device) 39 | 40 | res, _ = g_ema( 41 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 42 | ) 43 | if args.form == "pair": 44 | sample_face, _ = g_ema_ffhq( 45 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 46 | ) 47 | res = torch.cat([sample_face, res], 3) 48 | 49 | outpath = os.path.join(outdir, 
str(i).zfill(6)+'.png') 50 | utils.save_image( 51 | res, 52 | outpath, 53 | # f"sample/{str(i).zfill(6)}.png", 54 | nrow=1, 55 | normalize=True, 56 | range=(-1, 1), 57 | ) 58 | # print('save %s'% outpath) 59 | 60 | 61 | 62 | 63 | if __name__ == "__main__": 64 | device = "cuda" 65 | 66 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 67 | parser.add_argument('--config', type=str, default='config/conf_server_test_blend_shell.json') 68 | parser.add_argument('--name', type=str, default='') 69 | parser.add_argument('--save_dir', type=str, default='') 70 | 71 | parser.add_argument('--form', type=str, default='single') 72 | parser.add_argument( 73 | "--size", type=int, default=256, help="output image size of the generator" 74 | ) 75 | parser.add_argument( 76 | "--sample", 77 | type=int, 78 | default=1, 79 | help="number of samples to be generated for each image", 80 | ) 81 | parser.add_argument( 82 | "--pics", type=int, default=20, help="number of images to be generated" 83 | ) 84 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 85 | parser.add_argument( 86 | "--truncation_mean", 87 | type=int, 88 | default=4096, 89 | help="number of vectors to calculate mean for the truncation", 90 | ) 91 | parser.add_argument( 92 | "--ckpt", 93 | type=str, 94 | default="stylegan2-ffhq-config-f.pt", 95 | help="path to the model checkpoint", 96 | ) 97 | parser.add_argument( 98 | "--channel_multiplier", 99 | type=int, 100 | default=2, 101 | help="channel multiplier of the generator. 
config-f = 2, else = 1", 102 | ) 103 | 104 | args = parser.parse_args() 105 | # from config updata paras 106 | opt = vars(args) 107 | with open(args.config) as f: 108 | config = json.load(f)['parameters'] 109 | for key, value in config.items(): 110 | opt[key] = value 111 | 112 | # args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_001000_4.pt' 113 | args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_' 114 | args.ckpt = glob.glob(args.ckpt+'*')[0] 115 | 116 | args.latent = 512 117 | args.n_mlp = 8 118 | 119 | g_ema = Generator( 120 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 121 | ).to(device) 122 | checkpoint = torch.load(args.ckpt) 123 | 124 | # g_ema.load_state_dict(checkpoint["g_ema"]) 125 | g_ema.load_state_dict(checkpoint["g_ema"], strict=False) 126 | 127 | ## add G_ffhq 128 | g_ema_ffhq = Generator( 129 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 130 | ).to(device) 131 | checkpoint_ffhq = torch.load(args.ffhq_ckpt) 132 | g_ema_ffhq.load_state_dict(checkpoint_ffhq["g_ema"], strict=False) 133 | 134 | 135 | if args.truncation < 1: 136 | with torch.no_grad(): 137 | mean_latent = g_ema.mean_latent(args.truncation_mean) 138 | else: 139 | mean_latent = None 140 | 141 | model_name = os.path.basename(args.ckpt) 142 | print('save generated samples to %s'% os.path.join(args.output, args.name, 'eval_blend', model_name)) 143 | generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq) 144 | # generate_style_mix(args, g_ema, device, mean_latent, model_name, g_ema_ffhq) 145 | 146 | # latent_path = 'test2.pt' 147 | # generate_from_latent(args, g_ema, device, mean_latent, latent_path) 148 | -------------------------------------------------------------------------------- /source/stylegan2/noise.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/noise.pt -------------------------------------------------------------------------------- /source/stylegan2/non_leaking.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import autograd 5 | from torch.nn import functional as F 6 | import numpy as np 7 | 8 | from distributed import reduce_sum 9 | from op import upfirdn2d 10 | 11 | 12 | class AdaptiveAugment: 13 | def __init__(self, ada_aug_target, ada_aug_len, update_every, device): 14 | self.ada_aug_target = ada_aug_target 15 | self.ada_aug_len = ada_aug_len 16 | self.update_every = update_every 17 | 18 | self.ada_update = 0 19 | self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device) 20 | self.r_t_stat = 0 21 | self.ada_aug_p = 0 22 | 23 | @torch.no_grad() 24 | def tune(self, real_pred): 25 | self.ada_aug_buf += torch.tensor( 26 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]), 27 | device=real_pred.device, 28 | ) 29 | self.ada_update += 1 30 | 31 | if self.ada_update % self.update_every == 0: 32 | self.ada_aug_buf = reduce_sum(self.ada_aug_buf) 33 | pred_signs, n_pred = self.ada_aug_buf.tolist() 34 | 35 | self.r_t_stat = pred_signs / n_pred 36 | 37 | if self.r_t_stat > self.ada_aug_target: 38 | sign = 1 39 | 40 | else: 41 | sign = -1 42 | 43 | self.ada_aug_p += sign * n_pred / self.ada_aug_len 44 | self.ada_aug_p = min(1, max(0, self.ada_aug_p)) 45 | self.ada_aug_buf.mul_(0) 46 | self.ada_update = 0 47 | 48 | return self.ada_aug_p 49 | 50 | 51 | SYM6 = ( 52 | 0.015404109327027373, 53 | 0.0034907120842174702, 54 | -0.11799011114819057, 55 | -0.048311742585633, 56 | 0.4910559419267466, 57 | 0.787641141030194, 58 | 0.3379294217276218, 59 | -0.07263752278646252, 60 | -0.021060292512300564, 61 | 0.04472490177066578, 62 | 0.0017677118642428036, 63 | -0.007800708325034148, 64 | ) 65 | 66 | 67 | def 
translate_mat(t_x, t_y, device="cpu"): 68 | batch = t_x.shape[0] 69 | 70 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1) 71 | translate = torch.stack((t_x, t_y), 1) 72 | mat[:, :2, 2] = translate 73 | 74 | return mat 75 | 76 | 77 | def rotate_mat(theta, device="cpu"): 78 | batch = theta.shape[0] 79 | 80 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1) 81 | sin_t = torch.sin(theta) 82 | cos_t = torch.cos(theta) 83 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2) 84 | mat[:, :2, :2] = rot 85 | 86 | return mat 87 | 88 | 89 | def scale_mat(s_x, s_y, device="cpu"): 90 | batch = s_x.shape[0] 91 | 92 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1) 93 | mat[:, 0, 0] = s_x 94 | mat[:, 1, 1] = s_y 95 | 96 | return mat 97 | 98 | 99 | def translate3d_mat(t_x, t_y, t_z): 100 | batch = t_x.shape[0] 101 | 102 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 103 | translate = torch.stack((t_x, t_y, t_z), 1) 104 | mat[:, :3, 3] = translate 105 | 106 | return mat 107 | 108 | 109 | def rotate3d_mat(axis, theta): 110 | batch = theta.shape[0] 111 | 112 | u_x, u_y, u_z = axis 113 | 114 | eye = torch.eye(3).unsqueeze(0) 115 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0) 116 | outer = torch.tensor(axis) 117 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0) 118 | 119 | sin_t = torch.sin(theta).view(-1, 1, 1) 120 | cos_t = torch.cos(theta).view(-1, 1, 1) 121 | 122 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer 123 | 124 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 125 | eye_4[:, :3, :3] = rot 126 | 127 | return eye_4 128 | 129 | 130 | def scale3d_mat(s_x, s_y, s_z): 131 | batch = s_x.shape[0] 132 | 133 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 134 | mat[:, 0, 0] = s_x 135 | mat[:, 1, 1] = s_y 136 | mat[:, 2, 2] = s_z 137 | 138 | return mat 139 | 140 | 141 | def luma_flip_mat(axis, i): 142 | batch = i.shape[0] 143 | 144 | eye = 
torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 145 | axis = torch.tensor(axis + (0,)) 146 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1) 147 | 148 | return eye - flip 149 | 150 | 151 | def saturation_mat(axis, i): 152 | batch = i.shape[0] 153 | 154 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1) 155 | axis = torch.tensor(axis + (0,)) 156 | axis = torch.ger(axis, axis) 157 | saturate = axis + (eye - axis) * i.view(-1, 1, 1) 158 | 159 | return saturate 160 | 161 | 162 | def lognormal_sample(size, mean=0, std=1, device="cpu"): 163 | return torch.empty(size, device=device).log_normal_(mean=mean, std=std) 164 | 165 | 166 | def category_sample(size, categories, device="cpu"): 167 | category = torch.tensor(categories, device=device) 168 | sample = torch.randint(high=len(categories), size=(size,), device=device) 169 | 170 | return category[sample] 171 | 172 | 173 | def uniform_sample(size, low, high, device="cpu"): 174 | return torch.empty(size, device=device).uniform_(low, high) 175 | 176 | 177 | def normal_sample(size, mean=0, std=1, device="cpu"): 178 | return torch.empty(size, device=device).normal_(mean, std) 179 | 180 | 181 | def bernoulli_sample(size, p, device="cpu"): 182 | return torch.empty(size, device=device).bernoulli_(p) 183 | 184 | 185 | def random_mat_apply(p, transform, prev, eye, device="cpu"): 186 | size = transform.shape[0] 187 | select = bernoulli_sample(size, p, device=device).view(size, 1, 1) 188 | select_transform = select * transform + (1 - select) * eye 189 | 190 | return select_transform @ prev 191 | 192 | 193 | def sample_affine(p, size, height, width, device="cpu"): 194 | G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1) 195 | eye = G 196 | 197 | # flip 198 | param = category_sample(size, (0, 1)) 199 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device) 200 | G = random_mat_apply(p, Gc, G, eye, device=device) 201 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n') 202 | 203 | # 90 
rotate 204 | param = category_sample(size, (0, 3)) 205 | Gc = rotate_mat(-math.pi / 2 * param, device=device) 206 | G = random_mat_apply(p, Gc, G, eye, device=device) 207 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n') 208 | 209 | # integer translate 210 | param = uniform_sample(size, -0.125, 0.125) 211 | param_height = torch.round(param * height) / height 212 | param_width = torch.round(param * width) / width 213 | Gc = translate_mat(param_width, param_height, device=device) 214 | G = random_mat_apply(p, Gc, G, eye, device=device) 215 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n') 216 | 217 | # isotropic scale 218 | param = lognormal_sample(size, std=0.2 * math.log(2)) 219 | Gc = scale_mat(param, param, device=device) 220 | G = random_mat_apply(p, Gc, G, eye, device=device) 221 | # print('isotropic scale', G, scale_mat(param, param), sep='\n') 222 | 223 | p_rot = 1 - math.sqrt(1 - p) 224 | 225 | # pre-rotate 226 | param = uniform_sample(size, -math.pi, math.pi) 227 | Gc = rotate_mat(-param, device=device) 228 | G = random_mat_apply(p_rot, Gc, G, eye, device=device) 229 | # print('pre-rotate', G, rotate_mat(-param), sep='\n') 230 | 231 | # anisotropic scale 232 | param = lognormal_sample(size, std=0.2 * math.log(2)) 233 | Gc = scale_mat(param, 1 / param, device=device) 234 | G = random_mat_apply(p, Gc, G, eye, device=device) 235 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n') 236 | 237 | # post-rotate 238 | param = uniform_sample(size, -math.pi, math.pi) 239 | Gc = rotate_mat(-param, device=device) 240 | G = random_mat_apply(p_rot, Gc, G, eye, device=device) 241 | # print('post-rotate', G, rotate_mat(-param), sep='\n') 242 | 243 | # fractional translate 244 | param = normal_sample(size, std=0.125) 245 | Gc = translate_mat(param, param, device=device) 246 | G = random_mat_apply(p, Gc, G, eye, device=device) 247 | # print('fractional translate', G, translate_mat(param, param), 
sep='\n') 248 | 249 | return G 250 | 251 | 252 | def sample_color(p, size): 253 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1) 254 | eye = C 255 | axis_val = 1 / math.sqrt(3) 256 | axis = (axis_val, axis_val, axis_val) 257 | 258 | # brightness 259 | param = normal_sample(size, std=0.2) 260 | Cc = translate3d_mat(param, param, param) 261 | C = random_mat_apply(p, Cc, C, eye) 262 | 263 | # contrast 264 | param = lognormal_sample(size, std=0.5 * math.log(2)) 265 | Cc = scale3d_mat(param, param, param) 266 | C = random_mat_apply(p, Cc, C, eye) 267 | 268 | # luma flip 269 | param = category_sample(size, (0, 1)) 270 | Cc = luma_flip_mat(axis, param) 271 | C = random_mat_apply(p, Cc, C, eye) 272 | 273 | # hue rotation 274 | param = uniform_sample(size, -math.pi, math.pi) 275 | Cc = rotate3d_mat(axis, param) 276 | C = random_mat_apply(p, Cc, C, eye) 277 | 278 | # saturation 279 | param = lognormal_sample(size, std=1 * math.log(2)) 280 | Cc = saturation_mat(axis, param) 281 | C = random_mat_apply(p, Cc, C, eye) 282 | 283 | return C 284 | 285 | 286 | def make_grid(shape, x0, x1, y0, y1, device): 287 | n, c, h, w = shape 288 | grid = torch.empty(n, h, w, 3, device=device) 289 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device) 290 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1) 291 | grid[:, :, :, 2] = 1 292 | 293 | return grid 294 | 295 | 296 | def affine_grid(grid, mat): 297 | n, h, w, _ = grid.shape 298 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2) 299 | 300 | 301 | def get_padding(G, height, width, kernel_size): 302 | device = G.device 303 | 304 | cx = (width - 1) / 2 305 | cy = (height - 1) / 2 306 | cp = torch.tensor( 307 | [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device 308 | ) 309 | cp = G @ cp.T 310 | 311 | pad_k = kernel_size // 4 312 | 313 | pad = cp[:, :2, :].permute(1, 0, 2).flatten(1) 314 | pad = torch.cat((-pad, pad)).max(1).values 315 | pad = pad + torch.tensor([pad_k 
* 2 - cx, pad_k * 2 - cy] * 2, device=device) 316 | pad = pad.max(torch.tensor([0, 0] * 2, device=device)) 317 | pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device)) 318 | 319 | pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32) 320 | 321 | return pad_x1, pad_x2, pad_y1, pad_y2 322 | 323 | 324 | def try_sample_affine_and_pad(img, p, kernel_size, G=None): 325 | batch, _, height, width = img.shape 326 | 327 | G_try = G 328 | 329 | if G is None: 330 | G_try = torch.inverse(sample_affine(p, batch, height, width)) 331 | 332 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size) 333 | 334 | img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect") 335 | 336 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2) 337 | 338 | 339 | class GridSampleForward(autograd.Function): 340 | @staticmethod 341 | def forward(ctx, input, grid): 342 | out = F.grid_sample( 343 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False 344 | ) 345 | ctx.save_for_backward(input, grid) 346 | 347 | return out 348 | 349 | @staticmethod 350 | def backward(ctx, grad_output): 351 | input, grid = ctx.saved_tensors 352 | grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid) 353 | 354 | return grad_input, grad_grid 355 | 356 | 357 | class GridSampleBackward(autograd.Function): 358 | @staticmethod 359 | def forward(ctx, grad_output, input, grid): 360 | op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward") 361 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 362 | ctx.save_for_backward(grid) 363 | 364 | return grad_input, grad_grid 365 | 366 | @staticmethod 367 | def backward(ctx, grad_grad_input, grad_grad_grid): 368 | grid, = ctx.saved_tensors 369 | grad_grad_output = None 370 | 371 | if ctx.needs_input_grad[0]: 372 | grad_grad_output = GridSampleForward.apply(grad_grad_input, grid) 373 | 374 | return grad_grad_output, None, None 375 | 376 | 377 | grid_sample = 
GridSampleForward.apply 378 | 379 | 380 | def scale_mat_single(s_x, s_y): 381 | return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32) 382 | 383 | 384 | def translate_mat_single(t_x, t_y): 385 | return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32) 386 | 387 | 388 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6): 389 | kernel = antialiasing_kernel 390 | len_k = len(kernel) 391 | 392 | kernel = torch.as_tensor(kernel).to(img) 393 | # kernel = torch.ger(kernel, kernel).to(img) 394 | kernel_flip = torch.flip(kernel, (0,)) 395 | 396 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad( 397 | img, p, len_k, G 398 | ) 399 | 400 | G_inv = ( 401 | translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2) 402 | @ G 403 | ) 404 | up_pad = ( 405 | (len_k + 2 - 1) // 2, 406 | (len_k - 2) // 2, 407 | (len_k + 2 - 1) // 2, 408 | (len_k - 2) // 2, 409 | ) 410 | img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0)) 411 | img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:])) 412 | G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2) 413 | G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5) 414 | batch_size, channel, height, width = img.shape 415 | pad_k = len_k // 4 416 | shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2) 417 | G_inv = ( 418 | scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2]) 419 | @ G_inv 420 | @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2])) 421 | ) 422 | grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False) 423 | img_affine = grid_sample(img_2x, grid) 424 | d_p = -pad_k * 2 425 | down_pad = ( 426 | d_p + (len_k - 2 + 1) // 2, 427 | d_p + (len_k - 2) // 2, 428 | d_p + (len_k - 2 + 1) // 2, 429 | d_p + (len_k - 2) // 2, 430 | ) 431 | img_down = upfirdn2d( 432 | img_affine, 
kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0) 433 | ) 434 | img_down = upfirdn2d( 435 | img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:]) 436 | ) 437 | 438 | return img_down, G 439 | 440 | 441 | def apply_color(img, mat): 442 | batch = img.shape[0] 443 | img = img.permute(0, 2, 3, 1) 444 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3) 445 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3) 446 | img = img @ mat_mul + mat_add 447 | img = img.permute(0, 3, 1, 2) 448 | 449 | return img 450 | 451 | 452 | def random_apply_color(img, p, C=None): 453 | if C is None: 454 | C = sample_color(p, img.shape[0]) 455 | 456 | img = apply_color(img, C.to(img)) 457 | 458 | return img, C 459 | 460 | 461 | def augment(img, p, transform_matrix=(None, None)): 462 | img, G = random_apply_affine(img, p, transform_matrix[0]) 463 | img, C = random_apply_color(img, p, transform_matrix[1]) 464 | 465 | return img, (G, C) 466 | -------------------------------------------------------------------------------- /source/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /source/stylegan2/op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 
23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 
90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not 
transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /source/stylegan2/op/fused_act.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | 61 | ctx.bias = bias is not None 62 | 63 | if bias is None: 64 | bias = empty 65 | 66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 67 | ctx.save_for_backward(out) 68 | ctx.negative_slope = negative_slope 69 | ctx.scale = scale 70 | 71 | return out 72 | 73 | @staticmethod 74 | def backward(ctx, grad_output): 75 | out, = ctx.saved_tensors 76 | 77 | 
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 79 | ) 80 | 81 | if not ctx.bias: 82 | grad_bias = None 83 | 84 | return grad_input, grad_bias, None, None 85 | 86 | 87 | class FusedLeakyReLU(nn.Module): 88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 89 | super().__init__() 90 | 91 | if bias: 92 | self.bias = nn.Parameter(torch.zeros(channel)) 93 | 94 | else: 95 | self.bias = None 96 | 97 | self.negative_slope = negative_slope 98 | self.scale = scale 99 | 100 | def forward(self, input): 101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 102 | 103 | 104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 105 | if input.device.type == "cpu": 106 | if bias is not None: 107 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 108 | return ( 109 | F.leaky_relu( 110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 111 | ) 112 | * scale 113 | ) 114 | 115 | else: 116 | return F.leaky_relu(input, negative_slope=0.2) * scale 117 | 118 | else: 119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 120 | -------------------------------------------------------------------------------- /source/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 
20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /source/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? 
x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /source/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const 
torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /source/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | 
ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 
121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, 
[0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /source/stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + 
threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = 
((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim 
+ 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | 
p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | 
p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /source/stylegan2/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | import os 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 
| buffer = BytesIO() 17 | img.save(buffer, format="jpeg", quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple( 24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100 25 | ): 26 | imgs = [] 27 | 28 | for size in sizes: 29 | imgs.append(resize_and_convert(img, size, resample, quality)) 30 | 31 | return imgs 32 | 33 | 34 | def resize_worker(img_file, sizes, resample): 35 | i, file = img_file 36 | img = Image.open(file) 37 | img = img.convert("RGB") 38 | out = resize_multiple(img, sizes=sizes, resample=resample) 39 | 40 | return i, out 41 | 42 | 43 | def prepare( 44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS 45 | ): 46 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 47 | 48 | files = sorted(dataset.imgs, key=lambda x: x[0]) 49 | files = [(i, file) for i, (file, label) in enumerate(files)] 50 | total = 0 51 | 52 | with multiprocessing.Pool(n_worker) as pool: 53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 54 | for size, img in zip(sizes, imgs): 55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 56 | 57 | with env.begin(write=True) as txn: 58 | txn.put(key, img) 59 | 60 | total += 1 61 | 62 | with env.begin(write=True) as txn: 63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Preprocess images for model training") 68 | parser.add_argument("path", type=str, help="path to the image dataset") 69 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset") 70 | parser.add_argument( 71 | "--size", 72 | type=str, 73 | default="128,256,512,1024", 74 | help="resolutions of images for the dataset", 75 | ) 76 | parser.add_argument( 77 | "--n_worker", 78 | type=int, 79 | default=16, 80 | help="number of workers for preparing dataset", 81 | ) 82 | parser.add_argument( 83 | "--resample", 84 | type=str, 85 | 
default="lanczos", 86 | help="resampling methods for resizing images", 87 | ) 88 | 89 | args = parser.parse_args() 90 | if not os.path.exists(str(args.out)): 91 | os.makedirs(str(args.out)) 92 | 93 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 94 | resample = resample_map[args.resample] 95 | 96 | sizes = [int(s.strip()) for s in args.size.split(",")] 97 | 98 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 99 | 100 | imgset = datasets.ImageFolder(args.path) 101 | 102 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 103 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 104 | -------------------------------------------------------------------------------- /source/stylegan2/style_blend.py: -------------------------------------------------------------------------------- 1 | # model blending technique 2 | import os 3 | import cv2 as cv 4 | import torch 5 | from model import Generator 6 | import math 7 | import argparse 8 | 9 | def extract_conv_names(model): 10 | model = list(model.keys()) 11 | conv_name = [] 12 | resolutions = [4*2**x for x in range(9)] 13 | level_names = [["Conv0_up", "Const"], ["Conv1", "ToRGB"]] 14 | 15 | 16 | def blend_models(model_1, model_2, resolution, level, blend_width=None): 17 | resolutions = [4 * 2 ** i for i in range(7)] 18 | mid = resolutions.index(resolution) 19 | 20 | device = "cuda" 21 | 22 | size = 256 23 | latent = 512 24 | n_mlp = 8 25 | channel_multiplier =2 26 | G_1 = Generator( 27 | size, latent, n_mlp, channel_multiplier=channel_multiplier 28 | ).to(device) 29 | ckpt_ffhq = torch.load(model_1, map_location=lambda storage, loc: storage) 30 | G_1.load_state_dict(ckpt_ffhq["g"], strict=False) 31 | 32 | 33 | G_2 = Generator( 34 | size, latent, n_mlp, channel_multiplier=channel_multiplier 35 | ).to(device) 36 | ckpt_toon = torch.load(model_2) 37 | G_2.load_state_dict(ckpt_toon["g_ema"]) 38 | 39 | 40 | 41 | # G_1 = 
stylegan2.models.load(model_1) 42 | # G_2 = stylegan2.models.load(model_2) 43 | model_1_state_dict = G_1.state_dict() 44 | model_2_state_dict = G_2.state_dict() 45 | assert(model_1_state_dict.keys() == model_2_state_dict.keys()) 46 | G_out = G_1.clone() 47 | 48 | layers = [] 49 | ys = [] 50 | for k, v in model_1_state_dict.items(): 51 | if k.startswith('convs.'): 52 | pos = int(k[len('convs.')]) 53 | x = pos - mid 54 | if blend_width: 55 | exponent = -x / blend_width 56 | y = 1 / (1 + math.exp(exponent)) 57 | else: 58 | y = 1 if x > 0 else 0 59 | 60 | layers.append(k) 61 | ys.append(y) 62 | elif k.startswith('to_rgbs.'): 63 | pos = int(k[len('to_rgbs.')]) 64 | x = pos - mid 65 | if blend_width: 66 | exponent = -x / blend_width 67 | y = 1 / (1 + math.exp(exponent)) 68 | else: 69 | y = 1 if x > 0 else 0 70 | layers.append(k) 71 | ys.append(y) 72 | out_state = G_out.state_dict() 73 | for y, layer in zip(ys, layers): 74 | out_state[layer] = y * model_2_state_dict[layer] + \ 75 | (1 - y) * model_1_state_dict[layer] 76 | print('blend layer %s'%str(y)) 77 | G_out.load_state_dict(out_state) 78 | return G_out 79 | 80 | 81 | def blend_models_2(model_1, model_2, resolution, level, blend_width=None): 82 | # resolution = f"{resolution}x{resolution}" 83 | resolutions = [4 * 2 ** i for i in range(7)] 84 | mid = [resolutions.index(r) for r in resolution] 85 | 86 | G_1 = stylegan2.models.load(model_1) 87 | G_2 = stylegan2.models.load(model_2) 88 | model_1_state_dict = G_1.state_dict() 89 | model_2_state_dict = G_2.state_dict() 90 | assert(model_1_state_dict.keys() == model_2_state_dict.keys()) 91 | G_out = G_1.clone() 92 | 93 | layers = [] 94 | ys = [] 95 | for k, v in model_1_state_dict.items(): 96 | if k.startswith('G_synthesis.conv_blocks.'): 97 | pos = int(k[len('G_synthesis.conv_blocks.')]) 98 | y = 0 if pos in mid else 1 99 | layers.append(k) 100 | ys.append(y) 101 | elif k.startswith('G_synthesis.to_data_layers.'): 102 | pos = int(k[len('G_synthesis.to_data_layers.')]) 103 | 
y = 0 if pos in mid else 1 104 | layers.append(k) 105 | ys.append(y) 106 | # print(ys, layers) 107 | out_state = G_out.state_dict() 108 | for y, layer in zip(ys, layers): 109 | out_state[layer] = y * model_2_state_dict[layer] + \ 110 | (1 - y) * model_1_state_dict[layer] 111 | G_out.load_state_dict(out_state) 112 | return G_out 113 | 114 | 115 | def main(name): 116 | 117 | resolution = 4 118 | 119 | model_name = '001000.pt' 120 | 121 | G_out = blend_models("pretrained_models/stylegan2-ffhq-config-f-256-550000.pt", 122 | "face_generation/experiment_stylegan/"+name+"/models/"+model_name, 123 | resolution, 124 | None) 125 | # G_out.save('G_blend.pth') 126 | outdir = os.path.join('face_generation/experiment_stylegan',name,'models_blend') 127 | if not os.path.exists(outdir): 128 | os.makedirs(outdir) 129 | 130 | outpath = os.path.join(outdir, 'G_blend_'+str(model_name[:-3])+'_'+ str(resolution)+'.pt') 131 | torch.save( 132 | { 133 | "g_ema": G_out.state_dict(), 134 | }, 135 | # 'G_blend_570000_16.pth', 136 | outpath 137 | ) 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser(description="style blender") 142 | parser.add_argument('--name', type=str, default='') 143 | args = parser.parse_args() 144 | print('model name:%s'%args.name) 145 | main(args.name) -------------------------------------------------------------------------------- /source/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def resize_size(image, size=720): 8 | h, w, c = np.shape(image) 9 | if min(h, w) > size: 10 | if h > w: 11 | h, w = int(size * h / w), size 12 | else: 13 | h, w = size, int(size * w / h) 14 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA) 15 | return image 16 | 17 | 18 | def padTo16x(image): 19 | h, w, c = np.shape(image) 20 | if h % 16 == 0 and w % 16 == 0: 21 | return image, h, w 22 | nh, nw = (h // 16 + 1) * 16, (w // 16 + 1) * 
16 23 | img_new = np.ones((nh, nw, 3), np.uint8) * 255 24 | img_new[:h, :w, :] = image 25 | 26 | return img_new, h, w 27 | 28 | 29 | def get_f5p(landmarks, np_img): 30 | eye_left = find_pupil(landmarks[36:41], np_img) 31 | eye_right = find_pupil(landmarks[42:47], np_img) 32 | if eye_left is None or eye_right is None: 33 | print('cannot find 5 points with find_puil, used mean instead.!') 34 | eye_left = landmarks[36:41].mean(axis=0) 35 | eye_right = landmarks[42:47].mean(axis=0) 36 | nose = landmarks[30] 37 | mouth_left = landmarks[48] 38 | mouth_right = landmarks[54] 39 | f5p = [[eye_left[0], eye_left[1]], [eye_right[0], eye_right[1]], 40 | [nose[0], nose[1]], [mouth_left[0], mouth_left[1]], 41 | [mouth_right[0], mouth_right[1]]] 42 | return f5p 43 | 44 | 45 | def find_pupil(landmarks, np_img): 46 | h, w, _ = np_img.shape 47 | xmax = int(landmarks[:, 0].max()) 48 | xmin = int(landmarks[:, 0].min()) 49 | ymax = int(landmarks[:, 1].max()) 50 | ymin = int(landmarks[:, 1].min()) 51 | 52 | if ymin >= ymax or xmin >= xmax or ymin < 0 or xmin < 0 or ymax > h or xmax > w: 53 | return None 54 | eye_img_bgr = np_img[ymin:ymax, xmin:xmax, :] 55 | eye_img = cv2.cvtColor(eye_img_bgr, cv2.COLOR_BGR2GRAY) 56 | eye_img = cv2.equalizeHist(eye_img) 57 | n_marks = landmarks - np.array([xmin, ymin]).reshape([1, 2]) 58 | eye_mask = cv2.fillConvexPoly( 59 | np.zeros_like(eye_img), n_marks.astype(np.int32), 1) 60 | ret, thresh = cv2.threshold(eye_img, 100, 255, 61 | cv2.THRESH_BINARY | cv2.THRESH_OTSU) 62 | thresh = (1 - thresh / 255.) 
* eye_mask 63 | cnt = 0 64 | xm = [] 65 | ym = [] 66 | for i in range(thresh.shape[0]): 67 | for j in range(thresh.shape[1]): 68 | if thresh[i, j] > 0.5: 69 | xm.append(j) 70 | ym.append(i) 71 | cnt += 1 72 | if cnt != 0: 73 | xm.sort() 74 | ym.sort() 75 | xm = xm[cnt // 2] 76 | ym = ym[cnt // 2] 77 | else: 78 | xm = thresh.shape[1] / 2 79 | ym = thresh.shape[0] / 2 80 | 81 | return xm + xmin, ym + ymin 82 | 83 | 84 | def all_file(file_dir): 85 | L = [] 86 | for root, dirs, files in os.walk(file_dir): 87 | for file in files: 88 | extend = os.path.splitext(file)[1] 89 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg': 90 | L.append(os.path.join(root, file)) 91 | return L 92 | 93 | def initialize_mask(box_width): 94 | h, w = [box_width, box_width] 95 | mask = np.zeros((h, w), np.uint8) 96 | 97 | center = (int(w / 2), int(h / 2)) 98 | axes = (int(w * 0.4), int(h * 0.49)) 99 | mask = cv2.ellipse(img=mask, center=center, axes=axes, angle=0, startAngle=0, endAngle=360, color=(1), 100 | thickness=-1) 101 | mask = cv2.distanceTransform(mask, cv2.DIST_L2, 3) 102 | 103 | maxn = max(w, h) * 0.15 104 | mask[(mask < 255) & (mask > 0)] = mask[(mask < 255) & (mask > 0)] / maxn 105 | mask = np.clip(mask, 0, 1) 106 | 107 | return mask.astype(float) 108 | -------------------------------------------------------------------------------- /train_localtoon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from modelscope.trainers.cv import CartoonTranslationTrainer 4 | 5 | 6 | def main(args): 7 | 8 | data_photo = os.path.join(args.data_dir, 'face_photo') 9 | data_cartoon = os.path.join(args.data_dir, 'face_cartoon') 10 | 11 | style = args.style 12 | if style == "anime": 13 | style = "" 14 | else: 15 | style = '-' + style 16 | model_id = 'damo/cv_unet_person-image-cartoon' + style + '_compound-models' 17 | 18 | max_steps = 300000 19 | trainer = CartoonTranslationTrainer( 20 | model=model_id, 21 | 
work_dir=args.work_dir, 22 | photo=data_photo, 23 | cartoon=data_cartoon, 24 | max_steps=max_steps) 25 | trainer.train() 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser(description="process remove bg result") 29 | parser.add_argument("--data_dir", type=str, default='', help="Path to training images.") 30 | parser.add_argument("--work_dir", type=str, default='', help="Path to save results.") 31 | parser.add_argument("--style", type=str, default='anime', help="resume training from similar style.") 32 | 33 | args = parser.parse_args() 34 | 35 | main(args) 36 | 37 | --------------------------------------------------------------------------------