├── .DS_Store
├── .idea
│ ├── .gitignore
│ ├── DCT-Net.iml
│ ├── deployment.xml
│ ├── inspectionProfiles
│ │ ├── Project_Default.xml
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── vcs.xml
│ └── webServers.xml
├── LICENSE
├── README.md
├── assets
│ ├── .DS_Store
│ ├── demo.gif
│ ├── sim1.png
│ ├── sim2.png
│ ├── sim3.png
│ ├── sim4.png
│ ├── sim5.png
│ ├── sim_3d.png
│ ├── sim_anime.png
│ ├── sim_artstyle.png
│ ├── sim_design.png
│ ├── sim_handdrawn.png
│ ├── sim_illu.png
│ ├── sim_sketch.png
│ ├── styles.png
│ └── video.gif
├── celeb.txt
├── download.py
├── export.py
├── extract_align_faces.py
├── generate_data.py
├── input.mp4
├── input.png
├── multi-style
│ ├── download.py
│ ├── run.py
│ └── run_sdk.py
├── notebooks
│ ├── .DS_Store
│ ├── fastTrain.ipynb
│ └── inference.ipynb
├── prepare_data.sh
├── run.py
├── run_sdk.py
├── run_vid.py
├── source
│ ├── .DS_Store
│ ├── __init__.py
│ ├── cartoonize.py
│ ├── facelib
│ │ ├── LICENSE
│ │ ├── LK
│ │ │ ├── __init__.py
│ │ │ └── lk.py
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── face_detector.py
│ │ ├── face_landmark.py
│ │ └── facer.py
│ ├── image_flip_agument_parallel.py
│ ├── image_rotation_agument_parallel_flat.py
│ ├── image_scale_agument_parallel_flat.py
│ ├── mtcnn_pytorch
│ │ ├── .DS_Store
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── __init__.py
│ │ └── src
│ │   ├── __init__.py
│ │   ├── align_trans.py
│ │   └── matlab_cp2tform.py
│ ├── stylegan2
│ │ ├── .DS_Store
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── .DS_Store
│ │ │ ├── conf_server_test_blend_shell.json
│ │ │ └── conf_server_train_condition_shell.json
│ │ ├── criteria
│ │ │ ├── __init__.py
│ │ │ ├── helpers.py
│ │ │ ├── id_loss.py
│ │ │ ├── lpips
│ │ │ │ ├── __init__.py
│ │ │ │ ├── lpips.py
│ │ │ │ ├── networks.py
│ │ │ │ └── utils.py
│ │ │ ├── moco_loss.py
│ │ │ ├── model_irse.py
│ │ │ ├── vgg.py
│ │ │ ├── vgg_loss.py
│ │ │ └── w_norm.py
│ │ ├── dataset.py
│ │ ├── distributed.py
│ │ ├── generate_blendmodel.py
│ │ ├── model.py
│ │ ├── noise.pt
│ │ ├── non_leaking.py
│ │ ├── op
│ │ │ ├── __init__.py
│ │ │ ├── conv2d_gradfix.py
│ │ │ ├── fused_act.py
│ │ │ ├── fused_bias_act.cpp
│ │ │ ├── fused_bias_act_kernel.cu
│ │ │ ├── upfirdn2d.cpp
│ │ │ ├── upfirdn2d.py
│ │ │ └── upfirdn2d_kernel.cu
│ │ ├── prepare_data.py
│ │ ├── style_blend.py
│ │ └── train_condition.py
│ └── utils.py
└── train_localtoon.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/DCT-Net.iml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
(XML content omitted)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DCT-Net: Domain-Calibrated Translation for Portrait Stylization
2 |
3 | ### [Project page](https://menyifang.github.io/projects/DCTNet/DCTNet.html) | [Video](https://www.youtube.com/watch?v=Y8BrfOjXYQM) | [Paper](https://arxiv.org/abs/2207.02426)
4 |
5 | Official implementation of DCT-Net for Full-body Portrait Stylization.
6 |
7 |
8 | > [**DCT-Net: Domain-Calibrated Translation for Portrait Stylization**](https://arxiv.org/abs/2207.02426),
9 | > [Yifang Men](https://menyifang.github.io/)1, Yuan Yao1, Miaomiao Cui1, [Zhouhui Lian](https://www.icst.pku.edu.cn/zlian/)2, Xuansong Xie1,
10 | > _1[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Beijing, China_
11 | > _2[Wangxuan Institute of Computer Technology, Peking University](https://www.icst.pku.edu.cn/), China_
12 | > In: SIGGRAPH 2022 (**TOG**)
13 | > *[arXiv preprint](https://arxiv.org/abs/2207.02426)*
14 |
15 |
16 | [Hugging Face Spaces demo](https://huggingface.co/spaces/SIGGRAPH2022/DCT-Net)
17 |
18 |
19 | ## Demo
20 | ![demo](assets/demo.gif)
21 |
22 |
23 | ## News
24 |
25 | (2023-03-14) The training guidance has been released; you can now train DCT-Net with your own style data.
26 |
27 | (2023-02-20) Two new pre-trained style models (design, illustration), trained with DCT-Net combined with Stable Diffusion, are provided. The training guidance will be released soon.
28 |
29 | (2022-10-09) The multi-style pre-trained models (3d, handdrawn, sketch, artstyle) and their usage are available now.
30 |
31 | (2022-08-08) The pre-trained model and inference code for the 'anime' style are available now. More styles coming soon.
32 |
33 | (2022-08-08) The cartoonization function can now be called directly from the Python SDK.
34 |
35 | (2022-07-07) The paper is now available on arXiv (https://arxiv.org/abs/2207.02426).
36 |
37 |
38 | ## Web Demo
39 | - Integrated into a [Colab notebook](https://colab.research.google.com/github/menyifang/DCT-Net/blob/main/notebooks/inference.ipynb). Try out the Colab demo.
40 |
41 | - Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the [Web Demo](https://huggingface.co/spaces/SIGGRAPH2022/DCT-Net).
42 |
43 | - [Chinese version] Integrated into [ModelScope](https://modelscope.cn/#/models). Try out the [Web Demo](https://modelscope.cn/#/models/damo/cv_unet_person-image-cartoon_compound-models/summary).
45 |
46 | ## Requirements
47 | * python 3
48 | * tensorflow (>=1.14; training only supports TF 1.x)
49 | * easydict
50 | * numpy
51 | * Both CPU and GPU are supported
52 |
53 |
54 | ## Quick Start
55 |
56 |
57 |
58 | ```bash
59 | git clone https://github.com/menyifang/DCT-Net.git
60 | cd DCT-Net
61 |
62 | ```
63 |
64 | ### Installation
65 | ```bash
66 | conda create -n dctnet python=3.7
67 | conda activate dctnet
68 | pip install --upgrade tensorflow-gpu==1.15 # GPU support, use tensorflow for CPU only
69 | pip install "modelscope[cv]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
70 | pip install "modelscope[multi-modal]==1.3.2" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
71 | ```
72 |
73 | ### Downloads
74 |
75 | | [![anime](assets/sim_anime.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models/summary) | [![3d](assets/sim_3d.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary) | [![handdrawn](assets/sim_handdrawn.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-handdrawn_compound-models/summary) | [![sketch](assets/sim_sketch.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sketch_compound-models/summary) | [![artstyle](assets/sim_artstyle.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-artstyle_compound-models/summary) |
76 | |:--:|:--:|:--:|:--:|:--:|
77 | | [anime](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models/summary) | [3d](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary) | [handdrawn](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-handdrawn_compound-models/summary) | [sketch](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sketch_compound-models/summary) | [artstyle](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-artstyle_compound-models/summary) |
78 |
79 | | [![design](assets/sim_design.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-design_compound-models/summary) | [![illustration](assets/sim_illu.png)](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-illustration_compound-models/summary) |
80 | |:--:|:--:|
81 | | [design](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-design_compound-models/summary) | [illustration](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-sd-illustration_compound-models/summary)
82 |
83 | Pre-trained models in different styles can be downloaded by
84 | ```bash
85 | python download.py
86 | ```
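download.py fetches every style at once; if you only need a single style, the same `snapshot_download` call can be issued for just that model id. A minimal sketch (3d style shown; any of the model ids above work the same way):

```python
from modelscope.hub.snapshot_download import snapshot_download

# fetch only the 3d-style compound model into the current directory
model_dir = snapshot_download(
    'damo/cv_unet_person-image-cartoon-3d_compound-models', cache_dir='.')
print('model assets saved to %s' % model_dir)
```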
87 |
88 | ### Inference
89 |
90 | - From the Python SDK
91 | ```bash
92 | python run_sdk.py
93 | ```
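For reference, the SDK call made by run_sdk.py reduces to the following (anime style shown; swap the model id for the other styles):

```python
import cv2
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# build the portrait-stylization pipeline and run it on a local image
img_cartoon = pipeline(Tasks.image_portrait_stylization,
                       model='damo/cv_unet_person-image-cartoon_compound-models')
result = img_cartoon('input.png')
cv2.imwrite('result_anime.png', result[OutputKeys.OUTPUT_IMG])
```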
94 |
95 | - From source code
96 | ```bash
97 | python run.py
98 | ```
99 |
100 | ### Video cartoonization
101 |
102 | ![video](assets/video.gif)
103 |
104 | Videos can be processed directly as image sequences. Style options: anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration.
105 |
106 | ```bash
107 | python run_vid.py --style anime
108 | ```
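Under the hood, run_vid.py reads frames with imageio, stylizes each frame with the same pipeline, and re-encodes the result; a condensed sketch of that loop:

```python
import imageio
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

reader = imageio.get_reader('input.mp4')
fps = reader.get_meta_data()['fps']
writer = imageio.get_writer('res.mp4', mode='I', fps=fps, codec='libx264')

img_cartoon = pipeline(Tasks.image_portrait_stylization,
                       model='damo/cv_unet_person-image-cartoon_compound-models')
for frame in reader:
    # imageio yields RGB frames; run_vid.py flips channels before and after the pipeline
    result = img_cartoon(frame[..., ::-1])
    writer.append_data(result[OutputKeys.OUTPUT_IMG][..., ::-1].astype(np.uint8))
writer.close()
```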
109 |
110 |
111 | ## Training
112 |
113 |
114 | ### Data preparation
115 | ```
116 | face_photo: face dataset such as [FFHQ](https://github.com/NVlabs/ffhq-dataset) or other collected real faces.
117 | face_cartoon: 100-300 cartoon face images in a specific style, which can be self-collected or synthesized with generative models.
118 | ```
119 | Due to copyright issues, we cannot provide the collected cartoon exemplars for training. You can produce cartoon exemplars with the style-finetuned Stable Diffusion (SD) models, which can be downloaded from the ModelScope or Hugging Face hubs.
120 |
121 | The effects of some style-finetuned SD models are as follows:
122 |
123 | | [design](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_design/summary) | [watercolor](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_watercolor/summary) | [illustration](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_illustration/summary) | [clipart](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_clipart/summary) | [flat](https://modelscope.cn/models/damo/cv_cartoon_stable_diffusion_flat/summary) |
124 | |:--:|:--:|:--:|:--:|:--:|
126 |
127 | - Generate stylized data (style options: clipart, design, illustration, watercolor, flat)
128 | ```bash
129 | python generate_data.py --style clipart
130 | ```
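generate_data.py drives a style-finetuned SD model through the ModelScope text-to-image pipeline; its core call, reduced to a single prompt, looks roughly like this (clipart style and one example prompt shown; output filenames are illustrative):

```python
import cv2
import torch
from diffusers.schedulers import EulerAncestralDiscreteScheduler
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.text_to_image_synthesis,
                model='damo/cv_cartoon_stable_diffusion_clipart',
                model_revision='v1.0.0', torch_dtype=torch.float16)
pipe.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe.pipeline.scheduler.config)

# 'archer style' is the trigger phrase generate_data.py uses for the clipart model
images = pipe({'text': 'archer style, a portrait painting of Emma Watson',
               'num_images_per_prompt': 5})['output_imgs']
for idx, image in enumerate(images):
    cv2.imwrite('syn_%d.png' % idx, image)
```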
131 |
132 | - Preprocess
133 |
134 | Extract aligned faces from the raw style images:
135 | ```bash
136 | python extract_align_faces.py --src_dir 'data/raw_style_data'
137 | ```
138 |
139 | - Train the content calibration network
140 |
141 | Install the environment required by [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch):
142 | ```bash
143 | cd source/stylegan2
144 | python prepare_data.py '../../data/face_cartoon' --size 256 --out '../../data/stylegan2/traindata'
145 | python train_condition.py --name 'ffhq_style_s256' --path '../../data/stylegan2/traindata' --config config/conf_server_train_condition_shell.json
146 | ```
147 |
148 | After training, generate content-calibrated samples via:
149 | ```bash
150 | python style_blend.py --name 'ffhq_style_s256'
151 | python generate_blendmodel.py --name 'ffhq_style_s256' --save_dir '../../data/face_cartoon/syn_style_faces'
152 | ```
153 |
154 | - Geometry calibration
155 |
156 | Run geometry calibration for both photo and cartoon data:
157 | ```bash
158 | cd source
159 | python image_flip_agument_parallel.py --data_dir '../data/face_cartoon'
160 | python image_scale_agument_parallel_flat.py --data_dir '../data/face_cartoon'
161 | python image_rotation_agument_parallel_flat.py --data_dir '../data/face_cartoon'
162 | ```
163 |
164 | - Train the texture translator
165 |
166 | The recommended dataset structure is:
167 | ```
168 | data
169 | ├── face_photo
170 | └── face_cartoon
171 | ```
172 | Resume training from a pretrained model in a similar style.
173 |
174 | The style can be chosen from: anime, 3d, handdrawn, sketch, artstyle, sd-design, sd-illustration.
175 |
176 | ```bash
177 | python train_localtoon.py --data_dir PATH_TO_YOUR_DATA --work_dir PATH_SAVE --style anime
178 | ```
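The same texture-translator training can also be driven through the ModelScope trainer API, as export.py does; a minimal sketch (the data paths follow the recommended structure above):

```python
from modelscope.trainers.cv import CartoonTranslationTrainer

# resume from the anime compound model and finetune on your paired folders
trainer = CartoonTranslationTrainer(
    model='damo/cv_unet_person-image-cartoon_compound-models',
    work_dir='exp_localtoon',
    photo='data/face_photo',
    cartoon='data/face_cartoon',
    max_steps=10)  # tiny value from export.py's smoke test; increase for real training
trainer.train()
```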
179 |
180 |
181 |
182 |
183 | ## Acknowledgments
184 |
185 | Face detector and aligner are adapted from [Peppa_Pig_Face_Engine](https://github.com/610265158/Peppa_Pig_Face_Engine)
186 | and [InsightFace](https://github.com/TreB1eN/InsightFace_Pytorch).
187 |
188 |
189 |
190 | ## Citation
191 |
192 | If you find this code useful for your research, please use the following BibTeX entry.
193 |
194 | ```bibtex
195 | @inproceedings{men2022dct,
196 | title={DCT-Net: Domain-Calibrated Translation for Portrait Stylization},
197 | author={Men, Yifang and Yao, Yuan and Cui, Miaomiao and Lian, Zhouhui and Xie, Xuansong},
198 | journal={ACM Transactions on Graphics (TOG)},
199 | volume={41},
200 | number={4},
201 | pages={1--9},
202 | year={2022},
203 | publisher={ACM New York, NY, USA}
204 | }
205 | ```
206 |
207 |
208 |
209 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/demo.gif
--------------------------------------------------------------------------------
/assets/sim1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim1.png
--------------------------------------------------------------------------------
/assets/sim2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim2.png
--------------------------------------------------------------------------------
/assets/sim3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim3.png
--------------------------------------------------------------------------------
/assets/sim4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim4.png
--------------------------------------------------------------------------------
/assets/sim5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim5.png
--------------------------------------------------------------------------------
/assets/sim_3d.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_3d.png
--------------------------------------------------------------------------------
/assets/sim_anime.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_anime.png
--------------------------------------------------------------------------------
/assets/sim_artstyle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_artstyle.png
--------------------------------------------------------------------------------
/assets/sim_design.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_design.png
--------------------------------------------------------------------------------
/assets/sim_handdrawn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_handdrawn.png
--------------------------------------------------------------------------------
/assets/sim_illu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_illu.png
--------------------------------------------------------------------------------
/assets/sim_sketch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/sim_sketch.png
--------------------------------------------------------------------------------
/assets/styles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/styles.png
--------------------------------------------------------------------------------
/assets/video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/assets/video.gif
--------------------------------------------------------------------------------
/celeb.txt:
--------------------------------------------------------------------------------
1 | Beyoncé
2 | Taylor Swift
3 | Rihanna
4 | Ariana Grande
5 | Angelina Jolie
6 | Jennifer Lawrence
7 | Emma Watson
8 | Katy Perry
9 | Scarlett Johansson
10 | Lady Gaga
11 | Gal Gadot
12 | Selena Gomez
13 | Sandra Bullock
14 | Ellen DeGeneres
15 | Mila Kunis
16 | Jennifer Aniston
17 | Margot Robbie
18 | Blake Lively
19 | Miley Cyrus
20 | Charlize Theron
21 | Emma Stone
22 | Sofia Vergara
23 | Halle Berry
24 | Zendaya
25 | Dwayne Johnson
26 | Lady Amelia Windsor
27 | Brie Larson
28 | Adele
29 | Janelle Monae
30 | Shakira
31 | Priyanka Chopra
32 | Betty White
33 | Nina Dobrev
34 | Meghan Markle
35 | Lupita Nyong'o
36 | Emilia Clarke
37 | Kate Middleton
38 | Zooey Deschanel
39 | Sienna Miller
40 | Christina Aguilera
41 | Kate Hudson
42 | Gina Rodriguez
43 | Cardi B
44 | Yara Shahidi
45 | Michelle Obama
46 | Kourtney Kardashian
47 | Portia de Rossi
48 | Kerry Washington
49 | Jada Pinkett Smith
50 | Lucy Liu
51 | Victoria Beckham
52 | Gwyneth Paltrow
53 | Kim Kardashian
54 | Ellen Page
55 | Kerry Washington
56 | Maya Rudolph
57 | Alicia Keys
58 | Oprah Winfrey
59 | Tracee Ellis Ross
60 | Jennifer Lopez
61 | Rachel McAdams
62 | Pink
63 | Cameron Diaz
64 | Lily Collins
65 | Anne Hathaway
66 | Tyra Banks
67 | Ashley Tisdale
68 | Amanda Seyfried
69 | Jessica Alba
70 | Demi Lovato
71 | Keira Knightley
72 | Bella Hadid
73 | Kendall Jenner
74 | Emma Roberts
75 | Vanessa Hudgens
76 | Sofia Richie
77 | Hailey Bieber
78 | Gisele Bündchen
79 | Taylor Hill
80 | Kiki Layne
81 | Cate Blanchett
82 | Kate Winslet
83 | Gal Gadot
84 | Salma Hayek
85 | Julia Roberts
86 | Mariah Carey
87 | Scarlett Johansson
88 | Rosie Huntington-Whitely
89 | Marjane Satrapi
90 | Halle Berry
91 | Mariah Carey
92 | Selena Gomez
93 | Emma Watson
94 | Jennifer Aniston
95 | Rihanna
96 | Blake Lively
97 | Ariana Grande
98 | Angelina Jolie
99 | Lady Gaga
100 | Taylor Swift
101 |
102 | Robert Downey Jr.
103 | Tom Cruise
104 | George Clooney
105 | Brad Pitt
106 | Dwayne Johnson
107 | Leonardo DiCaprio
108 | Will Smith
109 | Johnny Depp
110 | Chris Evans
111 | Ryan Reynolds
112 | Tom Hanks
113 | Matt Damon
114 | Denzel Washington
115 | Hugh Jackman
116 | Chris Hemsworth
117 | Chris Pratt
118 | Idris Elba
119 | Daniel Craig
120 | Samuel L. Jackson
121 | Jeremy Renner
122 | Chris Pine
123 | Robert Pattinson
124 | Sebastian Stan
125 | Benedict Cumberbatch
126 | Paul Rudd
127 | Mark Wahlberg
128 | Zac Efron
129 | Jason Statham
130 | Michael Fassbender
131 | Joel Kinnaman
132 | Keanu Reeves
133 | Scarlett Johansson
134 | Vin Diesel
135 | Angelina Jolie
136 | Emma Watson
137 | Jennifer Lawrence
138 | Gal Gadot
139 | Margot Robbie
140 | Brie Larson
141 | Sofia Vergara
142 | Mila Kunis
143 | Emily Blunt
144 | Sandra Bullock
145 | Kate Winslet
146 | Nicole Kidman
147 | Charlize Theron
148 | Anne Hathaway
149 | Cate Blanchett
150 | Emma Stone
151 | Lupita Nyong'o
152 | Jennifer Aniston
153 | Halle Berry
154 | Rihanna
155 | Lady Gaga
156 | Beyoncé
157 | Taylor Swift
158 | Miley Cyrus
159 | Ariana Grande
160 | Meghan Markle
161 | Kate Middleton
162 | Angelina Jolie
163 | Jennifer Lopez
164 | Shakira
165 | Katy Perry
166 | Lady Gaga
167 | Britney Spears
168 | Adele
169 | Mariah Carey
170 | Madonna
171 | Janet Jackson
172 | Whitney Houston
173 | Tina Turner
174 | Celine Dion
175 | Barbra Streisand
176 | Cher
177 | Gloria Estefan
178 | Diana Ross
179 | Julie Andrews
180 | Liza Minnelli
181 | Bette Midler
182 | Elton John
183 | Freddie Mercury
184 | Paul McCartney
185 | Elvis Presley
186 | Michael Jackson
187 | Prince
188 | Madonna
189 | Mariah Carey
190 | Janet Jackson
191 | Whitney Houston
192 | Tina Turner
193 | Celine Dion
194 | Barbra Streisand
195 | Cher
196 | Gloria Estefan
197 | Diana Ross
198 | Julie Andrews
199 | Liza Minnelli
200 | Bette Midler
201 | Elton John
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | from modelscope.hub.snapshot_download import snapshot_download
2 | # pre-trained models in different style
3 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
4 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-3d_compound-models', cache_dir='.')
5 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-handdrawn_compound-models', cache_dir='.')
6 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sketch_compound-models', cache_dir='.')
7 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-artstyle_compound-models', cache_dir='.')
8 |
9 | # pre-trained models trained with DCT-Net + Stable-Diffusion
10 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sd-design_compound-models', revision='v1.0.0', cache_dir='.')
11 | model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sd-illustration_compound-models', revision='v1.0.0', cache_dir='.')
12 |
13 |
14 |
--------------------------------------------------------------------------------
/export.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | import cv2
5 |
6 | from modelscope.exporters.cv import CartoonTranslationExporter
7 | from modelscope.msdatasets import MsDataset
8 | from modelscope.outputs import OutputKeys
9 | from modelscope.pipelines import pipeline
10 | from modelscope.pipelines.base import Pipeline
11 | from modelscope.trainers.cv import CartoonTranslationTrainer
12 | from modelscope.utils.constant import Tasks
13 | from modelscope.utils.test_utils import test_level
14 |
15 |
16 | class TestImagePortraitStylizationTrainer(unittest.TestCase):
17 |
18 |     def setUp(self) -> None:
19 |         self.task = Tasks.image_portrait_stylization
20 |         self.test_image = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_cartoon.png'
21 |
22 |     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
23 |     def test_run_with_model_name(self):
24 |         model_id = 'damo/cv_unet_person-image-cartoon_compound-models'
25 |
26 |         data_dir = MsDataset.load(
27 |             'dctnet_train_clipart_mini_ms',
28 |             namespace='menyifang',
29 |             split='train').config_kwargs['split_config']['train']
30 |
31 |         data_photo = os.path.join(data_dir, 'face_photo')
32 |         data_cartoon = os.path.join(data_dir, 'face_cartoon')
33 |         work_dir = 'exp_localtoon'
34 |         max_steps = 10
35 |         trainer = CartoonTranslationTrainer(
36 |             model=model_id,
37 |             work_dir=work_dir,
38 |             photo=data_photo,
39 |             cartoon=data_cartoon,
40 |             max_steps=max_steps)
41 |         trainer.train()
42 |
43 |         # export pb file
44 |         ckpt_path = os.path.join(work_dir, 'saved_models', 'model-' + str(0))
45 |         pb_path = os.path.join(trainer.model_dir, 'cartoon_h.pb')
46 |         exporter = CartoonTranslationExporter()
47 |         exporter.export_frozen_graph_def(
48 |             ckpt_path=ckpt_path, frozen_graph_path=pb_path)
49 |
50 |         # infer with pb file
51 |         self.pipeline_person_image_cartoon(trainer.model_dir)
52 |
53 |     def pipeline_person_image_cartoon(self, model_dir):
54 |         pipeline_cartoon = pipeline(task=self.task, model=model_dir)
55 |         result = pipeline_cartoon(input=self.test_image)
56 |         if result is not None:
57 |             cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
58 |             print(f'Output written to {os.path.abspath("result.png")}')
59 |
60 |
61 | if __name__ == '__main__':
62 |     unittest.main()
--------------------------------------------------------------------------------
/extract_align_faces.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | import argparse
5 | from source.facelib.facer import FaceAna
6 | import source.utils as utils
7 | from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points
8 | from modelscope.hub.snapshot_download import snapshot_download
9 |
10 | class FaceProcesser:
11 |     def __init__(self, dataroot, crop_size = 256, max_face = 1):
12 |         self.max_face = max_face
13 |         self.crop_size = crop_size
14 |         self.facer = FaceAna(dataroot)
15 |
16 |     def filter_face(self, lm, crop_size):
17 |         a = max(lm[:, 0])-min(lm[:, 0])
18 |         b = max(lm[:, 1])-min(lm[:, 1])
19 |         # print("a:%d, b:%d"%(a,b))
20 |         if max(a, b)0:
48 |                 continue
49 |
50 |             if self.filter_face(landmark, self.crop_size)==0:
51 |                 print("filtered!")
52 |                 continue
53 |
54 |             f5p = utils.get_f5p(landmark, img_bgr)
55 |             # face alignment
56 |             warped_face, _ = warp_and_crop_face(
57 |                 img_bgr,
58 |                 f5p,
59 |                 ratio=0.75,
60 |                 reference_pts=get_reference_facial_points(default_square=True),
61 |                 crop_size=(self.crop_size, self.crop_size),
62 |                 return_trans_inv=True)
63 |
64 |             warped_faces.append(warped_face)
65 |             i = i+1
66 |
67 |
68 |         return warped_faces
69 |
70 |
71 |
72 |
73 | if __name__ == "__main__":
74 |
75 |
76 |     parser = argparse.ArgumentParser(description="process remove bg result")
77 |     parser.add_argument("--src_dir", type=str, default='', help="Path to src images.")
78 |     parser.add_argument("--save_dir", type=str, default='', help="Path to save images.")
79 |     parser.add_argument("--crop_size", type=int, default=256)
80 |     parser.add_argument("--max_face", type=int, default=1)
81 |     parser.add_argument("--overwrite", type=int, default=1)
82 |     args = parser.parse_args()
83 |     args.save_dir = os.path.dirname(args.src_dir) + '/face_cartoon/raw_style_faces'
84 |
85 |     crop_size = args.crop_size
86 |     max_face = args.max_face
87 |     overwrite = args.overwrite
88 |
89 |     # model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
90 |     # print('model assets saved to %s'%model_dir)
91 |     model_dir = 'damo/cv_unet_person-image-cartoon_compound-models'
92 |
93 |     processer = FaceProcesser(dataroot=model_dir,crop_size=crop_size, max_face =max_face)
94 |
95 |     src_dir = args.src_dir
96 |     save_dir = args.save_dir
97 |
98 |     # print('Step: start to extract aligned faces ... ...')
99 |
100 |     print('src_dir:%s'% src_dir)
101 |     print('save_dir:%s'% save_dir)
102 |
103 |     if not os.path.exists(save_dir):
104 |         os.makedirs(save_dir)
105 |
106 |     paths = utils.all_file(src_dir)
107 |     print('to process %d images'% len(paths))
108 |
109 |     for path in sorted(paths):
110 |         dirname = path[len(src_dir)+1:].split('/')[0]
111 |
112 |         outpath = save_dir + path[len(src_dir):]
113 |         if not overwrite:
114 |             if os.path.exists(outpath):
115 |                 continue
116 |
117 |         sub_dir = os.path.dirname(outpath)
118 |         # print(sub_dir)
119 |         if not os.path.exists(sub_dir):
120 |             os.makedirs(sub_dir, exist_ok=True)
121 |
122 |         imgb = None
123 |         imgc = None
124 |         img = cv2.imread(path, -1)
125 |         if img is None:
126 |             continue
127 |
128 |         if len(img.shape)==2:
129 |             img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
130 |
131 |         # print(img.shape)
132 |         h,w,c = img.shape
133 |         if h<256 or w<256:
134 |             continue
135 |         imgs = []
136 |
137 |         # if need resize, resize here
138 |         img_h, img_w, _ = img.shape
139 |         warped_faces = processer.process(img)
140 |         if warped_faces is None:
141 |             continue
142 |         # ### only for anime faces, single, not detect face
143 |         # warped_face = imga
144 |
145 |         i=0
146 |         for res in warped_faces:
147 |             # filter small faces
148 |             h, w, c = res.shape
149 |             if h < 256 or w < 256:
150 |                 continue
151 |             outpath = os.path.join(os.path.dirname(outpath), os.path.basename(outpath)[:-4] + '_' + str(i) + '.png')
152 |
153 |             cv2.imwrite(outpath, res)
154 |             print('save %s' % outpath)
155 |             i = i+1
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
--------------------------------------------------------------------------------
/generate_data.py:
--------------------------------------------------------------------------------
1 | from modelscope.pipelines import pipeline
2 | from modelscope.utils.constant import Tasks
3 | import torch
4 | import os, cv2
5 | import argparse
6 |
7 | def load_cele_txt(celeb_file='celeb.txt'):
8 |     celeb = open(celeb_file, 'r')
9 |     lines = celeb.readlines()
10 |     name_list = []
11 |     for line in lines:
12 |         name = line.strip('\n')
13 |         if name != '':
14 |             name_list.append(name)
15 |     return name_list
16 |
17 |
18 | def main(args):
19 |     style = args.style
20 |     repeat_num = 5
21 |
22 |     model_id = 'damo/cv_cartoon_stable_diffusion_' + style
23 |     pipe = pipeline(Tasks.text_to_image_synthesis, model=model_id,
24 |                     model_revision='v1.0.0', torch_dtype=torch.float16)
25 |     from diffusers.schedulers import EulerAncestralDiscreteScheduler
26 |     pipe.pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.pipeline.scheduler.config)
27 |     print('model init finished!')
28 |
29 |
30 |     save_dir = 'res_style_%s/syn_celeb' % (style)
31 |     if not os.path.exists(save_dir):
32 |         os.makedirs(save_dir)
33 |
34 |     name_list = load_cele_txt('celeb.txt')
35 |     person_num = len(name_list)
36 |     for i in range(person_num):
37 |         name = name_list[i]
38 |         print('process %s' % name)
39 |
40 |         if style == "clipart":
41 |             prompt = 'archer style, a portrait painting of %s' % (name)
42 |         else:
43 |             prompt = 'sks style, a painting of a %s, no text' % (name)
44 |
45 |         images = pipe({'text': prompt, 'num_images_per_prompt': repeat_num})['output_imgs']
46 |         idx = 0
47 |         for image in images:
48 |             outpath = os.path.join(save_dir, '%s_%d.png' % (name, idx))
49 |             cv2.imwrite(outpath, image)
50 |             idx += 1
51 |
52 |
53 | if __name__ == "__main__":
54 |     parser = argparse.ArgumentParser()
55 |     parser.add_argument('--style', type=str, default='clipart')
56 |
57 |     args = parser.parse_args()
58 |     main(args)
59 |
60 |
--------------------------------------------------------------------------------
/input.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/input.mp4
--------------------------------------------------------------------------------
/input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/input.png
--------------------------------------------------------------------------------
/multi-style/download.py:
--------------------------------------------------------------------------------
1 | from modelscope.hub.snapshot_download import snapshot_download
2 | import argparse
3 |
4 |
5 |
6 | def process(args):
7 |     style = args.style
8 |     print('download %s model'%style)
9 |     if style == "anime":
10 |         model_dir = snapshot_download('damo/cv_unet_person-image-cartoon_compound-models', cache_dir='.')
11 |
12 |     elif style == "3d":
13 |         model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-3d_compound-models', cache_dir='.')
14 |
15 |     elif style == "handdrawn":
16 |         model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-handdrawn_compound-models', cache_dir='.')
17 |
18 |     elif style == "sketch":
19 |         model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-sketch_compound-models', cache_dir='.')
20 |
21 |     elif style == "artstyle":
22 |         model_dir = snapshot_download('damo/cv_unet_person-image-cartoon-artstyle_compound-models', cache_dir='.')
23 |
24 |     else:
25 |         print('no such style %s'% style)
26 |
27 |
28 |
29 |
30 | if __name__ == '__main__':
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument('--style', type=str, default='anime')
33 |     args = parser.parse_args()
34 |
35 |     process(args)
36 |
--------------------------------------------------------------------------------
/multi-style/run.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('.')
3 |
4 | import cv2
5 | from source.cartoonize import Cartoonizer
6 | import os
7 | import argparse
8 |
9 |
10 | def process(args):
11 |
12 |     style = args.style
13 |     if style == "anime":
14 |         algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon_compound-models')
15 |
16 |     elif style == "3d":
17 |         algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-3d_compound-models')
18 |
19 |     elif style == "handdrawn":
20 |         algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-handdrawn_compound-models')
21 |
22 |     elif style == "sketch":
23 |         algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-sketch_compound-models')
24 |
25 |     elif style == "artstyle":
26 |         algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon-artstyle_compound-models')
27 |
28 |     else:
29 |         print('no such style %s' % style)
30 |         return 0
31 |
32 |     img = cv2.imread('input.png')[..., ::-1]
33 |     result = algo.cartoonize(img)
34 |     cv2.imwrite('result1_%s.png'%style, result)
35 |
36 |     print('finished!')
37 |
38 |
39 |
40 |
41 | if __name__ == '__main__':
42 |     parser = argparse.ArgumentParser()
43 |     parser.add_argument('--style', type=str, default='anime')
44 |     args = parser.parse_args()
45 |
46 |     process(args)
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
/multi-style/run_sdk.py:
--------------------------------------------------------------------------------
1 | import cv2, argparse
2 | from modelscope.outputs import OutputKeys
3 | from modelscope.pipelines import pipeline
4 | from modelscope.utils.constant import Tasks
5 |
6 | def process(args):
7 |     style = args.style
8 |     print('choose style %s'%style)
9 |     if style == "anime":
10 |         img_cartoon = pipeline(Tasks.image_portrait_stylization,
11 |                                model='damo/cv_unet_person-image-cartoon_compound-models')
12 |     elif style == "3d":
13 |         img_cartoon = pipeline(Tasks.image_portrait_stylization,
14 |                                model='damo/cv_unet_person-image-cartoon-3d_compound-models')
15 |     elif style == "handdrawn":
16 |         img_cartoon = pipeline(Tasks.image_portrait_stylization,
17 |                                model='damo/cv_unet_person-image-cartoon-handdrawn_compound-models')
18 |     elif style == "sketch":
19 |         img_cartoon = pipeline(Tasks.image_portrait_stylization,
20 |                                model='damo/cv_unet_person-image-cartoon-sketch_compound-models')
21 |     elif style == "artstyle":
22 |         img_cartoon = pipeline(Tasks.image_portrait_stylization,
23 |                                model='damo/cv_unet_person-image-cartoon-artstyle_compound-models')
24 |     else:
25 |         print('no such style %s'% style)
26 |         return 0
27 |
28 |
29 |     result = img_cartoon('input.png')
30 |
31 |     cv2.imwrite('result_%s.png'%style, result[OutputKeys.OUTPUT_IMG])
32 |     print('finished!')
33 |
34 |
35 |
36 |
37 | if __name__ == '__main__':
38 |     parser = argparse.ArgumentParser()
39 |     parser.add_argument('--style', type=str, default='anime')
40 |     args = parser.parse_args()
41 |
42 |     process(args)
43 |
--------------------------------------------------------------------------------
/notebooks/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/notebooks/.DS_Store
--------------------------------------------------------------------------------
/prepare_data.sh:
--------------------------------------------------------------------------------
1 |
2 | data_root='data'
3 | align_dir='raw_style_data_faces'
4 |
5 | echo "STEP: start to prepare data for stylegan ..."
6 | cd $data_root
7 | if [ ! -d stylegan ]; then
8 | mkdir stylegan
9 | fi
10 | cd stylegan
11 | stylegan_data_dir=$(pwd)
12 | if [ ! -d "$(date +"%Y%m%d")" ]; then
13 | mkdir "$(date +"%Y%m%d")"
14 | fi
15 | cd "$(date +"%Y%m%d")"
16 | cp $align_dir . -r
17 | if [ -d $(echo $align_dir) ]; then
18 | cp $(echo $align_dir) . -r
19 | fi
20 | src_dir_sg=$(pwd)
21 |
22 | cd $data_root/../source
23 | outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")"
24 | echo $outdir_sg
25 | echo $src_dir_sg
26 | if [ ! -d "$outdir_sg" ]; then
27 | python prepare_data.py --size 256 --out $outdir_sg $src_dir_sg
28 | fi
29 | echo "prepare data for stylegan finished!"
30 |
31 | ### train model
32 | #cd $data_root
33 | #cd stylegan
34 | #stylegan_data_dir=$(pwd)
35 | #outdir_sg="$(echo $stylegan_data_dir)/traindata_$(echo $stylename)_256_$(date +"%m%d")"
36 | #echo "STEP:start to train the style learner ..."
37 | #echo $outdir_sg
38 | #exp_name="ffhq_$(echo $stylename)_s256_id01_$(date +"%m%d")"
39 | #cd /data/vdb/qingyao/cartoon/mycode/stylegan2-pytorch
40 | #model_path=face_generation/experiment_stylegan/$(echo $exp_name)/models/001000.pt
41 | #if [ ! -f "$model_path" ]; then
42 | # CUDA_VISIBLE_DEVICES=6 python train_condition.py --name $exp_name --path $outdir_sg --config config/conf_server_train_condition_shell.json
43 | #fi
44 | #### [training...]
45 | #echo "train the style learner finished!"
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 |
2 | import cv2
3 | from source.cartoonize import Cartoonizer
4 | import os
5 |
6 | def process():
7 |
8 |     algo = Cartoonizer(dataroot='damo/cv_unet_person-image-cartoon_compound-models')
9 |     img = cv2.imread('input.png')[...,::-1]
10 |
11 |     result = algo.cartoonize(img)
12 |
13 |     cv2.imwrite('res.png', result)
14 |     print('finished!')
15 |
16 |
17 |
18 |
19 | if __name__ == '__main__':
20 |     process()
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/run_sdk.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from modelscope.outputs import OutputKeys
3 | from modelscope.pipelines import pipeline
4 | from modelscope.utils.constant import Tasks
5 |
6 | ##### DCT-Net
7 | ## anime style
8 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
9 | model='damo/cv_unet_person-image-cartoon_compound-models')
10 | result = img_cartoon('input.png')
11 | cv2.imwrite('result_anime.png', result[OutputKeys.OUTPUT_IMG])
12 |
13 | ## 3d style
14 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
15 | model='damo/cv_unet_person-image-cartoon-3d_compound-models')
16 | result = img_cartoon('input.png')
17 | cv2.imwrite('result_3d.png', result[OutputKeys.OUTPUT_IMG])
18 |
19 | ## handdrawn style
20 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
21 | model='damo/cv_unet_person-image-cartoon-handdrawn_compound-models')
22 | result = img_cartoon('input.png')
23 | cv2.imwrite('result_handdrawn.png', result[OutputKeys.OUTPUT_IMG])
24 |
25 | ## sketch style
26 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
27 | model='damo/cv_unet_person-image-cartoon-sketch_compound-models')
28 | result = img_cartoon('input.png')
29 | cv2.imwrite('result_sketch.png', result[OutputKeys.OUTPUT_IMG])
30 |
31 | ## artstyle style
32 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
33 | model='damo/cv_unet_person-image-cartoon-artstyle_compound-models')
34 | result = img_cartoon('input.png')
35 | cv2.imwrite('result_artstyle.png', result[OutputKeys.OUTPUT_IMG])
36 |
37 | #### DCT-Net + SD
38 | ## design style
39 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
40 | model='damo/cv_unet_person-image-cartoon-sd-design_compound-models')
41 | result = img_cartoon('input.png')
42 | cv2.imwrite('result_sd_design.png', result[OutputKeys.OUTPUT_IMG])
43 |
44 | ## illustration style
45 | img_cartoon = pipeline(Tasks.image_portrait_stylization,
46 | model='damo/cv_unet_person-image-cartoon-sd-illustration_compound-models')
47 | result = img_cartoon('input.png')
48 | cv2.imwrite('result_sd_illustration.png', result[OutputKeys.OUTPUT_IMG])
49 |
50 |
51 | print('finished!')
52 |
--------------------------------------------------------------------------------
/run_vid.py:
--------------------------------------------------------------------------------
1 | import cv2, argparse
2 | import imageio
3 | from tqdm import tqdm
4 | import numpy as np
5 | from modelscope.outputs import OutputKeys
6 | from modelscope.pipelines import pipeline
7 | from modelscope.utils.constant import Tasks
8 |
9 | def process(args):
10 |     style = args.style
11 |     print('choose style %s'%style)
12 |
13 |     reader = imageio.get_reader(args.video_path)
14 |     fps = reader.get_meta_data()['fps']
15 |     writer = imageio.get_writer(args.save_path, mode='I', fps=fps, codec='libx264')
16 |
17 |     if style == "anime":
18 |         style = ""
19 |     else:
20 |         style = '-' + style
21 |
22 |     model_name = 'damo/cv_unet_person-image-cartoon' + style + '_compound-models'
23 |     img_cartoon = pipeline(Tasks.image_portrait_stylization, model=model_name)
24 |
25 |     for _, img in tqdm(enumerate(reader)):
26 |         result = img_cartoon(img[..., ::-1])
27 |         res = result[OutputKeys.OUTPUT_IMG]
28 |         writer.append_data(res[..., ::-1].astype(np.uint8))
29 |     writer.close()
30 |     print('finished!')
31 |     print('result saved to %s'% args.save_path)
32 |
33 |
34 | if __name__ == '__main__':
35 |     parser = argparse.ArgumentParser()
36 |     parser.add_argument('--video_path', type=str, default='PATH/TO/YOUR/MP4')
37 |     parser.add_argument('--save_path', type=str, default='res.mp4')
38 |     parser.add_argument('--style', type=str, default='anime')
39 |
40 |     args = parser.parse_args()
41 |
42 |     process(args)
43 |
--------------------------------------------------------------------------------
/source/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/.DS_Store
--------------------------------------------------------------------------------
/source/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/__init__.py
--------------------------------------------------------------------------------
/source/cartoonize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import tensorflow as tf
4 | import numpy as np
5 | from source.facelib.facer import FaceAna
6 | import source.utils as utils
7 | from source.mtcnn_pytorch.src.align_trans import warp_and_crop_face, get_reference_facial_points
8 |
9 | if tf.__version__ >= '2.0':
10 |     tf = tf.compat.v1
11 |     tf.disable_eager_execution()
12 |
13 |
14 | class Cartoonizer():
15 |     def __init__(self, dataroot):
16 |
17 |         self.facer = FaceAna(dataroot)
18 |         self.sess_head = self.load_sess(
19 |             os.path.join(dataroot, 'cartoon_anime_h.pb'), 'model_head')
20 |         self.sess_bg = self.load_sess(
21 |             os.path.join(dataroot, 'cartoon_anime_bg.pb'), 'model_bg')
22 |
23 |         self.box_width = 288
24 |         global_mask = cv2.imread(os.path.join(dataroot, 'alpha.jpg'))
25 |         global_mask = cv2.resize(
26 |             global_mask, (self.box_width, self.box_width),
27 |             interpolation=cv2.INTER_AREA)
28 |         self.global_mask = cv2.cvtColor(
29 |             global_mask, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
30 |
31 |     def load_sess(self, model_path, name):
32 |         config = tf.ConfigProto(allow_soft_placement=True)
33 |         config.gpu_options.allow_growth = True
34 |         sess = tf.Session(config=config)
35 |         print(f'loading model from {model_path}')
36 |         with tf.gfile.FastGFile(model_path, 'rb') as f:
37 |             graph_def = tf.GraphDef()
38 |             graph_def.ParseFromString(f.read())
39 |             sess.graph.as_default()
40 |             tf.import_graph_def(graph_def, name=name)
41 |             sess.run(tf.global_variables_initializer())
42 |         print(f'load model {model_path} done.')
43 |         return sess
44 |
45 |
46 |     def detect_face(self, img):
47 |         src_h, src_w, _ = img.shape
48 |         src_x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
49 |         boxes, landmarks, _ = self.facer.run(src_x)
50 |         if boxes.shape[0] == 0:
51 |             return None
52 |         else:
53 |             return landmarks
54 |
55 |
56 |     def cartoonize(self, img):
57 |         # img: RGB input
58 |         ori_h, ori_w, _ = img.shape
59 |         img = utils.resize_size(img, size=720)
60 |
61 |         img_brg = img[:, :, ::-1]
62 |
63 |         # background process
64 |         pad_bg, pad_h, pad_w = utils.padTo16x(img_brg)
65 |
66 |         bg_res = self.sess_bg.run(
67 |             self.sess_bg.graph.get_tensor_by_name(
68 |                 'model_bg/output_image:0'),
69 |             feed_dict={'model_bg/input_image:0': pad_bg})
70 |         res = bg_res[:pad_h, :pad_w, :]
71 |
72 |         landmarks = self.detect_face(img_brg)
73 |         if landmarks is None:
74 |             print('No face detected!')
75 |             return res
76 |
77 |         print('%d faces detected!'%len(landmarks))
78 |         for landmark in landmarks:
79 |             # get facial 5 points
80 |             f5p = utils.get_f5p(landmark, img_brg)
81 |
82 |             # face alignment
83 |             head_img, trans_inv = warp_and_crop_face(
84 |                 img,
85 |                 f5p,
86 |                 ratio=0.75,
87 |                 reference_pts=get_reference_facial_points(default_square=True),
88 |                 crop_size=(self.box_width, self.box_width),
89 |                 return_trans_inv=True)
90 |
91 |             # head process
92 |             head_res = self.sess_head.run(
93 |                 self.sess_head.graph.get_tensor_by_name(
94 |                     'model_head/output_image:0'),
95 |                 feed_dict={
96 |                     'model_head/input_image:0': head_img[:, :, ::-1]
97 |                 })
98 |
99 |             # merge head and background
100 |             head_trans_inv = cv2.warpAffine(
101 |                 head_res,
102 |                 trans_inv, (np.size(img, 1), np.size(img, 0)),
103 |                 borderValue=(0, 0, 0))
104 |
105 |             mask = self.global_mask
106 |             mask_trans_inv = cv2.warpAffine(
107 |                 mask,
108 |                 trans_inv, (np.size(img, 1), np.size(img, 0)),
109 |                 borderValue=(0, 0, 0))
110 |             mask_trans_inv = np.expand_dims(mask_trans_inv, 2)
111 |
112 |             res = mask_trans_inv * head_trans_inv + (1 - mask_trans_inv) * res
113 |
114 |         res = cv2.resize(res, (ori_w, ori_h), interpolation=cv2.INTER_AREA)
115 |
116 |         return res
117 |
118 |
119 |
120 |
121 |
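A minimal usage sketch for `Cartoonizer`, assuming `dataroot` points at a directory containing the assets loaded in `__init__` above (`cartoon_anime_h.pb`, `cartoon_anime_bg.pb`, `alpha.jpg`, plus the facelib detector/keypoint graphs):

```python
import cv2
from source.cartoonize import Cartoonizer

# 'path/to/model_assets' is a placeholder for the downloaded model directory
algo = Cartoonizer(dataroot='path/to/model_assets')
img = cv2.imread('input.png')[..., ::-1]   # BGR -> RGB, as cartoonize() expects
res = algo.cartoonize(img)
cv2.imwrite('res.png', res)
```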
--------------------------------------------------------------------------------
/source/facelib/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Copyright (c) Peppa_Pig_Face_Engine
3 |
4 | https://github.com/610265158/Peppa_Pig_Face_Engine
5 |
--------------------------------------------------------------------------------
/source/facelib/LK/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/facelib/LK/__init__.py
--------------------------------------------------------------------------------
/source/facelib/LK/lk.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from modelscope.models.cv.cartoon.facelib.config import config as cfg
4 |
5 |
6 | class GroupTrack():
7 |
8 | def __init__(self):
9 | self.old_frame = None
10 | self.previous_landmarks_set = None
11 | self.with_landmark = True
12 | self.thres = cfg.TRACE.pixel_thres
13 | self.alpha = cfg.TRACE.smooth_landmark
14 | self.iou_thres = cfg.TRACE.iou_thres
15 |
16 | def calculate(self, img, current_landmarks_set):
17 | if self.previous_landmarks_set is None:
18 | self.previous_landmarks_set = current_landmarks_set
19 | result = current_landmarks_set
20 | else:
21 | previous_lm_num = self.previous_landmarks_set.shape[0]
22 | if previous_lm_num == 0:
23 | self.previous_landmarks_set = current_landmarks_set
24 | result = current_landmarks_set
25 | return result
26 | else:
27 | result = []
28 | for i in range(current_landmarks_set.shape[0]):
29 | not_in_flag = True
30 | for j in range(previous_lm_num):
31 | if self.iou(current_landmarks_set[i],
32 | self.previous_landmarks_set[j]
33 | ) > self.iou_thres:
34 | result.append(
35 | self.smooth(current_landmarks_set[i],
36 | self.previous_landmarks_set[j]))
37 | not_in_flag = False
38 | break
39 | if not_in_flag:
40 | result.append(current_landmarks_set[i])
41 |
42 | result = np.array(result)
43 | self.previous_landmarks_set = result
44 |
45 | return result
46 |
47 | def iou(self, p_set0, p_set1):
48 | rec1 = [
49 | np.min(p_set0[:, 0]),
50 | np.min(p_set0[:, 1]),
51 | np.max(p_set0[:, 0]),
52 | np.max(p_set0[:, 1])
53 | ]
54 | rec2 = [
55 | np.min(p_set1[:, 0]),
56 | np.min(p_set1[:, 1]),
57 | np.max(p_set1[:, 0]),
58 | np.max(p_set1[:, 1])
59 | ]
60 |
61 | # computing area of each rectangles
62 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
63 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
64 |
65 | # computing the sum_area
66 | sum_area = S_rec1 + S_rec2
67 |
68 | # find the each edge of intersect rectangle
69 | x1 = max(rec1[0], rec2[0])
70 | y1 = max(rec1[1], rec2[1])
71 | x2 = min(rec1[2], rec2[2])
72 | y2 = min(rec1[3], rec2[3])
73 |
74 | # judge if there is an intersect
75 | intersect = max(0, x2 - x1) * max(0, y2 - y1)
76 |
77 | iou = intersect / (sum_area - intersect)
78 | return iou
79 |
80 | def smooth(self, now_landmarks, previous_landmarks):
81 | result = []
82 | for i in range(now_landmarks.shape[0]):
83 | x = now_landmarks[i][0] - previous_landmarks[i][0]
84 | y = now_landmarks[i][1] - previous_landmarks[i][1]
85 | dis = np.sqrt(np.square(x) + np.square(y))
86 | if dis < self.thres:
87 | result.append(previous_landmarks[i])
88 | else:
89 | result.append(
90 | self.do_moving_average(now_landmarks[i],
91 | previous_landmarks[i]))
92 |
93 | return np.array(result)
94 |
95 | def do_moving_average(self, p_now, p_previous):
96 | p = self.alpha * p_now + (1 - self.alpha) * p_previous
97 | return p
98 |
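A minimal sketch of the landmark smoothing above. The `modelscope` package must be installed, since this module imports its cartoon face config; the landmark arrays are synthetic placeholders:

```python
import numpy as np
from source.facelib.LK.lk import GroupTrack

tracker = GroupTrack()
prev_set = np.random.rand(1, 68, 2) * 100             # frame t-1 landmarks: (faces, points, 2)
curr_set = prev_set + np.random.randn(1, 68, 2) * 2   # slightly shifted landmarks at frame t
_ = tracker.calculate(None, prev_set)                 # first call just caches the set
smoothed = tracker.calculate(None, curr_set)          # IoU-matches faces, then smooths per point
```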
--------------------------------------------------------------------------------
/source/facelib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/facelib/__init__.py
--------------------------------------------------------------------------------
/source/facelib/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from easydict import EasyDict as edict
5 |
6 | config = edict()
7 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
8 |
9 | config.DETECT = edict()
10 | config.DETECT.topk = 10
11 | config.DETECT.thres = 0.8
12 | config.DETECT.input_shape = (512, 512, 3)
13 | config.KEYPOINTS = edict()
14 | config.KEYPOINTS.p_num = 68
15 | config.KEYPOINTS.base_extend_range = [0.2, 0.3]
16 | config.KEYPOINTS.input_shape = (160, 160, 3)
17 | config.TRACE = edict()
18 | config.TRACE.pixel_thres = 1
19 | config.TRACE.smooth_box = 0.3
20 | config.TRACE.smooth_landmark = 0.95
21 | config.TRACE.iou_thres = 0.5
22 | config.DATA = edict()
23 | config.DATA.pixel_means = np.array([123., 116., 103.]) # RGB
24 |
--------------------------------------------------------------------------------
/source/facelib/face_detector.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from .config import config as cfg
8 |
9 | if tf.__version__ >= '2.0':
10 | tf = tf.compat.v1
11 |
12 |
13 | class FaceDetector:
14 |
15 | def __init__(self, dir):
16 |
17 | self.model_path = dir + '/detector.pb'
18 | self.thres = cfg.DETECT.thres
19 | self.input_shape = cfg.DETECT.input_shape
20 |
21 | self._graph = tf.Graph()
22 |
23 | with self._graph.as_default():
24 | self._graph, self._sess = self.init_model(self.model_path)
25 |
26 | self.input_image = tf.get_default_graph().get_tensor_by_name(
27 | 'tower_0/images:0')
28 | self.training = tf.get_default_graph().get_tensor_by_name(
29 | 'training_flag:0')
30 | self.output_ops = [
31 | tf.get_default_graph().get_tensor_by_name('tower_0/boxes:0'),
32 | tf.get_default_graph().get_tensor_by_name('tower_0/scores:0'),
33 | tf.get_default_graph().get_tensor_by_name(
34 | 'tower_0/num_detections:0'),
35 | ]
36 |
37 | def __call__(self, image):
38 |
39 | image, scale_x, scale_y = self.preprocess(
40 | image,
41 | target_width=self.input_shape[1],
42 | target_height=self.input_shape[0])
43 |
44 | image = np.expand_dims(image, 0)
45 |
46 | boxes, scores, num_boxes = self._sess.run(
47 | self.output_ops,
48 | feed_dict={
49 | self.input_image: image,
50 | self.training: False
51 | })
52 |
53 | num_boxes = num_boxes[0]
54 | boxes = boxes[0][:num_boxes]
55 |
56 | scores = scores[0][:num_boxes]
57 |
58 | to_keep = scores > self.thres
59 | boxes = boxes[to_keep]
60 | scores = scores[to_keep]
61 |
62 | y1 = self.input_shape[0] / scale_y
63 | x1 = self.input_shape[1] / scale_x
64 | y2 = self.input_shape[0] / scale_y
65 | x2 = self.input_shape[1] / scale_x
66 | scaler = np.array([y1, x1, y2, x2], dtype='float32')
67 | boxes = boxes * scaler
68 |
69 | scores = np.expand_dims(scores, 0).reshape([-1, 1])
70 |
71 | for i in range(boxes.shape[0]):
72 | boxes[i] = np.array(
73 | [boxes[i][1], boxes[i][0], boxes[i][3], boxes[i][2]])
74 | return np.concatenate([boxes, scores], axis=1)
75 |
76 | def preprocess(self, image, target_height, target_width, label=None):
77 |
78 | h, w, c = image.shape
79 |
80 | bimage = np.zeros(
81 | shape=[target_height, target_width, c],
82 | dtype=image.dtype) + np.array(
83 | cfg.DATA.pixel_means, dtype=image.dtype)
84 | long_side = max(h, w)
85 |
86 | scale_x = scale_y = target_height / long_side
87 |
88 | image = cv2.resize(image, None, fx=scale_x, fy=scale_y)
89 |
90 | h_, w_, _ = image.shape
91 | bimage[:h_, :w_, :] = image
92 |
93 | return bimage, scale_x, scale_y
94 |
95 | def init_model(self, *args):
96 | pb_path = args[0]
97 |
98 | def init_pb(model_path):
99 | config = tf.ConfigProto()
100 | config.gpu_options.per_process_gpu_memory_fraction = 0.2
101 | compute_graph = tf.Graph()
102 | compute_graph.as_default()
103 | sess = tf.Session(config=config)
104 | with tf.gfile.GFile(model_path, 'rb') as fid:
105 | graph_def = tf.GraphDef()
106 | graph_def.ParseFromString(fid.read())
107 | tf.import_graph_def(graph_def, name='')
108 |
109 | return (compute_graph, sess)
110 |
111 | model = init_pb(pb_path)
112 |
113 | graph = model[0]
114 | sess = model[1]
115 |
116 | return graph, sess
117 |
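A minimal usage sketch; the model directory path is a placeholder and must contain the `detector.pb` frozen graph loaded above:

```python
import cv2
from source.facelib.face_detector import FaceDetector

detector = FaceDetector('path/to/model_dir')
rgb = cv2.cvtColor(cv2.imread('input.png'), cv2.COLOR_BGR2RGB)
boxes = detector(rgb)
# boxes: (N, 5) float array, each row [x1, y1, x2, y2, score] in input-image pixels
print(boxes.shape)
```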
--------------------------------------------------------------------------------
/source/facelib/face_landmark.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import tensorflow as tf
4 |
5 | from .config import config as cfg
6 |
7 | if tf.__version__ >= '2.0':
8 | tf = tf.compat.v1
9 |
10 |
11 | class FaceLandmark:
12 |
13 | def __init__(self, dir):
14 | self.model_path = dir + '/keypoints.pb'
15 | self.min_face = 60
16 | self.keypoint_num = cfg.KEYPOINTS.p_num * 2
17 |
18 | self._graph = tf.Graph()
19 |
20 | with self._graph.as_default():
21 |
22 | self._graph, self._sess = self.init_model(self.model_path)
23 | self.img_input = tf.get_default_graph().get_tensor_by_name(
24 | 'tower_0/images:0')
25 | self.embeddings = tf.get_default_graph().get_tensor_by_name(
26 | 'tower_0/prediction:0')
27 | self.training = tf.get_default_graph().get_tensor_by_name(
28 | 'training_flag:0')
29 |
30 | self.landmark = self.embeddings[:, :self.keypoint_num]
31 | self.headpose = self.embeddings[:, -7:-4] * 90.
32 | self.state = tf.nn.sigmoid(self.embeddings[:, -4:])
33 |
34 | def __call__(self, img, bboxes):
35 | landmark_result = []
36 | state_result = []
37 | for i, bbox in enumerate(bboxes):
38 | landmark, state = self._one_shot_run(img, bbox, i)
39 | if landmark is not None:
40 | landmark_result.append(landmark)
41 | state_result.append(state)
42 | return np.array(landmark_result), np.array(state_result)
43 |
44 | def simple_run(self, cropped_img):
45 | with self._graph.as_default():
46 |
47 | cropped_img = np.expand_dims(cropped_img, axis=0)
48 | landmark, p, states = self._sess.run(
49 | [self.landmark, self.headpose, self.state],
50 | feed_dict={
51 | self.img_input: cropped_img,
52 | self.training: False
53 | })
54 |
55 | return landmark, states
56 |
57 | def _one_shot_run(self, image, bbox, i):
58 |
59 | bbox_width = bbox[2] - bbox[0]
60 | bbox_height = bbox[3] - bbox[1]
61 | if (bbox_width <= self.min_face and bbox_height <= self.min_face):
62 | return None, None
63 | add = int(max(bbox_width, bbox_height))
64 | bimg = cv2.copyMakeBorder(
65 | image,
66 | add,
67 | add,
68 | add,
69 | add,
70 | borderType=cv2.BORDER_CONSTANT,
71 | value=cfg.DATA.pixel_means)
72 | bbox += add
73 |
74 | one_edge = (1 + 2 * cfg.KEYPOINTS.base_extend_range[0]) * bbox_width
75 | center = [(bbox[0] + bbox[2]) // 2, (bbox[1] + bbox[3]) // 2]
76 |
77 | bbox[0] = center[0] - one_edge // 2
78 | bbox[1] = center[1] - one_edge // 2
79 | bbox[2] = center[0] + one_edge // 2
80 | bbox[3] = center[1] + one_edge // 2
81 |
 82 |         bbox = bbox.astype(np.int32)  # np.int is removed in recent NumPy releases
83 | crop_image = bimg[bbox[1]:bbox[3], bbox[0]:bbox[2], :]
84 | h, w, _ = crop_image.shape
85 | crop_image = cv2.resize(
86 | crop_image,
87 | (cfg.KEYPOINTS.input_shape[1], cfg.KEYPOINTS.input_shape[0]))
88 | crop_image = crop_image.astype(np.float32)
89 |
90 | keypoints, state = self.simple_run(crop_image)
91 |
92 | res = keypoints[0][:self.keypoint_num].reshape((-1, 2))
93 | res[:, 0] = res[:, 0] * w / cfg.KEYPOINTS.input_shape[1]
94 | res[:, 1] = res[:, 1] * h / cfg.KEYPOINTS.input_shape[0]
95 |
96 | landmark = []
97 | for _index in range(res.shape[0]):
98 | x_y = res[_index]
99 | landmark.append([
100 | int(x_y[0] * cfg.KEYPOINTS.input_shape[0] + bbox[0] - add),
101 | int(x_y[1] * cfg.KEYPOINTS.input_shape[1] + bbox[1] - add)
102 | ])
103 |
104 | landmark = np.array(landmark, np.float32)
105 |
106 | return landmark, state
107 |
108 | def init_model(self, *args):
109 |
110 | if len(args) == 1:
111 | use_pb = True
112 | pb_path = args[0]
113 | else:
114 | use_pb = False
115 | meta_path = args[0]
116 | restore_model_path = args[1]
117 |
118 | def ini_ckpt():
119 | graph = tf.Graph()
120 | graph.as_default()
121 | configProto = tf.ConfigProto()
122 | configProto.gpu_options.allow_growth = True
123 | sess = tf.Session(config=configProto)
124 | # load_model(model_path, sess)
125 | saver = tf.train.import_meta_graph(meta_path)
126 | saver.restore(sess, restore_model_path)
127 |
128 |             print('Model restored!')
129 | return (graph, sess)
130 |
131 | def init_pb(model_path):
132 | config = tf.ConfigProto()
133 | config.gpu_options.per_process_gpu_memory_fraction = 0.2
134 | compute_graph = tf.Graph()
135 | compute_graph.as_default()
136 | sess = tf.Session(config=config)
137 | with tf.gfile.GFile(model_path, 'rb') as fid:
138 | graph_def = tf.GraphDef()
139 | graph_def.ParseFromString(fid.read())
140 | tf.import_graph_def(graph_def, name='')
141 |
142 | # saver = tf.train.Saver(tf.global_variables())
143 | # saver.save(sess, save_path='./tmp.ckpt')
144 | return (compute_graph, sess)
145 |
146 | if use_pb:
147 | model = init_pb(pb_path)
148 | else:
149 | model = ini_ckpt()
150 |
151 | graph = model[0]
152 | sess = model[1]
153 |
154 | return graph, sess
155 |
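A minimal sketch chaining the detector and landmark models; the model directory path is a placeholder and must hold both `detector.pb` and `keypoints.pb`:

```python
import cv2
from source.facelib.face_detector import FaceDetector
from source.facelib.face_landmark import FaceLandmark

model_dir = 'path/to/model_dir'
rgb = cv2.cvtColor(cv2.imread('input.png'), cv2.COLOR_BGR2RGB)
boxes = FaceDetector(model_dir)(rgb)                # (N, 5): x1, y1, x2, y2, score
landmarks, states = FaceLandmark(model_dir)(rgb, boxes)
print(landmarks.shape)   # (num_faces, 68, 2) with the default 68-point configuration
```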
--------------------------------------------------------------------------------
/source/facelib/facer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cv2
4 | import numpy as np
5 |
6 | from .config import config as cfg
7 | from .face_detector import FaceDetector
8 | from .face_landmark import FaceLandmark
9 | from .LK.lk import GroupTrack
10 |
11 |
12 | class FaceAna():
13 | '''
14 |     By default, only the top-k faces (k = config.DETECT.topk) sorted by area are processed, for speed.
15 | '''
16 |
17 | def __init__(self, model_dir):
18 | self.face_detector = FaceDetector(model_dir)
19 | self.face_landmark = FaceLandmark(model_dir)
20 | self.trace = GroupTrack()
21 |
22 | self.track_box = None
23 | self.previous_image = None
24 | self.previous_box = None
25 |
26 | self.diff_thres = 5
27 | self.top_k = cfg.DETECT.topk
28 | self.iou_thres = cfg.TRACE.iou_thres
29 | self.alpha = cfg.TRACE.smooth_box
30 |
31 | def run(self, image):
32 |
33 | boxes = self.face_detector(image)
34 |
35 | if boxes.shape[0] > self.top_k:
36 | boxes = self.sort(boxes)
37 |
38 | boxes_return = np.array(boxes)
39 | landmarks, states = self.face_landmark(image, boxes)
40 |
41 | if 1:
42 | track = []
43 | for i in range(landmarks.shape[0]):
44 | track.append([
45 | np.min(landmarks[i][:, 0]),
46 | np.min(landmarks[i][:, 1]),
47 | np.max(landmarks[i][:, 0]),
48 | np.max(landmarks[i][:, 1])
49 | ])
50 | tmp_box = np.array(track)
51 |
52 | self.track_box = self.judge_boxs(boxes_return, tmp_box)
53 |
54 | self.track_box, landmarks = self.sort_res(self.track_box, landmarks)
55 | return self.track_box, landmarks, states
56 |
57 | def sort_res(self, bboxes, points):
58 | area = []
59 | for bbox in bboxes:
60 | bbox_width = bbox[2] - bbox[0]
61 | bbox_height = bbox[3] - bbox[1]
62 | area.append(bbox_height * bbox_width)
63 |
64 | area = np.array(area)
65 | picked = area.argsort()[::-1]
66 | sorted_bboxes = [bboxes[x] for x in picked]
67 | sorted_points = [points[x] for x in picked]
68 | return np.array(sorted_bboxes), np.array(sorted_points)
69 |
70 | def diff_frames(self, previous_frame, image):
71 | if previous_frame is None:
72 | return True
73 | else:
74 | _diff = cv2.absdiff(previous_frame, image)
75 | diff = np.sum(
76 | _diff) / previous_frame.shape[0] / previous_frame.shape[1] / 3.
77 | return diff > self.diff_thres
78 |
79 | def sort(self, bboxes):
80 | if self.top_k > 100:
81 | return bboxes
82 | area = []
83 | for bbox in bboxes:
84 |
85 | bbox_width = bbox[2] - bbox[0]
86 | bbox_height = bbox[3] - bbox[1]
87 | area.append(bbox_height * bbox_width)
88 |
89 | area = np.array(area)
90 |
91 | picked = area.argsort()[-self.top_k:][::-1]
92 | sorted_bboxes = [bboxes[x] for x in picked]
93 | return np.array(sorted_bboxes)
94 |
95 | def judge_boxs(self, previuous_bboxs, now_bboxs):
96 |
97 | def iou(rec1, rec2):
98 |
99 | # computing area of each rectangles
100 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
101 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
102 |
103 | # computing the sum_area
104 | sum_area = S_rec1 + S_rec2
105 |
106 | # find the each edge of intersect rectangle
107 | x1 = max(rec1[0], rec2[0])
108 | y1 = max(rec1[1], rec2[1])
109 | x2 = min(rec1[2], rec2[2])
110 | y2 = min(rec1[3], rec2[3])
111 |
112 | # judge if there is an intersect
113 | intersect = max(0, x2 - x1) * max(0, y2 - y1)
114 |
115 | return intersect / (sum_area - intersect)
116 |
117 | if previuous_bboxs is None:
118 | return now_bboxs
119 |
120 | result = []
121 |
122 | for i in range(now_bboxs.shape[0]):
123 | contain = False
124 | for j in range(previuous_bboxs.shape[0]):
125 | if iou(now_bboxs[i], previuous_bboxs[j]) > self.iou_thres:
126 | result.append(
127 | self.smooth(now_bboxs[i], previuous_bboxs[j]))
128 | contain = True
129 | break
130 | if not contain:
131 | result.append(now_bboxs[i])
132 |
133 | return np.array(result)
134 |
135 | def smooth(self, now_box, previous_box):
136 |
137 | return self.do_moving_average(now_box[:4], previous_box[:4])
138 |
139 | def do_moving_average(self, p_now, p_previous):
140 | p = self.alpha * p_now + (1 - self.alpha) * p_previous
141 | return p
142 |
143 | def reset(self):
144 | '''
145 |         Reset the previous info used for tracking.
146 | :return:
147 | '''
148 | self.track_box = None
149 | self.previous_image = None
150 | self.previous_box = None
151 |
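A minimal usage sketch of the high-level `FaceAna` wrapper (the model directory path is a placeholder for the directory with `detector.pb` and `keypoints.pb`):

```python
import cv2
from source.facelib.facer import FaceAna

facer = FaceAna('path/to/model_dir')
rgb = cv2.cvtColor(cv2.imread('input.png'), cv2.COLOR_BGR2RGB)
boxes, landmarks, states = facer.run(rgb)
for box, lmk in zip(boxes, landmarks):
    print('face box:', box, 'landmark points:', lmk.shape)
```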
--------------------------------------------------------------------------------
/source/image_flip_agument_parallel.py:
--------------------------------------------------------------------------------
1 | import oss2
2 | import argparse
3 | import cv2
4 | import glob
5 | import os
6 | import tqdm
7 | import numpy as np
8 | # from .utils import get_rmbg_alpha, get_img_from_url,reasonable_resize,major_detection,crop_img
9 | import tqdm
10 | import urllib
11 | import random
12 | from multiprocessing import Pool
13 |
14 |
15 | parser = argparse.ArgumentParser(description="horizontal-flip augmentation for background-removed results")
16 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
17 | parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
18 | args = parser.parse_args()
19 |
20 |
21 | args.save_dir = os.path.join(args.data_dir, 'total_flip')
22 | form = 'single'
23 |
24 |
25 | def flipImage(image):
26 | new_image = cv2.flip(image, 1)
27 | return new_image
28 |
29 | def all_file(file_dir):
30 | L=[]
31 | for root, dirs, files in os.walk(file_dir):
32 | for file in files:
33 | extend = os.path.splitext(file)[1]
34 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg' or extend == '.JPG':
35 | L.append(os.path.join(root, file))
36 | return L
37 |
38 |
39 | paths = all_file(args.data_dir)
40 |
41 |
42 | def process(path):
43 |
44 | print(path)
45 | outpath = args.save_dir+path[len(args.data_dir):]
46 | if os.path.exists(outpath):
47 | return
48 |
49 | sub_dir = os.path.dirname(outpath)
50 | # print(sub_dir)
51 | if not os.path.exists(sub_dir):
52 | os.makedirs(sub_dir,exist_ok=True)
53 |
54 | img = cv2.imread(path, -1)
55 | h, w, c = img.shape
56 | if form == "pair":
57 | imga = img[:, :int(w / 2), :]
58 | imgb = img[:, int(w / 2):, :]
59 | imga = flipImage(imga)
60 | imgb = flipImage(imgb)
61 |         res = cv2.hconcat([imga, imgb])  # horizontal concatenation
62 |
63 | else:
64 | res = flipImage(img)
65 |
66 | cv2.imwrite(outpath, res)
67 | print('save %s' % outpath)
68 |
69 |
70 |
71 |
72 |
73 | if __name__ == "__main__":
74 | # main(args)
75 | pool = Pool(100)
76 | rl = pool.map(process, paths)
77 | pool.close()
78 | pool.join()
--------------------------------------------------------------------------------
/source/image_rotation_agument_parallel_flat.py:
--------------------------------------------------------------------------------
1 | import oss2
2 | import argparse
3 | import cv2
4 | import glob
5 | import os
6 | import tqdm
7 | import numpy as np
8 | import tqdm
9 | import urllib
10 | import random
11 | from multiprocessing import Pool
12 |
13 |
14 | parser = argparse.ArgumentParser(description="rotation augmentation for background-removed results")
15 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
16 | parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
17 | args = parser.parse_args()
18 |
19 |
20 | args.save_dir = os.path.join(args.data_dir, 'total_rotate')
21 | form = 'single'
22 |
23 | if not os.path.exists(args.save_dir):
24 | os.makedirs(args.save_dir,exist_ok=True)
25 |
26 | def all_file(file_dir):
27 | L=[]
28 | for root, dirs, files in os.walk(file_dir):
29 | for file in files:
30 | extend = os.path.splitext(file)[1]
31 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
32 | L.append(os.path.join(root, file))
33 | return L
34 |
35 | def rotateImage(image, angle):
36 | row,col,_ = image.shape
37 | center=tuple(np.array([row,col])/2)
38 | rot_mat = cv2.getRotationMatrix2D(center,angle,1.0)
39 | new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT)
40 | return new_image
41 |
42 |
43 | paths = all_file(args.data_dir)
44 |
45 |
46 | def process(path):
47 |
48 | if 'total_scale' in path:
49 | return
50 |
51 | outpath = args.save_dir + path[len(args.data_dir):]
52 | sub_dir = os.path.dirname(outpath)
53 | if not os.path.exists(sub_dir):
54 | os.makedirs(sub_dir, exist_ok=True)
55 |
56 | img0 = cv2.imread(path, -1)
57 | h, w, c = img0.shape
58 | img = img0[:, :, :3].copy()
59 | if c == 4:
60 | alpha = img0[:, :, 3]
61 | mask = alpha[:, :, np.newaxis].copy() / 255.
62 | img = (img * mask + (1 - mask) * 255)
63 |
64 | imgb = None
65 | imgc = None
66 |     if form == 'single':
67 |         imga = img
68 |     elif form == 'pair':
69 |         imga = img[:, :int(w / 2), :]
70 |         imgb = img[:, int(w / 2):, :]
71 |     elif form == 'tuple':
72 | imga = img[:, :int(w / 3), :]
73 | imgb = img[:, int(w / 3): int(w * 2 / 3), :]
74 | imgc = img[:, int(w * 2 / 3):, :]
75 |
76 | angles = [ random.randint(-10, 0), random.randint(0, 10)]
77 |
78 | for angle in angles:
79 |
80 | imga_r = rotateImage(imga, angle)
81 |         if form == 'single':
82 |             res = imga_r
83 |         elif form == 'pair':
84 |             imgb_r = rotateImage(imgb, angle)
85 |             res = cv2.hconcat([imga_r, imgb_r])  # horizontal concatenation
86 |         else:
87 |             imgb_r = rotateImage(imgb, angle)
88 |             imgc_r = rotateImage(imgc, angle)
89 |             res = cv2.hconcat([imga_r, imgb_r, imgc_r])  # horizontal concatenation
90 |
91 | cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res)
92 | print('save %s'% outpath)
93 |
94 |
95 |
96 | if __name__ == "__main__":
97 | # main(args)
98 | pool = Pool(100)
99 | rl = pool.map(process, paths)
100 | pool.close()
101 | pool.join()
--------------------------------------------------------------------------------
/source/image_scale_agument_parallel_flat.py:
--------------------------------------------------------------------------------
1 | import oss2
2 | import argparse
3 | import cv2
4 | import glob
5 | import os
6 | import tqdm
7 | import numpy as np
8 | import tqdm
9 | import urllib
10 | import random
11 | from multiprocessing import Pool
12 |
13 |
14 | parser = argparse.ArgumentParser(description="scale augmentation for background-removed results")
15 | parser.add_argument("--data_dir", type=str, default="", help="Path to images.")
16 | parser.add_argument("--save_dir", type=str, default="", help="Path to save images.")
17 | args = parser.parse_args()
18 |
19 |
20 | args.save_dir = os.path.join(args.data_dir, 'total_scale')
21 | form = 'single'
22 |
23 | if not os.path.exists(args.save_dir):
24 | os.makedirs(args.save_dir,exist_ok=True)
25 |
26 | def all_file(file_dir):
27 | L=[]
28 | for root, dirs, files in os.walk(file_dir):
29 | for file in files:
30 | extend = os.path.splitext(file)[1]
31 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
32 | L.append(os.path.join(root, file))
33 | return L
34 |
35 | def scaleImage(image, degree):
36 |
37 | h, w, _ = image.shape
38 | canvas = np.ones((h, w, 3), dtype="uint8")*255
39 | nw, nh = (int(w*degree), int(h*degree))
40 | image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA) # w, h
41 |
42 | if degree<1:
43 | canvas[int((h-nh)/2):int((h-nh)/2)+nh, int((w-nw)/2):int((w-nw)/2)+nw,:] = image
44 | elif degree>1:
45 | canvas = image[int((nh-h)/2):int((nh-h)/2)+h, int((nw-w)/2):int((nw-w)/2)+w, :]
46 | else:
47 | canvas = image.copy()
48 |
49 | return canvas
50 |
51 | def scaleImage2(image, degree, angle=0):
52 | row,col,_ = image.shape
53 | center=tuple(np.array([row,col])/2)
54 | rot_mat = cv2.getRotationMatrix2D(center,angle,degree)
55 | new_image = cv2.warpAffine(image, rot_mat, (col,row), borderMode=cv2.BORDER_REFLECT)
56 | return new_image
57 |
58 |
59 | paths = all_file(args.data_dir)
60 |
61 |
62 | def process(path):
63 |
64 | outpath = args.save_dir+path[len(args.data_dir):]
65 | sub_dir = os.path.dirname(outpath)
66 | if not os.path.exists(sub_dir):
67 | os.makedirs(sub_dir, exist_ok=True)
68 |
69 |
70 | img0 = cv2.imread(path, -1)
71 | h, w, c = img0.shape
72 | img = img0[:, :, :3].copy()
73 | if c==4:
74 | alpha = img0[:, :, 3]
75 | mask = alpha[:, :, np.newaxis].copy() / 255.
76 | img = (img * mask + (1 - mask) * 255)
77 |
78 | imgb = None
79 | imgc = None
80 |     if form == 'single':
81 |         imga = img
82 |     elif form == 'pair':
83 |         imga = img[:, :int(w / 2), :]
84 |         imgb = img[:, int(w / 2):, :]
85 |     elif form == 'tuple':
86 | imga = img[:, :int(w / 3), :]
87 | imgb = img[:, int(w / 3): int(w * 2 / 3), :]
88 | imgc = img[:, int(w * 2 / 3):, :]
89 |
90 | if random.random()>0.9:
91 | angles = [random.uniform(1, 1.1)]
92 | else:
93 | angles = [random.uniform(0.8, 1)]
94 |
95 | for angle in angles:
96 |
97 | imga_r = scaleImage(imga, angle)
 98 |         if form == 'single':
 99 |             res = imga_r
100 |         elif form == 'pair':
101 |             imgb_r = scaleImage(imgb, angle)
102 |             res = cv2.hconcat([imga_r, imgb_r])  # horizontal concatenation
103 |         else:
104 |             imgb_r = scaleImage(imgb, angle)
105 |             imgc_r = scaleImage(imgc, angle)
106 |             res = cv2.hconcat([imga_r, imgb_r, imgc_r])  # horizontal concatenation
107 |
108 | cv2.imwrite(outpath[:-4]+'_'+str(angle)+'.png', res)
109 | print('save %s'% outpath)
110 |
111 |
112 | if __name__ == "__main__":
113 | # main(args)
114 | pool = Pool(100)
115 | rl = pool.map(process, paths)
116 | pool.close()
117 | pool.join()
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/.DS_Store
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Dan Antoshchenko
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/README.md:
--------------------------------------------------------------------------------
1 | # MTCNN
2 |
3 | `pytorch` implementation of **inference stage** of face detection algorithm described in
4 | [Joint Face Detection and Alignment using Multi-task Cascaded Convolutional Networks](https://arxiv.org/abs/1604.02878).
5 |
6 | ## Example
7 | 
8 |
9 | ## How to use it
10 | Just download the repository and then do this
11 | ```python
12 | from src import detect_faces
13 | from PIL import Image
14 |
15 | image = Image.open('image.jpg')
16 | bounding_boxes, landmarks = detect_faces(image)
17 | ```
18 | For examples see `test_on_images.ipynb`.
19 |
20 | ## Requirements
21 | * pytorch 0.2
22 | * Pillow, numpy
23 |
24 | ## Credit
25 | This implementation is heavily inspired by:
26 | * [pangyupo/mxnet_mtcnn_face_detection](https://github.com/pangyupo/mxnet_mtcnn_face_detection)
27 |
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/__init__.py
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/mtcnn_pytorch/src/__init__.py
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/src/align_trans.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Mon Apr 24 15:43:29 2017
3 | @author: zhaoy
4 | """
5 | import cv2
6 | import numpy as np
7 |
8 | from .matlab_cp2tform import get_similarity_transform_for_cv2
9 |
10 | # reference facial points, a list of coordinates (x,y)
11 | dx = 1
12 | dy = 1
13 | REFERENCE_FACIAL_POINTS = [
14 | [30.29459953 + dx, 51.69630051 + dy], # left eye
15 | [65.53179932 + dx, 51.50139999 + dy], # right eye
16 | [48.02519989 + dx, 71.73660278 + dy], # nose
17 | [33.54930115 + dx, 92.3655014 + dy], # left mouth
18 | [62.72990036 + dx, 92.20410156 + dy] # right mouth
19 | ]
20 |
21 | DEFAULT_CROP_SIZE = (96, 112)
22 |
23 | global FACIAL_POINTS
24 |
25 |
26 | class FaceWarpException(Exception):
27 |
28 | def __str__(self):
29 |         return 'In File {}:{}'.format(__file__, super().__str__())
30 |
31 |
32 | def get_reference_facial_points(output_size=None,
33 | inner_padding_factor=0.0,
34 | outer_padding=(0, 0),
35 | default_square=False):
36 |
37 | tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
38 | tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
39 |
40 | # 0) make the inner region a square
41 | if default_square:
42 | size_diff = max(tmp_crop_size) - tmp_crop_size
43 | tmp_5pts += size_diff / 2
44 | tmp_crop_size += size_diff
45 |
46 | h_crop = tmp_crop_size[0]
47 | w_crop = tmp_crop_size[1]
48 | if (output_size):
49 | if (output_size[0] == h_crop and output_size[1] == w_crop):
50 | return tmp_5pts
51 |
52 | if (inner_padding_factor == 0 and outer_padding == (0, 0)):
53 | if output_size is None:
54 | return tmp_5pts
55 | else:
56 | raise FaceWarpException(
57 | 'No paddings to do, output_size must be None or {}'.format(
58 | tmp_crop_size))
59 |
60 | # check output size
61 | if not (0 <= inner_padding_factor <= 1.0):
62 | raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
63 |
64 | factor = inner_padding_factor > 0 or outer_padding[0] > 0
65 | factor = factor or outer_padding[1] > 0
66 | if (factor and output_size is None):
67 |         output_size = (tmp_crop_size
68 |                        * (1 + inner_padding_factor * 2)).astype(np.int32)
69 | output_size += np.array(outer_padding)
70 |
71 | cond1 = outer_padding[0] < output_size[0]
72 | cond2 = outer_padding[1] < output_size[1]
73 | if not (cond1 and cond2):
74 | raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
75 | 'and outer_padding[1] < output_size[1])')
76 |
77 | # 1) pad the inner region according inner_padding_factor
78 | if inner_padding_factor > 0:
79 | size_diff = tmp_crop_size * inner_padding_factor * 2
80 | tmp_5pts += size_diff / 2
81 | tmp_crop_size += np.round(size_diff).astype(np.int32)
82 |
83 | # 2) resize the padded inner region
84 | size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
85 |
86 | if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[
87 | 1] * tmp_crop_size[0]:
88 | raise FaceWarpException(
89 | 'Must have (output_size - outer_padding)'
90 | '= some_scale * (crop_size * (1.0 + inner_padding_factor)')
91 |
92 | scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
93 | tmp_5pts = tmp_5pts * scale_factor
94 |
95 | # 3) add outer_padding to make output_size
96 | reference_5point = tmp_5pts + np.array(outer_padding)
97 |
98 | return reference_5point
99 |
100 |
101 | def get_affine_transform_matrix(src_pts, dst_pts):
102 |
103 | tfm = np.float32([[1, 0, 0], [0, 1, 0]])
104 | n_pts = src_pts.shape[0]
105 | ones = np.ones((n_pts, 1), src_pts.dtype)
106 | src_pts_ = np.hstack([src_pts, ones])
107 | dst_pts_ = np.hstack([dst_pts, ones])
108 |
109 | A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
110 |
111 | if rank == 3:
112 | tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]],
113 | [A[0, 1], A[1, 1], A[2, 1]]])
114 | elif rank == 2:
115 | tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
116 |
117 | return tfm
118 |
119 |
120 | def warp_and_crop_face(src_img,
121 | facial_pts,
122 | ratio=0.84,
123 | reference_pts=None,
124 | crop_size=(96, 112),
125 | align_type='similarity'
126 | '',
127 | return_trans_inv=False):
128 |
129 | if reference_pts is None:
130 | if crop_size[0] == 96 and crop_size[1] == 112:
131 | reference_pts = REFERENCE_FACIAL_POINTS
132 | else:
133 | default_square = False
134 | inner_padding_factor = 0
135 | outer_padding = (0, 0)
136 | output_size = crop_size
137 |
138 | reference_pts = get_reference_facial_points(
139 | output_size, inner_padding_factor, outer_padding,
140 | default_square)
141 |
142 | ref_pts = np.float32(reference_pts)
143 |
144 | factor = ratio
145 | ref_pts = (ref_pts - 112 / 2) * factor + 112 / 2
146 | ref_pts *= crop_size[0] / 112.
147 |
148 | ref_pts_shp = ref_pts.shape
149 | if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
150 | raise FaceWarpException(
151 | 'reference_pts.shape must be (K,2) or (2,K) and K>2')
152 |
153 | if ref_pts_shp[0] == 2:
154 | ref_pts = ref_pts.T
155 |
156 | src_pts = np.float32(facial_pts)
157 | src_pts_shp = src_pts.shape
158 | if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
159 | raise FaceWarpException(
160 | 'facial_pts.shape must be (K,2) or (2,K) and K>2')
161 |
162 | if src_pts_shp[0] == 2:
163 | src_pts = src_pts.T
164 |
165 | if src_pts.shape != ref_pts.shape:
166 | raise FaceWarpException(
167 | 'facial_pts and reference_pts must have the same shape')
168 |
169 | if align_type == 'cv2_affine':
170 | tfm = cv2.getAffineTransform(src_pts, ref_pts)
171 | tfm_inv = cv2.getAffineTransform(ref_pts, src_pts)
172 |
173 | elif align_type == 'affine':
174 | tfm = get_affine_transform_matrix(src_pts, ref_pts)
175 | tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
176 | else:
177 | tfm, tfm_inv = get_similarity_transform_for_cv2(src_pts, ref_pts)
178 |
179 | face_img = cv2.warpAffine(
180 | src_img,
181 | tfm, (crop_size[0], crop_size[1]),
182 | borderValue=(255, 255, 255))
183 |
184 | if return_trans_inv:
185 | return face_img, tfm_inv
186 | else:
187 | return face_img
188 |
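A minimal sketch of the alignment round-trip used by `Cartoonizer` above; the five facial points here are hypothetical placeholder coordinates:

```python
import cv2
from source.mtcnn_pytorch.src.align_trans import (
    get_reference_facial_points, warp_and_crop_face)

img = cv2.imread('input.png')
# five (x, y) points: left eye, right eye, nose, left/right mouth corners (placeholder values)
f5p = [(120.0, 150.0), (180.0, 150.0), (150.0, 190.0), (128.0, 230.0), (172.0, 230.0)]
ref = get_reference_facial_points(default_square=True)
face, trans_inv = warp_and_crop_face(
    img, f5p, ratio=0.75, reference_pts=ref,
    crop_size=(288, 288), return_trans_inv=True)
# trans_inv maps the 288x288 crop back into the original image frame
back = cv2.warpAffine(face, trans_inv, (img.shape[1], img.shape[0]))
```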
--------------------------------------------------------------------------------
/source/mtcnn_pytorch/src/matlab_cp2tform.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Tue Jul 11 06:54:28 2017
3 |
4 | @author: zhaoyafei
5 | """
6 |
7 | import numpy as np
8 | from numpy.linalg import inv, lstsq
9 | from numpy.linalg import matrix_rank as rank
10 | from numpy.linalg import norm
11 |
12 |
13 | class MatlabCp2tormException(Exception):
14 |
15 | def __str__(self):
16 |         return 'In File {}:{}'.format(__file__, super().__str__())
17 |
18 |
19 | def tformfwd(trans, uv):
20 | """
21 | Function:
22 | ----------
23 | apply affine transform 'trans' to uv
24 |
25 | Parameters:
26 | ----------
27 | @trans: 3x3 np.array
28 | transform matrix
29 | @uv: Kx2 np.array
30 | each row is a pair of coordinates (x, y)
31 |
32 | Returns:
33 | ----------
34 | @xy: Kx2 np.array
35 | each row is a pair of transformed coordinates (x, y)
36 | """
37 | uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
38 | xy = np.dot(uv, trans)
39 | xy = xy[:, 0:-1]
40 | return xy
41 |
42 |
43 | def tforminv(trans, uv):
44 | """
45 | Function:
46 | ----------
47 | apply the inverse of affine transform 'trans' to uv
48 |
49 | Parameters:
50 | ----------
51 | @trans: 3x3 np.array
52 | transform matrix
53 | @uv: Kx2 np.array
54 | each row is a pair of coordinates (x, y)
55 |
56 | Returns:
57 | ----------
58 | @xy: Kx2 np.array
59 | each row is a pair of inverse-transformed coordinates (x, y)
60 | """
61 | Tinv = inv(trans)
62 | xy = tformfwd(Tinv, uv)
63 | return xy
64 |
65 |
66 | def findNonreflectiveSimilarity(uv, xy, options=None):
67 |
68 | options = {'K': 2}
69 |
70 | K = options['K']
71 | M = xy.shape[0]
72 | x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
73 | y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
74 | # print('--->x, y:\n', x, y
75 |
76 | tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
77 | tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
78 | X = np.vstack((tmp1, tmp2))
79 | # print('--->X.shape: ', X.shape
80 | # print('X:\n', X
81 |
82 | u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
83 | v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
84 | U = np.vstack((u, v))
85 | # print('--->U.shape: ', U.shape
86 | # print('U:\n', U
87 |
88 | # We know that X * r = U
89 | if rank(X) >= 2 * K:
90 | r, _, _, _ = lstsq(X, U)
91 | r = np.squeeze(r)
92 | else:
93 | raise Exception('cp2tform:twoUniquePointsReq')
94 |
95 | # print('--->r:\n', r
96 |
97 | sc = r[0]
98 | ss = r[1]
99 | tx = r[2]
100 | ty = r[3]
101 |
102 | Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
103 |
104 | # print('--->Tinv:\n', Tinv
105 |
106 | T = inv(Tinv)
107 | # print('--->T:\n', T
108 |
109 | T[:, 2] = np.array([0, 0, 1])
110 |
111 | return T, Tinv
112 |
113 |
114 | def findSimilarity(uv, xy, options=None):
115 |
116 | options = {'K': 2}
117 |
118 | # uv = np.array(uv)
119 | # xy = np.array(xy)
120 |
121 | # Solve for trans1
122 | trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
123 |
124 | # Solve for trans2
125 |
126 | # manually reflect the xy data across the Y-axis
127 |     xyR = xy.copy()
128 | xyR[:, 0] = -1 * xyR[:, 0]
129 |
130 | trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
131 |
132 | # manually reflect the tform to undo the reflection done on xyR
133 | TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
134 |
135 | trans2 = np.dot(trans2r, TreflectY)
136 |
137 | # Figure out if trans1 or trans2 is better
138 | xy1 = tformfwd(trans1, uv)
139 | norm1 = norm(xy1 - xy)
140 |
141 | xy2 = tformfwd(trans2, uv)
142 | norm2 = norm(xy2 - xy)
143 |
144 | if norm1 <= norm2:
145 | return trans1, trans1_inv
146 | else:
147 | trans2_inv = inv(trans2)
148 | return trans2, trans2_inv
149 |
150 |
151 | def get_similarity_transform(src_pts, dst_pts, reflective=True):
152 | """
153 | Function:
154 | ----------
155 | Find Similarity Transform Matrix 'trans':
156 | u = src_pts[:, 0]
157 | v = src_pts[:, 1]
158 | x = dst_pts[:, 0]
159 | y = dst_pts[:, 1]
160 | [x, y, 1] = [u, v, 1] * trans
161 |
162 | Parameters:
163 | ----------
164 | @src_pts: Kx2 np.array
165 | source points, each row is a pair of coordinates (x, y)
166 | @dst_pts: Kx2 np.array
167 | destination points, each row is a pair of transformed
168 | coordinates (x, y)
169 | @reflective: True or False
170 | if True:
171 | use reflective similarity transform
172 | else:
173 | use non-reflective similarity transform
174 |
175 | Returns:
176 | ----------
177 | @trans: 3x3 np.array
178 | transform matrix from uv to xy
179 | trans_inv: 3x3 np.array
180 | inverse of trans, transform matrix from xy to uv
181 | """
182 |
183 | if reflective:
184 | trans, trans_inv = findSimilarity(src_pts, dst_pts)
185 | else:
186 | trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
187 |
188 | return trans, trans_inv
189 |
190 |
191 | def cvt_tform_mat_for_cv2(trans):
192 | """
193 | Function:
194 | ----------
195 | Convert Transform Matrix 'trans' into 'cv2_trans' which could be
196 | directly used by cv2.warpAffine():
197 | u = src_pts[:, 0]
198 | v = src_pts[:, 1]
199 | x = dst_pts[:, 0]
200 | y = dst_pts[:, 1]
201 | [x, y].T = cv_trans * [u, v, 1].T
202 |
203 | Parameters:
204 | ----------
205 | @trans: 3x3 np.array
206 | transform matrix from uv to xy
207 |
208 | Returns:
209 | ----------
210 | @cv2_trans: 2x3 np.array
211 | transform matrix from src_pts to dst_pts, could be directly used
212 | for cv2.warpAffine()
213 | """
214 | cv2_trans = trans[:, 0:2].T
215 |
216 | return cv2_trans
217 |
218 |
219 | def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
220 | """
221 | Function:
222 | ----------
223 | Find Similarity Transform Matrix 'cv2_trans' which could be
224 | directly used by cv2.warpAffine():
225 | u = src_pts[:, 0]
226 | v = src_pts[:, 1]
227 | x = dst_pts[:, 0]
228 | y = dst_pts[:, 1]
229 | [x, y].T = cv_trans * [u, v, 1].T
230 |
231 | Parameters:
232 | ----------
233 | @src_pts: Kx2 np.array
234 | source points, each row is a pair of coordinates (x, y)
235 | @dst_pts: Kx2 np.array
236 | destination points, each row is a pair of transformed
237 | coordinates (x, y)
238 | reflective: True or False
239 | if True:
240 | use reflective similarity transform
241 | else:
242 | use non-reflective similarity transform
243 |
244 | Returns:
245 | ----------
246 | @cv2_trans: 2x3 np.array
247 | transform matrix from src_pts to dst_pts, could be directly used
248 | for cv2.warpAffine()
249 | """
250 | trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
251 | cv2_trans = cvt_tform_mat_for_cv2(trans)
252 | cv2_trans_inv = cvt_tform_mat_for_cv2(trans_inv)
253 |
254 | return cv2_trans, cv2_trans_inv
255 |
256 |
257 | if __name__ == '__main__':
258 | """
259 | u = [0, 6, -2]
260 | v = [0, 3, 5]
261 | x = [-1, 0, 4]
262 | y = [-1, -10, 4]
263 |
264 | # In Matlab, run:
265 | #
266 | # uv = [u'; v'];
267 | # xy = [x'; y'];
268 | # tform_sim=cp2tform(uv,xy,'similarity');
269 | #
270 | # trans = tform_sim.tdata.T
271 | # ans =
272 | # -0.0764 -1.6190 0
273 | # 1.6190 -0.0764 0
274 | # -3.2156 0.0290 1.0000
275 | # trans_inv = tform_sim.tdata.Tinv
276 | # ans =
277 | #
278 | # -0.0291 0.6163 0
279 | # -0.6163 -0.0291 0
280 | # -0.0756 1.9826 1.0000
281 | # xy_m=tformfwd(tform_sim, u,v)
282 | #
283 | # xy_m =
284 | #
285 | # -3.2156 0.0290
286 | # 1.1833 -9.9143
287 | # 5.0323 2.8853
288 | # uv_m=tforminv(tform_sim, x,y)
289 | #
290 | # uv_m =
291 | #
292 | # 0.5698 1.3953
293 | # 6.0872 2.2733
294 | # -2.6570 4.3314
295 | """
296 | u = [0, 6, -2]
297 | v = [0, 3, 5]
298 | x = [-1, 0, 4]
299 | y = [-1, -10, 4]
300 |
301 | uv = np.array((u, v)).T
302 | xy = np.array((x, y)).T
303 |
304 | print('\n--->uv:')
305 | print(uv)
306 | print('\n--->xy:')
307 | print(xy)
308 |
309 | trans, trans_inv = get_similarity_transform(uv, xy)
310 |
311 | print('\n--->trans matrix:')
312 | print(trans)
313 |
314 | print('\n--->trans_inv matrix:')
315 | print(trans_inv)
316 |
317 | print('\n---> apply transform to uv')
318 | print('\nxy_m = uv_augmented * trans')
319 | uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
320 | xy_m = np.dot(uv_aug, trans)
321 | print(xy_m)
322 |
323 | print('\nxy_m = tformfwd(trans, uv)')
324 | xy_m = tformfwd(trans, uv)
325 | print(xy_m)
326 |
327 | print('\n---> apply inverse transform to xy')
328 | print('\nuv_m = xy_augmented * trans_inv')
329 | xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
330 | uv_m = np.dot(xy_aug, trans_inv)
331 | print(uv_m)
332 |
333 | print('\nuv_m = tformfwd(trans_inv, xy)')
334 | uv_m = tformfwd(trans_inv, xy)
335 | print(uv_m)
336 |
337 | uv_m = tforminv(trans, xy)
338 | print('\nuv_m = tforminv(trans, xy)')
339 | print(uv_m)
340 |
--------------------------------------------------------------------------------
/source/stylegan2/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/.DS_Store
--------------------------------------------------------------------------------
/source/stylegan2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/__init__.py
--------------------------------------------------------------------------------
/source/stylegan2/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/config/.DS_Store
--------------------------------------------------------------------------------
/source/stylegan2/config/conf_server_test_blend_shell.json:
--------------------------------------------------------------------------------
1 | {
2 | "parameters": {
3 | "output": "face_generation/experiment_stylegan",
4 | "ffhq_ckpt": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
5 | "size": 256,
6 | "sample": 1,
7 | "pics": 5000,
8 | "truncation": 0.7,
9 | "form": "single"
10 | }
11 | }
12 |
13 |
--------------------------------------------------------------------------------
/source/stylegan2/config/conf_server_train_condition_shell.json:
--------------------------------------------------------------------------------
1 | {
2 | "parameters": {
3 | "output": "face_generation/experiment_stylegan",
4 | "ckpt": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
5 | "ckpt_ffhq": "pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
6 | "size": 256,
7 | "batch": 8,
8 | "n_sample": 4,
9 | "iter": 1500,
10 | "sample_every": 100,
11 | "save_every": 100
12 | }
13 | }
14 |
15 |
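Both config files are plain JSON with a single `parameters` object. A sketch of loading one (how `train_condition.py` consumes it is assumed, not shown here):

```python
import json

with open('source/stylegan2/config/conf_server_train_condition_shell.json') as f:
    params = json.load(f)['parameters']

print(params['ckpt'], params['size'], params['iter'])
```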
--------------------------------------------------------------------------------
/source/stylegan2/criteria/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/criteria/__init__.py
--------------------------------------------------------------------------------
/source/stylegan2/criteria/helpers.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
4 |
5 | """
6 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
7 | """
8 |
9 |
10 | class Flatten(Module):
11 | def forward(self, input):
12 | return input.view(input.size(0), -1)
13 |
14 |
15 | def l2_norm(input, axis=1):
16 | norm = torch.norm(input, 2, axis, True)
17 | output = torch.div(input, norm)
18 | return output
19 |
20 |
21 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
22 | """ A named tuple describing a ResNet block. """
23 |
24 |
25 | def get_block(in_channel, depth, num_units, stride=2):
26 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
27 |
28 |
29 | def get_blocks(num_layers):
30 | if num_layers == 50:
31 | blocks = [
32 | get_block(in_channel=64, depth=64, num_units=3),
33 | get_block(in_channel=64, depth=128, num_units=4),
34 | get_block(in_channel=128, depth=256, num_units=14),
35 | get_block(in_channel=256, depth=512, num_units=3)
36 | ]
37 | elif num_layers == 100:
38 | blocks = [
39 | get_block(in_channel=64, depth=64, num_units=3),
40 | get_block(in_channel=64, depth=128, num_units=13),
41 | get_block(in_channel=128, depth=256, num_units=30),
42 | get_block(in_channel=256, depth=512, num_units=3)
43 | ]
44 | elif num_layers == 152:
45 | blocks = [
46 | get_block(in_channel=64, depth=64, num_units=3),
47 | get_block(in_channel=64, depth=128, num_units=8),
48 | get_block(in_channel=128, depth=256, num_units=36),
49 | get_block(in_channel=256, depth=512, num_units=3)
50 | ]
51 | else:
52 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
53 | return blocks
54 |
55 |
56 | class SEModule(Module):
57 | def __init__(self, channels, reduction):
58 | super(SEModule, self).__init__()
59 | self.avg_pool = AdaptiveAvgPool2d(1)
60 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
61 | self.relu = ReLU(inplace=True)
62 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
63 | self.sigmoid = Sigmoid()
64 |
65 | def forward(self, x):
66 | module_input = x
67 | x = self.avg_pool(x)
68 | x = self.fc1(x)
69 | x = self.relu(x)
70 | x = self.fc2(x)
71 | x = self.sigmoid(x)
72 | return module_input * x
73 |
74 |
75 | class bottleneck_IR(Module):
76 | def __init__(self, in_channel, depth, stride):
77 | super(bottleneck_IR, self).__init__()
78 | if in_channel == depth:
79 | self.shortcut_layer = MaxPool2d(1, stride)
80 | else:
81 | self.shortcut_layer = Sequential(
82 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
83 | BatchNorm2d(depth)
84 | )
85 | self.res_layer = Sequential(
86 | BatchNorm2d(in_channel),
87 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
88 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
89 | )
90 |
91 | def forward(self, x):
92 | shortcut = self.shortcut_layer(x)
93 | res = self.res_layer(x)
94 | return res + shortcut
95 |
96 |
97 | class bottleneck_IR_SE(Module):
98 | def __init__(self, in_channel, depth, stride):
99 | super(bottleneck_IR_SE, self).__init__()
100 | if in_channel == depth:
101 | self.shortcut_layer = MaxPool2d(1, stride)
102 | else:
103 | self.shortcut_layer = Sequential(
104 | Conv2d(in_channel, depth, (1, 1), stride, bias=False),
105 | BatchNorm2d(depth)
106 | )
107 | self.res_layer = Sequential(
108 | BatchNorm2d(in_channel),
109 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
110 | PReLU(depth),
111 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
112 | BatchNorm2d(depth),
113 | SEModule(depth, 16)
114 | )
115 |
116 | def forward(self, x):
117 | shortcut = self.shortcut_layer(x)
118 | res = self.res_layer(x)
119 | return res + shortcut
120 |
--------------------------------------------------------------------------------
/source/stylegan2/criteria/id_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .model_irse import Backbone
4 |
5 |
6 | class IDLoss(nn.Module):
7 | def __init__(self):
8 | super(IDLoss, self).__init__()
9 | print('Loading ResNet ArcFace')
10 | model_paths = '/data/vdb/qingyao/cartoon/mycode/pretrained_models/model_ir_se50.pth'
11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
12 | self.facenet.load_state_dict(torch.load(model_paths))
13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
14 | self.facenet.eval()
15 |
16 | def extract_feats(self, x):
17 | x = x[:, :, 35:223, 32:220] # Crop interesting region
18 | x = self.face_pool(x)
19 | x_feats = self.facenet(x)
20 | return x_feats
21 |
22 | def forward(self, y_hat, x):
23 | n_samples = x.shape[0]
24 | x_feats = self.extract_feats(x)
25 | y_hat_feats = self.extract_feats(y_hat)
26 | loss = 0
27 | sim_improvement = 0
28 | id_logs = []
29 | count = 0
30 | for i in range(n_samples):
31 | diff_input = y_hat_feats[i].dot(x_feats[i])
32 | id_logs.append({
33 | 'diff_input': float(diff_input)
34 | })
35 | # loss += 1 - diff_target
36 | # modify
37 | loss += 1 - diff_input
38 | count += 1
39 |
40 | return loss / count
41 |
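A minimal sketch of how `IDLoss` is invoked. Note that `model_paths` above is hard-coded, so the ArcFace weights `model_ir_se50.pth` must exist at that path (or the path must be edited first); the input tensors are synthetic placeholders:

```python
import sys
import torch

# the criteria package is imported relative to source/stylegan2 (path assumed from the repo root)
sys.path.append('source/stylegan2')
from criteria.id_loss import IDLoss

id_loss = IDLoss().to('cuda')
x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1      # reference images
y_hat = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1  # generated images
loss = id_loss(y_hat, x)   # average of (1 - feature similarity) over the batch
print(float(loss))
```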
--------------------------------------------------------------------------------
/source/stylegan2/criteria/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/criteria/lpips/__init__.py
--------------------------------------------------------------------------------
/source/stylegan2/criteria/lpips/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from criteria.lpips.networks import get_network, LinLayers
5 | from criteria.lpips.utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 | Arguments:
12 | net_type (str): the network type to compare the features:
13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14 | version (str): the version of LPIPS. Default: 0.1.
15 | """
16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17 |
18 | assert version in ['0.1'], 'v0.1 is only supported now'
19 |
20 | super(LPIPS, self).__init__()
21 |
22 | # pretrained network
23 | self.net = get_network(net_type).to("cuda")
24 |
25 | # linear layers
26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27 | self.lin.load_state_dict(get_state_dict(net_type, version))
28 |
29 | def forward(self, x: torch.Tensor, y: torch.Tensor):
30 | feat_x, feat_y = self.net(x), self.net(y)
31 |
32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34 |
35 | return torch.sum(torch.cat(res, 0)) / x.shape[0]
36 |
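A minimal usage sketch for LPIPS, assuming a CUDA device (the module moves its sub-networks to "cuda" in __init__) and network access to fetch the torchvision backbone and the linear-layer weights on first use:

    import torch
    from criteria.lpips.lpips import LPIPS

    lpips = LPIPS(net_type='alex')
    x = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1  # images scaled to [-1, 1]
    y = torch.rand(2, 3, 256, 256, device='cuda') * 2 - 1
    dist = lpips(x, y)  # batch-averaged LPIPS distance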
--------------------------------------------------------------------------------
/source/stylegan2/criteria/lpips/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from criteria.lpips.utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(True).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
--------------------------------------------------------------------------------
/source/stylegan2/criteria/lpips/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/source/stylegan2/criteria/moco_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from configs.paths_config import model_paths
5 |
6 |
7 | class MocoLoss(nn.Module):
8 |
9 | def __init__(self):
10 | super(MocoLoss, self).__init__()
11 | print("Loading MOCO model from path: {}".format(model_paths["moco"]))
12 | self.model = self.__load_model()
13 | self.model.cuda()
14 | self.model.eval()
15 |
16 | @staticmethod
17 | def __load_model():
18 | import torchvision.models as models
19 | model = models.__dict__["resnet50"]()
20 | # freeze all layers but the last fc
21 | for name, param in model.named_parameters():
22 | if name not in ['fc.weight', 'fc.bias']:
23 | param.requires_grad = False
24 | checkpoint = torch.load(model_paths['moco'], map_location="cpu")
25 | state_dict = checkpoint['state_dict']
26 | # rename moco pre-trained keys
27 | for k in list(state_dict.keys()):
28 | # retain only encoder_q up to before the embedding layer
29 | if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
30 | # remove prefix
31 | state_dict[k[len("module.encoder_q."):]] = state_dict[k]
32 | # delete renamed or unused k
33 | del state_dict[k]
34 | msg = model.load_state_dict(state_dict, strict=False)
35 | assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
36 | # remove output layer
37 | model = nn.Sequential(*list(model.children())[:-1]).cuda()
38 | return model
39 |
40 | def extract_feats(self, x):
41 | x = F.interpolate(x, size=224)
42 | x_feats = self.model(x)
43 | x_feats = nn.functional.normalize(x_feats, dim=1)
44 | x_feats = x_feats.squeeze()
45 | return x_feats
46 |
47 | def forward(self, y_hat, y, x):
48 | n_samples = x.shape[0]
49 | x_feats = self.extract_feats(x)
50 | y_feats = self.extract_feats(y)
51 | y_hat_feats = self.extract_feats(y_hat)
52 | y_feats = y_feats.detach()
53 | loss = 0
54 | sim_improvement = 0
55 | sim_logs = []
56 | count = 0
57 | for i in range(n_samples):
58 | diff_target = y_hat_feats[i].dot(y_feats[i])
59 | diff_input = y_hat_feats[i].dot(x_feats[i])
60 | diff_views = y_feats[i].dot(x_feats[i])
61 | sim_logs.append({'diff_target': float(diff_target),
62 | 'diff_input': float(diff_input),
63 | 'diff_views': float(diff_views)})
64 | loss += 1 - diff_target
65 | sim_diff = float(diff_target) - float(diff_views)
66 | sim_improvement += sim_diff
67 | count += 1
68 |
69 | return loss / count, sim_improvement / count, sim_logs
70 |
--------------------------------------------------------------------------------
/source/stylegan2/criteria/model_irse.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
2 | from .helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
3 |
4 | """
5 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
6 | """
7 |
8 |
9 | class Backbone(Module):
10 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
11 | super(Backbone, self).__init__()
12 | assert input_size in [112, 224], "input_size should be 112 or 224"
13 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
14 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
15 | blocks = get_blocks(num_layers)
16 | if mode == 'ir':
17 | unit_module = bottleneck_IR
18 | elif mode == 'ir_se':
19 | unit_module = bottleneck_IR_SE
20 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
21 | BatchNorm2d(64),
22 | PReLU(64))
23 | if input_size == 112:
24 | self.output_layer = Sequential(BatchNorm2d(512),
25 | Dropout(drop_ratio),
26 | Flatten(),
27 | Linear(512 * 7 * 7, 512),
28 | BatchNorm1d(512, affine=affine))
29 | else:
30 | self.output_layer = Sequential(BatchNorm2d(512),
31 | Dropout(drop_ratio),
32 | Flatten(),
33 | Linear(512 * 14 * 14, 512),
34 | BatchNorm1d(512, affine=affine))
35 |
36 | modules = []
37 | for block in blocks:
38 | for bottleneck in block:
39 | modules.append(unit_module(bottleneck.in_channel,
40 | bottleneck.depth,
41 | bottleneck.stride))
42 | self.body = Sequential(*modules)
43 |
44 | def forward(self, x):
45 | x = self.input_layer(x)
46 | x = self.body(x)
47 | x = self.output_layer(x)
48 | return l2_norm(x)
49 |
50 |
51 | def IR_50(input_size):
52 | """Constructs a ir-50 model."""
53 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
54 | return model
55 |
56 |
57 | def IR_101(input_size):
58 | """Constructs a ir-101 model."""
59 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
60 | return model
61 |
62 |
63 | def IR_152(input_size):
64 | """Constructs a ir-152 model."""
65 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
66 | return model
67 |
68 |
69 | def IR_SE_50(input_size):
70 | """Constructs a ir_se-50 model."""
71 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
72 | return model
73 |
74 |
75 | def IR_SE_101(input_size):
76 | """Constructs a ir_se-101 model."""
77 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
78 | return model
79 |
80 |
81 | def IR_SE_152(input_size):
82 | """Constructs a ir_se-152 model."""
83 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
84 | return model
85 |
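A minimal sketch of the Backbone API via IR_SE_50; the weights are randomly initialised here (in this repo the pretrained ir_se50 checkpoint is loaded by IDLoss), and 112x112 aligned face crops are the expected input:

    import torch
    from criteria.model_irse import IR_SE_50

    net = IR_SE_50(input_size=112).eval()
    with torch.no_grad():
        faces = torch.randn(4, 3, 112, 112)  # aligned face crops
        feats = net(faces)                   # (4, 512) L2-normalised embeddings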
--------------------------------------------------------------------------------
/source/stylegan2/criteria/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import optim
6 |
7 | from PIL import Image
8 | import os
9 |
10 |
11 | # gram matrix and loss
12 | class GramMatrix(nn.Module):
13 | def forward(self, input):
14 | b, c, h, w = input.size()
15 | F = input.view(b, c, h * w)
16 | G = torch.bmm(F, F.transpose(1, 2))
17 | G.div_(h * w)
18 | return G
19 |
20 |
21 | class GramMSELoss(nn.Module):
22 | def forward(self, input, target):
23 | out = nn.MSELoss()(GramMatrix()(input), target)
24 | return (out)
25 |
26 |
27 | # vgg definition that conveniently lets you grab the outputs from any layer
28 | class VGG(nn.Module):
29 | def __init__(self, pool='max'):
30 | super(VGG, self).__init__()
31 | # vgg modules
32 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
33 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
34 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
35 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
36 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
37 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
38 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
39 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
40 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
41 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
42 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
43 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
44 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
45 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
46 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
47 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
48 |
49 | if pool == 'max':
50 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
51 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
52 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
53 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
54 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
55 |
56 | elif pool == 'avg':
57 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
58 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
59 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
60 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
61 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
62 |
63 |
64 | def forward(self, x):
65 | out = {}
66 | out['r11'] = F.relu(self.conv1_1(x))
67 | out['r12'] = F.relu(self.conv1_2(out['r11']))
68 | out['p1'] = self.pool1(out['r12'])
69 | out['r21'] = F.relu(self.conv2_1(out['p1']))
70 | out['r22'] = F.relu(self.conv2_2(out['r21']))
71 | out['p2'] = self.pool2(out['r22'])
72 | out['r31'] = F.relu(self.conv3_1(out['p2']))
73 | out['r32'] = F.relu(self.conv3_2(out['r31']))
74 | out['r33'] = F.relu(self.conv3_3(out['r32']))
75 | out['r34'] = F.relu(self.conv3_4(out['r33']))
76 | out['p3'] = self.pool3(out['r34'])
77 | out['r41'] = F.relu(self.conv4_1(out['p3']))
78 | out['r42'] = F.relu(self.conv4_2(out['r41']))
79 | out['r43'] = F.relu(self.conv4_3(out['r42']))
80 | conv_4_4 = self.conv4_4(out['r43'])
81 | out['r44'] = F.relu(conv_4_4)
82 | out['p4'] = self.pool4(out['r44'])
83 | return conv_4_4
--------------------------------------------------------------------------------
/source/stylegan2/criteria/vgg_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .vgg import VGG
4 | import os
5 | from configs.paths_config import model_paths
6 |
7 | class VggLoss(nn.Module):
8 | def __init__(self):
9 | super(VggLoss, self).__init__()
10 | print("Loading VGG19 model from path: {}".format(model_paths["vgg"]))
11 |
12 | self.vgg_model = VGG()
13 | self.vgg_model.load_state_dict(torch.load(model_paths['vgg']))
14 | self.vgg_model.cuda()
15 | self.vgg_model.eval()
16 |
17 | self.l1loss = torch.nn.L1Loss()
18 |
19 |
20 |
21 |
22 |
23 | def forward(self, input_photo, output):
24 | vgg_photo = self.vgg_model(input_photo)
25 | vgg_output = self.vgg_model(output)
26 | n, c, h, w = vgg_photo.shape
27 | # h, w, c = vgg_photo.get_shape().as_list()[1:]
28 | loss = self.l1loss(vgg_photo, vgg_output)
29 |
30 | return loss
31 |
--------------------------------------------------------------------------------
/source/stylegan2/criteria/w_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class WNormLoss(nn.Module):
6 |
7 | def __init__(self, start_from_latent_avg=True):
8 | super(WNormLoss, self).__init__()
9 | self.start_from_latent_avg = start_from_latent_avg
10 |
11 | def forward(self, latent, latent_avg=None):
12 | if self.start_from_latent_avg:
13 | latent = latent - latent_avg
14 | return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]
15 |
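A minimal sketch of WNormLoss; the latent shape below is illustrative (a batch of W+ codes for a 256px generator), and latent_avg broadcasts against it:

    import torch
    from criteria.w_norm import WNormLoss

    w_norm = WNormLoss(start_from_latent_avg=True)
    latent = torch.randn(4, 14, 512)   # batch of W+ codes
    latent_avg = torch.zeros(512)      # broadcasts over batch and style dimensions
    loss = w_norm(latent, latent_avg)  # mean L2 norm of the centred codes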
--------------------------------------------------------------------------------
/source/stylegan2/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class MultiResolutionDataset(Dataset):
9 | def __init__(self, path, transform, resolution=256):
10 | self.env = lmdb.open(
11 | path,
12 | max_readers=32,
13 | readonly=True,
14 | lock=False,
15 | readahead=False,
16 | meminit=False,
17 | )
18 |
19 | if not self.env:
20 | raise IOError('Cannot open lmdb dataset', path)
21 |
22 | with self.env.begin(write=False) as txn:
23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24 |
25 | self.resolution = resolution
26 | self.transform = transform
27 |
28 | def __len__(self):
29 | return self.length
30 |
31 | def __getitem__(self, index):
32 | with self.env.begin(write=False) as txn:
33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34 | img_bytes = txn.get(key)
35 |
36 | buffer = BytesIO(img_bytes)
37 | img = Image.open(buffer)
38 | img = self.transform(img)
39 |
40 | return img
41 |
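A minimal usage sketch; the lmdb path is a placeholder and is assumed to have been created by prepare_data.py at a matching resolution:

    from torch.utils.data import DataLoader
    from torchvision import transforms
    from dataset import MultiResolutionDataset

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = MultiResolutionDataset('/path/to/lmdb', transform, resolution=256)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
    batch = next(iter(loader))  # (8, 3, 256, 256) tensor in [-1, 1]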
--------------------------------------------------------------------------------
/source/stylegan2/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def reduce_sum(tensor):
45 | if not dist.is_available():
46 | return tensor
47 |
48 | if not dist.is_initialized():
49 | return tensor
50 |
51 | tensor = tensor.clone()
52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53 |
54 | return tensor
55 |
56 |
57 | def gather_grad(params):
58 | world_size = get_world_size()
59 |
60 | if world_size == 1:
61 | return
62 |
63 | for param in params:
64 | if param.grad is not None:
65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66 | param.grad.data.div_(world_size)
67 |
68 |
69 | def all_gather(data):
70 | world_size = get_world_size()
71 |
72 | if world_size == 1:
73 | return [data]
74 |
75 | buffer = pickle.dumps(data)
76 | storage = torch.ByteStorage.from_buffer(buffer)
77 | tensor = torch.ByteTensor(storage).to('cuda')
78 |
79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81 | dist.all_gather(size_list, local_size)
82 | size_list = [int(size.item()) for size in size_list]
83 | max_size = max(size_list)
84 |
85 | tensor_list = []
86 | for _ in size_list:
87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88 |
89 | if local_size != max_size:
90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91 | tensor = torch.cat((tensor, padding), 0)
92 |
93 | dist.all_gather(tensor_list, tensor)
94 |
95 | data_list = []
96 |
97 | for size, tensor in zip(size_list, tensor_list):
98 | buffer = tensor.cpu().numpy().tobytes()[:size]
99 | data_list.append(pickle.loads(buffer))
100 |
101 | return data_list
102 |
103 |
104 | def reduce_loss_dict(loss_dict):
105 | world_size = get_world_size()
106 |
107 | if world_size < 2:
108 | return loss_dict
109 |
110 | with torch.no_grad():
111 | keys = []
112 | losses = []
113 |
114 | for k in sorted(loss_dict.keys()):
115 | keys.append(k)
116 | losses.append(loss_dict[k])
117 |
118 | losses = torch.stack(losses, 0)
119 | dist.reduce(losses, dst=0)
120 |
121 | if dist.get_rank() == 0:
122 | losses /= world_size
123 |
124 | reduced_losses = {k: v for k, v in zip(keys, losses)}
125 |
126 | return reduced_losses
127 |
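A minimal sketch of these helpers; without torch.distributed initialised they degrade gracefully (rank 0, world size 1, dict returned unchanged), while under DDP reduce_loss_dict averages the losses onto rank 0:

    import torch
    from distributed import get_rank, get_world_size, reduce_loss_dict

    losses = {'d': torch.tensor(0.7), 'g': torch.tensor(1.3)}
    reduced = reduce_loss_dict(losses)  # unchanged when world_size == 1
    if get_rank() == 0:
        print(get_world_size(), {k: v.item() for k, v in reduced.items()})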
--------------------------------------------------------------------------------
/source/stylegan2/generate_blendmodel.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "7"
4 | import torch
5 | from torchvision import utils
6 | from model import Generator
7 | from tqdm import tqdm
8 | import json
9 | import glob
10 |
11 | from PIL import Image
12 |
13 | def make_image(tensor):
14 | return (
15 | tensor.detach()
16 | .clamp_(min=-1, max=1)
17 | .add(1)
18 | .div_(2)
19 | .mul(255)
20 | .type(torch.uint8)
21 | .permute(0, 2, 3, 1)
22 | .to("cpu")
23 | .numpy()
24 | )
25 |
26 | def generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq):
27 |
28 | outdir = args.save_dir
29 |
30 | # print(outdir)
31 | # outdir = os.path.join(args.output, args.name, 'eval','toons_paired_0512')
32 | if not os.path.exists(outdir):
33 | os.makedirs(outdir)
34 |
35 | with torch.no_grad():
36 | g_ema.eval()
37 | for i in tqdm(range(args.pics)):
38 | sample_z = torch.randn(args.sample, args.latent, device=device)
39 |
40 | res, _ = g_ema(
41 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
42 | )
43 | if args.form == "pair":
44 | sample_face, _ = g_ema_ffhq(
45 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
46 | )
47 | res = torch.cat([sample_face, res], 3)
48 |
49 | outpath = os.path.join(outdir, str(i).zfill(6)+'.png')
50 | utils.save_image(
51 | res,
52 | outpath,
53 | # f"sample/{str(i).zfill(6)}.png",
54 | nrow=1,
55 | normalize=True,
56 | range=(-1, 1),
57 | )
58 | # print('save %s'% outpath)
59 |
60 |
61 |
62 |
63 | if __name__ == "__main__":
64 | device = "cuda"
65 |
66 | parser = argparse.ArgumentParser(description="Generate samples from the generator")
67 | parser.add_argument('--config', type=str, default='config/conf_server_test_blend_shell.json')
68 | parser.add_argument('--name', type=str, default='')
69 | parser.add_argument('--save_dir', type=str, default='')
70 |
71 | parser.add_argument('--form', type=str, default='single')
72 | parser.add_argument(
73 | "--size", type=int, default=256, help="output image size of the generator"
74 | )
75 | parser.add_argument(
76 | "--sample",
77 | type=int,
78 | default=1,
79 | help="number of samples to be generated for each image",
80 | )
81 | parser.add_argument(
82 | "--pics", type=int, default=20, help="number of images to be generated"
83 | )
84 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
85 | parser.add_argument(
86 | "--truncation_mean",
87 | type=int,
88 | default=4096,
89 | help="number of vectors to calculate mean for the truncation",
90 | )
91 | parser.add_argument(
92 | "--ckpt",
93 | type=str,
94 | default="stylegan2-ffhq-config-f.pt",
95 | help="path to the model checkpoint",
96 | )
97 | parser.add_argument(
98 | "--channel_multiplier",
99 | type=int,
100 | default=2,
101 | help="channel multiplier of the generator. config-f = 2, else = 1",
102 | )
103 |
104 | args = parser.parse_args()
105 |     # update args from the config file
106 | opt = vars(args)
107 | with open(args.config) as f:
108 | config = json.load(f)['parameters']
109 | for key, value in config.items():
110 | opt[key] = value
111 |
112 | # args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_001000_4.pt'
113 | args.ckpt = 'face_generation/experiment_stylegan/'+args.name+'/models_blend/G_blend_'
114 | args.ckpt = glob.glob(args.ckpt+'*')[0]
115 |
116 | args.latent = 512
117 | args.n_mlp = 8
118 |
119 | g_ema = Generator(
120 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
121 | ).to(device)
122 | checkpoint = torch.load(args.ckpt)
123 |
124 | # g_ema.load_state_dict(checkpoint["g_ema"])
125 | g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
126 |
127 | ## add G_ffhq
128 | g_ema_ffhq = Generator(
129 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
130 | ).to(device)
131 | checkpoint_ffhq = torch.load(args.ffhq_ckpt)
132 | g_ema_ffhq.load_state_dict(checkpoint_ffhq["g_ema"], strict=False)
133 |
134 |
135 | if args.truncation < 1:
136 | with torch.no_grad():
137 | mean_latent = g_ema.mean_latent(args.truncation_mean)
138 | else:
139 | mean_latent = None
140 |
141 | model_name = os.path.basename(args.ckpt)
142 | print('save generated samples to %s'% os.path.join(args.output, args.name, 'eval_blend', model_name))
143 | generate(args, g_ema, device, mean_latent, model_name, g_ema_ffhq)
144 | # generate_style_mix(args, g_ema, device, mean_latent, model_name, g_ema_ffhq)
145 |
146 | # latent_path = 'test2.pt'
147 | # generate_from_latent(args, g_ema, device, mean_latent, latent_path)
148 |
--------------------------------------------------------------------------------
/source/stylegan2/noise.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/menyifang/DCT-Net/147e4f2bb7156d77ecf301a13b63d997ea2ecf10/source/stylegan2/noise.pt
--------------------------------------------------------------------------------
/source/stylegan2/non_leaking.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import autograd
5 | from torch.nn import functional as F
6 | import numpy as np
7 |
8 | from distributed import reduce_sum
9 | from op import upfirdn2d
10 |
11 |
12 | class AdaptiveAugment:
13 | def __init__(self, ada_aug_target, ada_aug_len, update_every, device):
14 | self.ada_aug_target = ada_aug_target
15 | self.ada_aug_len = ada_aug_len
16 | self.update_every = update_every
17 |
18 | self.ada_update = 0
19 | self.ada_aug_buf = torch.tensor([0.0, 0.0], device=device)
20 | self.r_t_stat = 0
21 | self.ada_aug_p = 0
22 |
23 | @torch.no_grad()
24 | def tune(self, real_pred):
25 | self.ada_aug_buf += torch.tensor(
26 | (torch.sign(real_pred).sum().item(), real_pred.shape[0]),
27 | device=real_pred.device,
28 | )
29 | self.ada_update += 1
30 |
31 | if self.ada_update % self.update_every == 0:
32 | self.ada_aug_buf = reduce_sum(self.ada_aug_buf)
33 | pred_signs, n_pred = self.ada_aug_buf.tolist()
34 |
35 | self.r_t_stat = pred_signs / n_pred
36 |
37 | if self.r_t_stat > self.ada_aug_target:
38 | sign = 1
39 |
40 | else:
41 | sign = -1
42 |
43 | self.ada_aug_p += sign * n_pred / self.ada_aug_len
44 | self.ada_aug_p = min(1, max(0, self.ada_aug_p))
45 | self.ada_aug_buf.mul_(0)
46 | self.ada_update = 0
47 |
48 | return self.ada_aug_p
49 |
50 |
51 | SYM6 = (
52 | 0.015404109327027373,
53 | 0.0034907120842174702,
54 | -0.11799011114819057,
55 | -0.048311742585633,
56 | 0.4910559419267466,
57 | 0.787641141030194,
58 | 0.3379294217276218,
59 | -0.07263752278646252,
60 | -0.021060292512300564,
61 | 0.04472490177066578,
62 | 0.0017677118642428036,
63 | -0.007800708325034148,
64 | )
65 |
66 |
67 | def translate_mat(t_x, t_y, device="cpu"):
68 | batch = t_x.shape[0]
69 |
70 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
71 | translate = torch.stack((t_x, t_y), 1)
72 | mat[:, :2, 2] = translate
73 |
74 | return mat
75 |
76 |
77 | def rotate_mat(theta, device="cpu"):
78 | batch = theta.shape[0]
79 |
80 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
81 | sin_t = torch.sin(theta)
82 | cos_t = torch.cos(theta)
83 | rot = torch.stack((cos_t, -sin_t, sin_t, cos_t), 1).view(batch, 2, 2)
84 | mat[:, :2, :2] = rot
85 |
86 | return mat
87 |
88 |
89 | def scale_mat(s_x, s_y, device="cpu"):
90 | batch = s_x.shape[0]
91 |
92 | mat = torch.eye(3, device=device).unsqueeze(0).repeat(batch, 1, 1)
93 | mat[:, 0, 0] = s_x
94 | mat[:, 1, 1] = s_y
95 |
96 | return mat
97 |
98 |
99 | def translate3d_mat(t_x, t_y, t_z):
100 | batch = t_x.shape[0]
101 |
102 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
103 | translate = torch.stack((t_x, t_y, t_z), 1)
104 | mat[:, :3, 3] = translate
105 |
106 | return mat
107 |
108 |
109 | def rotate3d_mat(axis, theta):
110 | batch = theta.shape[0]
111 |
112 | u_x, u_y, u_z = axis
113 |
114 | eye = torch.eye(3).unsqueeze(0)
115 | cross = torch.tensor([(0, -u_z, u_y), (u_z, 0, -u_x), (-u_y, u_x, 0)]).unsqueeze(0)
116 | outer = torch.tensor(axis)
117 | outer = (outer.unsqueeze(1) * outer).unsqueeze(0)
118 |
119 | sin_t = torch.sin(theta).view(-1, 1, 1)
120 | cos_t = torch.cos(theta).view(-1, 1, 1)
121 |
122 | rot = cos_t * eye + sin_t * cross + (1 - cos_t) * outer
123 |
124 | eye_4 = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
125 | eye_4[:, :3, :3] = rot
126 |
127 | return eye_4
128 |
129 |
130 | def scale3d_mat(s_x, s_y, s_z):
131 | batch = s_x.shape[0]
132 |
133 | mat = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
134 | mat[:, 0, 0] = s_x
135 | mat[:, 1, 1] = s_y
136 | mat[:, 2, 2] = s_z
137 |
138 | return mat
139 |
140 |
141 | def luma_flip_mat(axis, i):
142 | batch = i.shape[0]
143 |
144 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
145 | axis = torch.tensor(axis + (0,))
146 | flip = 2 * torch.ger(axis, axis) * i.view(-1, 1, 1)
147 |
148 | return eye - flip
149 |
150 |
151 | def saturation_mat(axis, i):
152 | batch = i.shape[0]
153 |
154 | eye = torch.eye(4).unsqueeze(0).repeat(batch, 1, 1)
155 | axis = torch.tensor(axis + (0,))
156 | axis = torch.ger(axis, axis)
157 | saturate = axis + (eye - axis) * i.view(-1, 1, 1)
158 |
159 | return saturate
160 |
161 |
162 | def lognormal_sample(size, mean=0, std=1, device="cpu"):
163 | return torch.empty(size, device=device).log_normal_(mean=mean, std=std)
164 |
165 |
166 | def category_sample(size, categories, device="cpu"):
167 | category = torch.tensor(categories, device=device)
168 | sample = torch.randint(high=len(categories), size=(size,), device=device)
169 |
170 | return category[sample]
171 |
172 |
173 | def uniform_sample(size, low, high, device="cpu"):
174 | return torch.empty(size, device=device).uniform_(low, high)
175 |
176 |
177 | def normal_sample(size, mean=0, std=1, device="cpu"):
178 | return torch.empty(size, device=device).normal_(mean, std)
179 |
180 |
181 | def bernoulli_sample(size, p, device="cpu"):
182 | return torch.empty(size, device=device).bernoulli_(p)
183 |
184 |
185 | def random_mat_apply(p, transform, prev, eye, device="cpu"):
186 | size = transform.shape[0]
187 | select = bernoulli_sample(size, p, device=device).view(size, 1, 1)
188 | select_transform = select * transform + (1 - select) * eye
189 |
190 | return select_transform @ prev
191 |
192 |
193 | def sample_affine(p, size, height, width, device="cpu"):
194 | G = torch.eye(3, device=device).unsqueeze(0).repeat(size, 1, 1)
195 | eye = G
196 |
197 | # flip
198 | param = category_sample(size, (0, 1))
199 | Gc = scale_mat(1 - 2.0 * param, torch.ones(size), device=device)
200 | G = random_mat_apply(p, Gc, G, eye, device=device)
201 | # print('flip', G, scale_mat(1 - 2.0 * param, torch.ones(size)), sep='\n')
202 |
203 | # 90 rotate
204 | param = category_sample(size, (0, 3))
205 | Gc = rotate_mat(-math.pi / 2 * param, device=device)
206 | G = random_mat_apply(p, Gc, G, eye, device=device)
207 | # print('90 rotate', G, rotate_mat(-math.pi / 2 * param), sep='\n')
208 |
209 | # integer translate
210 | param = uniform_sample(size, -0.125, 0.125)
211 | param_height = torch.round(param * height) / height
212 | param_width = torch.round(param * width) / width
213 | Gc = translate_mat(param_width, param_height, device=device)
214 | G = random_mat_apply(p, Gc, G, eye, device=device)
215 | # print('integer translate', G, translate_mat(param_width, param_height), sep='\n')
216 |
217 | # isotropic scale
218 | param = lognormal_sample(size, std=0.2 * math.log(2))
219 | Gc = scale_mat(param, param, device=device)
220 | G = random_mat_apply(p, Gc, G, eye, device=device)
221 | # print('isotropic scale', G, scale_mat(param, param), sep='\n')
222 |
223 | p_rot = 1 - math.sqrt(1 - p)
224 |
225 | # pre-rotate
226 | param = uniform_sample(size, -math.pi, math.pi)
227 | Gc = rotate_mat(-param, device=device)
228 | G = random_mat_apply(p_rot, Gc, G, eye, device=device)
229 | # print('pre-rotate', G, rotate_mat(-param), sep='\n')
230 |
231 | # anisotropic scale
232 | param = lognormal_sample(size, std=0.2 * math.log(2))
233 | Gc = scale_mat(param, 1 / param, device=device)
234 | G = random_mat_apply(p, Gc, G, eye, device=device)
235 | # print('anisotropic scale', G, scale_mat(param, 1 / param), sep='\n')
236 |
237 | # post-rotate
238 | param = uniform_sample(size, -math.pi, math.pi)
239 | Gc = rotate_mat(-param, device=device)
240 | G = random_mat_apply(p_rot, Gc, G, eye, device=device)
241 | # print('post-rotate', G, rotate_mat(-param), sep='\n')
242 |
243 | # fractional translate
244 | param = normal_sample(size, std=0.125)
245 | Gc = translate_mat(param, param, device=device)
246 | G = random_mat_apply(p, Gc, G, eye, device=device)
247 | # print('fractional translate', G, translate_mat(param, param), sep='\n')
248 |
249 | return G
250 |
251 |
252 | def sample_color(p, size):
253 | C = torch.eye(4).unsqueeze(0).repeat(size, 1, 1)
254 | eye = C
255 | axis_val = 1 / math.sqrt(3)
256 | axis = (axis_val, axis_val, axis_val)
257 |
258 | # brightness
259 | param = normal_sample(size, std=0.2)
260 | Cc = translate3d_mat(param, param, param)
261 | C = random_mat_apply(p, Cc, C, eye)
262 |
263 | # contrast
264 | param = lognormal_sample(size, std=0.5 * math.log(2))
265 | Cc = scale3d_mat(param, param, param)
266 | C = random_mat_apply(p, Cc, C, eye)
267 |
268 | # luma flip
269 | param = category_sample(size, (0, 1))
270 | Cc = luma_flip_mat(axis, param)
271 | C = random_mat_apply(p, Cc, C, eye)
272 |
273 | # hue rotation
274 | param = uniform_sample(size, -math.pi, math.pi)
275 | Cc = rotate3d_mat(axis, param)
276 | C = random_mat_apply(p, Cc, C, eye)
277 |
278 | # saturation
279 | param = lognormal_sample(size, std=1 * math.log(2))
280 | Cc = saturation_mat(axis, param)
281 | C = random_mat_apply(p, Cc, C, eye)
282 |
283 | return C
284 |
285 |
286 | def make_grid(shape, x0, x1, y0, y1, device):
287 | n, c, h, w = shape
288 | grid = torch.empty(n, h, w, 3, device=device)
289 | grid[:, :, :, 0] = torch.linspace(x0, x1, w, device=device)
290 | grid[:, :, :, 1] = torch.linspace(y0, y1, h, device=device).unsqueeze(-1)
291 | grid[:, :, :, 2] = 1
292 |
293 | return grid
294 |
295 |
296 | def affine_grid(grid, mat):
297 | n, h, w, _ = grid.shape
298 | return (grid.view(n, h * w, 3) @ mat.transpose(1, 2)).view(n, h, w, 2)
299 |
300 |
301 | def get_padding(G, height, width, kernel_size):
302 | device = G.device
303 |
304 | cx = (width - 1) / 2
305 | cy = (height - 1) / 2
306 | cp = torch.tensor(
307 | [(-cx, -cy, 1), (cx, -cy, 1), (cx, cy, 1), (-cx, cy, 1)], device=device
308 | )
309 | cp = G @ cp.T
310 |
311 | pad_k = kernel_size // 4
312 |
313 | pad = cp[:, :2, :].permute(1, 0, 2).flatten(1)
314 | pad = torch.cat((-pad, pad)).max(1).values
315 | pad = pad + torch.tensor([pad_k * 2 - cx, pad_k * 2 - cy] * 2, device=device)
316 | pad = pad.max(torch.tensor([0, 0] * 2, device=device))
317 | pad = pad.min(torch.tensor([width - 1, height - 1] * 2, device=device))
318 |
319 | pad_x1, pad_y1, pad_x2, pad_y2 = pad.ceil().to(torch.int32)
320 |
321 | return pad_x1, pad_x2, pad_y1, pad_y2
322 |
323 |
324 | def try_sample_affine_and_pad(img, p, kernel_size, G=None):
325 | batch, _, height, width = img.shape
326 |
327 | G_try = G
328 |
329 | if G is None:
330 | G_try = torch.inverse(sample_affine(p, batch, height, width))
331 |
332 | pad_x1, pad_x2, pad_y1, pad_y2 = get_padding(G_try, height, width, kernel_size)
333 |
334 | img_pad = F.pad(img, (pad_x1, pad_x2, pad_y1, pad_y2), mode="reflect")
335 |
336 | return img_pad, G_try, (pad_x1, pad_x2, pad_y1, pad_y2)
337 |
338 |
339 | class GridSampleForward(autograd.Function):
340 | @staticmethod
341 | def forward(ctx, input, grid):
342 | out = F.grid_sample(
343 | input, grid, mode="bilinear", padding_mode="zeros", align_corners=False
344 | )
345 | ctx.save_for_backward(input, grid)
346 |
347 | return out
348 |
349 | @staticmethod
350 | def backward(ctx, grad_output):
351 | input, grid = ctx.saved_tensors
352 | grad_input, grad_grid = GridSampleBackward.apply(grad_output, input, grid)
353 |
354 | return grad_input, grad_grid
355 |
356 |
357 | class GridSampleBackward(autograd.Function):
358 | @staticmethod
359 | def forward(ctx, grad_output, input, grid):
360 | op = torch._C._jit_get_operation("aten::grid_sampler_2d_backward")
361 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
362 | ctx.save_for_backward(grid)
363 |
364 | return grad_input, grad_grid
365 |
366 | @staticmethod
367 | def backward(ctx, grad_grad_input, grad_grad_grid):
368 | grid, = ctx.saved_tensors
369 | grad_grad_output = None
370 |
371 | if ctx.needs_input_grad[0]:
372 | grad_grad_output = GridSampleForward.apply(grad_grad_input, grid)
373 |
374 | return grad_grad_output, None, None
375 |
376 |
377 | grid_sample = GridSampleForward.apply
378 |
379 |
380 | def scale_mat_single(s_x, s_y):
381 | return torch.tensor(((s_x, 0, 0), (0, s_y, 0), (0, 0, 1)), dtype=torch.float32)
382 |
383 |
384 | def translate_mat_single(t_x, t_y):
385 | return torch.tensor(((1, 0, t_x), (0, 1, t_y), (0, 0, 1)), dtype=torch.float32)
386 |
387 |
388 | def random_apply_affine(img, p, G=None, antialiasing_kernel=SYM6):
389 | kernel = antialiasing_kernel
390 | len_k = len(kernel)
391 |
392 | kernel = torch.as_tensor(kernel).to(img)
393 | # kernel = torch.ger(kernel, kernel).to(img)
394 | kernel_flip = torch.flip(kernel, (0,))
395 |
396 | img_pad, G, (pad_x1, pad_x2, pad_y1, pad_y2) = try_sample_affine_and_pad(
397 | img, p, len_k, G
398 | )
399 |
400 | G_inv = (
401 | translate_mat_single((pad_x1 - pad_x2).item() / 2, (pad_y1 - pad_y2).item() / 2)
402 | @ G
403 | )
404 | up_pad = (
405 | (len_k + 2 - 1) // 2,
406 | (len_k - 2) // 2,
407 | (len_k + 2 - 1) // 2,
408 | (len_k - 2) // 2,
409 | )
410 | img_2x = upfirdn2d(img_pad, kernel.unsqueeze(0), up=(2, 1), pad=(*up_pad[:2], 0, 0))
411 | img_2x = upfirdn2d(img_2x, kernel.unsqueeze(1), up=(1, 2), pad=(0, 0, *up_pad[2:]))
412 | G_inv = scale_mat_single(2, 2) @ G_inv @ scale_mat_single(1 / 2, 1 / 2)
413 | G_inv = translate_mat_single(-0.5, -0.5) @ G_inv @ translate_mat_single(0.5, 0.5)
414 | batch_size, channel, height, width = img.shape
415 | pad_k = len_k // 4
416 | shape = (batch_size, channel, (height + pad_k * 2) * 2, (width + pad_k * 2) * 2)
417 | G_inv = (
418 | scale_mat_single(2 / img_2x.shape[3], 2 / img_2x.shape[2])
419 | @ G_inv
420 | @ scale_mat_single(1 / (2 / shape[3]), 1 / (2 / shape[2]))
421 | )
422 | grid = F.affine_grid(G_inv[:, :2, :].to(img_2x), shape, align_corners=False)
423 | img_affine = grid_sample(img_2x, grid)
424 | d_p = -pad_k * 2
425 | down_pad = (
426 | d_p + (len_k - 2 + 1) // 2,
427 | d_p + (len_k - 2) // 2,
428 | d_p + (len_k - 2 + 1) // 2,
429 | d_p + (len_k - 2) // 2,
430 | )
431 | img_down = upfirdn2d(
432 | img_affine, kernel_flip.unsqueeze(0), down=(2, 1), pad=(*down_pad[:2], 0, 0)
433 | )
434 | img_down = upfirdn2d(
435 | img_down, kernel_flip.unsqueeze(1), down=(1, 2), pad=(0, 0, *down_pad[2:])
436 | )
437 |
438 | return img_down, G
439 |
440 |
441 | def apply_color(img, mat):
442 | batch = img.shape[0]
443 | img = img.permute(0, 2, 3, 1)
444 | mat_mul = mat[:, :3, :3].transpose(1, 2).view(batch, 1, 3, 3)
445 | mat_add = mat[:, :3, 3].view(batch, 1, 1, 3)
446 | img = img @ mat_mul + mat_add
447 | img = img.permute(0, 3, 1, 2)
448 |
449 | return img
450 |
451 |
452 | def random_apply_color(img, p, C=None):
453 | if C is None:
454 | C = sample_color(p, img.shape[0])
455 |
456 | img = apply_color(img, C.to(img))
457 |
458 | return img, C
459 |
460 |
461 | def augment(img, p, transform_matrix=(None, None)):
462 | img, G = random_apply_affine(img, p, transform_matrix[0])
463 | img, C = random_apply_color(img, p, transform_matrix[1])
464 |
465 | return img, (G, C)
466 |
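A minimal sketch of the augment() entry point, assuming the CUDA upfirdn2d extension builds (it is compiled when the op package is imported) and the script runs from source/stylegan2/; the returned (G, C) matrices can be fed back in to re-apply the same sampled transforms to another batch of the same size:

    import torch
    from non_leaking import augment

    imgs = torch.randn(4, 3, 256, 256, device='cuda')
    aug_imgs, (G, C) = augment(imgs, p=0.2)  # G: 3x3 affine mats, C: 4x4 colour mats
    other = torch.randn(4, 3, 256, 256, device='cuda')
    aug_other, _ = augment(other, p=0.2, transform_matrix=(G, C))  # same transforms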
--------------------------------------------------------------------------------
/source/stylegan2/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/source/stylegan2/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
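A minimal sketch of the conv2d drop-in with weight gradients disabled, the pattern StyleGAN2-style training code typically uses during regularisation passes; note that could_use_op() only activates the custom op on CUDA with PyTorch 1.7/1.8, otherwise the call falls back to F.conv2d and the flag has no effect:

    import torch
    from op import conv2d_gradfix  # importing the op package JIT-builds the CUDA extensions

    x = torch.randn(1, 3, 64, 64, device='cuda', requires_grad=True)
    w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
    with conv2d_gradfix.no_weight_gradients():
        out = conv2d_gradfix.conv2d(x, w, padding=1)
        out.sum().backward()  # with the custom op active, w.grad stays None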
--------------------------------------------------------------------------------
/source/stylegan2/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | if bias:
39 | grad_bias = grad_input.sum(dim).detach()
40 |
41 | else:
42 | grad_bias = empty
43 |
44 | return grad_input, grad_bias
45 |
46 | @staticmethod
47 | def backward(ctx, gradgrad_input, gradgrad_bias):
48 | out, = ctx.saved_tensors
49 | gradgrad_out = fused.fused_bias_act(
50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
51 | )
52 |
53 | return gradgrad_out, None, None, None, None
54 |
55 |
56 | class FusedLeakyReLUFunction(Function):
57 | @staticmethod
58 | def forward(ctx, input, bias, negative_slope, scale):
59 | empty = input.new_empty(0)
60 |
61 | ctx.bias = bias is not None
62 |
63 | if bias is None:
64 | bias = empty
65 |
66 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
67 | ctx.save_for_backward(out)
68 | ctx.negative_slope = negative_slope
69 | ctx.scale = scale
70 |
71 | return out
72 |
73 | @staticmethod
74 | def backward(ctx, grad_output):
75 | out, = ctx.saved_tensors
76 |
77 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
78 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
79 | )
80 |
81 | if not ctx.bias:
82 | grad_bias = None
83 |
84 | return grad_input, grad_bias, None, None
85 |
86 |
87 | class FusedLeakyReLU(nn.Module):
88 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
89 | super().__init__()
90 |
91 | if bias:
92 | self.bias = nn.Parameter(torch.zeros(channel))
93 |
94 | else:
95 | self.bias = None
96 |
97 | self.negative_slope = negative_slope
98 | self.scale = scale
99 |
100 | def forward(self, input):
101 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
102 |
103 |
104 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
105 | if input.device.type == "cpu":
106 | if bias is not None:
107 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
108 | return (
109 | F.leaky_relu(
110 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
111 | )
112 | * scale
113 | )
114 |
115 | else:
116 | return F.leaky_relu(input, negative_slope=0.2) * scale
117 |
118 | else:
119 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
120 |
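A minimal sketch of the fused bias + leaky ReLU activation; importing the op package JIT-compiles the CUDA extension above, so a CUDA toolchain is required even though CPU tensors take the pure-PyTorch fallback in fused_leaky_relu():

    import torch
    from op import FusedLeakyReLU, fused_leaky_relu

    act = FusedLeakyReLU(channel=64)             # learnable per-channel bias
    x = torch.randn(2, 64, 32, 32)
    y = act(x)                                   # (x + bias) -> leaky_relu -> * sqrt(2)
    y2 = fused_leaky_relu(x.cuda(), bias=None)   # CUDA path runs the compiled kernel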
--------------------------------------------------------------------------------
/source/stylegan2/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <ATen/ATen.h>
3 | #include <torch/extension.h>
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6 | const torch::Tensor &bias,
7 | const torch::Tensor &refer, int act, int grad,
8 | float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) \
11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) \
15 | CHECK_CUDA(x); \
16 | CHECK_CONTIGUOUS(x)
17 |
18 | torch::Tensor fused_bias_act(const torch::Tensor &input,
19 | const torch::Tensor &bias,
20 | const torch::Tensor &refer, int act, int grad,
21 | float alpha, float scale) {
22 | CHECK_INPUT(input);
23 | CHECK_INPUT(bias);
24 |
25 | at::DeviceGuard guard(input.device());
26 |
27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28 | }
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32 | }
--------------------------------------------------------------------------------
/source/stylegan2/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
12 | #include
13 |
14 |
15 | #include
16 | #include
17 |
18 | template <typename scalar_t>
19 | static __global__ void
20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22 | scalar_t scale, int loop_x, int size_x, int step_b,
23 | int size_b, int use_bias, int use_ref) {
24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25 |
26 | scalar_t zero = 0.0;
27 |
28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29 | loop_idx++, xi += blockDim.x) {
30 | scalar_t x = p_x[xi];
31 |
32 | if (use_bias) {
33 | x += p_b[(xi / step_b) % size_b];
34 | }
35 |
36 | scalar_t ref = use_ref ? p_ref[xi] : zero;
37 |
38 | scalar_t y;
39 |
40 | switch (act * 10 + grad) {
41 | default:
42 | case 10:
43 | y = x;
44 | break;
45 | case 11:
46 | y = x;
47 | break;
48 | case 12:
49 | y = 0.0;
50 | break;
51 |
52 | case 30:
53 | y = (x > 0.0) ? x : x * alpha;
54 | break;
55 | case 31:
56 | y = (ref > 0.0) ? x : x * alpha;
57 | break;
58 | case 32:
59 | y = 0.0;
60 | break;
61 | }
62 |
63 | out[xi] = y * scale;
64 | }
65 | }
66 |
67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68 | const torch::Tensor &bias,
69 | const torch::Tensor &refer, int act, int grad,
70 | float alpha, float scale) {
71 | int curDevice = -1;
72 | cudaGetDevice(&curDevice);
73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74 |
75 | auto x = input.contiguous();
76 | auto b = bias.contiguous();
77 | auto ref = refer.contiguous();
78 |
79 | int use_bias = b.numel() ? 1 : 0;
80 | int use_ref = ref.numel() ? 1 : 0;
81 |
82 | int size_x = x.numel();
83 | int size_b = b.numel();
84 | int step_b = 1;
85 |
86 | for (int i = 1 + 1; i < x.dim(); i++) {
87 | step_b *= x.size(i);
88 | }
89 |
90 | int loop_x = 4;
91 | int block_size = 4 * 32;
92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93 |
94 | auto y = torch::empty_like(x);
95 |
96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97 | x.scalar_type(), "fused_bias_act_kernel", [&] {
98 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99 |             y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100 |             b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102 | });
103 |
104 | return y;
105 | }
--------------------------------------------------------------------------------
/source/stylegan2/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5 | const torch::Tensor &kernel, int up_x, int up_y,
6 | int down_x, int down_y, int pad_x0, int pad_x1,
7 | int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) \
10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 |
17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18 | int up_x, int up_y, int down_x, int down_y, int pad_x0,
19 | int pad_x1, int pad_y0, int pad_y1) {
20 | CHECK_INPUT(input);
21 | CHECK_INPUT(kernel);
22 |
23 | at::DeviceGuard guard(input.device());
24 |
25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26 | pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
--------------------------------------------------------------------------------
/source/stylegan2/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | upfirdn2d_op = load(
12 | "upfirdn2d",
13 | sources=[
14 | os.path.join(module_path, "upfirdn2d.cpp"),
15 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class UpFirDn2dBackward(Function):
21 | @staticmethod
22 | def forward(
23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24 | ):
25 |
26 | up_x, up_y = up
27 | down_x, down_y = down
28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29 |
30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31 |
32 | grad_input = upfirdn2d_op.upfirdn2d(
33 | grad_output,
34 | grad_kernel,
35 | down_x,
36 | down_y,
37 | up_x,
38 | up_y,
39 | g_pad_x0,
40 | g_pad_x1,
41 | g_pad_y0,
42 | g_pad_y1,
43 | )
44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45 |
46 | ctx.save_for_backward(kernel)
47 |
48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
49 |
50 | ctx.up_x = up_x
51 | ctx.up_y = up_y
52 | ctx.down_x = down_x
53 | ctx.down_y = down_y
54 | ctx.pad_x0 = pad_x0
55 | ctx.pad_x1 = pad_x1
56 | ctx.pad_y0 = pad_y0
57 | ctx.pad_y1 = pad_y1
58 | ctx.in_size = in_size
59 | ctx.out_size = out_size
60 |
61 | return grad_input
62 |
63 | @staticmethod
64 | def backward(ctx, gradgrad_input):
65 | kernel, = ctx.saved_tensors
66 |
67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68 |
69 | gradgrad_out = upfirdn2d_op.upfirdn2d(
70 | gradgrad_input,
71 | kernel,
72 | ctx.up_x,
73 | ctx.up_y,
74 | ctx.down_x,
75 | ctx.down_y,
76 | ctx.pad_x0,
77 | ctx.pad_x1,
78 | ctx.pad_y0,
79 | ctx.pad_y1,
80 | )
81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82 | gradgrad_out = gradgrad_out.view(
83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84 | )
85 |
86 | return gradgrad_out, None, None, None, None, None, None, None, None
87 |
88 |
89 | class UpFirDn2d(Function):
90 | @staticmethod
91 | def forward(ctx, input, kernel, up, down, pad):
92 | up_x, up_y = up
93 | down_x, down_y = down
94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
95 |
96 | kernel_h, kernel_w = kernel.shape
97 | batch, channel, in_h, in_w = input.shape
98 | ctx.in_size = input.shape
99 |
100 | input = input.reshape(-1, in_h, in_w, 1)
101 |
102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103 |
104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106 | ctx.out_size = (out_h, out_w)
107 |
108 | ctx.up = (up_x, up_y)
109 | ctx.down = (down_x, down_y)
110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111 |
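# Padding for the backward pass: the gradient of upfirdn2d is another upfirdn2d with up/down swapped and a flipped kernel, so these pads map output-gradient coordinates back onto the input grid.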
112 | g_pad_x0 = kernel_w - pad_x0 - 1
113 | g_pad_y0 = kernel_h - pad_y0 - 1
114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116 |
117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118 |
119 | out = upfirdn2d_op.upfirdn2d(
120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121 | )
122 | # out = out.view(major, out_h, out_w, minor)
123 | out = out.view(-1, channel, out_h, out_w)
124 |
125 | return out
126 |
127 | @staticmethod
128 | def backward(ctx, grad_output):
129 | kernel, grad_kernel = ctx.saved_tensors
130 |
131 | grad_input = None
132 |
133 | if ctx.needs_input_grad[0]:
134 | grad_input = UpFirDn2dBackward.apply(
135 | grad_output,
136 | kernel,
137 | grad_kernel,
138 | ctx.up,
139 | ctx.down,
140 | ctx.pad,
141 | ctx.g_pad,
142 | ctx.in_size,
143 | ctx.out_size,
144 | )
145 |
146 | return grad_input, None, None, None, None
147 |
148 |
149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150 | if not isinstance(up, abc.Iterable):
151 | up = (up, up)
152 |
153 | if not isinstance(down, abc.Iterable):
154 | down = (down, down)
155 |
156 | if len(pad) == 2:
157 | pad = (pad[0], pad[1], pad[0], pad[1])
158 |
159 | if input.device.type == "cpu":
160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161 |
162 | else:
163 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
164 |
165 | return out
166 |
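# Usage sketch (assuming a separable 4-tap binomial blur kernel, as used for StyleGAN2-style 2x upsampling):
#   k = torch.tensor([1., 3., 3., 1.])
#   k = k[None, :] * k[:, None]
#   k = k / k.sum()
#   y = upfirdn2d(x, (k * 4).to(x), up=2, down=1, pad=(2, 1))  # x: (N, C, H, W) -> y: (N, C, 2H, 2W)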
167 |
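# Pure-PyTorch fallback for CPU tensors: zero-stuff to upsample, pad/crop, convolve with the flipped FIR kernel via conv2d, then stride to downsample.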
168 | def upfirdn2d_native(
169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170 | ):
171 | _, channel, in_h, in_w = input.shape
172 | input = input.reshape(-1, in_h, in_w, 1)
173 |
174 | _, in_h, in_w, minor = input.shape
175 | kernel_h, kernel_w = kernel.shape
176 |
177 | out = input.view(-1, in_h, 1, in_w, 1, minor)
178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180 |
181 | out = F.pad(
182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183 | )
184 | out = out[
185 | :,
186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188 | :,
189 | ]
190 |
191 | out = out.permute(0, 3, 1, 2)
192 | out = out.reshape(
193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194 | )
195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196 | out = F.conv2d(out, w)
197 | out = out.reshape(
198 | -1,
199 | minor,
200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202 | )
203 | out = out.permute(0, 2, 3, 1)
204 | out = out[:, ::down_y, ::down_x, :]
205 |
206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208 |
209 | return out.view(-1, channel, out_h, out_w)
210 |
--------------------------------------------------------------------------------
/source/stylegan2/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAApplyUtils.cuh>
12 | #include <ATen/cuda/CUDAContext.h>
13 |
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
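// Generic kernel: each thread accumulates one output element by walking its filter taps; handles arbitrary up/down factors and filter sizes.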
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
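// Tiled kernel specialized at compile time: the filter and an input tile are staged in shared memory, then each thread computes output pixels of that tile.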
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
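// Select a specialized tiled kernel for the common up/down/filter-size combinations; anything not matched falls back to the generic upfirdn2d_kernel_large path.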
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/source/stylegan2/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from io import BytesIO
3 | import multiprocessing
4 | from functools import partial
5 |
6 | from PIL import Image
7 | import lmdb
8 | from tqdm import tqdm
9 | from torchvision import datasets
10 | from torchvision.transforms import functional as trans_fn
11 | import os
12 |
13 | def resize_and_convert(img, size, resample, quality=100):
14 | img = trans_fn.resize(img, size, resample)
15 | img = trans_fn.center_crop(img, size)
16 | buffer = BytesIO()
17 | img.save(buffer, format="jpeg", quality=quality)
18 | val = buffer.getvalue()
19 |
20 | return val
21 |
22 |
23 | def resize_multiple(
24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
25 | ):
26 | imgs = []
27 |
28 | for size in sizes:
29 | imgs.append(resize_and_convert(img, size, resample, quality))
30 |
31 | return imgs
32 |
33 |
34 | def resize_worker(img_file, sizes, resample):
35 | i, file = img_file
36 | img = Image.open(file)
37 | img = img.convert("RGB")
38 | out = resize_multiple(img, sizes=sizes, resample=resample)
39 |
40 | return i, out
41 |
42 |
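# Resize every image to each requested resolution and store the JPEG bytes in LMDB under keys like "1024-00042"; a final "length" entry records the number of images.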
43 | def prepare(
44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
45 | ):
46 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
47 |
48 | files = sorted(dataset.imgs, key=lambda x: x[0])
49 | files = [(i, file) for i, (file, label) in enumerate(files)]
50 | total = 0
51 |
52 | with multiprocessing.Pool(n_worker) as pool:
53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
54 | for size, img in zip(sizes, imgs):
55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
56 |
57 | with env.begin(write=True) as txn:
58 | txn.put(key, img)
59 |
60 | total += 1
61 |
62 | with env.begin(write=True) as txn:
63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
64 |
65 |
66 | if __name__ == "__main__":
67 | parser = argparse.ArgumentParser(description="Preprocess images for model training")
68 | parser.add_argument("path", type=str, help="path to the image dataset")
69 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
70 | parser.add_argument(
71 | "--size",
72 | type=str,
73 | default="128,256,512,1024",
74 | help="resolutions of images for the dataset",
75 | )
76 | parser.add_argument(
77 | "--n_worker",
78 | type=int,
79 | default=16,
80 | help="number of workers for preparing dataset",
81 | )
82 | parser.add_argument(
83 | "--resample",
84 | type=str,
85 | default="lanczos",
86 | help="resampling methods for resizing images",
87 | )
88 |
89 | args = parser.parse_args()
90 | if not os.path.exists(str(args.out)):
91 | os.makedirs(str(args.out))
92 |
93 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
94 | resample = resample_map[args.resample]
95 |
96 | sizes = [int(s.strip()) for s in args.size.split(",")]
97 |
98 | print("Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
99 |
100 | imgset = datasets.ImageFolder(args.path)
101 |
102 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
103 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
104 |
--------------------------------------------------------------------------------
/source/stylegan2/style_blend.py:
--------------------------------------------------------------------------------
1 | # model blending technique
2 | import os
3 | import copy
4 | import torch
5 | from model import Generator
6 | import math
7 | import argparse
8 |
9 | def extract_conv_names(model):
10 | model = list(model.keys())
11 | conv_name = []
12 | resolutions = [4*2**x for x in range(9)]
13 | level_names = [["Conv0_up", "Const"], ["Conv1", "ToRGB"]]
14 |
15 |
16 | def blend_models(model_1, model_2, resolution, level, blend_width=None):
17 | resolutions = [4 * 2 ** i for i in range(7)]
18 | mid = resolutions.index(resolution)
19 |
20 | device = "cuda"
21 |
22 | size = 256
23 | latent = 512
24 | n_mlp = 8
25 | channel_multiplier = 2
26 | G_1 = Generator(
27 | size, latent, n_mlp, channel_multiplier=channel_multiplier
28 | ).to(device)
29 | ckpt_ffhq = torch.load(model_1, map_location=lambda storage, loc: storage)
30 | G_1.load_state_dict(ckpt_ffhq["g"], strict=False)
31 |
32 |
33 | G_2 = Generator(
34 | size, latent, n_mlp, channel_multiplier=channel_multiplier
35 | ).to(device)
36 | ckpt_toon = torch.load(model_2)
37 | G_2.load_state_dict(ckpt_toon["g_ema"])
38 |
39 |
40 |
41 | # G_1 = stylegan2.models.load(model_1)
42 | # G_2 = stylegan2.models.load(model_2)
43 | model_1_state_dict = G_1.state_dict()
44 | model_2_state_dict = G_2.state_dict()
45 | assert(model_1_state_dict.keys() == model_2_state_dict.keys())
46 | G_out = copy.deepcopy(G_1)
47 |
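# Per-layer blend weight y in [0, 1]: a sigmoid ramp across resolution levels when blend_width is set, otherwise a hard switch (toon weights above the split resolution, photo weights at or below it).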
48 | layers = []
49 | ys = []
50 | for k, v in model_1_state_dict.items():
51 | if k.startswith('convs.'):
52 | pos = int(k.split('.')[1])
53 | x = pos - mid
54 | if blend_width:
55 | exponent = -x / blend_width
56 | y = 1 / (1 + math.exp(exponent))
57 | else:
58 | y = 1 if x > 0 else 0
59 |
60 | layers.append(k)
61 | ys.append(y)
62 | elif k.startswith('to_rgbs.'):
63 | pos = int(k.split('.')[1])
64 | x = pos - mid
65 | if blend_width:
66 | exponent = -x / blend_width
67 | y = 1 / (1 + math.exp(exponent))
68 | else:
69 | y = 1 if x > 0 else 0
70 | layers.append(k)
71 | ys.append(y)
72 | out_state = G_out.state_dict()
73 | for y, layer in zip(ys, layers):
74 | out_state[layer] = y * model_2_state_dict[layer] + \
75 | (1 - y) * model_1_state_dict[layer]
76 | print('blend layer %s with weight %.2f' % (layer, y))
77 | G_out.load_state_dict(out_state)
78 | return G_out
79 |
80 |
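# Alternative blend routine that appears to target a different StyleGAN2 wrapper ("stylegan2.models"), which is not imported here; main() below uses blend_models instead.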
81 | def blend_models_2(model_1, model_2, resolution, level, blend_width=None):
82 | # resolution = f"{resolution}x{resolution}"
83 | resolutions = [4 * 2 ** i for i in range(7)]
84 | mid = [resolutions.index(r) for r in resolution]
85 |
86 | G_1 = stylegan2.models.load(model_1)
87 | G_2 = stylegan2.models.load(model_2)
88 | model_1_state_dict = G_1.state_dict()
89 | model_2_state_dict = G_2.state_dict()
90 | assert(model_1_state_dict.keys() == model_2_state_dict.keys())
91 | G_out = G_1.clone()
92 |
93 | layers = []
94 | ys = []
95 | for k, v in model_1_state_dict.items():
96 | if k.startswith('G_synthesis.conv_blocks.'):
97 | pos = int(k[len('G_synthesis.conv_blocks.')])
98 | y = 0 if pos in mid else 1
99 | layers.append(k)
100 | ys.append(y)
101 | elif k.startswith('G_synthesis.to_data_layers.'):
102 | pos = int(k[len('G_synthesis.to_data_layers.')])
103 | y = 0 if pos in mid else 1
104 | layers.append(k)
105 | ys.append(y)
106 | # print(ys, layers)
107 | out_state = G_out.state_dict()
108 | for y, layer in zip(ys, layers):
109 | out_state[layer] = y * model_2_state_dict[layer] + \
110 | (1 - y) * model_1_state_dict[layer]
111 | G_out.load_state_dict(out_state)
112 | return G_out
113 |
114 |
115 | def main(name):
116 |
117 | resolution = 4
118 |
119 | model_name = '001000.pt'
120 |
121 | G_out = blend_models("pretrained_models/stylegan2-ffhq-config-f-256-550000.pt",
122 | "face_generation/experiment_stylegan/"+name+"/models/"+model_name,
123 | resolution,
124 | None)
125 | # G_out.save('G_blend.pth')
126 | outdir = os.path.join('face_generation/experiment_stylegan',name,'models_blend')
127 | if not os.path.exists(outdir):
128 | os.makedirs(outdir)
129 |
130 | outpath = os.path.join(outdir, 'G_blend_'+str(model_name[:-3])+'_'+ str(resolution)+'.pt')
131 | torch.save(
132 | {
133 | "g_ema": G_out.state_dict(),
134 | },
135 | # 'G_blend_570000_16.pth',
136 | outpath
137 | )
138 |
139 |
140 | if __name__ == '__main__':
141 | parser = argparse.ArgumentParser(description="style blender")
142 | parser.add_argument('--name', type=str, default='')
143 | args = parser.parse_args()
144 | print('model name:%s'%args.name)
145 | main(args.name)
--------------------------------------------------------------------------------
/source/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 |
6 |
7 | def resize_size(image, size=720):
8 | h, w, c = np.shape(image)
9 | if min(h, w) > size:
10 | if h > w:
11 | h, w = int(size * h / w), size
12 | else:
13 | h, w = size, int(size * w / h)
14 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
15 | return image
16 |
17 |
18 | def padTo16x(image):
19 | h, w, c = np.shape(image)
20 | if h % 16 == 0 and w % 16 == 0:
21 | return image, h, w
22 | nh, nw = (h // 16 + 1) * 16, (w // 16 + 1) * 16
23 | img_new = np.ones((nh, nw, 3), np.uint8) * 255
24 | img_new[:h, :w, :] = image
25 |
26 | return img_new, h, w
27 |
28 |
29 | def get_f5p(landmarks, np_img):
30 | eye_left = find_pupil(landmarks[36:41], np_img)
31 | eye_right = find_pupil(landmarks[42:47], np_img)
32 | if eye_left is None or eye_right is None:
33 | print('cannot find pupils with find_pupil, using landmark mean instead.')
34 | eye_left = landmarks[36:41].mean(axis=0)
35 | eye_right = landmarks[42:47].mean(axis=0)
36 | nose = landmarks[30]
37 | mouth_left = landmarks[48]
38 | mouth_right = landmarks[54]
39 | f5p = [[eye_left[0], eye_left[1]], [eye_right[0], eye_right[1]],
40 | [nose[0], nose[1]], [mouth_left[0], mouth_left[1]],
41 | [mouth_right[0], mouth_right[1]]]
42 | return f5p
43 |
44 |
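# Estimate the pupil centre as the median dark pixel inside the eye-landmark polygon (histogram-equalised, then Otsu-thresholded); falls back to the patch centre when no dark pixel is found.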
45 | def find_pupil(landmarks, np_img):
46 | h, w, _ = np_img.shape
47 | xmax = int(landmarks[:, 0].max())
48 | xmin = int(landmarks[:, 0].min())
49 | ymax = int(landmarks[:, 1].max())
50 | ymin = int(landmarks[:, 1].min())
51 |
52 | if ymin >= ymax or xmin >= xmax or ymin < 0 or xmin < 0 or ymax > h or xmax > w:
53 | return None
54 | eye_img_bgr = np_img[ymin:ymax, xmin:xmax, :]
55 | eye_img = cv2.cvtColor(eye_img_bgr, cv2.COLOR_BGR2GRAY)
56 | eye_img = cv2.equalizeHist(eye_img)
57 | n_marks = landmarks - np.array([xmin, ymin]).reshape([1, 2])
58 | eye_mask = cv2.fillConvexPoly(
59 | np.zeros_like(eye_img), n_marks.astype(np.int32), 1)
60 | ret, thresh = cv2.threshold(eye_img, 100, 255,
61 | cv2.THRESH_BINARY | cv2.THRESH_OTSU)
62 | thresh = (1 - thresh / 255.) * eye_mask
63 | cnt = 0
64 | xm = []
65 | ym = []
66 | for i in range(thresh.shape[0]):
67 | for j in range(thresh.shape[1]):
68 | if thresh[i, j] > 0.5:
69 | xm.append(j)
70 | ym.append(i)
71 | cnt += 1
72 | if cnt != 0:
73 | xm.sort()
74 | ym.sort()
75 | xm = xm[cnt // 2]
76 | ym = ym[cnt // 2]
77 | else:
78 | xm = thresh.shape[1] / 2
79 | ym = thresh.shape[0] / 2
80 |
81 | return xm + xmin, ym + ymin
82 |
83 |
84 | def all_file(file_dir):
85 | L = []
86 | for root, dirs, files in os.walk(file_dir):
87 | for file in files:
88 | extend = os.path.splitext(file)[1]
89 | if extend == '.png' or extend == '.jpg' or extend == '.jpeg':
90 | L.append(os.path.join(root, file))
91 | return L
92 |
93 | def initialize_mask(box_width):
94 | h, w = [box_width, box_width]
95 | mask = np.zeros((h, w), np.uint8)
96 |
97 | center = (int(w / 2), int(h / 2))
98 | axes = (int(w * 0.4), int(h * 0.49))
99 | mask = cv2.ellipse(img=mask, center=center, axes=axes, angle=0, startAngle=0, endAngle=360, color=(1),
100 | thickness=-1)
101 | mask = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
102 |
103 | maxn = max(w, h) * 0.15
104 | mask[(mask < 255) & (mask > 0)] = mask[(mask < 255) & (mask > 0)] / maxn
105 | mask = np.clip(mask, 0, 1)
106 |
107 | return mask.astype(float)
108 |
--------------------------------------------------------------------------------
/train_localtoon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from modelscope.trainers.cv import CartoonTranslationTrainer
4 |
5 |
6 | def main(args):
7 |
8 | data_photo = os.path.join(args.data_dir, 'face_photo')
9 | data_cartoon = os.path.join(args.data_dir, 'face_cartoon')
10 |
11 | style = args.style
12 | if style == "anime":
13 | style = ""
14 | else:
15 | style = '-' + style
16 | model_id = 'damo/cv_unet_person-image-cartoon' + style + '_compound-models'
17 |
18 | max_steps = 300000
19 | trainer = CartoonTranslationTrainer(
20 | model=model_id,
21 | work_dir=args.work_dir,
22 | photo=data_photo,
23 | cartoon=data_cartoon,
24 | max_steps=max_steps)
25 | trainer.train()
26 |
27 | if __name__ == '__main__':
28 | parser = argparse.ArgumentParser(description="Train the DCT-Net texture translator on a local photo/cartoon dataset")
29 | parser.add_argument("--data_dir", type=str, default='', help="Path to training images.")
30 | parser.add_argument("--work_dir", type=str, default='', help="Path to save results.")
31 | parser.add_argument("--style", type=str, default='anime', help="resume training from similar style.")
32 |
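# Example invocation (hypothetical paths):
#   python train_localtoon.py --data_dir data/my_style --work_dir exp/my_style --style anime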
33 | args = parser.parse_args()
34 |
35 | main(args)
36 |
37 |
--------------------------------------------------------------------------------