├── LICENSE
├── README.md
├── docs
└── assets
│ ├── interpolation.gif
│ └── synthesis.jpg
├── models
├── __init__.py
├── clip_model.py
├── text_generator.py
└── utils
│ ├── __init__.py
│ └── ops.py
├── requirements
├── convert.txt
├── develop.txt
└── minimal.txt
├── run_interpolate.py
├── run_synthesize.py
├── third_party
├── .DS_Store
├── __init__.py
├── clip_official
│ ├── __init__.py
│ ├── bpe_simple_vocab_16e6.txt.gz
│ └── simple_tokenizer.py
├── stylegan2_official_ops
│ ├── README.md
│ ├── __init__.py
│ ├── bias_act.cpp
│ ├── bias_act.cu
│ ├── bias_act.h
│ ├── bias_act.py
│ ├── conv2d_gradfix.py
│ ├── conv2d_resample.py
│ ├── custom_ops.py
│ ├── fma.py
│ ├── grid_sample_gradfix.py
│ ├── misc.py
│ ├── upfirdn2d.cpp
│ ├── upfirdn2d.cu
│ ├── upfirdn2d.h
│ └── upfirdn2d.py
└── stylegan3_official_ops
│ ├── README.md
│ ├── __init__.py
│ ├── bias_act.cpp
│ ├── bias_act.cu
│ ├── bias_act.h
│ ├── bias_act.py
│ ├── conv2d_gradfix.py
│ ├── conv2d_resample.py
│ ├── custom_ops.py
│ ├── filtered_lrelu.cpp
│ ├── filtered_lrelu.cu
│ ├── filtered_lrelu.h
│ ├── filtered_lrelu.py
│ ├── filtered_lrelu_ns.cu
│ ├── filtered_lrelu_rd.cu
│ ├── filtered_lrelu_wr.cu
│ ├── fma.py
│ ├── grid_sample_gradfix.py
│ ├── misc.py
│ ├── upfirdn2d.cpp
│ ├── upfirdn2d.cu
│ ├── upfirdn2d.h
│ └── upfirdn2d.py
└── utils
├── .DS_Store
├── __init__.py
├── dist_utils.py
├── file_transmitters
├── __init__.py
├── base_file_transmitter.py
├── dummy_file_transmitter.py
└── local_file_transmitter.py
├── formatting_utils.py
├── image_utils.py
├── loggers
├── __init__.py
├── base_logger.py
├── dummy_logger.py
├── normal_logger.py
├── rich_logger.py
└── test.py
├── misc.py
├── parsing_utils.py
├── tf_utils.py
└── visualizers
├── __init__.py
├── gif_visualizer.py
├── grid_visualizer.py
├── html_visualizer.py
├── test.py
└── video_visualizer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | ------------------------------ LICENSE for Aurora ------------------------------
2 |
3 | Copyright (c) 2023 Ant Group.
4 |
5 | MIT License
6 |
7 | Permission is hereby granted, free of charge, to any person obtaining a copy
8 | of this software and associated documentation files (the "Software"), to deal
9 | in the Software without restriction, including without limitation the rights
10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | copies of the Software, and to permit persons to whom the Software is
12 | furnished to do so, subject to the following conditions:
13 |
14 | The above copyright notice and this permission notice shall be included in all
15 | copies or substantial portions of the Software.
16 |
17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | SOFTWARE.
24 |
25 | ------------------------------- LICENSE for CLIP -------------------------------
26 |
27 | MIT License
28 |
29 | Copyright (c) 2021 OpenAI
30 |
31 | Permission is hereby granted, free of charge, to any person obtaining a copy
32 | of this software and associated documentation files (the "Software"), to deal
33 | in the Software without restriction, including without limitation the rights
34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
35 | copies of the Software, and to permit persons to whom the Software is
36 | furnished to do so, subject to the following conditions:
37 |
38 | The above copyright notice and this permission notice shall be included in all
39 | copies or substantial portions of the Software.
40 |
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 | SOFTWARE.
48 |
49 | ------------------------------ LICENSE for Hammer ------------------------------
50 |
51 | Copyright (c) 2022 ByteDance, Inc.
52 |
53 | MIT License
54 |
55 | Permission is hereby granted, free of charge, to any person obtaining a copy
56 | of this software and associated documentation files (the "Software"), to deal
57 | in the Software without restriction, including without limitation the rights
58 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
59 | copies of the Software, and to permit persons to whom the Software is
60 | furnished to do so, subject to the following conditions:
61 |
62 | The above copyright notice and this permission notice shall be included in all
63 | copies or substantial portions of the Software.
64 |
65 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
66 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
67 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
68 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
69 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
70 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
71 | SOFTWARE.
72 |
73 | ---------- LICENSE for custom CUDA kernels in StyleGAN2 and StyleGAN3 ----------
74 |
75 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
76 |
77 | NVIDIA Source Code License
78 |
79 | =======================================================================
80 |
81 | 1. Definitions
82 |
83 | "Licensor" means any person or entity that distributes its Work.
84 |
85 | "Software" means the original work of authorship made available under
86 | this License.
87 |
88 | "Work" means the Software and any additions to or derivative works of
89 | the Software that are made available under this License.
90 |
91 | The terms "reproduce," "reproduction," "derivative works," and
92 | "distribution" have the meaning as provided under U.S. copyright law;
93 | provided, however, that for the purposes of this License, derivative
94 | works shall not include works that remain separable from, or merely
95 | link (or bind by name) to the interfaces of, the Work.
96 |
97 | Works, including the Software, are "made available" under this License
98 | by including in or with the Work either (a) a copyright notice
99 | referencing the applicability of this License to the Work, or (b) a
100 | copy of this License.
101 |
102 | 2. License Grants
103 |
104 | 2.1 Copyright Grant. Subject to the terms and conditions of this
105 | License, each Licensor grants to you a perpetual, worldwide,
106 | non-exclusive, royalty-free, copyright license to reproduce,
107 | prepare derivative works of, publicly display, publicly perform,
108 | sublicense and distribute its Work and any resulting derivative
109 | works in any form.
110 |
111 | 3. Limitations
112 |
113 | 3.1 Redistribution. You may reproduce or distribute the Work only
114 | if (a) you do so under this License, (b) you include a complete
115 | copy of this License with your distribution, and (c) you retain
116 | without modification any copyright, patent, trademark, or
117 | attribution notices that are present in the Work.
118 |
119 | 3.2 Derivative Works. You may specify that additional or different
120 | terms apply to the use, reproduction, and distribution of your
121 | derivative works of the Work ("Your Terms") only if (a) Your Terms
122 | provide that the use limitation in Section 3.3 applies to your
123 | derivative works, and (b) you identify the specific derivative
124 | works that are subject to Your Terms. Notwithstanding Your Terms,
125 | this License (including the redistribution requirements in Section
126 | 3.1) will continue to apply to the Work itself.
127 |
128 | 3.3 Use Limitation. The Work and any derivative works thereof only
129 | may be used or intended for use non-commercially. Notwithstanding
130 | the foregoing, NVIDIA and its affiliates may use the Work and any
131 | derivative works commercially. As used herein, "non-commercially"
132 | means for research or evaluation purposes only.
133 |
134 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
135 | against any Licensor (including any claim, cross-claim or
136 | counterclaim in a lawsuit) to enforce any patents that you allege
137 | are infringed by any Work, then your rights under this License from
138 | such Licensor (including the grant in Section 2.1) will terminate
139 | immediately.
140 |
141 | 3.5 Trademarks. This License does not grant any rights to use any
142 | Licensor’s or its affiliates’ names, logos, or trademarks, except
143 | as necessary to reproduce the notices described in this License.
144 |
145 | 3.6 Termination. If you violate any term of this License, then your
146 | rights under this License (including the grant in Section 2.1) will
147 | terminate immediately.
148 |
149 | 4. Disclaimer of Warranty.
150 |
151 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
152 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
153 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
154 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
155 | THIS LICENSE.
156 |
157 | 5. Limitation of Liability.
158 |
159 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
160 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
161 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
162 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
163 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
164 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
165 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
166 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
167 | THE POSSIBILITY OF SUCH DAMAGES.
168 |
169 | =======================================================================
170 |
171 | ----------------------------- LICENSE for GenForce -----------------------------
172 |
173 | Copyright (c) 2020 GenForce
174 |
175 | MIT License
176 |
177 | Permission is hereby granted, free of charge, to any person obtaining a copy
178 | of this software and associated documentation files (the "Software"), to deal
179 | in the Software without restriction, including without limitation the rights
180 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
181 | copies of the Software, and to permit persons to whom the Software is
182 | furnished to do so, subject to the following conditions:
183 |
184 | The above copyright notice and this permission notice shall be included in all
185 | copies or substantial portions of the Software.
186 |
187 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
188 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
189 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
190 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
191 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
192 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
193 | SOFTWARE.
194 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Aurora -- An Open-sourced GAN-based Text-to-Image Generation Model
2 |
3 | > **Exploring Sparse MoE in GANs for Text-conditioned Image Synthesis**
4 | > Jiapeng Zhu*, Ceyuan Yang*, Kecheng Zheng, Yinghao Xu, Zifan Shi, Yujun Shen
5 | > *arXiv preprint arXiv:2309.03904*
6 |
7 | [[Paper](https://arxiv.org/pdf/2309.03904.pdf)]
8 |
9 | ## TODO
10 |
11 | - [x] Release inference code
12 | - [x] Release text-to-image generator at 64x64 resolution
13 | - [ ] Release models at higher resolution
14 | - [ ] Release training code
15 | - [ ] Release plug-ins/efficient algorithms for more functionalities
16 |
17 | ## Installation
18 |
19 | This repository is developed based on [Hammer](https://github.com/bytedance/Hammer), where you can find more detailed instructions on installation. Here, we summarize the necessary steps to facilitate reproduction.
20 |
21 | 1. Environment: CUDA version == 11.3.
22 |
23 | 2. Install package requirements with `conda`:
24 |
25 | ```shell
26 | conda create -n aurora python=3.8 # create virtual environment with Python 3.8
27 | conda activate aurora
28 | pip install -r requirements/minimal.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
29 | ```
30 |
31 | ## Inference
32 |
33 | First, please download the pre-trained model [here](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jzhubt_connect_ust_hk/EZJRG5BV_URMjywZLcAW95YBKNQaD7M35Ba6PCHe_Gf16w?e=0n5BOm).
34 |
35 | To synthesize an image from a given text prompt, use the following command:
36 |
37 | ```bash
38 | python run_synthesize.py aurora_v1.pth 'A photo of a tree with autumn leaves'
39 | ```
40 |
41 | To interpolate between two text prompts, use the following command:
42 |
43 | ```bash
44 | python run_interpolate.py aurora_v1.pth \
45 | --src_prompt 'A photo of a tree with autumn leaves' \
46 | --dst_prompt 'A photo of a victorian house'
47 | ```
48 |
49 | ## Results
50 |
51 | - Text-conditioned image generation
52 |
53 | 
54 |
55 | - Text prompt interpolation
56 |
57 | 
58 |
59 | ## LICENSE
60 |
61 | The project is released under the [MIT License](./LICENSE) and is for research purposes ONLY. The licenses of the bundled third-party components are reproduced in the same file.
62 |
63 | ## Acknowledgements
64 |
65 | We highly appreciate [StyleGAN2](https://github.com/NVlabs/stylegan2), [StyleGAN3](https://github.com/NVlabs/stylegan3), [CLIP](https://github.com/openai/CLIP), and [Hammer](https://github.com/bytedance/Hammer) for their contributions to the community.
66 |
67 | ## BibTeX
68 |
69 | ```bibtex
70 | @article{zhu2023aurora,
71 | title = {Exploring Sparse {MoE} in {GANs} for Text-conditioned Image Synthesis},
72 | author = {Zhu, Jiapeng and Yang, Ceyuan and Zheng, Kecheng and Xu, Yinghao and Shi, Zifan and Shen, Yujun},
73 | journal = {arXiv preprint arXiv:2309.03904},
74 | year = {2023}
75 | }
76 | ```
77 |
--------------------------------------------------------------------------------
/docs/assets/interpolation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/docs/assets/interpolation.gif
--------------------------------------------------------------------------------
/docs/assets/synthesis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/docs/assets/synthesis.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all models."""
3 |
4 | from .clip_model import CLIPModel
5 | from .text_generator import Text2ImageGenerator
6 |
# Names exported by `from models import *`.
__all__ = ['build_model']

# Registry mapping a model type name (case sensitive) to the class that
# implements it. `build_model()` resolves its `model_type` argument here.
_MODELS = {
    'CLIPModel': CLIPModel,
    'Text2ImageGenerator': Text2ImageGenerator
}
13 |
14 |
def build_model(model_type, **kwargs):
    """Instantiates the model registered under `model_type`.

    Args:
        model_type: Name of the model class to build. The lookup in `_MODELS`
            is case sensitive.
        **kwargs: Keyword arguments forwarded unchanged to the model
            constructor.

    Returns:
        The newly constructed model instance.

    Raises:
        ValueError: If `model_type` is not a key of `_MODELS`.
    """
    if model_type in _MODELS:
        model_class = _MODELS[model_type]
        return model_class(**kwargs)
    raise ValueError(f'Invalid model type: `{model_type}`!\n'
                     f'Types allowed: {list(_MODELS)}.')
30 |
--------------------------------------------------------------------------------
/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/models/utils/__init__.py
--------------------------------------------------------------------------------
/models/utils/ops.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains operators for neural networks."""
3 |
4 | import torch
5 | import torch.distributed as dist
6 |
# Names exported by `from ops import *`.
__all__ = ['all_gather']
8 |
9 |
def all_gather(tensor):
    """Averages `tensor` element-wise over all distributed processes.

    When no process group is in use (either this PyTorch build has no
    distributed support, or the default group has not been initialized), the
    input is returned untouched.

    Args:
        tensor: The local tensor to gather from every process.

    Returns:
        The input `tensor` itself when not running distributed; otherwise a
        new tensor holding the mean of the gathered tensors, with the same
        shape as the input.
    """
    # `is_initialized` is not guaranteed to exist on builds without
    # distributed support, so check availability first.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # Receive buffers, one per process; their contents are overwritten by
    # the collective call below.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, async_op=False)
    return torch.stack(gathered, dim=0).mean(dim=0)
19 |
--------------------------------------------------------------------------------
/requirements/convert.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.1
2 | tensorflow-gpu==1.15
3 | ninja==1.10.2
4 | scikit-video==1.1.11
5 | pillow==9.0.0
6 | opencv-python-headless==4.5.5.62
7 | requests
8 | bs4
9 | tqdm
10 | rich
11 | easydict
12 |
--------------------------------------------------------------------------------
/requirements/develop.txt:
--------------------------------------------------------------------------------
1 | bpytop # Monitor system resources.
2 | gpustat # Monitor GPU usage.
3 | pylint # Check coding style.
4 |
--------------------------------------------------------------------------------
/requirements/minimal.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
2 | torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
3 | tensorboard==2.7.0
4 | torch-tb-profiler==0.3.1
5 | ninja==1.10.2
6 | numpy==1.22.3
7 | scipy==1.7.3
8 | scikit-learn==1.0.2
9 | scikit-video==1.1.11
10 | pillow==9.0.0
11 | opencv-python-headless==4.5.5.62
12 | requests
13 | bs4
14 | tqdm
15 | rich
16 | click
17 | cloup
18 | psutil
19 | easydict
20 | lmdb
21 | matplotlib
22 | regex
23 | ftfy
24 | einops==0.6.1
25 | huggingface_hub==0.15.1
26 |
--------------------------------------------------------------------------------
/run_synthesize.py:
--------------------------------------------------------------------------------
1 |
2 | # python3.7
3 | """Contains the code to synthesize images from a pre-trained model.
4 | """
5 | import os
6 | import argparse
7 | from tqdm import tqdm
8 |
9 | import torch
10 | from models import build_model
11 | from utils.image_utils import postprocess_image, save_image
12 | from utils.visualizers import HtmlVisualizer
13 |
14 |
def run_mapping(G, z, context, eot_ind=None):
    """Runs the text head and the mapping network without gradients.

    Args:
        G: Generator exposing `text_head` and `mapping` callables.
        z: Latent input passed to `G.mapping`.
        context: Text encoding passed to `G.text_head`.
        eot_ind: Forwarded to `G.text_head` as `eot_ind`. (default: None)

    Returns:
        A tuple `(wp, local_text)`: the `'wp'` entry of the mapping output
        and the second value returned by the text head.
    """
    with torch.no_grad():
        text_features = G.text_head(context, eot_ind=eot_ind)
        global_text, local_text = text_features
        # The mapping network is conditioned on the first (global) text
        # feature only; no class label is used.
        wp = G.mapping(z, label=None, context=global_text)['wp']
    return wp, local_text
23 |
24 |
def run_synthesize(G, wp, local_text):
    """Runs the synthesis network of the generator without gradients.

    Args:
        G: Generator exposing a `synthesis` callable.
        wp: Latent codes passed as the first argument of `G.synthesis`.
        local_text: Text feature passed to `G.synthesis` as `context`.

    Returns:
        Whatever `G.synthesis` returns.
    """
    with torch.no_grad():
        return G.synthesis(wp, context=local_text)
30 |
31 |
def read_text(text_path):
    """Reads text prompts from a file, one prompt per line.

    Args:
        text_path: Path to a UTF-8 encoded text file.

    Returns:
        A list with one entry per line of the file, each stripped of leading
        and trailing whitespace. Blank lines are kept as empty strings.
    """
    print(f'Loading text from {text_path}')
    # Decode explicitly as UTF-8 so that prompts with non-ASCII characters
    # do not depend on the platform's default locale encoding.
    with open(text_path, encoding='utf-8') as f:
        text = [line.strip() for line in f]
    return text
38 |
39 |
def parse_float(arg):
    """Parses a comma-separated string into a list of floats.

    Empty items (e.g. produced by a trailing comma) are ignored.

    Args:
        arg: String such as `'0,0.05,0.1'`, or an empty string / `None`.

    Returns:
        A list of floats, or `None` if `arg` is empty or contains no values.

    Raises:
        ValueError: If a non-empty item cannot be converted to float.
    """
    if not arg:
        return None
    values = [float(item) for item in arg.split(',') if item.strip()]
    return values or None
47 |
48 |
def parse_args():
    """Builds the command-line parser and parses `sys.argv`.

    Returns:
        The `argparse.Namespace` holding the two positional arguments
        (`weight_path`, `text_prompt`) and all optional settings.
    """
    parser = argparse.ArgumentParser()

    # Each entry is (argument name, keyword arguments for `add_argument`).
    # Positional arguments come first, followed by the optional ones.
    argument_specs = (
        ('weight_path', dict(
            type=str, default='',
            help='Path to the pre-trained models.')),
        ('text_prompt', dict(
            type=str, default='',
            help='The text prompt, support reading from a file '
                 'or just given a text prompt.')),
        ('--batch_size', dict(
            type=int, default=1,
            help='Batch size.')),
        ('--syn_num', dict(
            type=int, default=100,
            help='Number of synthesized images.')),
        ('--resolution', dict(
            type=int, default=64,
            help='Resolution of the model output.')),
        ('--results_dir', dict(
            type=str, default='work_dirs/syn_res',
            help='Results directory.')),
        ('--seed', dict(
            type=int, default=4,
            help='Random seed.')),
        ('--trunc_layers', dict(
            type=int, default=None,
            help='Number of layers to perform truncation.')),
        ('--loop_mapping', dict(
            type=int, default=16,
            help='Loop number for getting average for wp.')),
        ('--save_name', dict(
            type=str, default='0',
            help='Name to help save the file.')),
        ('--num_z', dict(
            type=int, default=3,
            help='Number of z for each text prompt.')),
        ('--save_png', dict(
            action='store_true',
            help='Whether or not to save the synthesized images.')),
        ('--trunc_vals', dict(
            type=str, default='0,0.05,0.1,0.15,0.2',
            help='Default values for truncation.')),
    )
    for argument_name, options in argument_specs:
        parser.add_argument(argument_name, **options)
    return parser.parse_args()
80 |
81 |
def main():
    """Synthesizes images for text prompt(s) and writes an HTML overview.

    Steps performed:

    1. Parse the command line; the prompt argument is either literal text or
       the path of a file with one prompt per line.
    2. Build the CLIP text model and the text-to-image generator, and load
       the generator weights from `args.weight_path` (preferring the
       `generator_smooth` entry of the checkpoint when present).
    3. For each prompt, sample `args.num_z` latent codes and synthesize one
       image per truncation value in `args.trunc_vals`.
    4. Save all results as an HTML page in `args.results_dir`, and optionally
       every image as PNG under `<results_dir>/images`.
    """
    args = parse_args()
    assert args.batch_size == 1, 'Current script only support bs equals to 1.'
    # Treat the argument as a prompt file only if it is not a directory;
    # otherwise a prompt such as 'docs' would be opened and crash.
    if (os.path.exists(args.text_prompt)
            and not os.path.isdir(args.text_prompt)):
        text_prompt = read_text(args.text_prompt)
    else:
        text_prompt = [args.text_prompt]
    syn_num = min(args.syn_num, len(text_prompt))
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    clip_config = {'model_name':'ViT-L-14',
                   'pretrained':'openai',
                   'freeze_clip': True}
    clip = build_model('CLIPModel', **clip_config)

    # Generator configuration; presumably it has to match the released
    # 64x64 checkpoint -- confirm before changing any value.
    g_config = {'resolution': 64,
                'image_channels':3,
                'init_res': 4,
                'z_dim': 128,
                'w_dim': 1024,
                'mapping_fmaps': 1024,
                'label_dim': 0,
                'context_dim': 1024,
                'clip_out_dim': 768,
                'head_dim': 64,
                'embedding_dim': 1024,
                'use_text_cond': True,
                'num_layers_text_enc': 4,
                'use_w_cond': False,
                'use_class_label': False,
                'mapping_layers': 4,
                'fmaps_base': 16384,
                'fmaps_max': 1600,
                'num_adaptive_kernels': {"4":1,"8":1,"16":2,"32":4,"64":8},
                'num_block_per_res': {"4":3,"8":3,"16":3,"32":2,"64":2},
                'attn_resolutions': ['8', '16', '32', '64'],
                'attn_depth': {"8":2,"16":2,"32":2,"64":1},
                'attn_ch_factor': 1,
                'attn_gain': 0.3,
                'residual_gain': 0.4,
                'text_head_gain': 1.0,
                'zero_out': True,
                'fourier_feat': True,
                'l2_attention': True,
                'tie': False,
                'scale_in': False,
                'include_ff': True,
                'use_checkpoint': False,
                'checkpoint_res': ['8', '16', '32'],
                'mask_self': False,
                'conv_clamp': None,
                'mtm': True,
                'num_experts': {"8":4,"16":8,"32":16,"64":16},
                'ms_training_res': ['4','8','16','32','64'],
                'skip_connection': True}

    G = build_model('Text2ImageGenerator', **g_config)
    checkpoint = torch.load(args.weight_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        print('Loading checkpoint from generator smooth!')
        G.load_state_dict(checkpoint['generator_smooth'])
    else:
        print('Loading checkpoint from generator!')
        G.load_state_dict(checkpoint['generator'])
    G = G.eval().to(device)

    trunc_vals = parse_float(args.trunc_vals)
    if not trunc_vals:
        trunc_vals = [0, 0.05, 0.1, 0.15, 0.2]
    # NOTE(review): the resolved `trunc_layers` is never used below; all
    # slicing uses `args.trunc_layers` (where `None` selects every layer).
    # Confirm which one is intended before unifying them.
    trunc_layers = args.trunc_layers
    if not trunc_layers:
        trunc_layers = G.num_layers

    # One row per (prompt, z) pair; the first column holds the z label and
    # the remaining columns hold one image per truncation value.
    visualizer_syn = HtmlVisualizer(image_size=args.resolution)
    visualizer_syn.reset(num_rows=syn_num * args.num_z,
                         num_cols=len(trunc_vals) + 1)
    head = ['Number Z']
    head += [f'trunc_val_{val_}' for val_ in trunc_vals]
    visualizer_syn.set_headers(head)
    torch.manual_seed(args.seed)
    os.makedirs(args.results_dir, exist_ok=True)
    if args.save_png:
        os.makedirs(f'{args.results_dir}/images', exist_ok=True)
    w_avg = G.w_avg.reshape(1, -1, G.w_dim)[:, :args.trunc_layers]
    for idx in tqdm(range(syn_num)):
        text = text_prompt[idx]
        _, enc_text, eot_ind = clip.encode_text(text=text, is_tokenize=True)
        if args.loop_mapping > 0:
            # Estimate a prompt-specific average latent code by averaging
            # the mapping output over several random z.
            sum_wp = 0
            for _ in range(args.loop_mapping):
                z = torch.randn((args.batch_size, *G.latent_dim), device=device)
                tmp_res, _ = run_mapping(G, z, enc_text, eot_ind=eot_ind)
                sum_wp += tmp_res
            avg_wp = sum_wp / args.loop_mapping
            avg_wp = avg_wp[:, :args.trunc_layers]
        for z_i in range(args.num_z):
            z = torch.randn((args.batch_size, *G.latent_dim), device=device)
            row_ind = idx * args.num_z + z_i
            visualizer_syn.set_cell(row_ind, 0, text=f'z_{z_i}')
            for col_idx, trunc_psi in enumerate(trunc_vals):
                wp, local_text = run_mapping(G, z, enc_text, eot_ind=eot_ind)
                wp[:, :args.trunc_layers] = w_avg.lerp(wp[:, :args.trunc_layers], trunc_psi)
                # NOTE(review): when `loop_mapping > 0` this second lerp is
                # applied on top of the one above instead of replacing it --
                # confirm that the double interpolation is intended.
                if args.loop_mapping > 0:
                    wp[:, :args.trunc_layers] = avg_wp.lerp(wp[:, :args.trunc_layers], trunc_psi)
                fake_results = run_synthesize(G, wp, local_text)
                syn_imgs = fake_results['image'].detach().cpu().numpy()
                syn_imgs = postprocess_image(syn_imgs)
                visualizer_syn.set_cell(row_ind,
                                        col_idx + 1,
                                        text=text,
                                        image=syn_imgs[0])
                if args.save_png:
                    prefix = f'{args.results_dir}/images/text_{idx:04d}_z_'
                    # Two decimals are required: with one decimal the default
                    # values 0.05, 0.1 and 0.15 all format to '0.1' and the
                    # images overwrite each other.
                    save_path0 = f'{prefix}{z_i:02d}_psi_{trunc_psi:.2f}.png'
                    save_image(save_path0, syn_imgs[0])

    # Save result.
    save_name_syn = f'syn_{syn_num:04d}_seed_{args.seed}_{args.save_name}.html'
    save_path_syn = os.path.join(args.results_dir, save_name_syn)
    visualizer_syn.save(save_path_syn)
206 |
207 |
# Run the synthesis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
210 |
--------------------------------------------------------------------------------
/third_party/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/.DS_Store
--------------------------------------------------------------------------------
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/__init__.py
--------------------------------------------------------------------------------
/third_party/clip_official/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/clip_official/__init__.py
--------------------------------------------------------------------------------
/third_party/clip_official/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/clip_official/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/third_party/clip_official/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021 OpenAI. All rights reserved.
4 |
5 | """Text tokenizer.
6 |
7 | Please refer to https://github.com/openai/CLIP
8 | """
9 |
10 | # pylint: disable=line-too-long
11 | # pylint: disable=inconsistent-quotes
12 | # pylint: disable=missing-class-docstring
13 | # pylint: disable=missing-function-docstring
14 | # pylint: disable=bare-except
15 | # pylint: disable=no-else-break
16 |
17 | import gzip
18 | import html
19 | import os
20 | from functools import lru_cache
21 |
22 | import ftfy
23 | import regex as re
24 |
25 |
@lru_cache()
def default_bpe():
    """Returns the absolute path of the BPE vocabulary next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
29 |
30 |
@lru_cache()
def bytes_to_unicode():
    """Builds a reversible mapping from byte values to unicode characters.

    The BPE codes operate on unicode strings. To cover arbitrary input
    without unknown tokens and without a huge vocabulary, every one of the
    256 byte values is mapped to a single character. Bytes in the three
    ranges "!".."~", "¡".."¬" and "®".."ÿ" map to themselves; every other
    byte is mapped to a code point starting at 256, which avoids the
    whitespace/control characters the BPE code cannot handle.

    Returns:
        A dict with 256 entries mapping each byte value (int) to a distinct
        one-character string.
    """
    kept_as_is = (list(range(ord("!"), ord("~") + 1))
                  + list(range(ord("¡"), ord("¬") + 1))
                  + list(range(ord("®"), ord("ÿ") + 1)))
    byte_values = list(kept_as_is)
    code_points = list(kept_as_is)
    shifted = 0
    for byte in range(2**8):
        if byte in byte_values:
            continue
        # Remaining bytes are assigned consecutive code points above 255.
        byte_values.append(byte)
        code_points.append(2**8 + shifted)
        shifted += 1
    return {byte: chr(point) for byte, point in zip(byte_values, code_points)}
52 |
53 |
def get_pairs(word):
    """Returns the set of adjacent symbol pairs in a word.

    Args:
        word: Sequence of symbols (symbols being variable-length strings),
            typically a tuple.

    Returns:
        A set of `(left, right)` tuples for every pair of neighbouring
        symbols. Empty for words with fewer than two symbols, including the
        empty word.
    """
    # Pairing the word with itself shifted by one yields all neighbours and
    # naturally handles empty and single-symbol words.
    return set(zip(word, word[1:]))
64 |
65 |
def basic_clean(text):
    """Fixes the text with `ftfy`, unescapes HTML entities twice, and strips it."""
    fixed = ftfy.fix_text(text)
    # Unescape twice so that doubly escaped entities are fully resolved.
    once = html.unescape(fixed)
    twice = html.unescape(once)
    return twice.strip()
70 |
71 |
def whitespace_clean(text):
    """Collapses every whitespace run into one space and strips both ends."""
    return re.sub(r'\s+', ' ', text).strip()
76 |
77 |
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP.

    Text is cleaned and lower-cased, split into coarse tokens with a regular
    expression, mapped byte-by-byte to unicode characters, and then merged
    with the learned BPE merges. The last symbol of every word carries the
    end-of-word marker '</w>'.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Loads the merges and builds the vocabulary.

        Args:
            bpe_path: Path to the gzip-compressed BPE merges file.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the archive deterministically instead of leaking the handle.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split('\n')
        # Skip the header line and keep the first 49152-256-2 merges.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary: the 256 byte symbols, the same symbols with the
        # end-of-word marker, all merged symbols, and two special tokens.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are never split, so they are pre-seeded in the cache.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Applies the BPE merges to one token.

        Args:
            token: A non-empty string of byte-encoded characters.

        Returns:
            The resulting BPE symbols joined by single spaces; the last
            symbol ends with '</w>'. Results are memoized in `self.cache`.
        """
        if token in self.cache:
            return self.cache[token]
        # Mark the end of the word on the last character.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the adjacent pair with the lowest (earliest learned) rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur any more; copy the remainder.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Converts a text string into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Converts token ids back to text.

        Each end-of-word marker becomes a space, so the result normally ends
        with a trailing space. Undecodable bytes are replaced.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
149 |
150 | # pylint: enable=line-too-long
151 | # pylint: enable=inconsistent-quotes
152 | # pylint: enable=missing-class-docstring
153 | # pylint: enable=missing-function-docstring
154 | # pylint: enable=bare-except
155 | # pylint: enable=no-else-break
156 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/README.md:
--------------------------------------------------------------------------------
1 | # Operators for StyleGAN2
2 |
3 | All files in this directory are borrowed from repository [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including:
4 |
5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel.
8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
11 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
12 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
13 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`).
14 |
15 | We make the following slight modifications beyond disabling some lint warnings:
16 |
17 | - Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
18 | - Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators.
19 | - Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its location in function `_find_compiler_bindir_posix()`.*)
20 | - Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch).
21 | - Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
22 | - Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
23 | - Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
24 | - Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
25 | - Line 66 of file `conv2d_gradfix.py`: Update PyTorch version check considering the sustained development of the community.
26 | - Line 47 of file `grid_sample_gradfix.py`: Update PyTorch version check considering the sustained development of the community.
27 | - Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
28 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
29 |
30 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
31 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/stylegan2_official_ops/__init__.py
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include
11 | #include
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
// Python bindings: exposes the host-side `bias_act()` function under the
// name "bias_act" in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
// Maps the storage type T onto the type used for arithmetic inside the
// kernel: double stays double, float and half are both computed in float.
template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

// T is the element type of the tensors, A the compile-time activation id
// (1 = linear, 2 = relu, 3 = lrelu, 4 = tanh, 5 = sigmoid, 6 = elu,
// 7 = selu, 8 = softplus, 9 = swish). Each thread processes up to p.loopX
// elements, stepping by blockDim.x between them.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    // G selects which formula of the activation is evaluated below: G == 0
    // evaluates the activation on x, G >= 1 uses the reference values
    // (presumably first/second order gradients -- confirm with the caller).
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load. Missing optional inputs default to 0 (dy defaults to 1).
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        // yy is the reference output with the gain divided back out.
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias: to x when G == 0, otherwise to the reference input.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                // Recompute yref from xref so that clamping below can use it.
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp (disabled when clamp is negative).
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

// Returns the kernel instantiated for element type T and activation id
// p.act, or NULL when p.act is not a known activation id.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
// Parameter pack passed by value from the host launcher to the CUDA kernel.
struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;   // Selects which formula of the activation the kernel evaluates.
    int         act;    // Activation id used to choose the kernel.
    float       alpha;  // Activation parameter (used by the kernel as the lrelu negative slope).
    float       gain;   // Multiplier applied to the result.
    float       clamp;  // Clamp magnitude; clamping is skipped when negative.

    int         sizeX;  // Number of elements in x.
    int         sizeB;  // Number of elements in b.
    int         stepB;  // Divisor mapping an element index of x to its bias index.
    int         loopX;  // Maximum number of elements processed per thread.
};
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Custom replacement for convolution operators.
12 |
13 | Operators in this file support arbitrarily high order gradients with zero
14 | performance penalty. Please set `impl` as `cuda` to use faster customized
15 | operators, OR as `ref` to use native `torch.nn.functional.conv2d` and
16 | `torch.nn.functional.conv_transpose2d`.
17 |
18 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
19 | """
20 |
21 | # pylint: disable=redefined-builtin
22 | # pylint: disable=arguments-differ
23 | # pylint: disable=protected-access
24 | # pylint: disable=line-too-long
25 | # pylint: disable=global-statement
26 | # pylint: disable=missing-class-docstring
27 | # pylint: disable=missing-function-docstring
28 | # pylint: disable=deprecated-module
29 | # pylint: disable=wrong-import-order
30 |
31 | import warnings
32 | import contextlib
33 | import torch
34 |
35 | from distutils.version import LooseVersion
36 |
enabled = True  # Enable the custom op by setting this to true.
weight_gradients_disabled = False  # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily disable weight-gradient computation of the custom conv ops.

    While the context is active, the module-level flag
    `weight_gradients_disabled` is set to `True`. The previous value is
    restored on exit, including when the body of the `with` statement raises,
    so that a failure inside the context cannot leave weight gradients
    permanently disabled.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        weight_gradients_disabled = old
47 |
48 | #----------------------------------------------------------------------------
49 |
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'):
    """2D convolution that dispatches to the custom op when possible.

    With `impl='cuda'` and an input for which `_should_use_custom_op()`
    returns `True`, the cached custom autograd function is applied.
    Otherwise the call is forwarded to `torch.nn.functional.conv2d`.
    """
    use_custom_op = impl == 'cuda' and _should_use_custom_op(input)
    if not use_custom_op:
        return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    custom_op = _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups)
    return custom_op.apply(input, weight, bias)
54 |
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'):
    """Transposed 2D convolution that dispatches to the custom op when possible.

    With `impl='cuda'` and an input for which `_should_use_custom_op()`
    returns `True`, the cached custom autograd function is applied.
    Otherwise the call is forwarded to
    `torch.nn.functional.conv_transpose2d`.
    """
    use_custom_op = impl == 'cuda' and _should_use_custom_op(input)
    if not use_custom_op:
        return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
    custom_op = _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
    return custom_op.apply(input, weight, bias)
59 |
60 | #----------------------------------------------------------------------------
61 |
def _should_use_custom_op(input):
    """Decide whether the custom convolution op can be used for `input`.

    Returns `True` only if the module-level `enabled` flag is set, cuDNN is
    enabled, `input` lives on a CUDA device, and the installed PyTorch is at
    least version 1.7.0. For older PyTorch versions a warning is emitted
    before returning `False`.
    """
    assert isinstance(input, torch.Tensor)
    globally_allowed = enabled and torch.backends.cudnn.enabled
    if not globally_allowed:
        return False
    if input.device.type != 'cuda':
        return False
    version_supported = LooseVersion(torch.__version__) >= LooseVersion('1.7.0')
    if not version_supported:
        warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
        return False
    return True
72 |
73 | def _tuple_of_ints(xs, ndim):
74 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
75 | assert len(xs) == ndim
76 | assert all(isinstance(x, int) for x in xs)
77 | return xs
78 |
79 | #----------------------------------------------------------------------------
80 |
# Maps a full op configuration (see `key` below) to its generated autograd class.
_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Return (and cache) an autograd function class for one conv configuration.

    Args:
        transpose: `False` for convolution, `True` for transposed convolution.
        weight_shape: Shape of the weight tensor (4 entries).
        stride: Int or pair of ints.
        padding: Int or pair of ints.
        output_padding: Int or pair of ints; must be 0 unless `transpose`.
        dilation: Int or pair of ints.
        groups: Number of groups (>= 1).

    Returns:
        A `torch.autograd.Function` subclass whose `apply(input, weight, bias)`
        runs the configured operation. The class is created once per distinct
        configuration and reused afterwards.
    """
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        # Output padding needed by the opposite-direction op used in the
        # backward pass so that its result has exactly `input_shape`.
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else: # transpose
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                # The input gradient is the opposite-direction op applied to grad_output.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            # Skipped while the `no_weight_gradients()` flag is set.
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            # NOTE(review): relies on private cuDNN backward-weight operators
            # looked up by name; confirm they exist in the targeted PyTorch version.
            op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d
184 |
185 | #----------------------------------------------------------------------------
186 |
187 | # pylint: enable=redefined-builtin
188 | # pylint: enable=arguments-differ
189 | # pylint: enable=protected-access
190 | # pylint: enable=line-too-long
191 | # pylint: enable=global-statement
192 | # pylint: enable=missing-class-docstring
193 | # pylint: enable=missing-function-docstring
194 | # pylint: enable=deprecated-module
195 | # pylint: enable=wrong-import-order
196 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """2D convolution with optional up/downsampling.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14 | """
15 |
16 | # pylint: disable=line-too-long
17 |
18 | import torch
19 |
20 | from . import misc
21 | from . import conv2d_gradfix
22 | from . import upfirdn2d
23 | from .upfirdn2d import _parse_padding
24 | from .upfirdn2d import _get_filter_size
25 |
26 | #----------------------------------------------------------------------------
27 |
def _get_weight_shape(w):
    """Return the shape of `w` as a list of Python ints.

    The conversion runs with tracer warnings suppressed, and the resulting
    list is checked against `w` with `misc.assert_shape()`.
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        dims = list(map(int, w.shape))
    misc.assert_shape(w, dims)
    return dims
33 |
34 | #----------------------------------------------------------------------------
35 |
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.

    Args:
        x: Input tensor.
        w: Weight tensor with 4 dimensions, unpacked as
            `[out_channels, in_channels_per_group, kh, kw]`.
        stride: Stride forwarded to the underlying op (default: 1).
        padding: Padding forwarded to the underlying op (default: 0).
        groups: Number of groups (default: 1).
        transpose: Whether to use `conv_transpose2d()` instead of `conv2d()`
            (default: False).
        flip_weight: If False, the weight is flipped along its two spatial
            dimensions before use (default: True).
        impl: Implementation mode forwarded to `conv2d_gradfix`
            (default: 'cuda').

    Returns:
        The result of the (transposed) convolution.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                # Very few output channels: express the 1x1 conv as a matmul.
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                # Run the conv in contiguous format instead.
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl)
            # Convert the result back to channels_last.
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl)
62 |
63 | #----------------------------------------------------------------------------
64 |
65 | @misc.profiled_function
66 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'):
67 | r"""2D convolution with optional up/downsampling.
68 |
69 | Padding is performed only once at the beginning, not between the operations.
70 |
71 | Args:
72 | x: Input tensor of shape
73 | `[batch_size, in_channels, in_height, in_width]`.
74 | w: Weight tensor of shape
75 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
76 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by
77 | calling upfirdn2d.setup_filter(). None = identity (default).
78 | up: Integer upsampling factor (default: 1).
79 | down: Integer downsampling factor (default: 1).
80 | padding: Padding with respect to the upsampled image. Can be a single number
81 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
82 | (default: 0).
83 | groups: Split input channels into N groups (default: 1).
84 | flip_weight: False = convolution, True = correlation (default: True).
85 | flip_filter: False = convolution, True = correlation (default: False).
86 | impl: Implementation mode of customized ops. 'ref' for native PyTorch
87 | implementation, 'cuda' for `.cu` implementation
88 | (default: 'cuda').
89 |
90 | Returns:
91 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
92 | """
93 | # Validate arguments.
94 | assert isinstance(x, torch.Tensor) and (x.ndim == 4)
95 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
96 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
97 | assert isinstance(up, int) and (up >= 1)
98 | assert isinstance(down, int) and (down >= 1)
99 | assert isinstance(groups, int) and (groups >= 1)
100 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
101 | fw, fh = _get_filter_size(f)
102 | px0, px1, py0, py1 = _parse_padding(padding)
103 |
104 | # Adjust padding to account for up/downsampling.
105 | if up > 1:
106 | px0 += (fw + up - 1) // 2
107 | px1 += (fw - up) // 2
108 | py0 += (fh + up - 1) // 2
109 | py1 += (fh - up) // 2
110 | if down > 1:
111 | px0 += (fw - down + 1) // 2
112 | px1 += (fw - down) // 2
113 | py0 += (fh - down + 1) // 2
114 | py1 += (fh - down) // 2
115 |
116 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
117 | if kw == 1 and kh == 1 and (down > 1 and up == 1):
118 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
119 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
120 | return x
121 |
122 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
123 | if kw == 1 and kh == 1 and (up > 1 and down == 1):
124 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
125 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
126 | return x
127 |
128 | # Fast path: downsampling only => use strided convolution.
129 | if down > 1 and up == 1:
130 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
131 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl)
132 | return x
133 |
134 | # Fast path: upsampling with optional downsampling => use transpose strided convolution.
135 | if up > 1:
136 | if groups == 1:
137 | w = w.transpose(0, 1)
138 | else:
139 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
140 | w = w.transpose(1, 2)
141 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
142 | px0 -= kw - 1
143 | px1 -= kw - up
144 | py0 -= kh - 1
145 | py1 -= kh - up
146 | pxt = max(min(-px0, -px1), 0)
147 | pyt = max(min(-py0, -py1), 0)
148 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl)
149 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl)
150 | if down > 1:
151 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
152 | return x
153 |
154 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
155 | if up == 1 and down == 1:
156 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
157 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl)
158 |
159 | # Fallback: Generic reference implementation.
160 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
161 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
162 | if down > 1:
163 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
164 | return x
165 |
166 | #----------------------------------------------------------------------------
167 |
168 | # pylint: enable=line-too-long
169 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/custom_ops.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Utility functions to setup customized operators.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=missing-function-docstring
18 | # pylint: disable=useless-suppression
19 | # pylint: disable=inconsistent-quotes
20 |
21 | import os
22 | import glob
23 | import importlib
24 | import hashlib
25 | import shutil
26 | from pathlib import Path
27 |
28 | import torch
29 | from torch.utils.file_baton import FileBaton
30 | import torch.utils.cpp_extension
31 |
32 | #----------------------------------------------------------------------------
33 | # Global options.
34 |
35 | verbosity = 'none' # Verbosity level: 'none', 'brief', 'full'
36 |
37 | #----------------------------------------------------------------------------
38 | # Internal helper funcs.
39 |
40 | def _find_compiler_bindir():
41 | patterns = [
42 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
43 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
44 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
45 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
46 | ]
47 | for pattern in patterns:
48 | matches = sorted(glob.glob(pattern))
49 | if len(matches):
50 | return matches[-1]
51 | return None
52 |
53 | def _find_compiler_bindir_posix():
54 | patterns = [
55 | '/usr/local/cuda/bin'
56 | ]
57 | for pattern in patterns:
58 | matches = sorted(glob.glob(pattern))
59 | if len(matches):
60 | return matches[-1]
61 | return None
62 |
63 | #----------------------------------------------------------------------------
64 | # Main entry point for compiling and loading C++/CUDA plugins.
65 |
66 | _cached_plugins = dict()
67 |
def get_plugin(module_name, sources, **build_kwargs):
    """Compiles (if necessary), loads, and caches a C++/CUDA PyTorch plugin.

    Args:
        module_name: Name of the extension module to build and import. Also
            used as the key of the in-process plugin cache.
        sources: List of source file paths passed to
            `torch.utils.cpp_extension.load()`.
        **build_kwargs: Extra keyword arguments forwarded unchanged to
            `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: If the required compiler cannot be located.
        Any exception raised while building or importing the extension is
        re-raised after the status message has been printed.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += os.pathsep + compiler_bindir

        # Mirror the Windows branch: only touch PATH when `nvcc` is not
        # already resolvable, so that a CUDA toolkit installed outside of
        # the hard-coded location still works.
        elif os.name == 'posix' and shutil.which('nvcc') is None:
            compiler_bindir = _find_compiler_bindir_posix()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
            # NOTE: The separator must be `os.pathsep` (':' on POSIX). A
            # hard-coded ';' would be glued onto the last PATH entry and
            # corrupt it instead of adding a new search directory.
            os.environ['PATH'] += os.pathsep + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files.  Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            all_source_files = sorted(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        # Deliberately broad: only finishes the 'brief' status line, then
        # re-raises whatever went wrong.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
153 |
154 | #----------------------------------------------------------------------------
155 |
156 | # pylint: enable=line-too-long
157 | # pylint: enable=missing-function-docstring
158 | # pylint: enable=useless-suppression
159 | # pylint: enable=inconsistent-quotes
160 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/fma.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=missing-function-docstring
18 |
19 | import torch
20 |
21 | #----------------------------------------------------------------------------
22 |
def fma(a, b, c, impl='cuda'): # => a * b + c
    """Fused multiply-add, i.e. `a * b + c`.

    Args:
        a: First factor.
        b: Second factor.
        c: Addend.
        impl: 'cuda' routes through the custom autograd function
            `_FusedMultiplyAdd`; any other value uses plain
            `torch.addcmul()`. (default: 'cuda')

    Returns:
        The result of `a * b + c`.
    """
    if impl != 'cuda':
        return torch.addcmul(c, a, b)
    return _FusedMultiplyAdd.apply(a, b, c)
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
31 | @staticmethod
32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
33 | out = torch.addcmul(c, a, b)
34 | ctx.save_for_backward(a, b)
35 | ctx.c_shape = c.shape
36 | return out
37 |
38 | @staticmethod
39 | def backward(ctx, dout): # pylint: disable=arguments-differ
40 | a, b = ctx.saved_tensors
41 | c_shape = ctx.c_shape
42 | da = None
43 | db = None
44 | dc = None
45 |
46 | if ctx.needs_input_grad[0]:
47 | da = _unbroadcast(dout * b, a.shape)
48 |
49 | if ctx.needs_input_grad[1]:
50 | db = _unbroadcast(dout * a, b.shape)
51 |
52 | if ctx.needs_input_grad[2]:
53 | dc = _unbroadcast(dout, c_shape)
54 |
55 | return da, db, dc
56 |
57 | #----------------------------------------------------------------------------
58 |
59 | def _unbroadcast(x, shape):
60 | extra_dims = x.ndim - len(shape)
61 | assert extra_dims >= 0
62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
63 | if len(dim):
64 | x = x.sum(dim=dim, keepdim=True)
65 | if extra_dims:
66 | x = x.reshape(-1, *x.shape[extra_dims+1:])
67 | assert x.shape == shape
68 | return x
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | # pylint: enable=line-too-long
73 | # pylint: enable=missing-function-docstring
74 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample`.
12 |
13 | This is useful for differentiable augmentation. This customized operator
14 | supports arbitrarily high order gradients between the input and output. Only
15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
16 | `align_corners=False`.
17 |
18 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
19 | """
20 |
21 | # pylint: disable=redefined-builtin
22 | # pylint: disable=arguments-differ
23 | # pylint: disable=protected-access
24 | # pylint: disable=line-too-long
25 | # pylint: disable=missing-function-docstring
26 | # pylint: disable=deprecated-module
27 | # pylint: disable=wrong-import-order
28 |
29 | import warnings
30 | import torch
31 | from distutils.version import LooseVersion
32 |
33 | #----------------------------------------------------------------------------
34 |
35 | enabled = True # Enable the custom op by setting this to true.
36 |
37 | #----------------------------------------------------------------------------
38 |
def grid_sample(input, grid, impl='cuda'):
    """Samples `input` at the locations given by `grid`.

    Always uses `mode='bilinear'`, `padding_mode='zeros'` and
    `align_corners=False`.

    Args:
        input: Input tensor.
        grid: Sampling grid.
        impl: 'cuda' uses the custom autograd op when it is enabled and
            supported; any other value uses
            `torch.nn.functional.grid_sample()`. (default: 'cuda')

    Returns:
        The sampled tensor.
    """
    use_custom_op = (impl == 'cuda') and _should_use_custom_op()
    if not use_custom_op:
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return _GridSample2dForward.apply(input, grid)
43 |
44 | #----------------------------------------------------------------------------
45 |
def _should_use_custom_op():
    """Decides whether the custom grid-sample op may be used.

    Returns:
        `False` if the module-level `enabled` switch is off, or if the
        installed PyTorch is older than 1.7.0 (a warning is emitted in that
        case); `True` otherwise.
    """
    if not enabled:
        return False
    if LooseVersion(torch.__version__) < LooseVersion('1.7.0'):
        warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
        return False
    return True
53 |
54 | #----------------------------------------------------------------------------
55 |
56 | class _GridSample2dForward(torch.autograd.Function):
57 | @staticmethod
58 | def forward(ctx, input, grid):
59 | assert input.ndim == 4
60 | assert grid.ndim == 4
61 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
62 | ctx.save_for_backward(input, grid)
63 | return output
64 |
65 | @staticmethod
66 | def backward(ctx, grad_output):
67 | input, grid = ctx.saved_tensors
68 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
69 | return grad_input, grad_grid
70 |
71 | #----------------------------------------------------------------------------
72 |
73 | class _GridSample2dBackward(torch.autograd.Function):
74 | @staticmethod
75 | def forward(ctx, grad_output, input, grid):
76 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
77 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
78 | ctx.save_for_backward(grid)
79 | return grad_input, grad_grid
80 |
81 | @staticmethod
82 | def backward(ctx, grad2_grad_input, grad2_grad_grid):
83 | _ = grad2_grad_grid # unused
84 | grid, = ctx.saved_tensors
85 | grad2_grad_output = None
86 | grad2_input = None
87 | grad2_grid = None
88 |
89 | if ctx.needs_input_grad[0]:
90 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
91 |
92 | assert not ctx.needs_input_grad[2]
93 | return grad2_grad_output, grad2_input, grad2_grid
94 |
95 | #----------------------------------------------------------------------------
96 |
97 | # pylint: enable=redefined-builtin
98 | # pylint: enable=arguments-differ
99 | # pylint: enable=protected-access
100 | # pylint: enable=line-too-long
101 | # pylint: enable=missing-function-docstring
102 | # pylint: enable=deprecated-module
103 | # pylint: enable=wrong-import-order
104 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include
11 | #include
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 |
30 | // Create output tensor.
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 |
38 | // Initialize CUDA kernel parameters.
39 | upfirdn2d_kernel_params p;
40 | p.x = x.data_ptr();
41 | p.f = f.data_ptr();
42 | p.y = y.data_ptr();
43 | p.up = make_int2(upx, upy);
44 | p.down = make_int2(downx, downy);
45 | p.pad0 = make_int2(padx0, pady0);
46 | p.flip = (flip) ? 1 : 0;
47 | p.gain = gain;
48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 |
57 | // Choose CUDA kernel.
58 | upfirdn2d_kernel_spec spec;
59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 | {
61 | spec = choose_upfirdn2d_kernel(p);
62 | });
63 |
64 | // Set looping options.
65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 | p.loopMinor = spec.loopMinor;
67 | p.loopX = spec.loopX;
68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 |
71 | // Compute grid size.
72 | dim3 blockSize, gridSize;
73 | if (spec.tileOutW < 0) // large
74 | {
75 | blockSize = dim3(4, 32, 1);
76 | gridSize = dim3(
77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 | p.launchMajor);
80 | }
81 | else // small
82 | {
83 | blockSize = dim3(256, 1, 1);
84 | gridSize = dim3(
85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 | p.launchMajor);
88 | }
89 |
90 | // Launch CUDA kernel.
91 | void* args[] = {&p};
92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 | return y;
94 | }
95 |
96 | //------------------------------------------------------------------------
97 |
// Python bindings: exposes the host wrapper above as `upfirdn2d` on the
// compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}
102 |
103 | //------------------------------------------------------------------------
104 |
--------------------------------------------------------------------------------
/third_party/stylegan2_official_ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
// Parameter block handed to the upfirdn2d CUDA kernels as their single
// argument. All fields are filled in by the host wrapper in upfirdn2d.cpp.
struct upfirdn2d_kernel_params
{
    const void*     x;              // Input tensor data (element type depends on the dispatched dtype).
    const float*    f;              // Filter data (always float32).
    void*           y;              // Output tensor data.

    int2            up;             // Upsampling factors [x, y].
    int2            down;           // Downsampling factors [x, y].
    int2            pad0;           // Leading padding [x, y].
    int             flip;           // 1 if the `flip` argument was set, 0 otherwise.
    float           gain;           // Gain value passed through from the caller.

    int4            inSize;         // [width, height, channel, batch]
    int4            inStride;       // Strides of `x`, same component order as inSize.
    int2            filterSize;     // [width, height]
    int2            filterStride;   // Strides of `f`, same component order as filterSize.
    int4            outSize;        // [width, height, channel, batch]
    int4            outStride;      // Strides of `y`, same component order as outSize.
    int             sizeMinor;      // Channel count if the channel stride is 1, otherwise 1.
    int             sizeMajor;      // Batch size if the channel stride is 1, otherwise batch * channels.

    int             loopMinor;      // Copied from the chosen kernel spec.
    int             loopMajor;      // (sizeMajor - 1) / 16384 + 1.
    int             loopX;          // Copied from the chosen kernel spec.
    int             launchMinor;    // ceil(sizeMinor / loopMinor).
    int             launchMajor;    // ceil(sizeMajor / loopMajor).
};
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
// Describes the kernel specialization selected for a given parameter block.
struct upfirdn2d_kernel_spec
{
    void*   kernel;     // Kernel entry point passed to cudaLaunchKernel().
    int     tileOutW;   // Output tile width; a negative value selects the "large" launch configuration.
    int     tileOutH;   // Output tile height (used by the "small" launch configuration).
    int     loopMinor;  // Copied into upfirdn2d_kernel_params::loopMinor by the host wrapper.
    int     loopX;      // Copied into upfirdn2d_kernel_params::loopX by the host wrapper.
};
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/README.md:
--------------------------------------------------------------------------------
1 | # Operators for StyleGAN3
2 |
3 | All files in this directory are borrowed from repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including
4 |
5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator.
6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering.
7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel.
8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map.
9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map.
10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map.
11 | - `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing.
12 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
13 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty.
14 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`)
15 |
16 | We make the following slight modifications beyond disabling some lint warnings:
17 |
18 | - Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3).
19 | - Line 36 of file `custom_ops.py`: Disable log message when setting up customized operators.
20 | - Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its location in function `_find_compiler_bindir_posix()`.*)
21 | - Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3).
22 | - Line 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to `ref`.
23 | - Line 31 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default.
24 | - Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator.
25 | - Line 34 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
26 | - Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators.
27 | - Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators.
28 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator.
29 |
30 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default.
31 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/stylegan3_official_ops/__init__.py
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include
11 | #include
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
// Python bindings: exposes the host wrapper above as `bias_act` on the
// compiled extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template struct InternalType;
16 | template <> struct InternalType { typedef double scalar_t; };
17 | template <> struct InternalType { typedef float scalar_t; };
18 | template <> struct InternalType { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel;
155 | if (p.act == 2) return (void*)bias_act_kernel;
156 | if (p.act == 3) return (void*)bias_act_kernel;
157 | if (p.act == 4) return (void*)bias_act_kernel;
158 | if (p.act == 5) return (void*)bias_act_kernel;
159 | if (p.act == 6) return (void*)bias_act_kernel;
160 | if (p.act == 7) return (void*)bias_act_kernel;
161 | if (p.act == 8) return (void*)bias_act_kernel;
162 | if (p.act == 9) return (void*)bias_act_kernel;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
// Argument block handed by value to bias_act_kernel. The tensor pointers are
// untyped; the kernel casts them to its template element type.
struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void* y;            // [sizeX]

    int grad;           // Selects which formula the kernel evaluates (it tests for 0, 1 and 2).
    int act;            // Activation selector; kernels exist for values 1..9.
    float alpha;        // Multiplier for the negative side of lrelu.
    float gain;         // Factor applied to the result.
    float clamp;        // Clamp limit; the kernel only clamps when this is >= 0.

    int sizeX;          // Number of elements in x.
    int sizeB;          // Number of elements in b.
    int stepB;          // Divisor used to map an element index to a bias index: (xi / stepB) % sizeB.
    int loopX;          // Maximum number of elements processed per thread.
};
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/conv2d_resample.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """2D convolution with optional up/downsampling.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan3
14 | """
15 |
16 | # pylint: disable=line-too-long
17 |
18 | import torch
19 |
20 | from . import misc
21 | from . import conv2d_gradfix
22 | from . import upfirdn2d
23 | from .upfirdn2d import _parse_padding
24 | from .upfirdn2d import _get_filter_size
25 |
26 | #----------------------------------------------------------------------------
27 |
def _get_weight_shape(w):
    """Return the shape of `w` as a list of plain Python ints.

    The sizes are converted to ints while tracer warnings are suppressed
    (the values are treated as constants), and the result is checked against
    `w` with `misc.assert_shape()`.
    """
    with misc.suppress_tracer_warnings():
        dims = list(map(int, w.shape))
    misc.assert_shape(w, dims)
    return dims
33 |
34 | #----------------------------------------------------------------------------
35 |
def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'):
    """Dispatch to `conv2d_gradfix.conv2d()` or `conv2d_gradfix.conv_transpose2d()`.

    `conv2d()` performs correlation, which corresponds to `flip_weight=True`.
    For true convolution (`flip_weight=False`) the kernel is flipped along
    both spatial axes first; a 1x1 kernel is left alone since flipping it
    changes nothing.
    """
    _, _, kernel_h, kernel_w = _get_weight_shape(w)

    wants_convolution = not flip_weight
    if wants_convolution and (kernel_w > 1 or kernel_h > 1):
        w = w.flip([2, 3])

    if transpose:
        conv_fn = conv2d_gradfix.conv_transpose2d
    else:
        conv_fn = conv2d_gradfix.conv2d
    return conv_fn(x, w, stride=stride, padding=padding, groups=groups, impl=impl)
49 |
50 | #----------------------------------------------------------------------------
51 |
@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    The function picks the first applicable strategy, in this order: 1x1
    kernel with downsampling only, 1x1 kernel with upsampling only,
    downsampling only (strided convolution), upsampling (transposed strided
    convolution, optionally followed by downsampling), plain convolution when
    the padding is symmetric and non-negative, and finally a generic
    upfirdn2d -> conv2d -> upfirdn2d fallback.

    Args:
        x: Input tensor of shape
            `[batch_size, in_channels, in_height, in_width]`.
        w: Weight tensor of shape
            `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
            Must have the same dtype as `x`.
        f: Low-pass filter for up/downsampling. Must be prepared beforehand by
            calling upfirdn2d.setup_filter(). None = identity (default).
        up: Integer upsampling factor (default: 1).
        down: Integer downsampling factor (default: 1).
        padding: Padding with respect to the upsampled image. Can be a single number
            or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
            (default: 0).
        groups: Split input channels into N groups (default: 1).
        flip_weight: False = convolution, True = correlation (default: True).
        flip_filter: False = convolution, True = correlation (default: False).
        impl: Implementation mode, 'cuda' for CUDA implementation, and 'ref' for
            native PyTorch implementation (default: 'cuda').

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    # The before/after amounts differ by rounding, so the padding may become asymmetric.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: downsampling only => use strided convolution.
    # The filter is applied at full resolution; the stride of the convolution does the decimation.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # Swap the in/out channel axes of the weight (per group when groups > 1),
        # as the transposed convolution is called with the axes in the opposite order.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        # Subtract the kernel extent from the padding; where the result is negative on
        # both sides, the common part (pxt/pyt) is passed to the transposed convolution
        # and the remainder is left for the filtering step.
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        # flip_weight is inverted for the transposed call.
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    # "Supported" here means symmetric and non-negative in both axes.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl)

    # Fallback: Generic reference implementation.
    # All padding is applied in the first upfirdn2d call; the filter is only used there when upsampling.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
    return x
151 |
152 | #----------------------------------------------------------------------------
153 |
154 | # pylint: enable=line-too-long
155 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/custom_ops.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Utility functions to setup customized operators.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan3
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=multiple-statements
18 | # pylint: disable=missing-function-docstring
19 | # pylint: disable=useless-suppression
20 | # pylint: disable=inconsistent-quotes
21 |
22 | import glob
23 | import hashlib
24 | import importlib
25 | import os
26 | import re
27 | import shutil
28 | import uuid
29 |
30 | import torch
31 | import torch.utils.cpp_extension
32 |
33 | #----------------------------------------------------------------------------
34 | # Global options.
35 |
36 | verbosity = 'none' # Verbosity level: 'none', 'brief', 'full'
37 |
38 | #----------------------------------------------------------------------------
39 | # Internal helper funcs.
40 |
41 | def _find_compiler_bindir():
42 | patterns = [
43 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
44 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
45 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
46 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
47 | ]
48 | for pattern in patterns:
49 | matches = sorted(glob.glob(pattern))
50 | if len(matches):
51 | return matches[-1]
52 | return None
53 |
54 | def _find_compiler_bindir_posix():
55 | patterns = [
56 | '/usr/local/cuda/bin'
57 | ]
58 | for pattern in patterns:
59 | matches = sorted(glob.glob(pattern))
60 | if len(matches):
61 | return matches[-1]
62 | return None
63 |
64 | #----------------------------------------------------------------------------
65 |
def _get_mangled_gpu_name():
    """Return the lower-cased CUDA device name with unsafe characters replaced.

    Every character other than `a-z`, `0-9`, `_` and `-` becomes `-`, so the
    result can be embedded in a directory name.
    """
    device_name = torch.cuda.get_device_name().lower()
    return ''.join(
        ch if re.match('[a-z0-9_-]+', ch) else '-'
        for ch in device_name
    )
75 |
76 | #----------------------------------------------------------------------------
77 | # Main entry point for compiling and loading C++/CUDA plugins.
78 |
_cached_plugins = dict()

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    """Compile (if needed), import and cache a C++/CUDA PyTorch extension.

    Args:
        module_name: Name of the extension module; also used as the cache key.
        sources: Source file names to compile.
        headers: Optional header file names. They are not passed to the
            compiler directly, but they take part in the source digest and
            are copied next to the sources (default: None, i.e. no headers).
        source_dir: Optional directory that `sources` and `headers` are
            relative to (default: None).
        **build_kwargs: Forwarded to `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module. Repeated calls with the same
        `module_name` return the cached module without rebuilding.

    Raises:
        RuntimeError: If the compiler directory has to be located and cannot
            be found. Any error raised by the build or import is propagated.
    """
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        # PATH entries are joined with os.pathsep (';' on Windows, ':' on POSIX).
        # A hard-coded ';' on POSIX would glue the new directory onto the last
        # existing entry and invalidate both.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += os.pathsep + compiler_bindir

        # Mirror the Windows branch: only search for (and require) the CUDA bin
        # directory when nvcc is not already reachable through PATH.
        elif os.name == 'posix' and shutil.which('nvcc') is None:
            compiler_bindir = _find_compiler_bindir_posix()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
            os.environ['PATH'] += os.pathsep + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files.  Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                # Populate a uniquely named temporary directory first and then
                # rename it into place, so a concurrent process never sees a
                # half-copied build directory.
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir): raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        # Report the failure in 'brief' mode, then let the error propagate.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
184 |
185 | #----------------------------------------------------------------------------
186 |
187 | # pylint: enable=line-too-long
188 | # pylint: enable=multiple-statements
189 | # pylint: enable=missing-function-docstring
190 | # pylint: enable=useless-suppression
191 | # pylint: enable=inconsistent-quotes
192 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
14 | struct filtered_lrelu_kernel_params
15 | {
16 | // These parameters decide which kernel to use.
17 | int up; // upsampling ratio (1, 2, 4)
18 | int down; // downsampling ratio (1, 2, 4)
19 | int2 fuShape; // [size, 1] | [size, size]
20 | int2 fdShape; // [size, 1] | [size, size]
21 |
22 | int _dummy; // Alignment.
23 |
24 | // Rest of the parameters.
25 | const void* x; // Input tensor.
26 | void* y; // Output tensor.
27 | const void* b; // Bias tensor.
28 | unsigned char* s; // Sign tensor in/out. NULL if unused.
29 | const float* fu; // Upsampling filter.
30 | const float* fd; // Downsampling filter.
31 |
32 | int2 pad0; // Left/top padding.
33 | float gain; // Additional gain factor.
34 | float slope; // Leaky ReLU slope on negative side.
35 | float clamp; // Clamp after nonlinearity.
36 | int flip; // Filter kernel flip for gradient computation.
37 |
38 | int tilesXdim; // Original number of horizontal output tiles.
39 | int tilesXrep; // Number of horizontal tiles per CTA.
40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41 |
42 | int4 xShape; // [width, height, channel, batch]
43 | int4 yShape; // [width, height, channel, batch]
44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46 | int swLimit; // Active width of sign tensor in bytes.
47 |
48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49 | longlong4 yStride; //
50 | int64_t bStride; //
51 | longlong3 fuStride; //
52 | longlong3 fdStride; //
53 | };
54 |
// Parameter block for the activation-only kernel (see
// choose_filtered_lrelu_act_kernel): it works on a single tensor in place and
// carries no filter or resampling information.
struct filtered_lrelu_act_kernel_params
{
    void*           x;          // Input/output, modified in-place.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.

    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.

    int4            xShape;     // [width, height, channel, batch]
    longlong4       xStride;    // Input/output tensor strides, same order as in shape.
    int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};
69 |
70 | //------------------------------------------------------------------------
71 | // CUDA kernel specialization.
72 |
// Result of kernel selection (returned by choose_filtered_lrelu_kernel): the
// function pointers to launch plus the launch configuration they expect.
struct filtered_lrelu_kernel_spec
{
    void*   setup;              // Function for filter kernel setup.
    void*   exec;               // Function for main operation.
    int2    tileOut;            // Width/height of launch tile.
    int     numWarps;           // Number of warps per thread block, determines launch block size.
    int     xrep;               // For processing multiple horizontal tiles per thread block.
    int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.
};
82 |
83 | //------------------------------------------------------------------------
84 | // CUDA kernel selection.
85 |
86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87 | template void* choose_filtered_lrelu_act_kernel(void);
88 | template cudaError_t copy_filters(cudaStream_t stream);
89 |
90 | //------------------------------------------------------------------------
91 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_ns.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for no signs mode (no gradients required).
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel(void);
23 | template void* choose_filtered_lrelu_act_kernel(void);
24 | template void* choose_filtered_lrelu_act_kernel(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_rd.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign read mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel(void);
23 | template void* choose_filtered_lrelu_act_kernel(void);
24 | template void* choose_filtered_lrelu_act_kernel(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/filtered_lrelu_wr.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include "filtered_lrelu.cu"
10 |
11 | // Template/kernel specializations for sign write mode.
12 |
13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
16 |
17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
20 |
21 | // Activation/signs only for generic variant. 64-bit indexing.
22 | template void* choose_filtered_lrelu_act_kernel(void);
23 | template void* choose_filtered_lrelu_act_kernel(void);
24 | template void* choose_filtered_lrelu_act_kernel(void);
25 |
26 | // Copy filters to constant memory.
27 | template cudaError_t copy_filters(cudaStream_t stream);
28 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/fma.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.
12 |
13 | Please refer to https://github.com/NVlabs/stylegan3
14 | """
15 |
16 | # pylint: disable=line-too-long
17 | # pylint: disable=missing-function-docstring
18 |
19 | import torch
20 |
21 | #----------------------------------------------------------------------------
22 |
def fma(a, b, c, impl='cuda'):
    """Computes `a * b + c` element-wise (fused multiply-add).

    Args:
        a: First multiplicand tensor.
        b: Second multiplicand tensor.
        c: Tensor added to the product.
        impl: Implementation selector. `'cuda'` routes the computation through
            the custom autograd function `_FusedMultiplyAdd`; any other value
            calls `torch.addcmul()` directly. (default: `cuda`)

    Returns:
        The tensor `a * b + c`.
    """
    use_custom_autograd = (impl == 'cuda')
    if not use_custom_autograd:
        return torch.addcmul(c, a, b)
    return _FusedMultiplyAdd.apply(a, b, c)
27 |
28 | #----------------------------------------------------------------------------
29 |
30 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
31 | @staticmethod
32 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ
33 | out = torch.addcmul(c, a, b)
34 | ctx.save_for_backward(a, b)
35 | ctx.c_shape = c.shape
36 | return out
37 |
38 | @staticmethod
39 | def backward(ctx, dout): # pylint: disable=arguments-differ
40 | a, b = ctx.saved_tensors
41 | c_shape = ctx.c_shape
42 | da = None
43 | db = None
44 | dc = None
45 |
46 | if ctx.needs_input_grad[0]:
47 | da = _unbroadcast(dout * b, a.shape)
48 |
49 | if ctx.needs_input_grad[1]:
50 | db = _unbroadcast(dout * a, b.shape)
51 |
52 | if ctx.needs_input_grad[2]:
53 | dc = _unbroadcast(dout, c_shape)
54 |
55 | return da, db, dc
56 |
57 | #----------------------------------------------------------------------------
58 |
59 | def _unbroadcast(x, shape):
60 | extra_dims = x.ndim - len(shape)
61 | assert extra_dims >= 0
62 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
63 | if len(dim):
64 | x = x.sum(dim=dim, keepdim=True)
65 | if extra_dims:
66 | x = x.reshape(-1, *x.shape[extra_dims+1:])
67 | assert x.shape == shape
68 | return x
69 |
70 | #----------------------------------------------------------------------------
71 |
72 | # pylint: enable=line-too-long
73 | # pylint: enable=missing-function-docstring
74 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/grid_sample_gradfix.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 |
3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4 | #
5 | # NVIDIA CORPORATION and its licensors retain all intellectual property
6 | # and proprietary rights in and to this software, related documentation
7 | # and any modifications thereto. Any use, reproduction, disclosure or
8 | # distribution of this software and related documentation without an express
9 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
10 |
11 | """Custom replacement for `torch.nn.functional.grid_sample`.
12 |
13 | This is useful for differentiable augmentation. This customized operator
14 | supports arbitrarily high order gradients between the input and output. Only
15 | works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
16 | `align_corners=False`.
17 |
18 | Please refer to https://github.com/NVlabs/stylegan3
19 | """
20 |
21 | # pylint: disable=redefined-builtin
22 | # pylint: disable=arguments-differ
23 | # pylint: disable=protected-access
24 | # pylint: disable=line-too-long
25 | # pylint: disable=missing-function-docstring
26 |
27 | import torch
28 | from pkg_resources import parse_version
29 |
30 | #----------------------------------------------------------------------------
31 |
32 | enabled = True # Enable the custom op by setting this to true.
33 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
34 | #----------------------------------------------------------------------------
35 |
def grid_sample(input, grid, impl='cuda'):
    """Samples `input` at the locations given by `grid`.

    Always uses `mode='bilinear'`, `padding_mode='zeros'` and
    `align_corners=False`.

    Args:
        input: Input tensor to sample from.
        grid: Sampling grid.
        impl: Implementation selector. With `'cuda'`, and when the custom op
            is enabled, the call goes through `_GridSample2dForward`;
            otherwise `torch.nn.functional.grid_sample()` is used.
            (default: `cuda`)

    Returns:
        The sampled tensor.
    """
    use_custom_op = impl == 'cuda' and _should_use_custom_op()
    if use_custom_op:
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(
        input=input,
        grid=grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)
40 |
41 | #----------------------------------------------------------------------------
42 |
def _should_use_custom_op():
    """Returns the module-level `enabled` flag controlling the custom op."""
    return enabled
45 |
46 | #----------------------------------------------------------------------------
47 |
48 | class _GridSample2dForward(torch.autograd.Function):
49 | @staticmethod
50 | def forward(ctx, input, grid):
51 | assert input.ndim == 4
52 | assert grid.ndim == 4
53 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
54 | ctx.save_for_backward(input, grid)
55 | return output
56 |
57 | @staticmethod
58 | def backward(ctx, grad_output):
59 | input, grid = ctx.saved_tensors
60 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
61 | return grad_input, grad_grid
62 |
63 | #----------------------------------------------------------------------------
64 |
class _GridSample2dBackward(torch.autograd.Function):
    """Backward pass of bilinear 2D grid sampling, itself differentiable.

    The forward of this function computes the first-order gradients w.r.t.
    `input` and `grid` via `aten::grid_sampler_2d_backward`. Its backward
    provides the second-order gradient w.r.t. `grad_output` only.
    """

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Newer PyTorch releases return `(op, overload_names)` instead of the
        # bare callable; unwrap it so that both API flavours are supported.
        if isinstance(op, tuple):
            op = op[0]
        # Positional arguments `0, 0, False` are the interpolation mode,
        # padding mode and `align_corners` of the ATen op.
        # NOTE(review): 0/0 presumably map to bilinear/zeros — confirm against
        # the ATen enum definitions.
        if _use_pytorch_1_11_api:
            # From 1.11 on, the op takes a mask telling which of
            # (grad_input, grad_grid) should actually be computed.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid  # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        # Second-order gradients w.r.t. `grid` are not supported.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid
90 |
91 | #----------------------------------------------------------------------------
92 |
93 | # pylint: enable=redefined-builtin
94 | # pylint: enable=arguments-differ
95 | # pylint: enable=protected-access
96 | # pylint: enable=line-too-long
97 | # pylint: enable=missing-function-docstring
98 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 | #include
11 | #include
12 | #include "upfirdn2d.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 | TORCH_CHECK(x.numel() > 0, "x has zero size");
25 | TORCH_CHECK(f.numel() > 0, "f has zero size");
26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32 |
33 | // Create output tensor.
34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41 |
42 | // Initialize CUDA kernel parameters.
43 | upfirdn2d_kernel_params p;
44 | p.x = x.data_ptr();
45 | p.f = f.data_ptr();
46 | p.y = y.data_ptr();
47 | p.up = make_int2(upx, upy);
48 | p.down = make_int2(downx, downy);
49 | p.pad0 = make_int2(padx0, pady0);
50 | p.flip = (flip) ? 1 : 0;
51 | p.gain = gain;
52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60 |
61 | // Choose CUDA kernel.
62 | upfirdn2d_kernel_spec spec;
63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64 | {
65 | spec = choose_upfirdn2d_kernel(p);
66 | });
67 |
68 | // Set looping options.
69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70 | p.loopMinor = spec.loopMinor;
71 | p.loopX = spec.loopX;
72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74 |
75 | // Compute grid size.
76 | dim3 blockSize, gridSize;
77 | if (spec.tileOutW < 0) // large
78 | {
79 | blockSize = dim3(4, 32, 1);
80 | gridSize = dim3(
81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83 | p.launchMajor);
84 | }
85 | else // small
86 | {
87 | blockSize = dim3(256, 1, 1);
88 | gridSize = dim3(
89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91 | p.launchMajor);
92 | }
93 |
94 | // Launch CUDA kernel.
95 | void* args[] = {&p};
96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97 | return y;
98 | }
99 |
100 | //------------------------------------------------------------------------
101 |
// Python bindings: exposes the C++ `upfirdn2d` function to Python under the
// same name, in the extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}
106 |
107 | //------------------------------------------------------------------------
108 |
--------------------------------------------------------------------------------
/third_party/stylegan3_official_ops/upfirdn2d.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include
10 |
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 |
// Parameter block passed to the upfirdn2d CUDA kernel. It is filled in by the
// host-side `upfirdn2d()` wrapper before the kernel launch.
struct upfirdn2d_kernel_params
{
    const void*     x;              // Input tensor data.
    const float*    f;              // Filter data (float32).
    void*           y;              // Output tensor data.

    int2            up;             // Upsampling factors (x, y).
    int2            down;           // Downsampling factors (x, y).
    int2            pad0;           // Leading padding (padx0, pady0).
    int             flip;           // 1 if the `flip` argument was true, else 0.
    float           gain;           // Gain forwarded from the caller.

    int4            inSize;         // [width, height, channel, batch]
    int4            inStride;       // Input strides, same component order as inSize.
    int2            filterSize;     // [width, height]
    int2            filterStride;   // Filter strides, same component order as filterSize.
    int4            outSize;        // [width, height, channel, batch]
    int4            outStride;      // Output strides, same component order as outSize.
    int             sizeMinor;      // Channel count if the channel stride is 1, else 1.
    int             sizeMajor;      // Batch size if the channel stride is 1, else batch * channels.

    int             loopMinor;      // Copied from the chosen kernel spec.
    int             loopMajor;      // (sizeMajor - 1) / 16384 + 1.
    int             loopX;          // Copied from the chosen kernel spec.
    int             launchMinor;    // ceil(sizeMinor / loopMinor).
    int             launchMajor;    // ceil(sizeMajor / loopMajor).
};
41 |
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 |
// Describes the kernel chosen by `choose_upfirdn2d_kernel()` together with
// the launch configuration hints the host wrapper reads from it.
struct upfirdn2d_kernel_spec
{
    void*   kernel;     // Kernel entry point handed to cudaLaunchKernel().
    int     tileOutW;   // Output tile width; a negative value selects the "large" launch path.
    int     tileOutH;   // Output tile height (used by the "small" launch path).
    int     loopMinor;  // Copied into upfirdn2d_kernel_params::loopMinor.
    int     loopX;      // Copied into upfirdn2d_kernel_params::loopX.
};
53 |
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 |
57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 |
59 | //------------------------------------------------------------------------
60 |
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains utility functions used for distribution."""
3 |
4 | import contextlib
5 | import os
6 | import subprocess
7 |
8 | import torch
9 | import torch.distributed as dist
10 | import torch.multiprocessing as mp
11 |
12 | __all__ = ['init_dist', 'exit_dist', 'ddp_sync', 'get_ddp_module']
13 |
14 |
def init_dist(launcher, backend='nccl', **kwargs):
    """Initializes the distributed environment.

    Args:
        launcher: Name of the job launcher, either `pytorch` or `slurm`.
        backend: Backend passed to `dist.init_process_group()`.
            (default: `nccl`)
        **kwargs: Additional arguments forwarded to
            `dist.init_process_group()`. Only used by the `pytorch` launcher.

    Raises:
        NotImplementedError: If the `launcher` is not supported.
    """
    # Only pick `spawn` if no start method has been chosen yet.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')

    if launcher == 'pytorch':
        global_rank = int(os.environ['RANK'])
        torch.cuda.set_device(global_rank % torch.cuda.device_count())
        dist.init_process_group(backend=backend, **kwargs)
        return

    if launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        torch.cuda.set_device(proc_id % torch.cuda.device_count())
        # The first host of the allocation acts as the master node.
        master_addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        # Export the rendezvous settings through environment variables.
        os.environ.update({
            'MASTER_PORT': str(os.environ.get('PORT', 29500)),
            'MASTER_ADDR': master_addr,
            'WORLD_SIZE': str(ntasks),
            'RANK': str(proc_id),
        })
        dist.init_process_group(backend=backend)
        return

    raise NotImplementedError(f'Not implemented launcher type: '
                              f'`{launcher}`!')
41 |
42 |
def exit_dist():
    """Destroys the default process group, if one has been initialized."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
47 |
48 |
@contextlib.contextmanager
def ddp_sync(model, sync):
    """Context manager controlling gradient sync of a DDP-wrapped model.

    Inside the context, synchronization is disabled (via `model.no_sync()`)
    only when `model` is a `DistributedDataParallel` instance and `sync` is
    falsy. In every other case the context does nothing special.

    Args:
        model: A `torch.nn.Module`, possibly wrapped by
            `DistributedDataParallel`.
        sync: Whether the model should be synced.
    """
    assert isinstance(model, torch.nn.Module)
    wrapped = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    if wrapped and not sync:
        guard = model.no_sync()
    else:
        guard = contextlib.nullcontext()
    with guard:
        yield
59 |
60 |
def get_ddp_module(model):
    """Unwraps a model from `DistributedDataParallel`, if it is wrapped.

    Args:
        model: A `torch.nn.Module`, possibly wrapped by
            `DistributedDataParallel`.

    Returns:
        `model.module` for a DDP-wrapped model, otherwise `model` itself.
    """
    assert isinstance(model, torch.nn.Module)
    ddp_cls = torch.nn.parallel.DistributedDataParallel
    return model.module if isinstance(model, ddp_cls) else model
68 |
--------------------------------------------------------------------------------
/utils/file_transmitters/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all file transmitters."""
3 |
4 | from .local_file_transmitter import LocalFileTransmitter
5 | from .dummy_file_transmitter import DummyFileTransmitter
6 |
7 | __all__ = ['build_file_transmitter']
8 |
9 | _TRANSMITTERS = {
10 | 'local': LocalFileTransmitter,
11 | 'dummy': DummyFileTransmitter,
12 | }
13 |
14 |
def build_file_transmitter(transmitter_type='local', **kwargs):
    """Builds a file transmitter.

    Args:
        transmitter_type: Type of the file transmitter, which is case
            insensitive. (default: `local`)
        **kwargs: Additional arguments to build the file transmitter.

    Returns:
        An instance of the transmitter class registered under
        `transmitter_type`.

    Raises:
        ValueError: If the `transmitter_type` is not supported.
    """
    transmitter_type = transmitter_type.lower()
    try:
        transmitter_cls = _TRANSMITTERS[transmitter_type]
    except KeyError:
        raise ValueError(f'Invalid transmitter type: `{transmitter_type}`!\n'
                         f'Types allowed: {list(_TRANSMITTERS)}.') from None
    return transmitter_cls(**kwargs)
31 |
--------------------------------------------------------------------------------
/utils/file_transmitters/base_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the base class to transmit files across file systems.
3 |
4 | Basically, a file transmitter connects the local file system, on which the
5 | programme runs, to a remote file system. This is particularly used for
6 | (1) pulling files that are required by the programme from remote, and
7 | (2) pushing results that are produced by the programme to remote. In this way,
8 | the programme can focus on local file system only.
9 |
10 | NOTE: The remote file system can be the same as the local file system, since
11 | users may want to transmit files across directories.
12 | """
13 |
14 | import warnings
15 |
16 | __all__ = ['BaseFileTransmitter']
17 |
18 |
class BaseFileTransmitter(object):
    """Base class of all file transmitters.

    Public API offered to users:

    (1) pull(): Pull a file/directory from remote to local.
    (2) push(): Push a file/directory from local to remote.
    (3) remove(): Remove a file/directory.
    (4) make_remote_dir(): Make a directory remotely.

    Hooks every derived class is expected to implement:

    (1) download_hard(): Hard download a file/directory from remote to local.
    (2) download_soft(): Soft download a file/directory from remote to local.
        This is especially used to save space (e.g., soft link).
    (3) upload(): Upload a file/directory from local to remote.
    (4) delete(): Delete a file/directory according to given path.
    (5) make_remote_dir(): Make a directory on the remote system.
    """

    def __init__(self):
        """Initializes the transmitter; the base class keeps no state."""

    @property
    def name(self):
        """Returns the class name of the file transmitter."""
        return type(self).__name__

    @staticmethod
    def download_hard(src, dst):
        """Downloads (in hard mode) a file/directory from remote to local."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def download_soft(src, dst):
        """Downloads (in soft mode) a file/directory from remote to local."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def upload(src, dst):
        """Uploads a file/directory from local to remote."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def delete(path):
        """Deletes the given path."""
        # TODO: should we secure the path to avoid mis-removing / attacks?
        raise NotImplementedError('Should be implemented in derived class!')

    def pull(self, src, dst, hard=False):
        """Pulls a file/directory from remote to local.

        Args:
            src: Source path.
            dst: Destination path.
            hard: Whether to use `download_hard()` (e.g., a real copy) rather
                than `download_soft()` (e.g., a soft link). (default: False)
        """
        downloader = self.download_hard if hard else self.download_soft
        downloader(src, dst)

    def push(self, src, dst):
        """Pushes a file/directory from local to remote."""
        self.upload(src, dst)

    def remove(self, path):
        """Removes the given path, emitting a warning beforehand."""
        warnings.warn(f'`{path}` will be removed!')
        self.delete(path)

    def make_remote_dir(self, directory):
        """Makes a directory on the remote system."""
        raise NotImplementedError('Should be implemented in derived class!')
93 |
--------------------------------------------------------------------------------
/utils/file_transmitters/dummy_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of dummy file transmitter.
3 |
4 | This file transmitter has all expected data transmission functions but behaves
5 | silently, which is very useful in multi-processing mode. Only the chief process
6 | can have the file transmitter with normal behavior.
7 | """
8 |
9 | from .base_file_transmitter import BaseFileTransmitter
10 |
11 | __all__ = ['DummyFileTransmitter']
12 |
13 |
class DummyFileTransmitter(BaseFileTransmitter):
    """Implements a dummy transmitter which transmits nothing.

    Every operation is a silent no-op returning `None`.
    """

    @staticmethod
    def download_hard(src, dst):
        """Does nothing."""
        return None

    @staticmethod
    def download_soft(src, dst):
        """Does nothing."""
        return None

    @staticmethod
    def upload(src, dst):
        """Does nothing."""
        return None

    @staticmethod
    def delete(path):
        """Does nothing."""
        return None

    def make_remote_dir(self, directory):
        """Does nothing."""
        return None
35 |
--------------------------------------------------------------------------------
/utils/file_transmitters/local_file_transmitter.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of local file transmitter.
3 |
4 | The transmitter builds the connection between the local file system and itself.
5 | This can be used to transmit files from one directory to another. Consequently,
6 | `remote` in this file also means `local`.
7 | """
8 |
9 | from utils.misc import print_and_execute
10 | from .base_file_transmitter import BaseFileTransmitter
11 |
12 | __all__ = ['LocalFileTransmitter']
13 |
14 |
class LocalFileTransmitter(BaseFileTransmitter):
    """Implements the transmitter connecting local file system to itself.

    Each operation builds a shell command and hands it over to
    `print_and_execute()`.
    """

    @staticmethod
    def download_hard(src, dst):
        """Copies `src` to `dst` with `cp`."""
        command = f'cp {src} {dst}'
        print_and_execute(command)

    @staticmethod
    def download_soft(src, dst):
        """Creates a symbolic link `dst` pointing to `src` with `ln -s`."""
        command = f'ln -s {src} {dst}'
        print_and_execute(command)

    @staticmethod
    def upload(src, dst):
        """Copies `src` to `dst` with `cp`."""
        command = f'cp {src} {dst}'
        print_and_execute(command)

    @staticmethod
    def delete(path):
        """Removes `path` recursively with `rm -r`."""
        command = f'rm -r {path}'
        print_and_execute(command)

    def make_remote_dir(self, directory):
        """Creates `directory` (and missing parents) with `mkdir -p`."""
        command = f'mkdir -p {directory}'
        print_and_execute(command)
36 |
--------------------------------------------------------------------------------
/utils/formatting_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains utility functions used for formatting."""
3 |
4 | import cv2
5 | import numpy as np
6 |
7 | __all__ = [
8 | 'format_time', 'format_range', 'format_image_size', 'format_image',
9 | 'raw_label_to_one_hot', 'one_hot_to_raw_label'
10 | ]
11 |
12 |
def format_time(seconds):
    """Formats a duration in seconds as a short, readable string.

    Durations below one minute are shown as fractional seconds (three decimals
    below 10 seconds, two decimals otherwise). Longer durations are rounded to
    the nearest second and shown using their two most significant units.

    Args:
        seconds: Number of seconds to format.

    Returns:
        The formatted time string.

    Raises:
        ValueError: If the input `seconds` is less than 0.
    """
    if seconds < 0:
        raise ValueError(f'Input `seconds` should be greater than or equal to '
                         f'0, but `{seconds}` is received!')

    # Sub-minute durations keep their fractional part.
    if seconds < 60:
        precision = 3 if seconds < 10 else 2
        return f'{seconds:7.{precision}f} s'

    # Round to the nearest second, then split into d/h/m/s.
    total_seconds = int(seconds + 0.5)
    total_minutes, secs = divmod(total_seconds, 60)
    total_hours, minutes = divmod(total_minutes, 60)
    days, hours = divmod(total_hours, 24)

    if days:
        return f'{days:2d} d {hours:02d} h'
    if hours:
        return f'{hours:2d} h {minutes:02d} m'
    return f'{minutes:2d} m {secs:02d} s'
44 |
45 |
def format_range(obj, min_val=None, max_val=None):
    """Validates a two-element range and optionally clamps both ends.

    When `min_val` and/or `max_val` is given, the start and the end of the
    range are both clamped, first from below by `min_val` and then from above
    by `max_val`.

    NOTE: (a, b) is regarded as a valid range if and only if `a <= b`.

    Args:
        obj: The input object to format, a two-element tuple or list.
        min_val: Lower bound used for clamping. No lower clamping is applied
            if not provided. (default: None)
        max_val: Upper bound used for clamping. No upper clamping is applied
            if not provided. (default: None)

    Returns:
        A two-element tuple, indicating the start and the end of the range.

    Raises:
        ValueError: If the input object is an invalid range.
    """
    if not isinstance(obj, (tuple, list)):
        raise ValueError(f'Input object must be a tuple or a list, '
                         f'but `{type(obj)}` received!')
    if len(obj) != 2:
        raise ValueError(f'Input object is expected to contain two elements, '
                         f'but `{len(obj)}` received!')

    start, end = obj
    if start > end:
        raise ValueError(f'The second element is expected to be equal to or '
                         f'greater than the first one, '
                         f'but `({start}, {end})` received!')

    if min_val is not None:
        start, end = max(start, min_val), max(end, min_val)
    if max_val is not None:
        start, end = min(start, max_val), min(end, max_val)
    return start, end
86 |
87 |
def format_image_size(size):
    """Normalizes an image size to a `(height, width)` tuple.

    Accepted inputs are a single integer (used for both height and width), a
    one-element list/tuple (likewise), or a two-element list/tuple. Both the
    height and the width must be non-negative integers.

    Args:
        size: The input size to format.

    Returns:
        A two-element tuple, indicating the height and the width, respectively.

    Raises:
        ValueError: If the input size is invalid.
    """
    if not isinstance(size, (int, tuple, list)):
        raise ValueError(f'Input size must be an integer, a tuple, or a list, '
                         f'but `{type(size)}` received!')

    # Expand a scalar or a single-element sequence to (height, width).
    if isinstance(size, int):
        size = (size, size)
    elif len(size) == 1:
        size = (size[0], size[0])

    if len(size) != 2:
        raise ValueError(f'Input size is expected to have two numbers at '
                         f'most, but `{len(size)}` numbers received!')

    for label, value in zip(('height', 'width'), size):
        if not isinstance(value, int) or value < 0:
            raise ValueError(f'The {label} is expected to be a non-negative '
                             f'integer, but `{value}` received!')
    return tuple(size)
122 |
123 |
def format_image(image):
    """Formats an image read from `cv2`.

    NOTE: The result is always 3-dimensional (i.e., with shape [H, W, C]).
    Color inputs are expected in `BGR` or `BGRA` channel order, which is the
    raw order decoded by `cv2`, and are converted to `RGB` or `RGBA`
    respectively. Single-channel images are returned without conversion.

    Args:
        image: `np.ndarray` of dtype `np.uint8`, an image read by
            `cv2.imread()` or `cv2.imdecode()`.

    Returns:
        An image with shape [H, W, C] (where `C = 1` for grayscale image).
    """
    if image.ndim == 2:  # grayscale image: append a channel axis
        image = image[..., np.newaxis]

    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]

    num_channels = image.shape[2]
    if num_channels == 1:
        return image
    if num_channels == 3:  # BGR image
        conversion = cv2.COLOR_BGR2RGB
    else:  # BGRA image
        conversion = cv2.COLOR_BGRA2RGBA
    return cv2.cvtColor(image, conversion)
152 |
153 |
def raw_label_to_one_hot(raw_label, num_classes):
    """Converts a single label into a one-hot vector.

    Args:
        raw_label: The raw label, used as the index of the active entry.
        num_classes: Total number of classes, i.e. the vector length.

    Returns:
        A `np.float32` vector of length `num_classes` that is 1.0 at index
        `raw_label` and 0.0 elsewhere.
    """
    encoded = np.zeros((num_classes,), dtype=np.float32)
    encoded[raw_label] = 1.0
    return encoded
167 |
168 |
def one_hot_to_raw_label(one_hot):
    """Converts a one-hot vector to a single value label.

    Args:
        one_hot: `np.ndarray`, a one-hot encoded vector.

    Returns:
        The index of the largest entry, as returned by `np.argmax()`.
    """
    label = np.argmax(one_hot)
    return label
179 |
--------------------------------------------------------------------------------
/utils/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all loggers."""
3 |
4 | from .normal_logger import NormalLogger
5 | from .rich_logger import RichLogger
6 | from .dummy_logger import DummyLogger
7 |
8 | __all__ = ['build_logger']
9 |
10 | _LOGGERS = {
11 | 'normal': NormalLogger,
12 | 'rich': RichLogger,
13 | 'dummy': DummyLogger
14 | }
15 |
16 |
def build_logger(logger_type='normal', **kwargs):
    """Builds a logger.

    Args:
        logger_type: Type of logger, which is case insensitive.
            (default: `normal`)
        **kwargs: Additional arguments to build the logger.

    Returns:
        An instance of the logger class registered under `logger_type`.

    Raises:
        ValueError: If the `logger_type` is not supported.
    """
    logger_type = logger_type.lower()
    try:
        logger_cls = _LOGGERS[logger_type]
    except KeyError:
        raise ValueError(f'Invalid logger type: `{logger_type}`!\n'
                         f'Types allowed: {list(_LOGGERS)}.') from None
    return logger_cls(**kwargs)
33 |
--------------------------------------------------------------------------------
/utils/loggers/dummy_logger.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of dummy logger.
3 |
4 | This logger has all expected logging functions but behaves silently, which is
5 | very useful in multi-processing mode. Only the chief process can have the logger
6 | with normal behavior.
7 | """
8 |
9 | from .base_logger import BaseLogger
10 |
11 | __all__ = ['DummyLogger']
12 |
13 |
class DummyLogger(BaseLogger):
    """Implements a dummy logger which logs nothing.

    Every logging and progress-bar hook is overridden with a no-op.
    """

    def __init__(self,
                 logger_name='logger',
                 logfile=None,
                 screen_level=None,
                 file_level=None,
                 indent_space=4,
                 verbose_log=False):
        """Forwards all settings, unchanged, to the base logger."""
        super().__init__(logger_name=logger_name,
                         logfile=logfile,
                         screen_level=screen_level,
                         file_level=file_level,
                         indent_space=indent_space,
                         verbose_log=verbose_log)

    def _log(self, message, **kwargs):
        """Does nothing."""
        return None

    def _debug(self, message, **kwargs):
        """Does nothing."""
        return None

    def _info(self, message, **kwargs):
        """Does nothing."""
        return None

    def _warning(self, message, **kwargs):
        """Does nothing."""
        return None

    def _error(self, message, **kwargs):
        """Does nothing."""
        return None

    def _exception(self, message, **kwargs):
        """Does nothing."""
        return None

    def _critical(self, message, **kwargs):
        """Does nothing."""
        return None

    def _print(self, *messages, **kwargs):
        """Does nothing."""
        return None

    def init_pbar(self, leave=False):
        """Does nothing."""
        return None

    def add_pbar_task(self, name, total, **kwargs):
        """Creates no task; always returns the placeholder task id -1."""
        return -1

    def update_pbar(self, task_id, advance=1):
        """Does nothing."""
        return None

    def close_pbar(self):
        """Does nothing."""
        return None
66 |
--------------------------------------------------------------------------------
/utils/loggers/normal_logger.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of normal logger.
3 |
4 | This class is built based on the built-in function `print()`, the module
5 | `logging` and the module `tqdm` for progress bars.
6 | """
7 |
8 | import sys
9 | import logging
10 | from copy import deepcopy
11 | from tqdm import tqdm
12 |
13 | from .base_logger import BaseLogger
14 |
15 | __all__ = ['NormalLogger']
16 |
17 |
class NormalLogger(BaseLogger):
    """Implements the logger based on `logging` module.

    Messages are emitted through a `logging.Logger` with a terminal handler
    (writing to `sys.stdout`) and, if `logfile` is given, an additional handler
    writing to that file. Progress bars are rendered with `tqdm`.
    """

    def __init__(self,
                 logger_name='logger',
                 logfile=None,
                 screen_level=logging.INFO,
                 file_level=logging.DEBUG,
                 indent_space=4,
                 verbose_log=False):
        """Initializes the logger.

        Args:
            logger_name: Name passed to `logging.getLogger()`. The name must
                not belong to a logger that already owns handlers.
                (default: `logger`)
            logfile: Path of the log file, opened in append mode. If empty or
                `None`, no file handler is created. (default: None)
            screen_level: Level of the terminal handler.
                (default: `logging.INFO`)
            file_level: Level of the file handler. (default: `logging.DEBUG`)
            indent_space: Forwarded to `BaseLogger`. (default: 4)
            verbose_log: Forwarded to `BaseLogger`. (default: False)

        Raises:
            SystemExit: If the logger named `logger_name` already has handlers.
        """
        super().__init__(logger_name=logger_name,
                         logfile=logfile,
                         screen_level=screen_level,
                         file_level=file_level,
                         indent_space=indent_space,
                         verbose_log=verbose_log)

        # Get logger and check whether the logger has already been created.
        # NOTE: `propagate` is disabled BEFORE calling `hasHandlers()`, so that
        # only the handlers owned by this very logger are inspected (handlers
        # of ancestor loggers, e.g. the root logger, are not considered).
        self.logger = logging.getLogger(self.logger_name)
        self.logger.propagate = False
        if self.logger.hasHandlers():  # Already existed
            raise SystemExit(f'Logger `{self.logger_name}` has already '
                             f'existed!\n'
                             f'Please use another name, or otherwise the '
                             f'messages may be mixed up.')

        # Set format. The logger itself accepts everything; filtering by level
        # is done per handler.
        self.logger.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            '[%(asctime)s][%(levelname)s] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S')

        # Print log message onto the screen.
        terminal_handler = logging.StreamHandler(stream=sys.stdout)
        terminal_handler.setLevel(self.screen_level)
        terminal_handler.setFormatter(formatter)
        self.logger.addHandler(terminal_handler)

        # Save log message into log file if needed.
        if self.logfile:
            # File will be closed when the logger is closed in `self.close()`
            # (not defined in this class; presumably by `BaseLogger`).
            self.file_stream = open(self.logfile, 'a')  # pylint: disable=consider-using-with
            file_handler = logging.StreamHandler(stream=self.file_stream)
            file_handler.setLevel(self.file_level)
            file_handler.setFormatter(formatter)
            self.logger.addHandler(file_handler)

        self.pbar = []  # One `tqdm` instance per task; index is the task ID.
        self.pbar_kwargs = {}  # Shared `tqdm` settings, set by `init_pbar()`.

    def _log(self, message, **kwargs):
        """Logs `message` at the level given by `kwargs['level']`.

        `logging.Logger.log()` expects the level as its first positional
        argument, hence the level is taken out of `kwargs` and passed
        explicitly. `logging.INFO` is used when no level is provided.
        """
        level = kwargs.pop('level', logging.INFO)
        self.logger.log(level, message, **kwargs)

    def _debug(self, message, **kwargs):
        """Logs `message` with `DEBUG` level."""
        self.logger.debug(message, **kwargs)

    def _info(self, message, **kwargs):
        """Logs `message` with `INFO` level."""
        self.logger.info(message, **kwargs)

    def _warning(self, message, **kwargs):
        """Logs `message` with `WARNING` level."""
        self.logger.warning(message, **kwargs)

    def _error(self, message, **kwargs):
        """Logs `message` with `ERROR` level."""
        self.logger.error(message, **kwargs)

    def _exception(self, message, **kwargs):
        """Logs `message` with `ERROR` level, including exception info."""
        self.logger.exception(message, **kwargs)

    def _critical(self, message, **kwargs):
        """Logs `message` with `CRITICAL` level."""
        self.logger.critical(message, **kwargs)

    def _print(self, *messages, **kwargs):
        """Prints `messages` to the stream of every handler.

        The messages bypass the formatter and the level filtering. `kwargs` is
        accepted for interface compatibility but not used.
        """
        for handler in self.logger.handlers:
            print(*messages, file=handler.stream)

    def init_pbar(self, leave=False):
        """Prepares the `tqdm` settings shared by all progress bar tasks.

        Args:
            leave: Passed to `tqdm` as `leave`. (default: False)
        """
        columns = [
            '{desc}',
            '{bar}',
            ' {percentage:5.1f}%',
            '[{elapsed}<{remaining}, {rate_fmt}{postfix}]',
        ]
        self.pbar_kwargs = dict(
            leave=leave,
            bar_format=' '.join(columns),
            unit='',
        )

    def add_pbar_task(self, name, total, **kwargs):
        """Adds a progress bar task.

        Args:
            name: Description of the task.
            total: Total number of steps of the task.
            **kwargs: Additional `tqdm` arguments, which override the shared
                settings from `init_pbar()`.

        Returns:
            The task ID, i.e., the index of the new bar in `self.pbar`.
        """
        assert isinstance(self.pbar_kwargs, dict)
        pbar_kwargs = deepcopy(self.pbar_kwargs)
        pbar_kwargs.update(**kwargs)
        self.pbar.append(tqdm(desc=name, total=total, **pbar_kwargs))
        return len(self.pbar) - 1

    def update_pbar(self, task_id, advance=1):
        """Advances the task `task_id` by `advance` steps.

        A task that has already reached its total is not advanced any further,
        but only refreshed.
        """
        assert len(self.pbar) > task_id and isinstance(self.pbar[task_id], tqdm)
        if self.pbar[task_id].n < self.pbar[task_id].total:
            self.pbar[task_id].update(advance)
            if self.pbar[task_id].n >= self.pbar[task_id].total:
                self.pbar[task_id].refresh()

    def close_pbar(self):
        """Closes all progress bars (latest first) and resets the settings."""
        for pbar in self.pbar[::-1]:
            pbar.close()
        self.pbar = []
        self.pbar_kwargs = {}
125 |
--------------------------------------------------------------------------------
/utils/loggers/rich_logger.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the class of rich logger.
3 |
4 | This class is based on the module `rich`. Please refer to
5 | https://github.com/Textualize/rich for more details.
6 | """
7 |
8 | import sys
9 | import logging
10 | from copy import deepcopy
11 | from rich.console import Console
12 | from rich.logging import RichHandler
13 | from rich.progress import Progress
14 | from rich.progress import ProgressColumn
15 | from rich.progress import TextColumn
16 | from rich.progress import BarColumn
17 | from rich.text import Text
18 |
19 | from .base_logger import BaseLogger
20 |
21 | __all__ = ['RichLogger']
22 |
23 |
24 | def _format_time(seconds):
25 | """Formats seconds to readable time string.
26 |
27 | This function is used to display time in progress bar.
28 | """
29 | if not seconds:
30 | return '--:--'
31 |
32 | seconds = int(seconds)
33 | hours, seconds = divmod(seconds, 3600)
34 | minutes, seconds = divmod(seconds, 60)
35 | if hours:
36 | return f'{hours}:{minutes:02d}:{seconds:02d}'
37 | return f'{minutes:02d}:{seconds:02d}'
38 |
39 |
class TimeColumn(ProgressColumn):
    """Progress bar column showing elapsed time, remaining time, and speed."""

    # Only refresh twice a second to prevent jitter.
    max_refresh = 0.5

    def render(self, task):
        """Renders `[elapsed<remaining, speed]` for the given task.

        The speed is shown as `?/s` when `task.speed` is falsy.
        """
        elapsed_text = _format_time(task.elapsed)
        remaining_text = _format_time(task.time_remaining)
        if task.speed:
            speed_text = f'{task.speed:.2f}/s'
        else:
            speed_text = '?/s'
        summary = f'[{elapsed_text}<{remaining_text}, {speed_text}]'
        return Text(summary, style='progress.remaining')
51 |
52 |
class RichLogger(BaseLogger):
    """Implements the logger based on `rich` module.

    Messages are emitted through a `logging.Logger` with a `RichHandler`
    writing to `sys.stdout` and, if `logfile` is given, a second `RichHandler`
    writing to that file. Progress bars are rendered with `rich.progress`.
    """

    def __init__(self,
                 logger_name='logger',
                 logfile=None,
                 screen_level=logging.INFO,
                 file_level=logging.DEBUG,
                 indent_space=4,
                 verbose_log=False):
        """Initializes the logger.

        Args:
            logger_name: Name passed to `logging.getLogger()`. The name must
                not belong to a logger that already owns handlers.
                (default: `logger`)
            logfile: Path of the log file, opened in append mode. If empty or
                `None`, no file handler is created. (default: None)
            screen_level: Level of the terminal handler.
                (default: `logging.INFO`)
            file_level: Level of the file handler. (default: `logging.DEBUG`)
            indent_space: Forwarded to `BaseLogger`. (default: 4)
            verbose_log: Forwarded to `BaseLogger`. (default: False)

        Raises:
            SystemExit: If the logger named `logger_name` already has handlers.
        """
        super().__init__(logger_name=logger_name,
                         logfile=logfile,
                         screen_level=screen_level,
                         file_level=file_level,
                         indent_space=indent_space,
                         verbose_log=verbose_log)

        # Get logger and check whether the logger has already been created.
        # NOTE: `propagate` is disabled BEFORE calling `hasHandlers()`, so that
        # only the handlers owned by this very logger are inspected (handlers
        # of ancestor loggers, e.g. the root logger, are not considered).
        self.logger = logging.getLogger(self.logger_name)
        self.logger.propagate = False
        if self.logger.hasHandlers():  # Already existed
            raise SystemExit(f'Logger `{self.logger_name}` has already '
                             f'existed!\n'
                             f'Please use another name, or otherwise the '
                             f'messages may be mixed up.')

        # Set format. The logger itself accepts everything; filtering by level
        # is done per handler.
        self.logger.setLevel(logging.DEBUG)

        # Print log message onto the screen.
        terminal_console = Console(
            file=sys.stdout, log_time=False, log_path=False)
        terminal_handler = RichHandler(
            level=self.screen_level,
            console=terminal_console,
            show_time=True,
            show_level=True,
            show_path=False,
            log_time_format='[%Y-%m-%d %H:%M:%S] ')
        terminal_handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(terminal_handler)

        # Save log message into log file if needed.
        if self.logfile:
            # File will be closed when the logger is closed in `self.close()`
            # (not defined in this class; presumably by `BaseLogger`).
            self.file_stream = open(self.logfile, 'a')  # pylint: disable=consider-using-with
            file_console = Console(
                file=self.file_stream, log_time=False, log_path=False)
            file_handler = RichHandler(
                level=self.file_level,
                console=file_console,
                show_time=True,
                show_level=True,
                show_path=False,
                log_time_format='[%Y-%m-%d %H:%M:%S] ')
            file_handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(file_handler)

        self.pbar = None  # `rich.progress.Progress`, set by `init_pbar()`.
        self.pbar_kwargs = {}  # Default arguments for every added task.

    def _log(self, message, **kwargs):
        """Logs `message` at the level given by `kwargs['level']`.

        `logging.Logger.log()` expects the level as its first positional
        argument, hence the level is taken out of `kwargs` and passed
        explicitly. `logging.INFO` is used when no level is provided.
        """
        level = kwargs.pop('level', logging.INFO)
        self.logger.log(level, message, **kwargs)

    def _debug(self, message, **kwargs):
        """Logs `message` with `DEBUG` level."""
        self.logger.debug(message, **kwargs)

    def _info(self, message, **kwargs):
        """Logs `message` with `INFO` level."""
        self.logger.info(message, **kwargs)

    def _warning(self, message, **kwargs):
        """Logs `message` with `WARNING` level."""
        self.logger.warning(message, **kwargs)

    def _error(self, message, **kwargs):
        """Logs `message` with `ERROR` level."""
        self.logger.error(message, **kwargs)

    def _exception(self, message, **kwargs):
        """Logs `message` with `ERROR` level, including exception info."""
        self.logger.exception(message, **kwargs)

    def _critical(self, message, **kwargs):
        """Logs `message` with `CRITICAL` level."""
        self.logger.critical(message, **kwargs)

    def _print(self, *messages, **kwargs):
        """Prints `messages` on the console of every handler.

        The messages bypass the level filtering. `kwargs` is forwarded to
        `Console.print()`.
        """
        for handler in self.logger.handlers:
            handler.console.print(*messages, **kwargs)

    def init_pbar(self, leave=False):
        """Creates and starts the progress bar.

        Args:
            leave: Whether to keep the bar on screen after it is stopped; the
                bar is created with `transient=not leave`. (default: False)
        """
        assert self.pbar is None

        # Columns shown in the progress bar.
        columns = (
            TextColumn('[progress.description]{task.description}'),
            BarColumn(bar_width=None),
            TextColumn('[progress.percentage]{task.percentage:>5.1f}%'),
            TimeColumn(),
        )

        # The first handler is the terminal handler added in `__init__()`, so
        # the bar shares the console used for on-screen log messages.
        self.pbar = Progress(*columns,
                             console=self.logger.handlers[0].console,
                             transient=not leave,
                             auto_refresh=True,
                             refresh_per_second=10)
        self.pbar.start()

    def add_pbar_task(self, name, total, **kwargs):
        """Adds a task to the progress bar.

        Args:
            name: Description of the task.
            total: Total number of steps of the task.
            **kwargs: Additional arguments for `Progress.add_task()`, which
                override the defaults in `self.pbar_kwargs`.

        Returns:
            The task ID returned by `Progress.add_task()`.
        """
        assert isinstance(self.pbar, Progress)
        assert isinstance(self.pbar_kwargs, dict)
        pbar_kwargs = deepcopy(self.pbar_kwargs)
        pbar_kwargs.update(**kwargs)
        task_id = self.pbar.add_task(name, total=total, **pbar_kwargs)
        return task_id

    def update_pbar(self, task_id, advance=1):
        """Advances the task `task_id` by `advance` steps.

        A finished task is not advanced any further; it is stopped instead if
        it has not been stopped yet.
        """
        assert isinstance(self.pbar, Progress)
        if self.pbar.tasks[task_id].finished:
            if self.pbar.tasks[task_id].stop_time is None:
                self.pbar.stop_task(task_id)
        else:
            self.pbar.update(task_id, advance=advance)

    def close_pbar(self):
        """Stops the progress bar and resets the related states."""
        assert isinstance(self.pbar, Progress)
        self.pbar.stop()
        self.pbar = None
        self.pbar_kwargs = {}
178 |
--------------------------------------------------------------------------------
/utils/loggers/test.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Unit test for logger."""
3 |
4 | import os
5 | import time
6 |
7 | from . import build_logger
8 |
9 | __all__ = ['test_logger']
10 |
11 | _TEST_DIR = 'logger_test'
12 |
13 |
def test_logger(test_dir=_TEST_DIR):
    """Tests loggers.

    Every logger type is built with each combination of indentation and
    verbosity, and then driven through printing, logging, and progress bar
    calls. Log files are written under `test_dir`.

    Args:
        test_dir: Directory to save the log files, which is created if
            needed. (default: `_TEST_DIR`)
    """
    print('========== Start Logger Test ==========')

    os.makedirs(test_dir, exist_ok=True)

    # Class name reported in the banner for each logger type.
    class_names = {
        'normal': 'NormalLogger',
        'rich': 'RichLogger',
        'dummy': 'DummyLogger',
    }

    for logger_type, class_name in class_names.items():
        for indent_space in [2, 4]:
            for verbose_log in [False, True]:
                print(f'===== '
                      f'Testing `utils.loggers.{class_name}` '
                      f' (indent: {indent_space}, verbose: {verbose_log}) '
                      f'=====')
                # Each configuration gets its own logger name.
                logger_name = (f'{logger_type}_logger_'
                               f'indent_{indent_space}_'
                               f'verbose_{verbose_log}')
                logger = build_logger(
                    logger_type,
                    logger_name=logger_name,
                    logfile=os.path.join(test_dir, f'test_{logger_name}.log'),
                    verbose_log=verbose_log,
                    indent_space=indent_space)
                logger.print('print log')
                logger.print('print log,', 'log 2')
                logger.print('print log (indent level 0)', indent_level=0)
                logger.print('print log (indent level 1)', indent_level=1)
                logger.print('print log (indent level 2)', indent_level=2)
                logger.print('print log (verbose `False`)', is_verbose=False)
                logger.print('print log (verbose `True`)', is_verbose=True)
                logger.debug('debug log')
                logger.info('info log')
                logger.warning('warning log')
                logger.init_pbar()
                task_1 = logger.add_pbar_task('Task 1', 500)
                task_2 = logger.add_pbar_task('Task 2', 1000)
                # Task 1 (500 steps) keeps receiving updates after it has
                # finished, which exercises the over-update path.
                for _ in range(1000):
                    logger.update_pbar(task_1, 1)
                    logger.update_pbar(task_2, 1)
                    time.sleep(0.002)
                logger.close_pbar()
                print('Success!')

    print('========== Finish Logger Test ==========')
64 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Misc utility functions."""
3 |
4 | import os
5 | import hashlib
6 |
7 | from pathlib import Path
8 | from torch.hub import download_url_to_file
9 |
10 | __all__ = [
11 | 'REPO_NAME', 'Infix', 'print_and_execute', 'check_file_ext',
12 | 'IMAGE_EXTENSIONS', 'VIDEO_EXTENSIONS', 'MEDIA_EXTENSIONS',
13 | 'parse_file_format', 'set_cache_dir', 'get_cache_dir',
14 | 'md5_update_from_file', 'md5_update_from_dir', 'md5', 'download_url'
15 | ]
16 |
17 | REPO_NAME = 'Hammer' # Name of the repository (project).
18 |
19 |
class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    The left operand is temporarily stored on the instance between the `<<`
    and the `>>` evaluation, and is cleared before the wrapped function is
    called. Hence, the operator stays usable even if the function raises.

    Examples:

        # Use `Infix` to create infix operators directly.
        add = Infix(lambda a, b: a + b)
        1 << add >> 2  # gives 3
        1 << add >> 2 << add >> 3  # gives 6

        # Use `Infix` as a decorator.
        @Infix
        def mul(a, b):
            return a * b
        2 << mul >> 4  # gives 8
        2 << mul >> 3 << mul >> 7  # gives 42
    """

    def __init__(self, function):
        """Initializes with a two-argument callable `function(left, right)`."""
        self.function = function
        self.left_value = None

    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
        """Captures the left operand and returns the operator itself."""
        assert self.left_value is None  # make sure left is only called once
        self.left_value = left_value
        return self

    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
        """Applies the function to the captured left operand and the right."""
        left_value = self.left_value
        # Reset BEFORE calling the function, such that an exception raised by
        # the function does not leave a stale left operand behind (which would
        # make every later use of this operator fail the assertion above).
        self.left_value = None
        return self.function(left_value, right_value)
54 |
55 |
def print_and_execute(cmd):
    """Echoes a system command and then runs it via `os.system()`.

    The value returned by `os.system()` is discarded, so this function always
    returns `None`.

    Args:
        cmd: The command line to print and execute.
    """
    print(cmd)
    os.system(cmd)
64 |
65 |
def check_file_ext(filename, *ext_list):
    """Tells whether a filename ends with one of the given extensions.

    The comparison is case insensitive, and each target extension may be given
    with or without the leading dot. Only the last extension of the file's
    basename is considered.

    NOTE: If `ext_list` is empty, this function will always return `False`.

    Args:
        filename: Filename to check.
        *ext_list: Target extensions.

    Returns:
        `True` if the extension of `filename` matches one of `ext_list`,
            otherwise `False`.
    """
    if not ext_list:
        return False
    targets = set()
    for candidate in ext_list:
        dotted = candidate if candidate.startswith('.') else '.' + candidate
        targets.add(dotted.lower())
    _, file_ext = os.path.splitext(os.path.basename(filename))
    return file_ext.lower() in targets
86 |
87 |
# Extensions of still-image files (GIF is deliberately excluded).
IMAGE_EXTENSIONS = (
    '.bmp',
    '.ppm',
    '.pgm',
    '.jpeg',
    '.jpg',
    '.jpe',
    '.jp2',
    '.png',
    '.webp',
    '.tiff',
    '.tif',
)
# Extensions of video files.
VIDEO_EXTENSIONS = (
    '.avi',
    '.mkv',
    '.mp4',
    '.m4v',
    '.mov',
    '.webm',
    '.flv',
    '.rmvb',
    '.rm',
    '.3gp',
)
# Extensions of all media files: GIF first, then images, then videos.
MEDIA_EXTENSIONS = ('.gif',) + IMAGE_EXTENSIONS + VIDEO_EXTENSIONS
100 |
101 |
def parse_file_format(path):
    """Parses the file format of a given path.

    This function basically parses the file format according to its extension.
    It will also return `dir` if the given path is a directory.

    Parsable file formats:

    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - lmdb: a folder ending with `lmdb`.
    - txt: with `.txt` / `.text` extension, OR an existing file without
        extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
    - png: with `.png` extension.

    Args:
        path: The path to the file to parse format from. Both strings and
            `os.PathLike` objects (e.g. `pathlib.Path`) are supported.

    Returns:
        A lower-case string, indicating the file format, or `None` if the format
            cannot be successfully parsed.
    """
    # Normalize path-like objects to plain strings, since the string methods
    # below (`endswith()`, `lower()`) are not available on `pathlib.Path`.
    path = os.fspath(path)

    # Handle directory. A trailing `/` marks a directory even if it does not
    # exist on disk.
    if os.path.isdir(path) or path.endswith('/'):
        if path.rstrip('/').lower().endswith('lmdb'):
            return 'lmdb'
        return 'dir'

    # Handle file. An extension-less path is treated as text only if the file
    # really exists.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'

    path = path.lower()
    if path.endswith('.tar.gz'):  # Cannot parse accurate extension.
        return 'tar'

    ext_to_format = {
        '.zip': 'zip',
        '.tar': 'tar',
        '.tgz': 'tar',
        '.txt': 'txt',
        '.text': 'txt',
        '.json': 'json',
        '.jpeg': 'jpg',
        '.jpg': 'jpg',
        '.jpe': 'jpg',
        '.png': 'png',
    }
    # `None` is returned for unparsable extensions.
    return ext_to_format.get(os.path.splitext(path)[1])
151 |
152 |
153 | _cache_dir = None
154 |
155 |
def set_cache_dir(directory=None):
    """Redirects (or resets) the global cache directory.

    The cache directory can be used to save some files that will be shared
    across jobs. The given value is stored in the module-level variable
    `_cache_dir`. Passing `None` clears the redirection, such that the default
    cache directory (`~/.cache/`) is used again.

    Args:
        directory: The target directory used to cache files, either a string
            or `None`. (default: None)
    """
    global _cache_dir  # pylint: disable=global-statement
    is_valid = directory is None or isinstance(directory, str)
    assert is_valid, 'Invalid directory!'
    _cache_dir = directory
171 |
172 |
def get_cache_dir(use_repo_name=True):
    """Returns the global cache directory.

    The directory is `~/.cache/` unless it has been redirected with
    `set_cache_dir()`.

    Args:
        use_repo_name: Whether to append a sub-folder, named `REPO_NAME`, to
            the cache directory. (default: True)

    Returns:
        A string, representing the global cache directory.
    """
    base_dir = _cache_dir
    if base_dir is None:
        base_dir = os.path.join(os.path.expanduser('~'), '.cache')
    if not use_repo_name:
        return base_dir
    return os.path.join(base_dir, REPO_NAME)
193 |
194 |
def md5_update_from_file(filename, hash_obj):
    """Feeds the content of a file into `hash_obj`.

    The file is read in binary mode, 4096 bytes at a time.

    Args:
        filename: Path to the file, which must exist.
        hash_obj: The hash object (e.g. from `hashlib.md5()`) to be updated.

    Returns:
        The same `hash_obj`, after being updated.
    """
    assert Path(filename).is_file()
    with open(str(filename), 'rb') as stream:
        while True:
            block = stream.read(4096)
            if block == b'':  # End of file.
                break
            hash_obj.update(block)
    return hash_obj
210 |
211 |
def md5_update_from_dir(directory, hash_obj):
    """Feeds the content of a directory into `hash_obj`, recursively.

    Entries are visited in sorted order. For each entry, its name is hashed
    first, followed by its content if it is a file, or by its own entries if
    it is a directory.

    Args:
        directory: Path to the directory, which must exist.
        hash_obj: The hash object (e.g. from `hashlib.md5()`) to be updated.

    Returns:
        The updated hash object.
    """
    assert Path(directory).is_dir()
    entries = sorted(Path(directory).iterdir())
    for entry in entries:
        hash_obj.update(entry.name.encode())
        if entry.is_file():
            hash_obj = md5_update_from_file(entry, hash_obj)
            continue
        if entry.is_dir():
            hash_obj = md5_update_from_dir(entry, hash_obj)
    return hash_obj
230 |
231 |
def md5(path):
    """Computes the MD5 hex digest of a file or a directory.

    Args:
        path: Path to an existing file or directory.

    Returns:
        The MD5 of the given path in hex digest.

    Raises:
        ValueError: If `path` is neither a file nor a directory.
    """
    target = Path(path)
    if target.is_file():
        return md5_update_from_file(path, hashlib.md5()).hexdigest()
    if target.is_dir():
        return md5_update_from_dir(path, hashlib.md5()).hexdigest()
    raise ValueError(f'Currently calculation of MD5 does not supported for '
                     f'`{path}`')
242 |
243 |
def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed. The download is skipped if the target file already exists, while
    the hash check (if requested) is executed in either case.

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will be
            automatically parsed from the given URL. (default: None)
        sha256: The expected sha256 of the downloaded file. If set as `None`,
            the hash check will be skipped. Otherwise, this function will check
            whether the sha256 of the downloaded file matches this field.

    Returns:
        A two-element tuple, where the first term is the full path of the
            downloaded file, and the second term indicate the hash check result.
            `True` means hash check passes, `False` means hash check fails,
            while `None` means no hash check is executed.
    """
    # Handle file path.
    if path is None:
        path = get_cache_dir()
    if filename is None:
        filename = os.path.basename(url)
    save_path = os.path.join(path, filename)

    # Download file if needed.
    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)

    # Check hash if needed.
    check_result = None
    if sha256 is not None:
        file_hash = hashlib.sha256()
        with open(save_path, 'rb') as f:
            # Hash the file chunk by chunk, such that large files (e.g. model
            # checkpoints) are never loaded into memory as a whole.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                file_hash.update(chunk)
        check_result = (file_hash.hexdigest() == sha256)

    return save_path, check_result
286 |
--------------------------------------------------------------------------------
/utils/parsing_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the utility functions for parsing arguments."""
3 |
4 | import json
5 | import argparse
6 | import click
7 |
8 | __all__ = [
9 | 'parse_int', 'parse_float', 'parse_bool', 'parse_index', 'parse_json',
10 | 'IntegerParamType', 'FloatParamType', 'BooleanParamType', 'IndexParamType',
11 | 'JsonParamType', 'DictAction'
12 | ]
13 |
14 |
def parse_int(arg):
    """Converts an argument to an integer via `int()`.

    `None`, as well as the strings `none` and `null` (case insensitive), is
    converted to `None`.
    """
    if arg is None:
        return None
    is_null_text = isinstance(arg, str) and arg.lower() in ('none', 'null')
    return None if is_null_text else int(arg)
25 |
26 |
def parse_float(arg):
    """Converts an argument to a float number via `float()`.

    `None`, as well as the strings `none` and `null` (case insensitive), is
    converted to `None`.
    """
    if arg is None:
        return None
    is_null_text = isinstance(arg, str) and arg.lower() in ('none', 'null')
    return None if is_null_text else float(arg)
37 |
38 |
def parse_bool(arg):
    """Parses an argument to boolean.

    - Boolean arguments will be kept.
    - `None` will be converted to `False`.
    - Any other argument is converted to a string and matched (case
        insensitive) against `1` / `true` / `t` / `yes` / `y` for `True`, and
        `0` / `false` / `f` / `no` / `n` / `none` / `null` for `False`.

    Raises:
        ValueError: If the argument matches none of the above.
    """
    if isinstance(arg, bool):
        return arg
    if arg is None:
        return False
    # Go through `str()` such that non-string arguments (e.g. the integers `0`
    # and `1`) are parsed as well, instead of failing on the missing `lower()`.
    text = str(arg).lower()
    if text in ('1', 'true', 't', 'yes', 'y'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'none', 'null'):
        return False
    raise ValueError(f'`{arg}` cannot be converted to boolean!')
53 |
54 |
def parse_index(arg, min_val=None, max_val=None):
    """Parses indices.

    - `None` and the empty string give an empty list.
    - An integer gives a single-element list.
    - A list or tuple is taken as is.
    - A string can mix comma separated numbers `1, 3, 5` and dash separated
        (inclusive) ranges `3 - 10`. Spaces in the string will be ignored. The
        strings `none` and `null` (case insensitive) give an empty list.

    In all cases, the indices are de-duplicated and sorted before returning.

    Args:
        arg: The input argument to parse indices from.
        min_val: If not `None`, this function will check that all indices are
            equal to or larger than this value. (default: None)
        max_val: If not `None`, this function will check that all indices are
            equal to or smaller than this field. (default: None)

    Returns:
        A sorted list of unique integers.

    Raises:
        ValueError: If the input is of unsupported type, or is a string that
            cannot be parsed (e.g. non-integer items, or a range with more than
            two endpoints like `1-2-3`).
        AssertionError: If any index is not an integer, or falls outside
            `[min_val, max_val]`.
    """
    if arg is None or arg == '':
        indices = []
    elif isinstance(arg, int):
        indices = [arg]
    elif isinstance(arg, (list, tuple)):
        indices = list(arg)
    elif isinstance(arg, str):
        indices = []
        if arg.lower() not in ['none', 'null']:
            for split in arg.replace(' ', '').split(','):
                numbers = [int(number) for number in split.split('-')]
                if len(numbers) == 1:
                    indices.append(numbers[0])
                elif len(numbers) == 2:
                    indices.extend(range(numbers[0], numbers[1] + 1))
                else:
                    # Report malformed ranges instead of silently dropping
                    # them from the result.
                    raise ValueError(f'Invalid range `{split}` in `{arg}`!')
    else:
        raise ValueError(f'Invalid type of input: `{type(arg)}`!')

    indices = sorted(set(indices))
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices
107 |
108 |
def parse_json(arg):
    """Decodes a string argument as JSON, if possible.

    - Non-string arguments (including `None`) are returned unchanged.
    - Strings that are not valid JSON are returned unchanged.
    - Any other string is replaced by its decoded JSON value.
    """
    if not isinstance(arg, str):
        return arg
    try:
        decoded = json.loads(arg)
    except json.JSONDecodeError:
        return arg
    return decoded
121 |
122 |
class IntegerParamType(click.ParamType):
    """Defines a `click.ParamType` to parse integer arguments.

    The conversion is done by `parse_int()`, hence `none` / `null` strings are
    supported as well.
    """

    name = 'int'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        """Converts `value` to an integer (or `None`), failing via click."""
        try:
            return parse_int(value)
        except (TypeError, ValueError):
            # `int()` raises `TypeError` (rather than `ValueError`) for values
            # of unsupported types, e.g. a list. Both cases are reported
            # through click instead of escaping as a traceback.
            self.fail(f'`{value}` cannot be parsed as an integer!', param, ctx)
133 |
134 |
class FloatParamType(click.ParamType):
    """Defines a `click.ParamType` to parse float arguments.

    The conversion is done by `parse_float()`, hence `none` / `null` strings
    are supported as well.
    """

    name = 'float'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        """Converts `value` to a float (or `None`), failing via click."""
        try:
            return parse_float(value)
        except (TypeError, ValueError):
            # `float()` raises `TypeError` (rather than `ValueError`) for
            # values of unsupported types, e.g. a list. Both cases are reported
            # through click instead of escaping as a traceback.
            self.fail(f'`{value}` cannot be parsed as a float!', param, ctx)
145 |
146 |
class BooleanParamType(click.ParamType):
    """Defines a `click.ParamType` to parse boolean arguments.

    The conversion is done by `parse_bool()`.
    """

    name = 'bool'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        """Converts `value` to a boolean, failing via click."""
        try:
            return parse_bool(value)
        except (AttributeError, ValueError):
            # `AttributeError` covers values that are neither strings nor
            # booleans (i.e. without `lower()`), such that they are reported
            # through click as well.
            self.fail(f'`{value}` cannot be parsed as a boolean!', param, ctx)
157 |
158 |
class IndexParamType(click.ParamType):
    """Defines a `click.ParamType` to parse indices arguments.

    The conversion is done by `parse_index()`.
    """

    name = 'index'

    def __init__(self, min_val=None, max_val=None):
        """Initializes with the range check settings.

        Args:
            min_val: Forwarded to `parse_index()` as `min_val`.
                (default: None)
            max_val: Forwarded to `parse_index()` as `max_val`.
                (default: None)
        """
        self.min_val = min_val
        self.max_val = max_val

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        """Converts `value` to a list of indices, failing via click."""
        try:
            return parse_index(value, self.min_val, self.max_val)
        except (AssertionError, TypeError, ValueError):
            # `parse_index()` reports out-of-range indices with assertions,
            # which should also end up as a click usage error rather than as a
            # traceback.
            self.fail(
                f'`{value}` cannot be parsed as a list of indices!', param, ctx)
174 |
175 |
class JsonParamType(click.ParamType):
    """A `click.ParamType` whose values are decoded with `parse_json()`."""

    name = 'json'

    def convert(self, value, param, ctx):
        """Returns `parse_json(value)`; `param` and `ctx` are not used."""
        parsed = parse_json(value)
        return parsed
183 |
184 |
class DictAction(argparse.Action):
    r"""Argparse action to split each argument into (key, value) pair.

    Each argument should be with `key=value` format, where `value` should be a
    string with JSON format. Values that are not valid JSON are kept as plain
    strings. The parsed pairs are collected into a dictionary, which is saved
    to the destination attribute of the namespace.

    For example, with an argparse:

        parser.add_argument('--options', nargs='+', action=DictAction)

    , you can use following arguments in the command line:

        --options \
            a=1 \
            b=1.5 \
            c=true \
            d=null \
            e=[1,2,3,4,5] \
            f='{"x":1,"y":2,"z":3}'

    NOTE: No space is allowed in each argument. Also, the dictionary-type
    argument should be quoted with single quotation marks `'`.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        """Parses `values` and stores the resulting dictionary.

        Raises:
            argparse.ArgumentError: If any argument is not with `key=value`
                format. When raised during parsing, argparse turns it into a
                usage error.
        """
        # Depending on `nargs`, argparse may pass no value or a single string
        # instead of a list.
        if values is None:
            values = []
        elif isinstance(values, str):
            values = [values]

        options = {}
        for argument in values:
            key, separator, val = argument.partition('=')
            if not separator:
                raise argparse.ArgumentError(
                    self, f'`{argument}` is not with `key=value` format!')
            options[key] = parse_json(val)
        setattr(namespace, self.dest, options)
215 |
--------------------------------------------------------------------------------
/utils/tf_utils.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the utility functions to handle import TensorFlow modules.
3 |
4 | Basically, TensorFlow may not be supported in the current environment, or may
5 | cause some warnings. This file provides functions to help ease TensorFlow
6 | related imports, such as TensorBoard.
7 | """
8 |
9 | import warnings
10 |
11 | __all__ = ['import_tf', 'import_tb_writer']
12 |
13 |
def import_tf():
    """Imports TensorFlow module if possible.

    `FutureWarning`s are suppressed during the import only; the warning
    filters in effect before the call are restored afterwards.

    Returns:
        The module `tensorflow`, with its logging verbosity set to `ERROR`, or
            `None` if `ImportError` is raised.
    """
    # Use `catch_warnings()` such that the global warning filters are restored
    # on exit, instead of being permanently altered by this function.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=FutureWarning)
        try:
            import tensorflow as tf  # pylint: disable=import-outside-toplevel
            tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
            module = tf
        except ImportError:
            module = None
    return module
29 |
30 |
def import_tb_writer():
    """Imports the SummaryWriter of TensorBoard.

    `FutureWarning`s are suppressed during the import only; the warning
    filters in effect before the call are restored afterwards.

    NOTE: This function attempts to import `SummaryWriter` from
    `torch.utils.tensorboard`. But it does not necessarily mean the import
    always succeeds because installing TensorBoard is not a duty of `PyTorch`.

    Returns:
        The class `SummaryWriter`, or `None` if `ImportError` is raised.
    """
    # Use `catch_warnings()` such that the global warning filters are restored
    # on exit, instead of being permanently altered by this function.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=FutureWarning)
        try:
            from torch.utils.tensorboard import SummaryWriter  # pylint: disable=import-outside-toplevel
        except ImportError:  # In case TensorBoard is not supported.
            SummaryWriter = None
    return SummaryWriter
48 |
--------------------------------------------------------------------------------
/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Collects all visualizers."""
3 |
4 | from .grid_visualizer import GridVisualizer
5 | from .gif_visualizer import GifVisualizer
6 | from .html_visualizer import HtmlVisualizer
7 | from .html_visualizer import HtmlReader
8 | from .video_visualizer import VideoVisualizer
9 | from .video_visualizer import VideoReader
10 |
11 | __all__ = [
12 | 'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader',
13 | 'VideoVisualizer', 'VideoReader'
14 | ]
15 |
--------------------------------------------------------------------------------
/utils/visualizers/gif_visualizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the visualizer to visualize images as a GIF."""
3 |
4 | from PIL import Image
5 |
6 | from ..image_utils import parse_image_size
7 | from ..image_utils import load_image
8 | from ..image_utils import resize_image
9 | from ..image_utils import list_images_from_dir
10 |
11 | __all__ = ['GifVisualizer']
12 |
13 |
class GifVisualizer(object):
    """Defines the visualizer that visualizes an image collection as GIF.

    All frames of a GIF share one size. If no size is configured, the size of
    the first image is used, and any image with a different size is resized.
    """

    def __init__(self, image_size=None, duration=100, loop=0):
        """Initializes the GIF visualizer.

        Args:
            image_size: Size for image visualization. (default: None)
            duration: Duration between two frames, in milliseconds.
                (default: 100)
            loop: How many times to loop the GIF. `0` means infinite.
                (default: 0)
        """
        self.set_image_size(image_size)
        self.set_duration(duration)
        self.set_loop(loop)

    def set_image_size(self, image_size=None):
        """Sets the image size of the GIF.

        The given size is parsed with `parse_image_size()` into a height and a
        width.
        """
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_duration(self, duration=100):
        """Sets the GIF duration."""
        self.duration = duration

    def set_loop(self, loop=0):
        """Sets how many times the GIF will be looped. `0` means infinite."""
        self.loop = loop

    def _save_frames(self, images, save_path):
        """Converts images to GIF frames and writes them to `save_path`.

        Args:
            images: Iterable of image arrays. It is consumed only once, so a
                generator that loads images lazily is fine.
            save_path: Path to save the GIF.

        Raises:
            ValueError: If `images` does not contain any image.
        """
        target_size = None
        frames = []
        for image in images:
            if target_size is None:
                # A configured (non-zero) size wins. Otherwise, the first
                # image decides the size of the whole GIF.
                target_size = (self.image_height or image.shape[0],
                               self.image_width or image.shape[1])
            if image.shape[0:2] != target_size:
                # NOTE: `resize_image()` takes the size as (width, height).
                image = resize_image(image, (target_size[1], target_size[0]))
            frames.append(Image.fromarray(image))
        if not frames:
            raise ValueError('No image is provided to visualize as a GIF!')
        frames[0].save(save_path, format='GIF', save_all=True,
                       append_images=frames[1:],
                       duration=self.duration,
                       loop=self.loop)

    def visualize_collection(self, images, save_path):
        """Visualizes a collection of images one by one.

        Raises:
            ValueError: If `images` is empty.
        """
        self._save_frames(images, save_path)

    def visualize_list(self, image_list, save_path):
        """Visualizes a list of image files.

        Each file is loaded exactly once.

        Raises:
            ValueError: If `image_list` is empty.
        """
        self._save_frames((load_image(filename) for filename in image_list),
                          save_path)

    def visualize_directory(self, directory, save_path):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list, save_path)
80 |
--------------------------------------------------------------------------------
/utils/visualizers/grid_visualizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the visualizer to visualize images by composing them as a grid."""
3 |
4 | from ..image_utils import get_blank_image
5 | from ..image_utils import get_grid_shape
6 | from ..image_utils import parse_image_size
7 | from ..image_utils import load_image
8 | from ..image_utils import save_image
9 | from ..image_utils import resize_image
10 | from ..image_utils import list_images_from_dir
11 |
12 | __all__ = ['GridVisualizer']
13 |
14 |
class GridVisualizer(object):
    """Defines the visualizer that visualizes images as a grid.

    Basically, given a collection of images, this visualizer stitches them one
    by one. Notably, this class also supports adding spaces between images,
    adding borders around images, and using white/black background.

    Example:

    grid = GridVisualizer(num_rows=num_rows, num_cols=num_cols)
    for i in range(num_rows):
        for j in range(num_cols):
            grid.add(i, j, image)
    grid.save('visualize.jpg')
    """

    def __init__(self,
                 grid_size=0,
                 num_rows=0,
                 num_cols=0,
                 is_portrait=False,
                 image_size=None,
                 image_channels=0,
                 row_spacing=0,
                 col_spacing=0,
                 border_left=0,
                 border_right=0,
                 border_top=0,
                 border_bottom=0,
                 use_black_background=True):
        """Initializes the grid visualizer.

        Args:
            grid_size: Total number of cells, i.e., height * width. (default: 0)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Whether the grid should be portrait or landscape.
                This is only used when it requires to compute `num_rows` and
                `num_cols` automatically. See function `get_grid_shape()` in
                file `./image_utils.py` for details. (default: False)
            image_size: Size to visualize each image. If not set, the size of
                the first added image is used. (default: None)
            image_channels: Number of image channels. If `0`, the number of
                channels of the first added image is used. (default: 0)
            row_spacing: Spacing between rows. (default: 0)
            col_spacing: Spacing between columns. (default: 0)
            border_left: Width of left border. (default: 0)
            border_right: Width of right border. (default: 0)
            border_top: Width of top border. (default: 0)
            border_bottom: Width of bottom border. (default: 0)
            use_black_background: Whether to use black background.
                (default: True)
        """
        self.reset(grid_size, num_rows, num_cols, is_portrait)
        self.set_image_size(image_size)
        self.set_image_channels(image_channels)
        self.set_row_spacing(row_spacing)
        self.set_col_spacing(col_spacing)
        self.set_border_left(border_left)
        self.set_border_right(border_right)
        self.set_border_top(border_top)
        self.set_border_bottom(border_bottom)
        self.set_background(use_black_background)
        self.grid = None

    def reset(self,
              grid_size=0,
              num_rows=0,
              num_cols=0,
              is_portrait=False):
        """Resets the grid shape, i.e., number of rows/columns.

        If `grid_size` is positive, the shape is computed by
        `get_grid_shape()`, with `num_rows` and `num_cols` passed on as its
        `height` and `width`. Otherwise, `num_rows` and `num_cols` are used as
        they are. Any existing grid content is dropped.
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(grid_size,
                                                height=num_rows,
                                                width=num_cols,
                                                is_portrait=is_portrait)
        self.grid_size = num_rows * num_cols
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.grid = None

    def set_image_size(self, image_size=None):
        """Sets the image size of each cell in the grid."""
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_image_channels(self, image_channels=0):
        """Sets the number of channels of the grid."""
        self.image_channels = image_channels

    def set_row_spacing(self, row_spacing=0):
        """Sets the spacing between grid rows."""
        self.row_spacing = row_spacing

    def set_col_spacing(self, col_spacing=0):
        """Sets the spacing between grid columns."""
        self.col_spacing = col_spacing

    def set_border_left(self, border_left=0):
        """Sets the width of the left border of the grid."""
        self.border_left = border_left

    def set_border_right(self, border_right=0):
        """Sets the width of the right border of the grid."""
        self.border_right = border_right

    def set_border_top(self, border_top=0):
        """Sets the width of the top border of the grid."""
        self.border_top = border_top

    def set_border_bottom(self, border_bottom=0):
        """Sets the width of the bottom border of the grid."""
        self.border_bottom = border_bottom

    def set_background(self, use_black=True):
        """Sets the grid background."""
        self.use_black_background = use_black

    def init_grid(self):
        """Initializes the grid with a blank image.

        The grid shape, the image size, and the number of channels must all be
        positive at this point.
        """
        assert self.num_rows > 0
        assert self.num_cols > 0
        assert self.image_height > 0
        assert self.image_width > 0
        assert self.image_channels > 0
        # Spacing only exists between two neighboring cells, hence `- 1`.
        grid_height = (self.image_height * self.num_rows +
                       self.row_spacing * (self.num_rows - 1) +
                       self.border_top + self.border_bottom)
        grid_width = (self.image_width * self.num_cols +
                      self.col_spacing * (self.num_cols - 1) +
                      self.border_left + self.border_right)
        self.grid = get_blank_image(grid_height, grid_width,
                                    channels=self.image_channels,
                                    use_black=self.use_black_background)

    def add(self, i, j, image):
        """Adds an image into the grid.

        The first added image lazily creates the grid, and also decides the
        cell size and the number of channels if they are not set yet. An image
        whose size differs from the cell size is resized.

        NOTE: The input image is assumed to be with `RGB` channel order.

        Args:
            i: Row index of the target cell.
            j: Column index of the target cell.
            image: Image to add, with shape (height, width) or
                (height, width, channels).
        """
        channels = 1 if image.ndim == 2 else image.shape[2]
        if self.grid is None:
            height, width = image.shape[0:2]
            height = self.image_height or height
            width = self.image_width or width
            channels = self.image_channels or channels
            self.set_image_size((height, width))
            self.set_image_channels(channels)
            self.init_grid()
        if image.shape[0:2] != (self.image_height, self.image_width):
            image = resize_image(image, (self.image_width, self.image_height))
        if image.ndim == 2:
            # The grid is indexed with three axes below, so a 2-D (grayscale)
            # image needs an explicit channel axis to be assignable.
            image = image[:, :, None]
        y = self.border_top + i * (self.image_height + self.row_spacing)
        x = self.border_left + j * (self.image_width + self.col_spacing)
        self.grid[y:y + self.image_height,
                  x:x + self.image_width,
                  :channels] = image

    def _fill_grid(self, images, is_row_major=True):
        """Adds images into the grid one by one, in the requested order.

        Args:
            images: Iterable of images. It is consumed only once, so a
                generator that loads images lazily is fine.
            is_row_major: Whether to fill the grid row by row (`True`) or
                column by column (`False`).
        """
        for idx, image in enumerate(images):
            if is_row_major:
                row_idx, col_idx = divmod(idx, self.num_cols)
            else:
                col_idx, row_idx = divmod(idx, self.num_rows)
            self.add(row_idx, col_idx, image)

    def visualize_collection(self,
                             images,
                             save_path=None,
                             num_rows=0,
                             num_cols=0,
                             is_portrait=False,
                             is_row_major=True):
        """Visualizes a collection of images one by one.

        The grid is reset to hold `len(images)` cells first. The result is
        saved only if `save_path` is given.
        """
        self.reset(grid_size=len(images),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        self._fill_grid(images, is_row_major=is_row_major)
        if save_path:
            self.save(save_path)

    def visualize_list(self,
                       image_list,
                       save_path=None,
                       num_rows=0,
                       num_cols=0,
                       is_portrait=False,
                       is_row_major=True):
        """Visualizes a list of image files.

        The grid is reset to hold `len(image_list)` cells first. The result is
        saved only if `save_path` is given.
        """
        self.reset(grid_size=len(image_list),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        self._fill_grid((load_image(filename) for filename in image_list),
                        is_row_major=is_row_major)
        if save_path:
            self.save(save_path)

    def visualize_directory(self,
                            directory,
                            save_path=None,
                            num_rows=0,
                            num_cols=0,
                            is_portrait=False,
                            is_row_major=True):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list=image_list,
                            save_path=save_path,
                            num_rows=num_rows,
                            num_cols=num_cols,
                            is_portrait=is_portrait,
                            is_row_major=is_row_major)

    def save(self, path):
        """Saves the grid."""
        save_image(path, self.grid)
235 |
--------------------------------------------------------------------------------
/utils/visualizers/test.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Unit test for visualizer."""
3 |
4 | import os
5 | import skvideo.datasets
6 |
7 | from ..image_utils import save_image
8 | from . import GridVisualizer
9 | from . import HtmlVisualizer
10 | from . import HtmlReader
11 | from . import GifVisualizer
12 | from . import VideoVisualizer
13 | from . import VideoReader
14 |
15 | __all__ = ['test_visualizer']
16 |
17 | _TEST_DIR = 'visualizer_test'
18 |
19 |
def test_visualizer(test_dir=_TEST_DIR):
    """Tests visualizers.

    Frames are read from a sample video and saved as images first. These
    images are then visualized as grids, HTML pages, GIFs, and videos. The
    saved HTML page is also parsed back to check the reader.

    Args:
        test_dir: Directory to save all test results. (default: `_TEST_DIR`)
    """
    print('========== Start Visualizer Test ==========')

    frame_dir = os.path.join(test_dir, 'test_frames')
    os.makedirs(frame_dir, exist_ok=True)

    print('===== Testing `VideoReader` =====')
    # Total 132 frames, with size (720, 1280).
    video_reader = VideoReader(skvideo.datasets.bigbuckbunny())
    frame_height = video_reader.frame_height
    frame_width = video_reader.frame_width
    frame_size = (frame_height, frame_width)
    half_size = (frame_height // 2, frame_width // 2)
    # Save frames as the test set.
    for idx in range(80):
        frame = video_reader.read()
        save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame)

    print('===== Testing `GridVisualizer` =====')
    grid_visualizer = GridVisualizer()
    grid_visualizer.set_row_spacing(30)
    grid_visualizer.set_col_spacing(30)
    grid_visualizer.set_background(use_black=True)
    path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png')
    grid_visualizer.visualize_directory(frame_dir, path,
                                        is_portrait=True, is_row_major=True)
    path = os.path.join(
        test_dir, 'landscape_col_major_downsample_space15_white.png')
    grid_visualizer.set_image_size(half_size)
    grid_visualizer.set_row_spacing(15)
    grid_visualizer.set_col_spacing(15)
    grid_visualizer.set_background(use_black=False)
    grid_visualizer.visualize_directory(frame_dir, path,
                                        is_portrait=False, is_row_major=False)

    print('===== Testing `HtmlVisualizer` =====')
    html_visualizer = HtmlVisualizer()
    path = os.path.join(test_dir, 'portrait_col_major_ori.html')
    html_visualizer.visualize_directory(frame_dir, path,
                                        is_portrait=True, is_row_major=False)
    path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
    html_visualizer.set_image_size(half_size)
    html_visualizer.visualize_directory(frame_dir, path,
                                        is_portrait=False, is_row_major=True)

    print('===== Testing `HtmlReader` =====')
    path = os.path.join(test_dir, 'landscape_row_major_downsample.html')
    html_reader = HtmlReader(path)
    for j in range(html_reader.num_cols):
        assert html_reader.get_header(j) == ''
    parsed_dir = os.path.join(test_dir, 'parsed_frames')
    os.makedirs(parsed_dir, exist_ok=True)
    for i in range(html_reader.num_rows):
        for j in range(html_reader.num_cols):
            idx = i * html_reader.num_cols + j
            assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})')
            image = html_reader.get_image(i, j, image_size=frame_size)
            assert image.shape[0:2] == frame_size
            save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image)

    print('===== Testing `GifVisualizer` =====')
    gif_visualizer = GifVisualizer()
    path = os.path.join(test_dir, 'gif_ori.gif')
    gif_visualizer.visualize_directory(frame_dir, path)
    gif_visualizer.set_image_size(half_size)
    path = os.path.join(test_dir, 'gif_downsample.gif')
    gif_visualizer.visualize_directory(frame_dir, path)

    print('===== Testing `VideoVisualizer` =====')
    # `VideoVisualizer` refuses to overwrite an existing video, so results
    # left by a previous run are removed to keep this test re-runnable.
    for video_name in ('video_ori.mp4', 'video_downsample.mp4'):
        stale_path = os.path.join(test_dir, video_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)
    video_visualizer = VideoVisualizer()
    path = os.path.join(test_dir, 'video_ori.mp4')
    video_visualizer.visualize_directory(frame_dir, path)
    path = os.path.join(test_dir, 'video_downsample.mp4')
    video_visualizer.set_frame_size(half_size)
    video_visualizer.visualize_directory(frame_dir, path)

    print('========== Finish Visualizer Test ==========')
98 |
--------------------------------------------------------------------------------
/utils/visualizers/video_visualizer.py:
--------------------------------------------------------------------------------
1 | # python3.7
2 | """Contains the visualizer to visualize images as a video.
3 |
4 | This file relies on `FFmpeg`. Use `sudo apt-get install ffmpeg` and
5 | `brew install ffmpeg` to install on Ubuntu and MacOS respectively.
6 | """
7 |
8 | import os.path
9 | from skvideo.io import FFmpegWriter
10 | from skvideo.io import FFmpegReader
11 |
12 | from ..image_utils import parse_image_size
13 | from ..image_utils import load_image
14 | from ..image_utils import resize_image
15 | from ..image_utils import list_images_from_dir
16 |
17 | __all__ = ['VideoVisualizer', 'VideoReader']
18 |
19 |
class VideoVisualizer(object):
    """Defines the video visualizer that presents images as a video.

    The video file is created lazily when the first frame is added, and is
    finalized by `save()`. After `save()`, the path is cleared, so a new path
    must be set before the next video is written.
    """

    def __init__(self,
                 path=None,
                 frame_size=None,
                 fps=25.0,
                 codec='libx264',
                 pix_fmt='yuv420p',
                 crf=1):
        """Initializes the video visualizer.

        Args:
            path: Path to write the video. (default: None)
            frame_size: Frame size, i.e., (height, width). (default: None)
            fps: Frames per second. (default: 25.0)
            codec: Codec. (default: `libx264`)
            pix_fmt: Pixel format. (default: `yuv420p`)
            crf: Constant rate factor, which controls the compression. The
                larger this field is, the higher compression and lower quality.
                `0` means no compression and consequently the highest quality.
                To enable QuickTime playing (requires YUV to be 4:2:0, but
                `crf = 0` results YUV to be 4:4:4), please set this field as
                at least 1. (default: 1)
        """
        self.set_path(path)
        self.set_frame_size(frame_size)
        self.set_fps(fps)
        self.set_codec(codec)
        self.set_pix_fmt(pix_fmt)
        self.set_crf(crf)
        self.video = None

    def set_path(self, path=None):
        """Sets the path to save the video."""
        self.path = path

    def set_frame_size(self, frame_size=None):
        """Sets the video frame size."""
        height, width = parse_image_size(frame_size)
        self.frame_height = height
        self.frame_width = width

    def set_fps(self, fps=25.0):
        """Sets the FPS (frame per second) of the video."""
        self.fps = fps

    def set_codec(self, codec='libx264'):
        """Sets the video codec."""
        self.codec = codec

    def set_pix_fmt(self, pix_fmt='yuv420p'):
        """Sets the video pixel format."""
        self.pix_fmt = pix_fmt

    def set_crf(self, crf=1):
        """Sets the CRF (constant rate factor) of the video."""
        self.crf = crf

    def init_video(self):
        """Initializes an empty video with expected settings.

        Raises:
            ValueError: If no path has been set.
            AssertionError: If the video already exists, or the frame size is
                not positive.
        """
        if self.path is None:
            # Fail with a clear message instead of an obscure `TypeError` from
            # `os.path.exists(None)`.
            raise ValueError('Path to save the video is not set!')
        assert not os.path.exists(self.path), f'Video `{self.path}` existed!'
        assert self.frame_height > 0
        assert self.frame_width > 0

        video_setting = {
            '-r': f'{self.fps:.2f}',
            '-s': f'{self.frame_width}x{self.frame_height}',
            '-vcodec': f'{self.codec}',
            '-crf': f'{self.crf}',
            '-pix_fmt': f'{self.pix_fmt}',
        }
        self.video = FFmpegWriter(self.path, outputdict=video_setting)

    def add(self, frame):
        """Adds a frame into the video visualizer.

        The first frame lazily creates the video, and also decides the frame
        size if it is not set yet. A frame whose size differs from the video
        frame size is resized.

        NOTE: The input frame is assumed to be with `RGB` channel order.
        """
        if self.video is None:
            height, width = frame.shape[0:2]
            height = self.frame_height or height
            width = self.frame_width or width
            self.set_frame_size((height, width))
            self.init_video()
        if frame.shape[0:2] != (self.frame_height, self.frame_width):
            frame = resize_image(frame, (self.frame_width, self.frame_height))
        self.video.writeFrame(frame)

    def _write_frames(self, frames, save_path=None):
        """Writes frames into the video and finalizes it.

        Args:
            frames: Iterable of frames. It is consumed only once, so a
                generator that loads frames lazily is fine.
            save_path: Path to save the video. If `None` or equal to the
                current path, the current path is used. Otherwise, the video
                in progress (if any) is finalized first. (default: None)
        """
        if save_path is not None and save_path != self.path:
            self.save()
            self.set_path(save_path)
        try:
            for frame in frames:
                self.add(frame)
        finally:
            # Always close the writer, such that a failure on one frame does
            # not leave an opened video behind.
            self.save()

    def visualize_collection(self, images, save_path=None):
        """Visualizes a collection of images one by one."""
        self._write_frames(images, save_path)

    def visualize_list(self, image_list, save_path=None):
        """Visualizes a list of image files."""
        self._write_frames((load_image(filename) for filename in image_list),
                           save_path)

    def visualize_directory(self, directory, save_path=None):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list, save_path)

    def save(self):
        """Saves the video by closing the file.

        The path is cleared afterwards, no matter whether a video was opened.
        """
        if self.video is not None:
            self.video.close()
            self.video = None
        self.set_path(None)
139 |
140 |
class VideoReader(object):
    """Defines the video reader.

    This class can be used to read frames from a given video.

    NOTE: Each frame can be read only once.
    TODO: Fix this?
    """

    def __init__(self, path, inputdict=None):
        """Initializes the video reader by loading the video from disk.

        Args:
            path: Path to the video to read.
            inputdict: Passed to `FFmpegReader` as its `inputdict` argument.
                (default: None)
        """
        self.path = path
        self.video = FFmpegReader(path, inputdict=inputdict)

        # Meta information reported by the underlying reader.
        self.length = self.video.inputframenum
        self.frame_height = self.video.inputheight
        self.frame_width = self.video.inputwidth
        self.fps = self.video.inputfps
        self.pix_fmt = self.video.pix_fmt

    def __del__(self):
        """Releases the opened video."""
        # `self.video` does not exist if `__init__()` failed before opening
        # the video, in which case there is nothing to release.
        video = getattr(self, 'video', None)
        if video is not None:
            video.close()

    def read(self, image_size=None):
        """Reads the next frame.

        Args:
            image_size: Size of the returned frame. A dimension that is not
                set falls back to that of the raw frame. (default: None)

        Returns:
            The next frame, resized if its size differs from the target size.
        """
        # NOTE(review): A new generator is requested on every call. This
        # presumably works because the underlying reader keeps the stream
        # position by itself -- confirm.
        frame = next(self.video.nextFrame())
        height, width = parse_image_size(image_size)
        height = height or frame.shape[0]
        width = width or frame.shape[1]
        if frame.shape[0:2] != (height, width):
            frame = resize_image(frame, (width, height))
        return frame
174 |
--------------------------------------------------------------------------------