├── LICENSE ├── README.md ├── docs └── assets │ ├── interpolation.gif │ └── synthesis.jpg ├── models ├── __init__.py ├── clip_model.py ├── text_generator.py └── utils │ ├── __init__.py │ └── ops.py ├── requirements ├── convert.txt ├── develop.txt └── minimal.txt ├── run_interpolate.py ├── run_synthesize.py ├── third_party ├── .DS_Store ├── __init__.py ├── clip_official │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ └── simple_tokenizer.py ├── stylegan2_official_ops │ ├── README.md │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── custom_ops.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── misc.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py └── stylegan3_official_ops │ ├── README.md │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── custom_ops.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── misc.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py └── utils ├── .DS_Store ├── __init__.py ├── dist_utils.py ├── file_transmitters ├── __init__.py ├── base_file_transmitter.py ├── dummy_file_transmitter.py └── local_file_transmitter.py ├── formatting_utils.py ├── image_utils.py ├── loggers ├── __init__.py ├── base_logger.py ├── dummy_logger.py ├── normal_logger.py ├── rich_logger.py └── test.py ├── misc.py ├── parsing_utils.py ├── tf_utils.py └── visualizers ├── __init__.py ├── gif_visualizer.py ├── grid_visualizer.py ├── html_visualizer.py ├── test.py └── video_visualizer.py /LICENSE: -------------------------------------------------------------------------------- 1 | ------------------------------ LICENSE for Aurora 
------------------------------ 2 | 3 | Copyright (c) 2023 Ant Group. 4 | 5 | MIT License 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | ------------------------------- LICENSE for CLIP ------------------------------- 26 | 27 | MIT License 28 | 29 | Copyright (c) 2021 OpenAI 30 | 31 | Permission is hereby granted, free of charge, to any person obtaining a copy 32 | of this software and associated documentation files (the "Software"), to deal 33 | in the Software without restriction, including without limitation the rights 34 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 35 | copies of the Software, and to permit persons to whom the Software is 36 | furnished to do so, subject to the following conditions: 37 | 38 | The above copyright notice and this permission notice shall be included in all 39 | copies or substantial portions of the Software. 
40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 48 | 49 | ------------------------------ LICENSE for Hammer ------------------------------ 50 | 51 | Copyright (c) 2022 ByteDance, Inc. 52 | 53 | MIT License 54 | 55 | Permission is hereby granted, free of charge, to any person obtaining a copy 56 | of this software and associated documentation files (the "Software"), to deal 57 | in the Software without restriction, including without limitation the rights 58 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 59 | copies of the Software, and to permit persons to whom the Software is 60 | furnished to do so, subject to the following conditions: 61 | 62 | The above copyright notice and this permission notice shall be included in all 63 | copies or substantial portions of the Software. 64 | 65 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 66 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 67 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 68 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 69 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 70 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 71 | SOFTWARE. 72 | 73 | ---------- LICENSE for custom CUDA kernels in StyleGAN2 and StyleGAN3 ---------- 74 | 75 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 
76 | 77 | NVIDIA Source Code License 78 | 79 | ======================================================================= 80 | 81 | 1. Definitions 82 | 83 | "Licensor" means any person or entity that distributes its Work. 84 | 85 | "Software" means the original work of authorship made available under 86 | this License. 87 | 88 | "Work" means the Software and any additions to or derivative works of 89 | the Software that are made available under this License. 90 | 91 | The terms "reproduce," "reproduction," "derivative works," and 92 | "distribution" have the meaning as provided under U.S. copyright law; 93 | provided, however, that for the purposes of this License, derivative 94 | works shall not include works that remain separable from, or merely 95 | link (or bind by name) to the interfaces of, the Work. 96 | 97 | Works, including the Software, are "made available" under this License 98 | by including in or with the Work either (a) a copyright notice 99 | referencing the applicability of this License to the Work, or (b) a 100 | copy of this License. 101 | 102 | 2. License Grants 103 | 104 | 2.1 Copyright Grant. Subject to the terms and conditions of this 105 | License, each Licensor grants to you a perpetual, worldwide, 106 | non-exclusive, royalty-free, copyright license to reproduce, 107 | prepare derivative works of, publicly display, publicly perform, 108 | sublicense and distribute its Work and any resulting derivative 109 | works in any form. 110 | 111 | 3. Limitations 112 | 113 | 3.1 Redistribution. You may reproduce or distribute the Work only 114 | if (a) you do so under this License, (b) you include a complete 115 | copy of this License with your distribution, and (c) you retain 116 | without modification any copyright, patent, trademark, or 117 | attribution notices that are present in the Work. 118 | 119 | 3.2 Derivative Works. 
You may specify that additional or different 120 | terms apply to the use, reproduction, and distribution of your 121 | derivative works of the Work ("Your Terms") only if (a) Your Terms 122 | provide that the use limitation in Section 3.3 applies to your 123 | derivative works, and (b) you identify the specific derivative 124 | works that are subject to Your Terms. Notwithstanding Your Terms, 125 | this License (including the redistribution requirements in Section 126 | 3.1) will continue to apply to the Work itself. 127 | 128 | 3.3 Use Limitation. The Work and any derivative works thereof only 129 | may be used or intended for use non-commercially. Notwithstanding 130 | the foregoing, NVIDIA and its affiliates may use the Work and any 131 | derivative works commercially. As used herein, "non-commercially" 132 | means for research or evaluation purposes only. 133 | 134 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 135 | against any Licensor (including any claim, cross-claim or 136 | counterclaim in a lawsuit) to enforce any patents that you allege 137 | are infringed by any Work, then your rights under this License from 138 | such Licensor (including the grant in Section 2.1) will terminate 139 | immediately. 140 | 141 | 3.5 Trademarks. This License does not grant any rights to use any 142 | Licensor’s or its affiliates’ names, logos, or trademarks, except 143 | as necessary to reproduce the notices described in this License. 144 | 145 | 3.6 Termination. If you violate any term of this License, then your 146 | rights under this License (including the grant in Section 2.1) will 147 | terminate immediately. 148 | 149 | 4. Disclaimer of Warranty. 150 | 151 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 152 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 153 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 154 | NON-INFRINGEMENT. 
YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 155 | THIS LICENSE. 156 | 157 | 5. Limitation of Liability. 158 | 159 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 160 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 161 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 162 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 163 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 164 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 165 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 166 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 167 | THE POSSIBILITY OF SUCH DAMAGES. 168 | 169 | ======================================================================= 170 | 171 | ----------------------------- LICENSE for GenForce ----------------------------- 172 | 173 | Copyright (c) 2020 GenForce 174 | 175 | MIT License 176 | 177 | Permission is hereby granted, free of charge, to any person obtaining a copy 178 | of this software and associated documentation files (the "Software"), to deal 179 | in the Software without restriction, including without limitation the rights 180 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 181 | copies of the Software, and to permit persons to whom the Software is 182 | furnished to do so, subject to the following conditions: 183 | 184 | The above copyright notice and this permission notice shall be included in all 185 | copies or substantial portions of the Software. 186 | 187 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 188 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 189 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 190 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 191 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 192 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 193 | SOFTWARE. 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Aurora -- An Open-sourced GAN-based Text-to-Image Generation Model 2 | 3 | > **Exploring Sparse MoE in GANs for Text-conditioned Image Synthesis**
4 | > Jiapeng Zhu*, Ceyuan Yang*, Kecheng Zheng, Yinghao Xu, Zifan Shi, Yujun Shen
5 | > *arXiv preprint arXiv:2309.03904*
6 | 7 | [[Paper](https://arxiv.org/pdf/2309.03904.pdf)] 8 | 9 | ## TODO 10 | 11 | - [x] Release inference code 12 | - [x] Release text-to-image generator at 64x64 resolution 13 | - [ ] Release models at higher resolution 14 | - [ ] Release training code 15 | - [ ] Release plug-ins/efficient algorithms for more functionalities 16 | 17 | ## Installation 18 | 19 | This repository is developed based on [Hammer](https://github.com/bytedance/Hammer), where you can find more detailed instructions on installation. Here, we summarize the necessary steps to facilitate reproduction. 20 | 21 | 1. Environment: CUDA version == 11.3. 22 | 23 | 2. Install package requirements with `conda`: 24 | 25 | ```shell 26 | conda create -n aurora python=3.8 # create virtual environment with Python 3.8 27 | conda activate aurora 28 | pip install -r requirements/minimal.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html 29 | ``` 30 | 31 | ## Inference 32 | 33 | First, please download the pre-trained model [here](https://hkustconnect-my.sharepoint.com/:u:/g/personal/jzhubt_connect_ust_hk/EZJRG5BV_URMjywZLcAW95YBKNQaD7M35Ba6PCHe_Gf16w?e=0n5BOm). 34 | 35 | To synthesize an image with a given text prompt, you can use the following command: 36 | 37 | ```bash 38 | python run_synthesize.py aurora_v1.pth 'A photo of a tree with autumn leaves' 39 | ``` 40 | 41 | To interpolate between two text prompts, you can use the following command: 42 | 43 | ```bash 44 | python run_interpolate.py aurora_v1.pth \ 45 | --src_prompt 'A photo of a tree with autumn leaves' \ 46 | --dst_prompt 'A photo of a victorian house' 47 | ``` 48 | 49 | ## Results 50 | 51 | - Text-conditioned image generation 52 | 53 | ![image](./docs/assets/synthesis.jpg) 54 | 55 | - Text prompt interpolation 56 | 57 | ![image](./docs/assets/interpolation.gif) 58 | 59 | ## LICENSE 60 | 61 | The project is under the [MIT License](./LICENSE) and is for research purposes ONLY.
62 | 63 | ## Acknowledgements 64 | 65 | We highly appreciate [StyleGAN2](https://github.com/NVlabs/stylegan2), [StyleGAN3](https://github.com/NVlabs/stylegan3), [CLIP](https://github.com/openai/CLIP), and [Hammer](https://github.com/bytedance/Hammer) for their contributions to the community. 66 | 67 | ## BibTeX 68 | 69 | ```bibtex 70 | @article{zhu2023aurora, 71 | title = {Exploring Sparse {MoE} in {GANs} for Text-conditioned Image Synthesis}, 72 | author = {Zhu, Jiapeng and Yang, Ceyuan and Zheng, Kecheng and Xu, Yinghao and Shi, Zifan and Shen, Yujun}, 73 | journal = {arXiv preprint arXiv:2309.03904}, 74 | year = {2023} 75 | } 76 | ``` 77 | -------------------------------------------------------------------------------- /docs/assets/interpolation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/docs/assets/interpolation.gif -------------------------------------------------------------------------------- /docs/assets/synthesis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/docs/assets/synthesis.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Collects all models.""" 3 | 4 | from .clip_model import CLIPModel 5 | from .text_generator import Text2ImageGenerator 6 | 7 | __all__ = ['build_model'] 8 | 9 | _MODELS = { 10 | 'CLIPModel': CLIPModel, 11 | 'Text2ImageGenerator': Text2ImageGenerator 12 | } 13 | 14 | 15 | def build_model(model_type, **kwargs): 16 | """Builds a model based on its class type. 17 | 18 | Args: 19 | model_type: Class type to which the model belongs, which is case 20 | sensitive. 
def all_gather(tensor):
    """Gathers `tensor` from all processes and returns the element-wise mean.

    Despite its name, this function does not return the gathered list: the
    per-process tensors are stacked along a new leading dimension and averaged
    over it.

    Args:
        tensor: The local tensor to synchronize. NOTE(review): presumably
            every process passes a tensor of identical shape and dtype, as
            `torch.distributed.all_gather` expects -- confirm with callers.

    Returns:
        The input `tensor` object itself (no copy) when distributed execution
        is unavailable or not initialized; otherwise a new tensor holding the
        mean of the tensors gathered from all processes.
    """
    # `is_initialized()` may be missing on PyTorch builds compiled without
    # distributed support, so check availability first. The short-circuit
    # keeps single-process (e.g. CPU-only) runs working.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # Placeholders only; `dist.all_gather` overwrites every entry.
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    return torch.stack(tensor_list, dim=0).mean(dim=0)
2 | gpustat # Monitor GPU usage. 3 | pylint # Check coding style. 4 | -------------------------------------------------------------------------------- /requirements/minimal.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 2 | torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 3 | tensorboard==2.7.0 4 | torch-tb-profiler==0.3.1 5 | ninja==1.10.2 6 | numpy==1.22.3 7 | scipy==1.7.3 8 | scikit-learn==1.0.2 9 | scikit-video==1.1.11 10 | pillow==9.0.0 11 | opencv-python-headless==4.5.5.62 12 | requests 13 | bs4 14 | tqdm 15 | rich 16 | click 17 | cloup 18 | psutil 19 | easydict 20 | lmdb 21 | matplotlib 22 | regex 23 | ftfy 24 | einops==0.6.1 25 | huggingface_hub==0.15.1 26 | -------------------------------------------------------------------------------- /run_synthesize.py: -------------------------------------------------------------------------------- 1 | 2 | # python3.7 3 | """Contains the code to synthesize images from pre-trained models.
def run_mapping(G, z, context, eot_ind=None):
    """Runs the text head and the mapping network of the generator.

    Args:
        G: Generator exposing the `text_head` and `mapping` callables.
        z: Latent code passed to `G.mapping`.
        context: Encoded text passed to `G.text_head`.
        eot_ind: Forwarded to `G.text_head` as `eot_ind`. (default: None)

    Returns:
        A tuple `(wp, local_text)`, where `wp` is the `'wp'` entry of the
        mapping results and `local_text` is the second output of
        `G.text_head`.
    """
    with torch.no_grad():
        global_text, local_text = G.text_head(context, eot_ind=eot_ind)
        mapping_results = G.mapping(z, label=None, context=global_text)
    return mapping_results['wp'], local_text


def run_synthesize(G, wp, local_text):
    """Runs the synthesis network of the generator without gradients.

    Args:
        G: Generator exposing the `synthesis` callable.
        wp: Latent codes passed to `G.synthesis`.
        local_text: Text features passed to `G.synthesis` as `context`.

    Returns:
        Whatever `G.synthesis` returns (`main()` reads its `'image'` entry).
    """
    with torch.no_grad():
        return G.synthesis(wp, context=local_text)


def read_text(text_path):
    """Reads text prompts from a file, one prompt per line.

    Args:
        text_path: Path to the text file.

    Returns:
        A list with one whitespace-stripped string per line of the file.
    """
    print(f'Loading text from {text_path}')
    # Decode explicitly as UTF-8 so prompts are read the same way regardless
    # of the platform's default encoding.
    with open(text_path, encoding='utf-8') as f:
        return [line.strip() for line in f]


def parse_float(arg):
    """Parses a comma-separated string into a list of floats.

    Args:
        arg: String such as `'0,0.05,0.1'`, or an empty value.

    Returns:
        A list of floats, or `None` if `arg` is empty or `None`.

    Raises:
        ValueError: If any comma-separated item is not a valid float.
    """
    if not arg:
        return None
    return [float(item) for item in arg.split(',')]


def parse_args():
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument('weight_path', type=str, default='',
                        help='Path to the pre-trained models.')
    parser.add_argument('text_prompt', type=str, default='',
                        help='The text prompt, support reading from a file '
                             'or just given a text prompt.')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size.')
    parser.add_argument('--syn_num', type=int, default=100,
                        help='Number of synthesized images.')
    parser.add_argument('--resolution', type=int, default=64,
                        help='Resolution of the model output.')
    parser.add_argument('--results_dir', type=str, default='work_dirs/syn_res',
                        help='Results directory.')
    parser.add_argument('--seed', type=int, default=4,
                        help='Random seed.')
    parser.add_argument('--trunc_layers', type=int, default=None,
                        help='Number of layers to perform truncation.')
    parser.add_argument('--loop_mapping', type=int, default=16,
                        help='Loop number for getting average for wp.')
    parser.add_argument('--save_name', type=str, default='0',
                        help='Name to help save the file.')
    parser.add_argument('--num_z', type=int, default=3,
                        help='Number of z for each text prompt.')
    parser.add_argument('--save_png', action='store_true',
                        help='Whether or not to save the synthesized images.')
    parser.add_argument('--trunc_vals', type=str, default='0,0.05,0.1,0.15,0.2',
                        help='Default values for truncation.')
    return parser.parse_args()


def format_png_path(results_dir, text_idx, z_idx, trunc_psi):
    """Returns the PNG path for one (prompt, latent, truncation) sample.

    Args:
        results_dir: Root directory for results.
        text_idx: Index of the text prompt.
        z_idx: Index of the latent code sampled for that prompt.
        trunc_psi: Truncation value used for the sample.

    Returns:
        Path of the image file under `{results_dir}/images`.
    """
    # Two decimals are required: with a single decimal the default values
    # 0.05, 0.1 and 0.15 all render as `0.1`, so the files would silently
    # overwrite each other.
    return (f'{results_dir}/images/text_{text_idx:04d}_z_{z_idx:02d}'
            f'_psi_{trunc_psi:.2f}.png')


def build_generator(weight_path, device):
    """Builds the text-to-image generator and loads its weights.

    Args:
        weight_path: Path to the checkpoint file.
        device: Device the generator is moved to.

    Returns:
        The generator in evaluation mode, placed on `device`.
    """
    g_config = {'resolution': 64,
                'image_channels': 3,
                'init_res': 4,
                'z_dim': 128,
                'w_dim': 1024,
                'mapping_fmaps': 1024,
                'label_dim': 0,
                'context_dim': 1024,
                'clip_out_dim': 768,
                'head_dim': 64,
                'embedding_dim': 1024,
                'use_text_cond': True,
                'num_layers_text_enc': 4,
                'use_w_cond': False,
                'use_class_label': False,
                'mapping_layers': 4,
                'fmaps_base': 16384,
                'fmaps_max': 1600,
                'num_adaptive_kernels': {'4': 1, '8': 1, '16': 2, '32': 4,
                                         '64': 8},
                'num_block_per_res': {'4': 3, '8': 3, '16': 3, '32': 2,
                                      '64': 2},
                'attn_resolutions': ['8', '16', '32', '64'],
                'attn_depth': {'8': 2, '16': 2, '32': 2, '64': 1},
                'attn_ch_factor': 1,
                'attn_gain': 0.3,
                'residual_gain': 0.4,
                'text_head_gain': 1.0,
                'zero_out': True,
                'fourier_feat': True,
                'l2_attention': True,
                'tie': False,
                'scale_in': False,
                'include_ff': True,
                'use_checkpoint': False,
                'checkpoint_res': ['8', '16', '32'],
                'mask_self': False,
                'conv_clamp': None,
                'mtm': True,
                'num_experts': {'8': 4, '16': 8, '32': 16, '64': 16},
                'ms_training_res': ['4', '8', '16', '32', '64'],
                'skip_connection': True}

    G = build_model('Text2ImageGenerator', **g_config)
    # NOTE: `torch.load` unpickles the file, so only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(weight_path, map_location='cpu')
    # Prefer the smoothed weights when the checkpoint provides them.
    if 'generator_smooth' in checkpoint:
        print('Loading checkpoint from generator smooth!')
        G.load_state_dict(checkpoint['generator_smooth'])
    else:
        print('Loading checkpoint from generator!')
        G.load_state_dict(checkpoint['generator'])
    return G.eval().to(device)


def main():
    """Synthesizes images for the given prompts and saves an HTML overview.

    Raises:
        ValueError: If `--batch_size` is not 1.
    """
    args = parse_args()
    # Raise explicitly instead of `assert`, which is stripped under `-O`.
    if args.batch_size != 1:
        raise ValueError('Current script only support bs equals to 1.')
    # Only a regular file is treated as a prompt list; anything else
    # (including a string that happens to name a directory) is the prompt.
    if os.path.isfile(args.text_prompt):
        text_prompt = read_text(args.text_prompt)
    else:
        text_prompt = [args.text_prompt]
    syn_num = min(args.syn_num, len(text_prompt))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    clip_config = {'model_name': 'ViT-L-14',
                   'pretrained': 'openai',
                   'freeze_clip': True}
    clip = build_model('CLIPModel', **clip_config)
    G = build_generator(args.weight_path, device)

    trunc_vals = parse_float(args.trunc_vals)
    if not trunc_vals:
        trunc_vals = [0, 0.05, 0.1, 0.15, 0.2]
    # `None` makes every `[:, :trunc_layers]` slice below cover all layers.
    trunc_layers = args.trunc_layers

    visualizer_syn = HtmlVisualizer(image_size=args.resolution)
    visualizer_syn.reset(num_rows=syn_num * args.num_z,
                         num_cols=len(trunc_vals) + 1)
    head = ['Number Z'] + [f'trunc_val_{val_}' for val_ in trunc_vals]
    visualizer_syn.set_headers(head)
    torch.manual_seed(args.seed)
    os.makedirs(args.results_dir, exist_ok=True)
    if args.save_png:
        os.makedirs(f'{args.results_dir}/images', exist_ok=True)
    w_avg = G.w_avg.reshape(1, -1, G.w_dim)[:, :trunc_layers]
    for idx in tqdm(range(syn_num)):
        text = text_prompt[idx]
        _, enc_text, eot_ind = clip.encode_text(text=text, is_tokenize=True)
        if args.loop_mapping > 0:
            # Average `wp` over several random latents for this prompt; the
            # average serves as a truncation center below.
            sum_wp = 0
            for _ in range(args.loop_mapping):
                z = torch.randn((args.batch_size, *G.latent_dim),
                                device=device)
                tmp_res, _ = run_mapping(G, z, enc_text, eot_ind=eot_ind)
                sum_wp += tmp_res
            avg_wp = sum_wp / args.loop_mapping
            avg_wp = avg_wp[:, :trunc_layers]
        for z_i in range(args.num_z):
            z = torch.randn((args.batch_size, *G.latent_dim), device=device)
            row_ind = idx * args.num_z + z_i
            visualizer_syn.set_cell(row_ind, 0, text=f'z_{z_i}')
            for col_idx, trunc_psi in enumerate(trunc_vals):
                # Re-run the mapping for every psi because `wp` is modified
                # in place below.
                wp, local_text = run_mapping(G, z, enc_text, eot_ind=eot_ind)
                wp[:, :trunc_layers] = w_avg.lerp(wp[:, :trunc_layers],
                                                  trunc_psi)
                # NOTE(review): when `loop_mapping > 0` the already truncated
                # `wp` is truncated a second time towards `avg_wp`. This looks
                # like it may have been meant as an either/or choice of
                # center; kept as is to preserve the generated images.
                if args.loop_mapping > 0:
                    wp[:, :trunc_layers] = avg_wp.lerp(wp[:, :trunc_layers],
                                                       trunc_psi)
                fake_results = run_synthesize(G, wp, local_text)
                syn_imgs = fake_results['image'].detach().cpu().numpy()
                syn_imgs = postprocess_image(syn_imgs)
                visualizer_syn.set_cell(row_ind,
                                        col_idx + 1,
                                        text=text,
                                        image=syn_imgs[0])
                if args.save_png:
                    save_path0 = format_png_path(args.results_dir, idx, z_i,
                                                 trunc_psi)
                    save_image(save_path0, syn_imgs[0])

    # Save result.
    save_name_syn = f'syn_{syn_num:04d}_seed_{args.seed}_{args.save_name}.html'
    save_path_syn = os.path.join(args.results_dir, save_name_syn)
    visualizer_syn.save(save_path_syn)


if __name__ == '__main__':
    main()
# pylint: disable=line-too-long
# pylint: disable=inconsistent-quotes
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
# pylint: disable=bare-except
# pylint: disable=no-else-break


@lru_cache()
def default_bpe():
    """Returns the path of the BPE vocabulary file located next to this module."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping each utf-8 byte value (0-255) to a unicode string.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    # Bytes outside the ranges above are mapped to code points starting at 256.
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    """Fixes the text with `ftfy`, unescapes HTML entities twice and strips it."""
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapses every whitespace run into a single space and strips the ends."""
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    """Byte-pair-encoding text tokenizer.

    Symbols ending a word carry the `'</w>'` end-of-word marker, which keeps
    them distinct from the same symbols occurring inside a word.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        """Builds the vocabulary and merge ranks from the gzipped BPE file.

        Args:
            bpe_path: Path to the gzipped merges file. (default: the file
                returned by `default_bpe()`)
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Close the file handle deterministically instead of leaking it.
        with gzip.open(bpe_path) as f:
            merges = f.read().decode("utf-8").split('\n')
        # Skip the first line and keep exactly as many merges as fit a
        # vocabulary of 49152 entries (2 * 256 byte symbols + 2 special tokens).
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        # Second copy of the byte symbols, marked as word-final.
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Applies the BPE merges to a single token.

        Args:
            token: Non-empty string of byte-encoded characters.

        Returns:
            The resulting symbols joined by single spaces; the last symbol
            ends with `'</w>'`.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        # Single-symbol word: nothing to merge.
        if not pairs:
            return token+'</w>'

        while True:
            # Always merge the pair with the lowest rank first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur any more: copy the remainder.
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Converts text into a list of BPE token ids.

        Args:
            text: Raw input string; it is cleaned and lower-cased first.

        Returns:
            A list of integer ids looked up in `self.encoder`.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Converts token ids back to text.

        Args:
            tokens: Iterable of ids present in `self.decoder`.

        Returns:
            The decoded string, with every end-of-word marker replaced by a
            space. Invalid utf-8 sequences are replaced rather than raising.
        """
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

# pylint: enable=line-too-long
# pylint: enable=inconsistent-quotes
# pylint: enable=missing-class-docstring
# pylint: enable=missing-function-docstring
# pylint: enable=bare-except
# pylint: enable=no-else-break
7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with a given kernel. 8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. 9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. 10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. 11 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 12 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 13 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`). 14 | 15 | We make the following slight modifications beyond disabling some lint warnings: 16 | 17 | - Line 25 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). 18 | - Line 35 of file `custom_ops.py`: Disable log message when setting up customized operators. 19 | - Line 53/89 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binary is not located at `/usr/local/cuda/bin`, please specify its path in function `_find_compiler_bindir_posix()`.*) 20 | - Line 24 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan2-ada-pytorch](https://github.com/NVlabs/stylegan2-ada-pytorch). 21 | - Line 32 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default. 22 | - Line 36 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator. 23 | - Line 33 of file `conv2d_gradfix.py`: Enable customized convolution operators by default.
24 | - Line 46/51 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators. 25 | - Line 66 of file `conv2d_gradfix.py`: Update PyTorch version check considering the sustained development of the community. 26 | - Line 47 of file `grid_sample_gradfix.py`: Update PyTorch version check considering the sustained development of the community. 27 | - Line 36/66 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators. 28 | - Line 23 of file `fma.py`: Use `impl` to disable customized add-multiply operator. 29 | 30 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default. 31 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/stylegan2_official_ops/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 
84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 
22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 
0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 
151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 
11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Custom replacement for convolution operators. 12 | 13 | Operators in this file support arbitrarily high order gradients with zero 14 | performance penalty. Please set `impl` as `cuda` to use faster customized 15 | operators, OR as `ref` to use native `torch.nn.functional.conv2d` and 16 | `torch.nn.functional.conv_transpose2d`. 
17 | 18 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch 19 | """ 20 | 21 | # pylint: disable=redefined-builtin 22 | # pylint: disable=arguments-differ 23 | # pylint: disable=protected-access 24 | # pylint: disable=line-too-long 25 | # pylint: disable=global-statement 26 | # pylint: disable=missing-class-docstring 27 | # pylint: disable=missing-function-docstring 28 | # pylint: disable=deprecated-module 29 | # pylint: disable=wrong-import-order 30 | 31 | import warnings 32 | import contextlib 33 | import torch 34 | 35 | from distutils.version import LooseVersion 36 | 37 | enabled = True # Enable the custom op by setting this to true. 38 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 39 | 40 | @contextlib.contextmanager 41 | def no_weight_gradients(): 42 | global weight_gradients_disabled 43 | old = weight_gradients_disabled 44 | weight_gradients_disabled = True 45 | yield 46 | weight_gradients_disabled = old 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, impl='cuda'): 51 | if impl == 'cuda' and _should_use_custom_op(input): 52 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 53 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 54 | 55 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1, impl='cuda'): 56 | if impl == 'cuda' and _should_use_custom_op(input): 57 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 58 | return 
torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 59 | 60 | #---------------------------------------------------------------------------- 61 | 62 | def _should_use_custom_op(input): 63 | assert isinstance(input, torch.Tensor) 64 | if (not enabled) or (not torch.backends.cudnn.enabled): 65 | return False 66 | if input.device.type != 'cuda': 67 | return False 68 | if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'): 69 | return True 70 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 71 | return False 72 | 73 | def _tuple_of_ints(xs, ndim): 74 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 75 | assert len(xs) == ndim 76 | assert all(isinstance(x, int) for x in xs) 77 | return xs 78 | 79 | #---------------------------------------------------------------------------- 80 | 81 | _conv2d_gradfix_cache = dict() 82 | 83 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 84 | # Parse arguments. 85 | ndim = 2 86 | weight_shape = tuple(weight_shape) 87 | stride = _tuple_of_ints(stride, ndim) 88 | padding = _tuple_of_ints(padding, ndim) 89 | output_padding = _tuple_of_ints(output_padding, ndim) 90 | dilation = _tuple_of_ints(dilation, ndim) 91 | 92 | # Lookup from cache. 93 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 94 | if key in _conv2d_gradfix_cache: 95 | return _conv2d_gradfix_cache[key] 96 | 97 | # Validate arguments. 
98 | assert groups >= 1 99 | assert len(weight_shape) == ndim + 2 100 | assert all(stride[i] >= 1 for i in range(ndim)) 101 | assert all(padding[i] >= 0 for i in range(ndim)) 102 | assert all(dilation[i] >= 0 for i in range(ndim)) 103 | if not transpose: 104 | assert all(output_padding[i] == 0 for i in range(ndim)) 105 | else: # transpose 106 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 107 | 108 | # Helpers. 109 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 110 | def calc_output_padding(input_shape, output_shape): 111 | if transpose: 112 | return [0, 0] 113 | return [ 114 | input_shape[i + 2] 115 | - (output_shape[i + 2] - 1) * stride[i] 116 | - (1 - 2 * padding[i]) 117 | - dilation[i] * (weight_shape[i + 2] - 1) 118 | for i in range(ndim) 119 | ] 120 | 121 | # Forward & backward. 122 | class Conv2d(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx, input, weight, bias): 125 | assert weight.shape == weight_shape 126 | if not transpose: 127 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | else: # transpose 129 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 130 | ctx.save_for_backward(input, weight) 131 | return output 132 | 133 | @staticmethod 134 | def backward(ctx, grad_output): 135 | input, weight = ctx.saved_tensors 136 | grad_input = None 137 | grad_weight = None 138 | grad_bias = None 139 | 140 | if ctx.needs_input_grad[0]: 141 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 142 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 143 | assert grad_input.shape == input.shape 144 | 145 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 146 | grad_weight = 
Conv2dGradWeight.apply(grad_output, input) 147 | assert grad_weight.shape == weight_shape 148 | 149 | if ctx.needs_input_grad[2]: 150 | grad_bias = grad_output.sum([0, 2, 3]) 151 | 152 | return grad_input, grad_weight, grad_bias 153 | 154 | # Gradient with respect to the weights. 155 | class Conv2dGradWeight(torch.autograd.Function): 156 | @staticmethod 157 | def forward(ctx, grad_output, input): 158 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 159 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 160 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 161 | assert grad_weight.shape == weight_shape 162 | ctx.save_for_backward(grad_output, input) 163 | return grad_weight 164 | 165 | @staticmethod 166 | def backward(ctx, grad2_grad_weight): 167 | grad_output, input = ctx.saved_tensors 168 | grad2_grad_output = None 169 | grad2_input = None 170 | 171 | if ctx.needs_input_grad[0]: 172 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 173 | assert grad2_grad_output.shape == grad_output.shape 174 | 175 | if ctx.needs_input_grad[1]: 176 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 177 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 178 | assert grad2_input.shape == input.shape 179 | 180 | return grad2_grad_output, grad2_input 181 | 182 | _conv2d_gradfix_cache[key] = Conv2d 183 | return Conv2d 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | # pylint: enable=redefined-builtin 188 | # pylint: enable=arguments-differ 189 | # pylint: enable=protected-access 190 | # pylint: enable=line-too-long 191 | # pylint: enable=global-statement 192 | # 
pylint: enable=missing-class-docstring 193 | # pylint: enable=missing-function-docstring 194 | # pylint: enable=deprecated-module 195 | # pylint: enable=wrong-import-order 196 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """2D convolution with optional up/downsampling. 12 | 13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch 14 | """ 15 | 16 | # pylint: disable=line-too-long 17 | 18 | import torch 19 | 20 | from . import misc 21 | from . import conv2d_gradfix 22 | from . import upfirdn2d 23 | from .upfirdn2d import _parse_padding 24 | from .upfirdn2d import _get_filter_size 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def _get_weight_shape(w): 29 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 30 | shape = [int(sz) for sz in w.shape] 31 | misc.assert_shape(w, shape) 32 | return shape 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'): 37 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 38 | """ 39 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 40 | 41 | # Flip weight if requested. 
42 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 43 | w = w.flip([2, 3]) 44 | 45 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 46 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 47 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 48 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 49 | if out_channels <= 4 and groups == 1: 50 | in_shape = x.shape 51 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 52 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 53 | else: 54 | x = x.to(memory_format=torch.contiguous_format) 55 | w = w.to(memory_format=torch.contiguous_format) 56 | x = conv2d_gradfix.conv2d(x, w, groups=groups, impl=impl) 57 | return x.to(memory_format=torch.channels_last) 58 | 59 | # Otherwise => execute using conv2d_gradfix. 60 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 61 | return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl) 62 | 63 | #---------------------------------------------------------------------------- 64 | 65 | @misc.profiled_function 66 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'): 67 | r"""2D convolution with optional up/downsampling. 68 | 69 | Padding is performed only once at the beginning, not between the operations. 70 | 71 | Args: 72 | x: Input tensor of shape 73 | `[batch_size, in_channels, in_height, in_width]`. 74 | w: Weight tensor of shape 75 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 76 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 77 | calling upfirdn2d.setup_filter(). None = identity (default). 78 | up: Integer upsampling factor (default: 1). 79 | down: Integer downsampling factor (default: 1). 
80 | padding: Padding with respect to the upsampled image. Can be a single number 81 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 82 | (default: 0). 83 | groups: Split input channels into N groups (default: 1). 84 | flip_weight: False = convolution, True = correlation (default: True). 85 | flip_filter: False = convolution, True = correlation (default: False). 86 | impl: Implementation mode of customized ops. 'ref' for native PyTorch 87 | implementation, 'cuda' for `.cu` implementation 88 | (default: 'cuda'). 89 | 90 | Returns: 91 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 92 | """ 93 | # Validate arguments. 94 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 95 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 96 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 97 | assert isinstance(up, int) and (up >= 1) 98 | assert isinstance(down, int) and (down >= 1) 99 | assert isinstance(groups, int) and (groups >= 1) 100 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 101 | fw, fh = _get_filter_size(f) 102 | px0, px1, py0, py1 = _parse_padding(padding) 103 | 104 | # Adjust padding to account for up/downsampling. 105 | if up > 1: 106 | px0 += (fw + up - 1) // 2 107 | px1 += (fw - up) // 2 108 | py0 += (fh + up - 1) // 2 109 | py1 += (fh - up) // 2 110 | if down > 1: 111 | px0 += (fw - down + 1) // 2 112 | px1 += (fw - down) // 2 113 | py0 += (fh - down + 1) // 2 114 | py1 += (fh - down) // 2 115 | 116 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
117 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 118 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) 119 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) 120 | return x 121 | 122 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 123 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 124 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) 125 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) 126 | return x 127 | 128 | # Fast path: downsampling only => use strided convolution. 129 | if down > 1 and up == 1: 130 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl) 131 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl) 132 | return x 133 | 134 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
135 | if up > 1: 136 | if groups == 1: 137 | w = w.transpose(0, 1) 138 | else: 139 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 140 | w = w.transpose(1, 2) 141 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 142 | px0 -= kw - 1 143 | px1 -= kw - up 144 | py0 -= kh - 1 145 | py1 -= kh - up 146 | pxt = max(min(-px0, -px1), 0) 147 | pyt = max(min(-py0, -py1), 0) 148 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl) 149 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl) 150 | if down > 1: 151 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) 152 | return x 153 | 154 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 155 | if up == 1 and down == 1: 156 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 157 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl) 158 | 159 | # Fallback: Generic reference implementation. 160 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl) 161 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl) 162 | if down > 1: 163 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl) 164 | return x 165 | 166 | #---------------------------------------------------------------------------- 167 | 168 | # pylint: enable=line-too-long 169 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/custom_ops.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Utility functions to setup customized operators. 12 | 13 | Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch 14 | """ 15 | 16 | # pylint: disable=line-too-long 17 | # pylint: disable=missing-function-docstring 18 | # pylint: disable=useless-suppression 19 | # pylint: disable=inconsistent-quotes 20 | 21 | import os 22 | import glob 23 | import importlib 24 | import hashlib 25 | import shutil 26 | from pathlib import Path 27 | 28 | import torch 29 | from torch.utils.file_baton import FileBaton 30 | import torch.utils.cpp_extension 31 | 32 | #---------------------------------------------------------------------------- 33 | # Global options. 34 | 35 | verbosity = 'none' # Verbosity level: 'none', 'brief', 'full' 36 | 37 | #---------------------------------------------------------------------------- 38 | # Internal helper funcs. 
def _find_compiler_bindir():
    """Locate an MSVC compiler directory on Windows.

    The known Visual Studio install layouts are globbed in priority order.

    Returns:
        The last (sorted) match of the first pattern that matches anything, or
        `None` if no pattern matches.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if matches:
            return matches[-1]
    return None

def _find_compiler_bindir_posix():
    """Locate the CUDA binary directory on POSIX systems.

    Returns:
        The last (sorted) match of the first pattern that matches anything, or
        `None` if no pattern matches.
    """
    patterns = [
        '/usr/local/cuda/bin'
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if matches:
            return matches[-1]
    return None

def _append_to_path(directory):
    """Append `directory` to the `PATH` environment variable (at most once).

    Uses `os.pathsep`, i.e. ';' on Windows and ':' on POSIX, so that the added
    entry is actually recognized as a separate search directory.
    """
    current = os.environ.get('PATH', '')
    if directory in current.split(os.pathsep):
        return
    os.environ['PATH'] = f'{current}{os.pathsep}{directory}' if current else directory

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = {}

def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed), load, and cache a C++/CUDA PyTorch extension.

    Args:
        module_name: Name of the extension module; also used as the cache key.
        sources: Paths of the source files passed to the extension builder.
        **build_kwargs: Forwarded to `torch.utils.cpp_extension.load()`.

    Returns:
        The imported extension module. Repeated calls with the same
        `module_name` return the cached module.

    Raises:
        RuntimeError: If the required compiler directory cannot be located.
        Any exception raised while building or importing is re-raised.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            _append_to_path(compiler_bindir)

        elif os.name == 'posix':
            compiler_bindir = _find_compiler_bindir_posix()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".')
            # NOTE: The separator must be `os.pathsep` (':' on POSIX). A
            # hard-coded ';' would glue the directory onto the previous entry.
            _append_to_path(compiler_bindir)

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files.  Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = {os.path.dirname(source) for source in sources}
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            source_dir = Path(next(iter(source_dirs_set)))
            all_source_files = sorted(x for x in source_dir.iterdir() if x.is_file())

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                hash_md5.update(src.read_bytes())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except BaseException: # Only reports the failure; the exception is always re-raised.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
# pylint: enable=useless-suppression
# pylint: enable=inconsistent-quotes

# --------------------------------------------------------------------------------
# /third_party/stylegan2_official_ops/fma.py:
# --------------------------------------------------------------------------------
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring

import torch

#----------------------------------------------------------------------------

def fma(a, b, c, impl='cuda'): # => a * b + c
    """Compute `a * b + c` with broadcasting.

    Args:
        a: First factor.
        b: Second factor.
        c: Addend.
        impl: 'cuda' routes through the custom autograd function
            `_FusedMultiplyAdd`; any other value calls `torch.addcmul()`
            directly (default: 'cuda').

    Returns:
        Tensor holding `a * b + c`.
    """
    if impl == 'cuda':
        return _FusedMultiplyAdd.apply(a, b, c)
    return torch.addcmul(c, a, b)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Autograd function for `a * b + c` with hand-written gradients."""

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        # Only `a`, `b` and the shape of `c` are needed by `backward()`.
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that it has shape `shape`.

    Args:
        x: Tensor whose shape is `shape` after broadcasting.
        shape: Target shape; may be empty (0-dim operand).

    Returns:
        Tensor of shape `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Dimensions that were added in front, or expanded from size 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # All leading `extra_dims` dimensions have size 1 at this point, so
        # simply drop them. Unlike `reshape(-1, ...)`, this also works when
        # `shape` is empty and for zero-sized dimensions.
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring

# --------------------------------------------------------------------------------
# /third_party/stylegan2_official_ops/grid_sample_gradfix.py:
# --------------------------------------------------------------------------------
# python3.7

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample`.

This is useful for differentiable augmentation. This customized operator
supports arbitrarily high order gradients between the input and output. Only
works on 2D images and assumes `mode=bilinear`, `padding_mode=zeros`, and
`align_corners=False`.

Please refer to https://github.com/NVlabs/stylegan2-ada-pytorch
"""

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access
# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring

import re
import warnings

import torch

#----------------------------------------------------------------------------

enabled = True  # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid, impl='cuda'):
    """Bilinear, zero-padded, `align_corners=False` grid sampling of 2D images.

    Args:
        input: Rank-4 input tensor.
        grid: Rank-4 sampling grid.
        impl: 'cuda' uses the custom autograd op when it is enabled and
            supported; any other value calls
            `torch.nn.functional.grid_sample()` (default: 'cuda').

    Returns:
        The sampled tensor.
    """
    if impl == 'cuda' and _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _torch_version():
    """Return `(major, minor)` parsed from the start of `torch.__version__`.

    Replaces `distutils.version.LooseVersion`, which is no longer available on
    Python 3.12+. Returns `(0, 0)` if the version string cannot be parsed.
    """
    match = re.match(r'(\d+)\.(\d+)', str(torch.__version__))
    if match is None:
        return (0, 0)
    return (int(match.group(1)), int(match.group(2)))

def _should_use_custom_op():
    """Tell whether the custom op is enabled and PyTorch is at least 1.7."""
    if not enabled:
        return False
    if _torch_version() >= (1, 7):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    """Forward pass; its backward is itself a differentiable autograd function."""

    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    """Backward pass of grid sampling, differentiable w.r.t. `grad_output`."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # PyTorch >= 1.12 returns `(op, overload_names)` instead of the op.
        if isinstance(op, tuple):
            op = op[0]
        # The positional `0, 0, False` select bilinear interpolation, zero
        # padding and `align_corners=False`.
        if _torch_version() >= (1, 11):
            # PyTorch >= 1.11 requires an explicit `output_mask` argument.
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        # Second-order gradients w.r.t. the grid are not implemented.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

# pylint: enable=redefined-builtin
# pylint: enable=arguments-differ
# pylint: enable=protected-access
# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring

# --------------------------------------------------------------------------------
# /third_party/stylegan2_official_ops/upfirdn2d.cpp:
# --------------------------------------------------------------------------------
# NOTE: Head of the next file in this dump, kept as comments because it shares
# a source line with the module above. Its body continues right after this block.
#
# // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# //
# // NVIDIA CORPORATION and its licensors retain all intellectual property
# // and proprietary rights in and to this software, related documentation
# // and any modifications thereto.  Any use, reproduction, disclosure or
# // distribution of this software and related documentation without an express
# // license agreement from NVIDIA CORPORATION is strictly prohibited.
#
# #include <torch/extension.h>
# #include <ATen/cuda/CUDAContext.h>
# #include <c10/cuda/CUDAGuard.h>
# #include "upfirdn2d.h"
#
# //------------------------------------------------------------------------
#
# static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
# {
#     // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 
1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 
91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /third_party/stylegan2_official_ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 
44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/README.md: -------------------------------------------------------------------------------- 1 | # Operators for StyleGAN3 2 | 3 | All files in this directory are borrowed from repository [stylegan3](https://github.com/NVlabs/stylegan3). Basically, these files implement customized operators, which are faster than the native operators from PyTorch, especially for second-derivative computation, including: 4 | 5 | - `bias_act.bias_act()`: Fuse adding bias and then performing activation as one operator. 6 | - `upfirdn2d.setup_filter()`: Set up the kernel used for filtering. 7 | - `upfirdn2d.filter2d()`: Filtering a 2D feature map with given kernel. 8 | - `upfirdn2d.upsample2d()`: Upsampling a 2D feature map. 9 | - `upfirdn2d.downsample2d()`: Downsampling a 2D feature map. 10 | - `upfirdn2d.upfirdn2d()`: Upsampling, filtering, and then downsampling a 2D feature map. 11 | - `filtered_lrelu.filtered_lrelu()`: Leaky ReLU layer, wrapped with upsampling and downsampling for anti-aliasing. 12 | - `conv2d_gradfix.conv2d()`: Convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 13 | - `conv2d_gradfix.conv_transpose2d()`: Transposed convolutional layer, supporting arbitrarily high order gradients and fixing gradient when computing penalty. 14 | - `conv2d_resample.conv2d_resample()`: Wraps `upfirdn2d()` and `conv2d()` (or `conv_transpose2d()`). 
This is not used in our network implementation (*i.e.*, `models/stylegan2_generator.py` and `models/stylegan2_discriminator.py`). 15 | 16 | We make the following slight modifications beyond disabling some lint warnings: 17 | 18 | - Line 24 of file `misc.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3). 19 | - Line 36 of file `custom_ops.py`: Disable log message when setting up customized operators. 20 | - Line 54/109 of file `custom_ops.py`: Add necessary CUDA compiler path. (***NOTE**: If your CUDA binaries are not located at `/usr/local/cuda/bin`, please specify their location in function `_find_compiler_bindir_posix()`.*) 21 | - Line 21 of file `bias_act.py`: Use `EasyDict` from module `easydict` to replace that from `dnnlib` from [stylegan3](https://github.com/NVlabs/stylegan3). 22 | - Line 162-165 of file `filtered_lrelu.py`: Change some implementations in `_filtered_lrelu_ref()` to `ref`. 23 | - Line 31 of file `grid_sample_gradfix.py`: Enable customized grid sampling operator by default. 24 | - Line 35 of file `grid_sample_gradfix.py`: Use `impl` to disable customized grid sample operator. 25 | - Line 34 of file `conv2d_gradfix.py`: Enable customized convolution operators by default. 26 | - Line 48/53 of file `conv2d_gradfix.py`: Use `impl` to disable customized convolution operators. 27 | - Line 36/53 of file `conv2d_resample.py`: Use `impl` to disable customized convolution operators. 28 | - Line 23 of file `fma.py`: Use `impl` to disable customized multiply-add operator. 29 | 30 | Please use `ref` or `cuda` to choose which implementation to use. `ref` refers to native PyTorch operators while `cuda` refers to the customized operators from the official repository. `cuda` is used by default. 
31 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ant-research/Aurora/c911413609275bdd31d49032ea0478427f237bd3/third_party/stylegan3_official_ops/__init__.py -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? 
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

// Maps a storage type T to the type used for arithmetic inside the kernel.
// double computes in double; float and half both compute in float.
template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

// Fused bias + activation + gain + clamp over a flat array of p.sizeX elements.
//
// Template parameters:
//   T  storage type of the tensors behind p.x / p.b / p.xref / p.yref / p.dy / p.y.
//   A  activation selector, matched against p.act by choose_bias_act_kernel()
//      (1 linear, 2 relu, 3 lrelu, 4 tanh, 5 sigmoid, 6 elu, 7 selu,
//       8 softplus, 9 swish).
//
// p.grad (G below) selects which expression is evaluated: G == 0 computes the
// activation itself, while G == 1 and G == 2 scale the incoming x by terms built
// from the reference output yref (or xref for swish).
// NOTE(review): G == 1 / G == 2 look like first- and second-order gradient
// terms -- confirm against the Python caller.
//
// Each block covers p.loopX * blockDim.x consecutive elements; a thread starts
// at its threadIdx.x within that range and advances by blockDim.x per iteration.
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;   // |x| beyond this: use the saturated value instead of calling exp().
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load. Optional inputs (NULL pointers) fall back to neutral values.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;    // Reference output with the gain divided back out.
        scalar_t y = 0;

        // Apply bias. For G == 0 the bias goes onto x; otherwise onto xref.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                // Unlike the other activations, swish works from xref and
                // recomputes yref from it, so the clamp test below uses that value.
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp. A negative clamp disables this step.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;    // Zero wherever the reference output was clamped.
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

// Returns the kernel instantiated for storage type T and activation p.act,
// or NULL when p.act is outside 1..9.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);

//------------------------------------------------------------------------
//------------------------------------------------------------------------
// CUDA kernel parameters.

// Argument bundle passed by value to the bias_act CUDA kernel.
// All tensor pointers are untyped; the kernel casts them to its template
// element type. Optional tensors are passed as NULL.
struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void*       y;      // [sizeX]

    int         grad;   // Selects which expression the kernel evaluates (0, 1 or 2).
    int         act;    // Activation selector used by choose_bias_act_kernel().
    float       alpha;
    float       gain;
    float       clamp;

    int         sizeX;  // Number of elements in x / y.
    int         sizeB;  // Number of elements in b.
    int         stepB;  // Bias index for element i is (i / stepB) % sizeB.
    int         loopX;  // Elements processed per thread.
};

//------------------------------------------------------------------------
// CUDA kernel selection.

// Returns the kernel specialised for element type T and p.act.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return `w.shape` as a list of plain Python ints.

    The conversion runs with tracer warnings suppressed, and the resulting
    list is then passed to `misc.assert_shape()` together with `w`.
    """
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True, impl='cuda'):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.

    Args:
        x: Input tensor, forwarded unchanged to the selected op.
        w: 4D weight tensor; flipped along its last two dims when
            `flip_weight` is False and the kernel is larger than 1x1.
        stride, padding, groups: Forwarded unchanged to the selected op.
        transpose: True selects `conv2d_gradfix.conv_transpose2d`, False
            selects `conv2d_gradfix.conv2d`.
        flip_weight: See the note in the body.
        impl: Forwarded unchanged to the selected op.

    Returns:
        Whatever the selected `conv2d_gradfix` op returns.
    """
    _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
    if not flip_weight and (kw > 1 or kh > 1):
        w = w.flip([2, 3])

    # Execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups, impl=impl)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False, impl='cuda'):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    The function tries a sequence of fast paths in order (1x1 kernel with
    down-only, 1x1 kernel with up-only, down-only, any upsampling, plain
    conv2d) and returns from the first one that applies; otherwise it falls
    through to a generic upfirdn2d -> conv2d -> upfirdn2d sequence.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).
        impl:           Implementation mode, 'cuda' for CUDA implementation, and 'ref' for
                        native PyTorch implementation (default: 'cuda').

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.

    Raises:
        AssertionError: If an argument fails the type/rank/range checks at
            the top of the body.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    # Both adjustments are cumulative when up > 1 and down > 1 at the same time.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter, impl=impl)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight, impl=impl)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # Swap the out/in channel axes of the weight (per group when groups > 1)
        # before handing it to the transposed convolution.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # pxt/pyt: the symmetric, non-negative part of the (now possibly negative)
        # padding is given to the transposed conv; the remainder is applied by
        # upfirdn2d below.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight), impl=impl)
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter, impl=impl)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight, impl=impl)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter, impl=impl)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight, impl=impl)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter, impl=impl)
    return x

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Utility functions to setup customized operators. 12 | 13 | Please refer to https://github.com/NVlabs/stylegan3 14 | """ 15 | 16 | # pylint: disable=line-too-long 17 | # pylint: disable=multiple-statements 18 | # pylint: disable=missing-function-docstring 19 | # pylint: disable=useless-suppression 20 | # pylint: disable=inconsistent-quotes 21 | 22 | import glob 23 | import hashlib 24 | import importlib 25 | import os 26 | import re 27 | import shutil 28 | import uuid 29 | 30 | import torch 31 | import torch.utils.cpp_extension 32 | 33 | #---------------------------------------------------------------------------- 34 | # Global options. 35 | 36 | verbosity = 'none' # Verbosity level: 'none', 'brief', 'full' 37 | 38 | #---------------------------------------------------------------------------- 39 | # Internal helper funcs. 
40 | 41 | def _find_compiler_bindir(): 42 | patterns = [ 43 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 44 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 45 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 46 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 47 | ] 48 | for pattern in patterns: 49 | matches = sorted(glob.glob(pattern)) 50 | if len(matches): 51 | return matches[-1] 52 | return None 53 | 54 | def _find_compiler_bindir_posix(): 55 | patterns = [ 56 | '/usr/local/cuda/bin' 57 | ] 58 | for pattern in patterns: 59 | matches = sorted(glob.glob(pattern)) 60 | if len(matches): 61 | return matches[-1] 62 | return None 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | def _get_mangled_gpu_name(): 67 | name = torch.cuda.get_device_name().lower() 68 | out = [] 69 | for c in name: 70 | if re.match('[a-z0-9_-]+', c): 71 | out.append(c) 72 | else: 73 | out.append('-') 74 | return ''.join(out) 75 | 76 | #---------------------------------------------------------------------------- 77 | # Main entry point for compiling and loading C++/CUDA plugins. 78 | 79 | _cached_plugins = dict() 80 | 81 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 82 | assert verbosity in ['none', 'brief', 'full'] 83 | if headers is None: 84 | headers = [] 85 | if source_dir is not None: 86 | sources = [os.path.join(source_dir, fname) for fname in sources] 87 | headers = [os.path.join(source_dir, fname) for fname in headers] 88 | 89 | # Already cached? 90 | if module_name in _cached_plugins: 91 | return _cached_plugins[module_name] 92 | 93 | # Print status. 94 | if verbosity == 'full': 95 | print(f'Setting up PyTorch plugin "{module_name}"...') 96 | elif verbosity == 'brief': 97 | print(f'Setting up PyTorch plugin "{module_name}"... 
', end='', flush=True) 98 | verbose_build = (verbosity == 'full') 99 | 100 | # Compile and load. 101 | try: # pylint: disable=too-many-nested-blocks 102 | # Make sure we can find the necessary compiler binaries. 103 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 104 | compiler_bindir = _find_compiler_bindir() 105 | if compiler_bindir is None: 106 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 107 | os.environ['PATH'] += ';' + compiler_bindir 108 | 109 | elif os.name == 'posix': 110 | compiler_bindir = _find_compiler_bindir_posix() 111 | if compiler_bindir is None: 112 | raise RuntimeError(f'Could not find NVCC installation on this computer. Check _find_compiler_bindir_posix() in "{__file__}".') 113 | os.environ['PATH'] += ';' + compiler_bindir 114 | 115 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 116 | # break the build or unnecessarily restrict what's available to nvcc. 117 | # Unset it to let nvcc decide based on what's available on the 118 | # machine. 119 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 120 | 121 | # Incremental build md5sum trickery. Copies all the input source files 122 | # into a cached build directory under a combined md5 digest of the input 123 | # source files. Copying is done only if the combined digest has changed. 124 | # This keeps input file timestamps and filenames the same as in previous 125 | # extension builds, allowing for fast incremental rebuilds. 126 | # 127 | # This optimization is done only in case all the source files reside in 128 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 129 | # environment variable is set (we take this as a signal that the user 130 | # actually cares about this.) 131 | # 132 | # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work 133 | # around the *.cu dependency bug in ninja config. 
134 | # 135 | all_source_files = sorted(sources + headers) 136 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 137 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 138 | 139 | # Compute combined hash digest for all source files. 140 | hash_md5 = hashlib.md5() 141 | for src in all_source_files: 142 | with open(src, 'rb') as f: 143 | hash_md5.update(f.read()) 144 | 145 | # Select cached build directory name. 146 | source_digest = hash_md5.hexdigest() 147 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 148 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 149 | 150 | if not os.path.isdir(cached_build_dir): 151 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 152 | os.makedirs(tmpdir) 153 | for src in all_source_files: 154 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 155 | try: 156 | os.replace(tmpdir, cached_build_dir) # atomic 157 | except OSError: 158 | # source directory already exists, delete tmpdir and its contents. 159 | shutil.rmtree(tmpdir) 160 | if not os.path.isdir(cached_build_dir): raise 161 | 162 | # Compile. 163 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 164 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 165 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 166 | else: 167 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 168 | 169 | # Load. 170 | module = importlib.import_module(module_name) 171 | 172 | except: 173 | if verbosity == 'brief': 174 | print('Failed!') 175 | raise 176 | 177 | # Print status and add to cache dict. 
178 | if verbosity == 'full': 179 | print(f'Done setting up PyTorch plugin "{module_name}".') 180 | elif verbosity == 'brief': 181 | print('Done.') 182 | _cached_plugins[module_name] = module 183 | return module 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | # pylint: enable=line-too-long 188 | # pylint: enable=multiple-statements 189 | # pylint: enable=missing-function-docstring 190 | # pylint: enable=useless-suppression 191 | # pylint: enable=inconsistent-quotes 192 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 
31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 
78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /third_party/stylegan3_official_ops/fma.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | 3 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 4 | # 5 | # NVIDIA CORPORATION and its licensors retain all intellectual property 6 | # and proprietary rights in and to this software, related documentation 7 | # and any modifications thereto. Any use, reproduction, disclosure or 8 | # distribution of this software and related documentation without an express 9 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 10 | 11 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`. 
# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring

import torch

#----------------------------------------------------------------------------

def fma(a, b, c, impl='cuda'): # => a * b + c
    """Compute `a * b + c` with broadcasting.

    Args:
        a, b, c: Tensors whose shapes broadcast against each other.
        impl: 'cuda' routes through the custom autograd function below; any
            other value calls `torch.addcmul()` directly.

    Returns:
        Tensor holding `a * b + c`.
    """
    if impl == 'cuda':
        return _FusedMultiplyAdd.apply(a, b, c)
    return torch.addcmul(c, a, b)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Autograd function for `a * b + c` with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        # Only a and b are needed to form the gradients; for c its shape suffices.
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # Each gradient is computed only when requested, then reduced back to
        # the shape its input had before broadcasting.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that the result has `shape`.

    Args:
        x: Tensor whose shape is `shape` after broadcasting.
        shape: Target shape (any rank from 0 up to `x.ndim`).

    Returns:
        Tensor of shape `shape`.

    Raises:
        AssertionError: If `shape` has more dims than `x`, or the reduced
            tensor does not end up with `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Reduce every leading extra dim, plus every dim where the target is 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if dim:
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # The leading extra dims are all of size 1 now; drop them. Slicing the
        # shape (rather than `reshape(-1, ...)`) also handles a 0-dim target,
        # which would otherwise come out as shape [1].
        x = x.reshape(x.shape[extra_dims:])
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------

# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
#----------------------------------------------------------------------------

enabled = True  # Enable the custom op by setting this to true.


def _parse_major_minor(version):
    """Extracts the `(major, minor)` integers from a version string.

    Only the leading digits of the first two dot-separated fields are used,
    and any local suffix after `+` is ignored, so that `1.11.0a0+git1234`
    gives `(1, 11)`. Missing or non-numeric fields count as 0.
    """
    numbers = []
    for field in version.split('+', 1)[0].split('.')[:2]:
        digits = ''
        for char in field:
            if not char.isdigit():
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    numbers += [0] * (2 - len(numbers))
    return tuple(numbers)


# From PyTorch 1.11 on (prerelease builds included), the backward operator is
# called with an additional `output_mask` argument.
_use_pytorch_1_11_api = _parse_major_minor(torch.__version__) >= (1, 11)

#----------------------------------------------------------------------------

def grid_sample(input, grid, impl='cuda'):
    """Samples `input` at the locations given by `grid`.

    Always uses `mode='bilinear'`, `padding_mode='zeros'` and
    `align_corners=False`.

    Args:
        input: 4-dimensional input tensor.
        grid: 4-dimensional sampling grid.
        impl: If `cuda` and the custom op is enabled (see `enabled`), the
            custom autograd function is used; otherwise the call is forwarded
            to `torch.nn.functional.grid_sample()`. (default: `cuda`)

    Returns:
        The sampled tensor.
    """
    if impl == 'cuda' and _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    """Returns whether the custom op is enabled via the module-level flag."""
    return enabled

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    """Forward pass of `grid_sample` with a differentiable backward pass."""

    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        # The backward pass is itself an autograd function, which is what
        # makes gradients of gradients available.
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    """Backward pass of `grid_sample`, differentiable w.r.t. `grad_output`."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        # Newer PyTorch releases return a tuple whose first element is the
        # callable operation, while older ones return the callable itself.
        # Checking the type handles both without a version test.
        if isinstance(op, tuple):
            op = op[0]
        if _use_pytorch_1_11_api:
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        # Second-order gradients w.r.t. the grid are not implemented.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

# pylint: enable=redefined-builtin
# pylint: enable=arguments-differ
# pylint: enable=protected-access
# pylint: enable=line-too-long
# pylint: enable=missing-function-docstring
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 
1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 
//------------------------------------------------------------------------
// CUDA kernel parameters.

// Parameter bundle filled in by the host-side `upfirdn2d()` function (see
// upfirdn2d.cpp) and passed as the single argument of the launched kernel.
struct upfirdn2d_kernel_params
{
    const void* x; // Input tensor data (untyped; the host dispatches on the scalar type of x).
    const float* f; // Filter data; the host requires the filter to be float32.
    void* y; // Output tensor data.

    int2 up; // Upsampling factors, set by the host as (upx, upy).
    int2 down; // Downsampling factors, set by the host as (downx, downy).
    int2 pad0; // Leading padding, set by the host as (padx0, pady0).
    int flip; // 1 if the host was called with flip == true, otherwise 0.
    float gain; // Gain value forwarded unchanged from the host call.

    int4 inSize; // [width, height, channel, batch]
    int4 inStride; // Strides of x, in the same order as inSize.
    int2 filterSize; // [width, height]
    int2 filterStride; // Strides of f, in the same order as filterSize.
    int4 outSize; // [width, height, channel, batch]
    int4 outStride; // Strides of y, in the same order as outSize.
    int sizeMinor; // Host sets this to the channel count if the channel stride is 1, else 1.
    int sizeMajor; // Host sets this to batch if the channel stride is 1, else batch * channels.

    int loopMinor; // Copied by the host from the selected kernel spec.
    int loopMajor; // Host sets this to (sizeMajor - 1) / 16384 + 1.
    int loopX; // Copied by the host from the selected kernel spec.
    int launchMinor; // Host sets this to (sizeMinor - 1) / loopMinor + 1.
    int launchMajor; // Host sets this to (sizeMajor - 1) / loopMajor + 1.
};

//------------------------------------------------------------------------
def _set_cuda_device(process_index):
    """Binds the current process to a CUDA device, if any is visible.

    Devices are assigned round-robin by `process_index`. On a host without
    visible CUDA devices nothing is done, which avoids a division by zero and
    lets CPU-only backends be initialized.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus > 0:
        torch.cuda.set_device(process_index % num_gpus)


def _get_slurm_master_addr(node_list):
    """Returns the first hostname of a SLURM node list, resolved by `scontrol`.

    The command is run without a shell, and a failure raises instead of
    letting the error text of the command be used as an address.

    Raises:
        FileNotFoundError: If `scontrol` is not available.
        subprocess.CalledProcessError: If `scontrol` exits with an error.
        RuntimeError: If `scontrol` reports no hostname.
    """
    result = subprocess.run(['scontrol', 'show', 'hostname', node_list],
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            universal_newlines=True,
                            check=True)
    hostnames = result.stdout.split()
    if not hostnames:
        raise RuntimeError(f'Failed to resolve any hostname from SLURM node '
                           f'list `{node_list}`!')
    return hostnames[0]


def init_dist(launcher, backend='nccl', **kwargs):
    """Initializes distributed environment.

    Args:
        launcher: Type of the launcher, either `pytorch` or `slurm`. With
            `pytorch`, the rank is read from the environment variable `RANK`.
            With `slurm`, `RANK`, `WORLD_SIZE`, `MASTER_ADDR` and
            `MASTER_PORT` are derived from the SLURM environment (the port
            can be overridden by the environment variable `PORT`, and
            defaults to 29500).
        backend: Backend passed to `dist.init_process_group()`.
            (default: `nccl`)
        **kwargs: Additional arguments passed to `dist.init_process_group()`.
            These are only used by the `pytorch` launcher.

    Raises:
        NotImplementedError: If the `launcher` is not supported.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        _set_cuda_device(rank)
        dist.init_process_group(backend=backend, **kwargs)
    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        _set_cuda_device(proc_id)
        addr = _get_slurm_master_addr(node_list)
        port = os.environ.get('PORT', 29500)
        os.environ['MASTER_PORT'] = str(port)
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        dist.init_process_group(backend=backend)
    else:
        raise NotImplementedError(f'Not implemented launcher type: '
                                  f'`{launcher}`!')


def exit_dist():
    """Exits the distributed environment.

    This is a no-op if no process group has been initialized.
    """
    if dist.is_initialized():
        dist.destroy_process_group()


@contextlib.contextmanager
def ddp_sync(model, sync):
    """Controls whether the `DistributedDataParallel` model should be synced.

    Args:
        model: A `torch.nn.Module`, optionally wrapped by
            `DistributedDataParallel`.
        sync: Whether to keep gradient synchronization on. If false and the
            model is wrapped by `DistributedDataParallel`, the body of the
            `with` statement runs under `model.no_sync()`.
    """
    assert isinstance(model, torch.nn.Module)
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    if sync or not is_ddp:
        yield
    else:
        with model.no_sync():
            yield


def get_ddp_module(model):
    """Gets the module from `DistributedDataParallel`.

    Args:
        model: A `torch.nn.Module`, optionally wrapped by
            `DistributedDataParallel`.

    Returns:
        The wrapped module if `model` is a `DistributedDataParallel`
            instance, otherwise `model` itself.
    """
    assert isinstance(model, torch.nn.Module)
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    if is_ddp:
        return model.module
    return model
class BaseFileTransmitter:
    """Base class of all file transmitters.

    Public interface offered to users:

    (1) pull(): Fetch a file/directory from remote to local.
    (2) push(): Send a file/directory from local to remote.
    (3) remove(): Remove a file/directory.
    (4) make_remote_dir(): Create a directory remotely.

    A derived class is expected to provide the low-level helpers:

    (1) download_hard(): Hard download (e.g., copy) from remote to local.
    (2) download_soft(): Soft download (e.g., soft link, to save space) from
        remote to local.
    (3) upload(): Upload from local to remote.
    (4) delete(): Delete the given path.
    """

    def __init__(self):
        """Initializes the transmitter, which requires no settings."""

    @property
    def name(self):
        """Returns the class name of the file transmitter."""
        return type(self).__name__

    @staticmethod
    def download_hard(src, dst):
        """Downloads (in hard mode) a file/directory from remote to local."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def download_soft(src, dst):
        """Downloads (in soft mode) a file/directory from remote to local."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def upload(src, dst):
        """Uploads a file/directory from local to remote."""
        raise NotImplementedError('Should be implemented in derived class!')

    @staticmethod
    def delete(path):
        """Deletes the given path."""
        # TODO: should we secure the path to avoid mis-removing / attacks?
        raise NotImplementedError('Should be implemented in derived class!')

    def pull(self, src, dst, hard=False):
        """Pulls a file/directory from remote to local.

        Args:
            src: Path on the remote side.
            dst: Path on the local side.
            hard: Whether to download in hard mode (e.g., a real copy) rather
                than in soft mode (e.g., a soft link). (default: False)
        """
        download = self.download_hard if hard else self.download_soft
        download(src, dst)

    def push(self, src, dst):
        """Pushes a file/directory from local to remote."""
        self.upload(src, dst)

    def remove(self, path):
        """Removes the given path, after warning about the removal."""
        warnings.warn(f'`{path}` will be removed!')
        self.delete(path)

    def make_remote_dir(self, directory):
        """Makes a directory on the remote system."""
        raise NotImplementedError('Should be implemented in derived class!')
def format_time(seconds):
    """Converts a duration in seconds to a short, readable string.

    Durations below one minute are shown as fractional seconds; longer ones
    are rounded to the nearest second and shown with their two most
    significant units (minutes/seconds, hours/minutes, or days/hours).

    Args:
        seconds: Number of seconds to format.

    Returns:
        The formatted time string.

    Raises:
        ValueError: If the input `seconds` is less than 0.
    """
    if seconds < 0:
        raise ValueError(f'Input `seconds` should be greater than or equal to '
                         f'0, but `{seconds}` is received!')

    # Below one minute, keep the fractional part (more digits under 10s).
    if seconds < 60:
        precision = 3 if seconds < 10 else 2
        return f'{seconds:7.{precision}f} s'

    total_seconds = int(seconds + 0.5)
    total_minutes, secs = divmod(total_seconds, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hrs = divmod(total_hours, 24)
    if days:
        return f'{days:2d} d {hrs:02d} h'
    if hrs:
        return f'{hrs:2d} h {mins:02d} m'
    return f'{mins:2d} m {secs:02d} s'


def format_range(obj, min_val=None, max_val=None):
    """Validates a `(start, end)` pair and optionally clamps it.

    When `min_val` and/or `max_val` is given, both the start and the end are
    clamped accordingly, such that the result lies in `[min_val, max_val]`.

    NOTE: (a, b) is regarded as a valid range if and only if `a <= b`.

    Args:
        obj: The input object to format, a two-element tuple or list.
        min_val: The lower bound used for clamping. If not provided, no lower
            bound is applied. (default: None)
        max_val: The upper bound used for clamping. If not provided, no upper
            bound is applied. (default: None)

    Returns:
        A two-elements tuple, indicating the start and the end of the range.

    Raises:
        ValueError: If the input object is an invalid range.
    """
    if not isinstance(obj, (tuple, list)):
        raise ValueError(f'Input object must be a tuple or a list, '
                         f'but `{type(obj)}` received!')
    if len(obj) != 2:
        raise ValueError(f'Input object is expected to contain two elements, '
                         f'but `{len(obj)}` received!')
    start, end = obj
    if start > end:
        raise ValueError(f'The second element is expected to be equal to or '
                         f'greater than the first one, '
                         f'but `({start}, {end})` received!')

    if min_val is not None:
        start, end = max(start, min_val), max(end, min_val)
    if max_val is not None:
        start, end = min(start, max_val), min(end, max_val)
    return (start, end)


def format_image_size(size):
    """Normalizes an image size to a `(height, width)` tuple.

    The size can be given as a single integer (used for both height and
    width), or as a list/tuple with one or two integers. Negative values and
    non-integer values are rejected.

    Args:
        size: The input size to format.

    Returns:
        A two-elements tuple, indicating the height and the width, respectively.

    Raises:
        ValueError: If the input size is invalid.
    """
    if not isinstance(size, (int, tuple, list)):
        raise ValueError(f'Input size must be an integer, a tuple, or a list, '
                         f'but `{type(size)}` received!')
    if isinstance(size, int):
        size = (size, size)
    elif len(size) == 1:
        size = (size[0], size[0])
    elif len(size) != 2:
        raise ValueError(f'Input size is expected to have two numbers at '
                         f'most, but `{len(size)}` numbers received!')

    for side, value in zip(('height', 'width'), size):
        if not isinstance(value, int) or value < 0:
            raise ValueError(f'The {side} is expected to be a non-negative '
                             f'integer, but `{value}` received!')
    return tuple(size)


def format_image(image):
    """Formats an image read from `cv2`.

    NOTE: The output is always 3-dimensional, with shape [H, W, C]. Color
    inputs are expected in the channel order produced by `cv2` (`BGR` or
    `BGRA`) and are converted to `RGB` or `RGBA`; single-channel inputs are
    returned without conversion.

    Args:
        image: `np.ndarray` with dtype `np.uint8`, an image read by
            `cv2.imread()` or `cv2.imdecode()`.

    Returns:
        An image with shape [H, W, C] (where `C = 1` for grayscale image).
    """
    if image.ndim == 2:  # grayscale image without the channel axis
        image = image[:, :, None]

    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    assert image.ndim == 3 and image.shape[2] in [1, 3, 4]

    num_channels = image.shape[2]
    if num_channels == 3:  # BGR image
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if num_channels == 4:  # BGRA image
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return image


def raw_label_to_one_hot(raw_label, num_classes):
    """Encodes a single label as a one-hot vector.

    Args:
        raw_label: The raw label, used as the index of the non-zero entry.
        num_classes: Total number of classes, i.e., the vector length.

    Returns:
        A `np.float32` vector of length `num_classes`, with 1.0 at index
            `raw_label` and 0.0 elsewhere.
    """
    encoded = np.zeros(num_classes, dtype=np.float32)
    encoded[raw_label] = 1.0
    return encoded


def one_hot_to_raw_label(one_hot):
    """Decodes a one-hot vector to a single value label.

    Args:
        one_hot: `np.ndarray`, a one-hot encoded vector.

    Returns:
        The index of the largest entry, which represents the category.
    """
    label = np.argmax(one_hot)
    return label
class DummyLogger(BaseLogger):
    """Implements a silent logger, whose logging functions do nothing.

    The constructor takes the same arguments as the other loggers and simply
    forwards them to the base class.
    """

    def __init__(self,
                 logger_name='logger',
                 logfile=None,
                 screen_level=None,
                 file_level=None,
                 indent_space=4,
                 verbose_log=False):
        settings = dict(
            logger_name=logger_name,
            logfile=logfile,
            screen_level=screen_level,
            file_level=file_level,
            indent_space=indent_space,
            verbose_log=verbose_log,
        )
        super().__init__(**settings)

    # All message-related hooks below intentionally discard their inputs.

    def _log(self, message, **kwargs):
        return None

    def _debug(self, message, **kwargs):
        return None

    def _info(self, message, **kwargs):
        return None

    def _warning(self, message, **kwargs):
        return None

    def _error(self, message, **kwargs):
        return None

    def _exception(self, message, **kwargs):
        return None

    def _critical(self, message, **kwargs):
        return None

    def _print(self, *messages, **kwargs):
        return None

    # Progress-bar hooks: no bar is created, and -1 is reported as task ID.

    def init_pbar(self, leave=False):
        return None

    def add_pbar_task(self, name, total, **kwargs):
        return -1

    def update_pbar(self, task_id, advance=1):
        return None

    def close_pbar(self):
        return None
57 | if self.logfile: 58 | # File will be closed when the logger is closed in `self.close()`. 59 | self.file_stream = open(self.logfile, 'a') # pylint: disable=consider-using-with 60 | file_handler = logging.StreamHandler(stream=self.file_stream) 61 | file_handler.setLevel(self.file_level) 62 | file_handler.setFormatter(formatter) 63 | self.logger.addHandler(file_handler) 64 | 65 | self.pbar = [] 66 | self.pbar_kwargs = {} 67 | 68 | def _log(self, message, **kwargs): 69 | self.logger.log(message, **kwargs) 70 | 71 | def _debug(self, message, **kwargs): 72 | self.logger.debug(message, **kwargs) 73 | 74 | def _info(self, message, **kwargs): 75 | self.logger.info(message, **kwargs) 76 | 77 | def _warning(self, message, **kwargs): 78 | self.logger.warning(message, **kwargs) 79 | 80 | def _error(self, message, **kwargs): 81 | self.logger.error(message, **kwargs) 82 | 83 | def _exception(self, message, **kwargs): 84 | self.logger.exception(message, **kwargs) 85 | 86 | def _critical(self, message, **kwargs): 87 | self.logger.critical(message, **kwargs) 88 | 89 | def _print(self, *messages, **kwargs): 90 | for handler in self.logger.handlers: 91 | print(*messages, file=handler.stream) 92 | 93 | def init_pbar(self, leave=False): 94 | columns = [ 95 | '{desc}', 96 | '{bar}', 97 | ' {percentage:5.1f}%', 98 | '[{elapsed}<{remaining}, {rate_fmt}{postfix}]', 99 | ] 100 | self.pbar_kwargs = dict( 101 | leave=leave, 102 | bar_format=' '.join(columns), 103 | unit='', 104 | ) 105 | 106 | def add_pbar_task(self, name, total, **kwargs): 107 | assert isinstance(self.pbar_kwargs, dict) 108 | pbar_kwargs = deepcopy(self.pbar_kwargs) 109 | pbar_kwargs.update(**kwargs) 110 | self.pbar.append(tqdm(desc=name, total=total, **pbar_kwargs)) 111 | return len(self.pbar) - 1 112 | 113 | def update_pbar(self, task_id, advance=1): 114 | assert len(self.pbar) > task_id and isinstance(self.pbar[task_id], tqdm) 115 | if self.pbar[task_id].n < self.pbar[task_id].total: 116 | 
self.pbar[task_id].update(advance) 117 | if self.pbar[task_id].n >= self.pbar[task_id].total: 118 | self.pbar[task_id].refresh() 119 | 120 | def close_pbar(self): 121 | for pbar in self.pbar[::-1]: 122 | pbar.close() 123 | self.pbar = [] 124 | self.pbar_kwargs = {} 125 | -------------------------------------------------------------------------------- /utils/loggers/rich_logger.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the class of rich logger. 3 | 4 | This class is based on the module `rich`. Please refer to 5 | https://github.com/Textualize/rich for more details. 6 | """ 7 | 8 | import sys 9 | import logging 10 | from copy import deepcopy 11 | from rich.console import Console 12 | from rich.logging import RichHandler 13 | from rich.progress import Progress 14 | from rich.progress import ProgressColumn 15 | from rich.progress import TextColumn 16 | from rich.progress import BarColumn 17 | from rich.text import Text 18 | 19 | from .base_logger import BaseLogger 20 | 21 | __all__ = ['RichLogger'] 22 | 23 | 24 | def _format_time(seconds): 25 | """Formats seconds to readable time string. 26 | 27 | This function is used to display time in progress bar. 
28 | """ 29 | if not seconds: 30 | return '--:--' 31 | 32 | seconds = int(seconds) 33 | hours, seconds = divmod(seconds, 3600) 34 | minutes, seconds = divmod(seconds, 60) 35 | if hours: 36 | return f'{hours}:{minutes:02d}:{seconds:02d}' 37 | return f'{minutes:02d}:{seconds:02d}' 38 | 39 | 40 | class TimeColumn(ProgressColumn): 41 | """Renders total time, ETA, and speed in progress bar.""" 42 | 43 | max_refresh = 0.5 # Only refresh twice a second to prevent jitter 44 | 45 | def render(self, task): 46 | elapsed_time = _format_time(task.elapsed) 47 | eta = _format_time(task.time_remaining) 48 | speed = f'{task.speed:.2f}/s' if task.speed else '?/s' 49 | return Text(f'[{elapsed_time}<{eta}, {speed}]', 50 | style='progress.remaining') 51 | 52 | 53 | class RichLogger(BaseLogger): 54 | """Implements the logger based on `rich` module.""" 55 | 56 | def __init__(self, 57 | logger_name='logger', 58 | logfile=None, 59 | screen_level=logging.INFO, 60 | file_level=logging.DEBUG, 61 | indent_space=4, 62 | verbose_log=False): 63 | super().__init__(logger_name=logger_name, 64 | logfile=logfile, 65 | screen_level=screen_level, 66 | file_level=file_level, 67 | indent_space=indent_space, 68 | verbose_log=verbose_log) 69 | 70 | # Get logger and check whether the logger has already been created. 71 | self.logger = logging.getLogger(self.logger_name) 72 | self.logger.propagate = False 73 | if self.logger.hasHandlers(): # Already existed 74 | raise SystemExit(f'Logger `{self.logger_name}` has already ' 75 | f'existed!\n' 76 | f'Please use another name, or otherwise the ' 77 | f'messages may be mixed up.') 78 | 79 | # Set format. 80 | self.logger.setLevel(logging.DEBUG) 81 | 82 | # Print log message onto the screen. 
83 | terminal_console = Console( 84 | file=sys.stdout, log_time=False, log_path=False) 85 | terminal_handler = RichHandler( 86 | level=self.screen_level, 87 | console=terminal_console, 88 | show_time=True, 89 | show_level=True, 90 | show_path=False, 91 | log_time_format='[%Y-%m-%d %H:%M:%S] ') 92 | terminal_handler.setFormatter(logging.Formatter('%(message)s')) 93 | self.logger.addHandler(terminal_handler) 94 | 95 | # Save log message into log file if needed. 96 | if self.logfile: 97 | # File will be closed when the logger is closed in `self.close()`. 98 | self.file_stream = open(self.logfile, 'a') # pylint: disable=consider-using-with 99 | file_console = Console( 100 | file=self.file_stream, log_time=False, log_path=False) 101 | file_handler = RichHandler( 102 | level=self.file_level, 103 | console=file_console, 104 | show_time=True, 105 | show_level=True, 106 | show_path=False, 107 | log_time_format='[%Y-%m-%d %H:%M:%S] ') 108 | file_handler.setFormatter(logging.Formatter('%(message)s')) 109 | self.logger.addHandler(file_handler) 110 | 111 | self.pbar = None 112 | self.pbar_kwargs = {} 113 | 114 | def _log(self, message, **kwargs): 115 | self.logger.log(message, **kwargs) 116 | 117 | def _debug(self, message, **kwargs): 118 | self.logger.debug(message, **kwargs) 119 | 120 | def _info(self, message, **kwargs): 121 | self.logger.info(message, **kwargs) 122 | 123 | def _warning(self, message, **kwargs): 124 | self.logger.warning(message, **kwargs) 125 | 126 | def _error(self, message, **kwargs): 127 | self.logger.error(message, **kwargs) 128 | 129 | def _exception(self, message, **kwargs): 130 | self.logger.exception(message, **kwargs) 131 | 132 | def _critical(self, message, **kwargs): 133 | self.logger.critical(message, **kwargs) 134 | 135 | def _print(self, *messages, **kwargs): 136 | for handler in self.logger.handlers: 137 | handler.console.print(*messages, **kwargs) 138 | 139 | def init_pbar(self, leave=False): 140 | assert self.pbar is None 141 | 142 | # 
Columns shown in the progress bar. 143 | columns = ( 144 | TextColumn('[progress.description]{task.description}'), 145 | BarColumn(bar_width=None), 146 | TextColumn('[progress.percentage]{task.percentage:>5.1f}%'), 147 | TimeColumn(), 148 | ) 149 | 150 | self.pbar = Progress(*columns, 151 | console=self.logger.handlers[0].console, 152 | transient=not leave, 153 | auto_refresh=True, 154 | refresh_per_second=10) 155 | self.pbar.start() 156 | 157 | def add_pbar_task(self, name, total, **kwargs): 158 | assert isinstance(self.pbar, Progress) 159 | assert isinstance(self.pbar_kwargs, dict) 160 | pbar_kwargs = deepcopy(self.pbar_kwargs) 161 | pbar_kwargs.update(**kwargs) 162 | task_id = self.pbar.add_task(name, total=total, **pbar_kwargs) 163 | return task_id 164 | 165 | def update_pbar(self, task_id, advance=1): 166 | assert isinstance(self.pbar, Progress) 167 | if self.pbar.tasks[task_id].finished: 168 | if self.pbar.tasks[task_id].stop_time is None: 169 | self.pbar.stop_task(task_id) 170 | else: 171 | self.pbar.update(task_id, advance=advance) 172 | 173 | def close_pbar(self): 174 | assert isinstance(self.pbar, Progress) 175 | self.pbar.stop() 176 | self.pbar = None 177 | self.pbar_kwargs = {} 178 | -------------------------------------------------------------------------------- /utils/loggers/test.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Unit test for logger.""" 3 | 4 | import os 5 | import time 6 | 7 | from . 
import build_logger 8 | 9 | __all__ = ['test_logger'] 10 | 11 | _TEST_DIR = 'logger_test' 12 | 13 | 14 | def test_logger(test_dir=_TEST_DIR): 15 | """Tests loggers.""" 16 | print('========== Start Logger Test ==========') 17 | 18 | os.makedirs(test_dir, exist_ok=True) 19 | 20 | for logger_type in ['normal', 'rich', 'dummy']: 21 | for indent_space in [2, 4]: 22 | for verbose_log in [False, True]: 23 | if logger_type == 'normal': 24 | class_name = 'Logger' 25 | elif logger_type == 'rich': 26 | class_name = 'RichLogger' 27 | elif logger_type == 'dummy': 28 | class_name = 'DummyLogger' 29 | 30 | print(f'===== ' 31 | f'Testing `utils.logger.{class_name}` ' 32 | f' (indent: {indent_space}, verbose: {verbose_log}) ' 33 | f'=====') 34 | logger_name = (f'{logger_type}_logger_' 35 | f'indent_{indent_space}_' 36 | f'verbose_{verbose_log}') 37 | logger = build_logger( 38 | logger_type, 39 | logger_name=logger_name, 40 | logfile=os.path.join(test_dir, f'test_{logger_name}.log'), 41 | verbose_log=verbose_log, 42 | indent_space=indent_space) 43 | logger.print('print log') 44 | logger.print('print log,', 'log 2') 45 | logger.print('print log (indent level 0)', indent_level=0) 46 | logger.print('print log (indent level 1)', indent_level=1) 47 | logger.print('print log (indent level 2)', indent_level=2) 48 | logger.print('print log (verbose `False`)', is_verbose=False) 49 | logger.print('print log (verbose `True`)', is_verbose=True) 50 | logger.debug('debug log') 51 | logger.info('info log') 52 | logger.warning('warning log') 53 | logger.init_pbar() 54 | task_1 = logger.add_pbar_task('Task 1', 500) 55 | task_2 = logger.add_pbar_task('Task 2', 1000) 56 | for _ in range(1000): 57 | logger.update_pbar(task_1, 1) 58 | logger.update_pbar(task_2, 1) 59 | time.sleep(0.002) 60 | logger.close_pbar() 61 | print('Success!') 62 | 63 | print('========== Finish Logger Test ==========') 64 | -------------------------------------------------------------------------------- /utils/misc.py: 
# ------------------------------------------------------------------------------
# python3.7
"""Misc utility functions."""

import os
import hashlib

from pathlib import Path
from torch.hub import download_url_to_file

__all__ = [
    'REPO_NAME', 'Infix', 'print_and_execute', 'check_file_ext',
    'IMAGE_EXTENSIONS', 'VIDEO_EXTENSIONS', 'MEDIA_EXTENSIONS',
    'parse_file_format', 'set_cache_dir', 'get_cache_dir',
    'md5_update_from_file', 'md5_update_from_dir', 'md5', 'download_url'
]

REPO_NAME = 'Hammer'  # Name of the repository (project).


class Infix(object):
    """Helper class to create custom infix operators.

    When using it, make sure to put the operator between `<<` and `>>`.
    `<< INFIX_OP_NAME >>` should be considered as a whole operator.

    Examples:

    # Use `Infix` to create infix operators directly.
    add = Infix(lambda a, b: a + b)
    1 << add >> 2  # gives 3
    1 << add >> 2 << add >> 3  # gives 6

    # Use `Infix` as a decorator.
    @Infix
    def mul(a, b):
        return a * b
    2 << mul >> 4  # gives 8
    2 << mul >> 3 << mul >> 7  # gives 42
    """

    def __init__(self, function):
        self.function = function
        self.left_value = None

    def __rlshift__(self, left_value):  # override `<<` before `Infix` instance
        assert self.left_value is None  # make sure left is only called once
        self.left_value = left_value
        return self

    def __rshift__(self, right_value):  # override `>>` after `Infix` instance
        result = self.function(self.left_value, right_value)
        self.left_value = None  # reset to None
        return result


def print_and_execute(cmd):
    """Prints and executes a system command.

    Args:
        cmd: Command to be executed.
    """
    print(cmd)
    os.system(cmd)


def check_file_ext(filename, *ext_list):
    """Checks whether the given filename is with target extension(s).

    NOTE: If `ext_list` is empty, this function will always return `False`.

    Args:
        filename: Filename to check.
        *ext_list: A list of extensions. The leading dot is optional and the
            comparison is case insensitive.

    Returns:
        `True` if the filename is with one of extensions in `ext_list`,
            otherwise `False`.
    """
    if not ext_list:
        return False
    # Normalize every candidate to the `.ext` lower-case form.
    ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list]
    ext_list = [ext.lower() for ext in ext_list]
    basename = os.path.basename(filename)
    ext = os.path.splitext(basename)[1].lower()
    return ext in ext_list


# File extensions regarding images (not including GIFs).
IMAGE_EXTENSIONS = (
    '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
    '.tiff', '.tif'
)
# File extensions regarding videos.
VIDEO_EXTENSIONS = (
    '.avi', '.mkv', '.mp4', '.m4v', '.mov', '.webm', '.flv', '.rmvb', '.rm',
    '.3gp'
)
# File extensions regarding media, i.e., images, videos, GIFs.
MEDIA_EXTENSIONS = ('.gif', *IMAGE_EXTENSIONS, *VIDEO_EXTENSIONS)


def parse_file_format(path):
    """Parses the file format of a given path.

    This function basically parses the file format according to its extension.
    It will also return `dir` if the given path is a directory.

    Parsable file formats:

    - zip: with `.zip` extension.
    - tar: with `.tar` / `.tgz` / `.tar.gz` extension.
    - lmdb: a folder ending with `lmdb`.
    - txt: with `.txt` / `.text` extension, OR without extension (e.g. LICENSE).
    - json: with `.json` extension.
    - jpg: with `.jpeg` / `.jpg` / `.jpe` extension.
    - png: with `.png` extension.

    Args:
        path: The path to the file to parse format from.

    Returns:
        A lower-case string, indicating the file format, or `None` if the format
            cannot be successfully parsed.
    """
    # Handle directory.
    if os.path.isdir(path) or path.endswith('/'):
        if path.rstrip('/').lower().endswith('lmdb'):
            return 'lmdb'
        return 'dir'
    # Handle file.
    if os.path.isfile(path) and os.path.splitext(path)[1] == '':
        return 'txt'
    path = path.lower()
    if path.endswith('.tar.gz'):  # Cannot parse accurate extension.
        return 'tar'
    ext = os.path.splitext(path)[1]
    if ext == '.zip':
        return 'zip'
    if ext in ['.tar', '.tgz']:
        return 'tar'
    if ext in ['.txt', '.text']:
        return 'txt'
    if ext == '.json':
        return 'json'
    if ext in ['.jpeg', '.jpg', '.jpe']:
        return 'jpg'
    if ext == '.png':
        return 'png'
    # Unparsable.
    return None


_cache_dir = None


def set_cache_dir(directory=None):
    """Sets the global cache directory.

    The cache directory can be used to save some files that will be shared
    across jobs. The default cache directory is set as `~/.cache/`. This
    function can be used to redirect the cache directory. Or, users can use
    `None` to reset the cache directory back to default.

    Args:
        directory: The target directory used to cache files. If set as `None`,
            the cache directory will be reset back to default. (default: None)
    """
    assert directory is None or isinstance(directory, str), 'Invalid directory!'
    global _cache_dir  # pylint: disable=global-statement
    _cache_dir = directory


def get_cache_dir(use_repo_name=True):
    """Gets the global cache directory.

    The global cache directory is primarily set as `~/.cache/` by default, and
    can be redirected with `set_cache_dir()`.

    Args:
        use_repo_name: Whether to create a folder, named `REPO_NAME`, under
            `_cache_dir` as the actual cache directory. (default: True)

    Returns:
        A string, representing the global cache directory.
    """
    if _cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.cache')
    else:
        cache_dir = _cache_dir
    if use_repo_name:
        return os.path.join(cache_dir, REPO_NAME)
    return cache_dir


def md5_update_from_file(filename, hash_obj):
    """Updates the `hash_obj` with a given file.

    Args:
        filename: Path to the file.
        hash_obj: The `hashlib.md5` object to be updated.

    Returns:
        An updated `hashlib.md5` object.
    """
    assert Path(filename).is_file()
    with open(str(filename), 'rb') as f:
        # Read in small chunks so that large files are never fully loaded.
        for chunk in iter(lambda: f.read(4096), b''):
            hash_obj.update(chunk)
    return hash_obj


def md5_update_from_dir(directory, hash_obj):
    """Updates the `hash_obj` with a given directory.

    Entries are visited in sorted order, and each entry name is hashed before
    its content, so that renaming a file changes the digest.

    Args:
        directory: Path to the directory.
        hash_obj: The `hashlib.md5` object to be updated.

    Returns:
        An updated `hashlib.md5` object.
    """
    assert Path(directory).is_dir()
    for path in sorted(Path(directory).iterdir()):
        hash_obj.update(path.name.encode())
        if path.is_file():
            hash_obj = md5_update_from_file(path, hash_obj)
        elif path.is_dir():
            hash_obj = md5_update_from_dir(path, hash_obj)
    return hash_obj


def md5(path):
    """Returns the MD5 of the given path in hex digest.

    Args:
        path: Path to a file or a directory.

    Returns:
        A string, the hex digest of the MD5.

    Raises:
        ValueError: If `path` is neither a file nor a directory.
    """
    if Path(path).is_file():
        md5_hash = md5_update_from_file(path, hashlib.md5())
        return md5_hash.hexdigest()
    if Path(path).is_dir():
        md5_hash = md5_update_from_dir(path, hashlib.md5())
        return md5_hash.hexdigest()
    raise ValueError(f'Currently calculation of MD5 does not supported for '
                     f'`{path}`')


def download_url(url, path=None, filename=None, sha256=None):
    """Downloads file from URL.

    This function downloads a file from given URL, and executes Hash check if
    needed. The download is skipped if the target file already exists.

    Args:
        url: The URL to download file from.
        path: Path (directory) to save the downloaded file. If set as `None`,
            the cache directory will be used. Please see `get_cache_dir()` for
            more details. (default: None)
        filename: The name to save the file. If set as `None`, this name will be
            automatically parsed from the given URL. (default: None)
        sha256: The expected sha256 of the downloaded file. If set as `None`,
            the hash check will be skipped. Otherwise, this function will check
            whether the sha256 of the downloaded file matches this field.

    Returns:
        A two-element tuple, where the first term is the full path of the
            downloaded file, and the second term indicate the hash check result.
            `True` means hash check passes, `False` means hash check fails,
            while `None` means no hash check is executed.
    """
    # Handle file path.
    if path is None:
        path = get_cache_dir()
    if filename is None:
        filename = os.path.basename(url)
    save_path = os.path.join(path, filename)
    # Download file if needed.
    if not os.path.exists(save_path):
        print(f'Downloading URL `{url}` to path `{save_path}` ...')
        os.makedirs(path, exist_ok=True)
        download_url_to_file(url, save_path, hash_prefix=None, progress=True)
    # Check hash if needed.
    check_result = None
    if sha256 is not None:
        # Hash the file chunk by chunk, such that large checkpoints do not need
        # to be entirely loaded into memory.
        file_hash = hashlib.sha256()
        with open(save_path, 'rb') as f:
            for chunk in iter(lambda: f.read(1 << 20), b''):
                file_hash.update(chunk)
        check_result = (file_hash.hexdigest() == sha256)

    return save_path, check_result

# ------------------------------------------------------------------------------
# /utils/parsing_utils.py:
# ------------------------------------------------------------------------------
# python3.7
"""Contains the utility functions for parsing arguments."""

import json
import argparse
import click

__all__ = [
    'parse_int', 'parse_float', 'parse_bool', 'parse_index', 'parse_json',
    'IntegerParamType', 'FloatParamType', 'BooleanParamType', 'IndexParamType',
    'JsonParamType', 'DictAction'
]


def parse_int(arg):
    """Parses an argument to integer.

    Support converting string `none` and `null` to `None`.
    """
    if arg is None:
        return None
    if isinstance(arg, str) and arg.lower() in ['none', 'null']:
        return None
    return int(arg)


def parse_float(arg):
    """Parses an argument to float number.

    Support converting string `none` and `null` to `None`.
    """
    if arg is None:
        return None
    if isinstance(arg, str) and arg.lower() in ['none', 'null']:
        return None
    return float(arg)


def parse_bool(arg):
    """Parses an argument to boolean.

    `None` will be converted to `False`. Non-string arguments (e.g., integer
    `0` / `1`) are converted to string before parsing.

    Raises:
        ValueError: If the argument cannot be interpreted as a boolean.
    """
    if isinstance(arg, bool):
        return arg
    if arg is None:
        return False
    # Stringify first, so that unsupported inputs raise `ValueError` (which is
    # what `BooleanParamType` handles) instead of `AttributeError`.
    text = str(arg).lower()
    if text in ['1', 'true', 't', 'yes', 'y']:
        return True
    if text in ['0', 'false', 'f', 'no', 'n', 'none', 'null']:
        return False
    raise ValueError(f'`{arg}` cannot be converted to boolean!')


def parse_index(arg, min_val=None, max_val=None):
    """Parses indices.

    If the input is a list or tuple, its elements are used as the indices.

    If the input is a string, it can be either a comma separated list of numbers
    `1, 3, 5`, or a dash separated range `3 - 10`. Spaces in the string will be
    ignored.

    In all cases, the result is de-duplicated and sorted in ascending order.

    Args:
        arg: The input argument to parse indices from.
        min_val: If not `None`, this function will check that all indices are
            equal to or larger than this value. (default: None)
        max_val: If not `None`, this function will check that all indices are
            equal to or smaller than this field. (default: None)

    Returns:
        A list of integers.

    Raises:
        ValueError: If the input is invalid, i.e., neither a list or tuple, nor
            a string, or if a piece of the string is not a number or a
            well-formed `start-end` range.
    """
    if arg is None or arg == '':
        indices = []
    elif isinstance(arg, int):
        indices = [arg]
    elif isinstance(arg, (list, tuple)):
        indices = list(arg)
    elif isinstance(arg, str):
        indices = []
        if arg.lower() not in ['none', 'null']:
            splits = arg.replace(' ', '').split(',')
            for split in splits:
                numbers = list(map(int, split.split('-')))
                if len(numbers) == 1:
                    indices.append(numbers[0])
                elif len(numbers) == 2:
                    indices.extend(range(numbers[0], numbers[1] + 1))
                else:
                    # Something like `1-2-3`, which used to be dropped silently.
                    raise ValueError(f'Invalid range `{split}`! Expected '
                                     f'`start-end`.')
    else:
        raise ValueError(f'Invalid type of input: `{type(arg)}`!')

    assert isinstance(indices, list)
    indices = sorted(set(indices))
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices


def parse_json(arg):
    """Parses a string-like argument following JSON format.

    - `None` arguments will be kept.
    - Non-string arguments will be kept.
    - Strings that are not valid JSON will be kept as they are.
    """
    if not isinstance(arg, str):
        return arg
    try:
        return json.loads(arg)
    except json.decoder.JSONDecodeError:
        return arg


class IntegerParamType(click.ParamType):
    """Defines a `click.ParamType` to parse integer arguments."""

    name = 'int'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        try:
            return parse_int(value)
        except ValueError:
            self.fail(f'`{value}` cannot be parsed as an integer!', param, ctx)


class FloatParamType(click.ParamType):
    """Defines a `click.ParamType` to parse float arguments."""

    name = 'float'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        try:
            return parse_float(value)
        except ValueError:
            self.fail(f'`{value}` cannot be parsed as a float!', param, ctx)


class BooleanParamType(click.ParamType):
    """Defines a `click.ParamType` to parse boolean arguments."""

    name = 'bool'

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        try:
            return parse_bool(value)
        except ValueError:
            self.fail(f'`{value}` cannot be parsed as a boolean!', param, ctx)


class IndexParamType(click.ParamType):
    """Defines a `click.ParamType` to parse indices arguments."""

    name = 'index'

    def __init__(self, min_val=None, max_val=None):
        self.min_val = min_val
        self.max_val = max_val

    def convert(self, value, param, ctx):  # pylint: disable=inconsistent-return-statements
        try:
            return parse_index(value, self.min_val, self.max_val)
        except ValueError:
            self.fail(
                f'`{value}` cannot be parsed as a list of indices!', param, ctx)


class JsonParamType(click.ParamType):
    """Defines a `click.ParamType` to parse arguments following JSON format."""

    name = 'json'

    def convert(self, value, param, ctx):
        return parse_json(value)


class DictAction(argparse.Action):
    r"""Argparse action to split each argument into (key, value) pair.

    Each argument should be with `key=value` format, where `value` should be a
    string with JSON format.

    For example, with an argparse:

    parser.add_argument('--options', nargs='+', action=DictAction)

    , you can use following arguments in the command line:

    --options \
        a=1 \
        b=1.5 \
        c=true \
        d=null \
        e=[1,2,3,4,5] \
        f='{"x":1,"y":2,"z":3}'

    NOTE: No space is allowed in each argument. Also, the dictionary-type
    argument should be quoted with single quotation marks `'`.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for argument in values:
            # Only split on the first `=`, since the value may contain `=`.
            key, separator, val = argument.partition('=')
            if not separator:
                raise ValueError(f'Invalid option `{argument}`! Expected '
                                 f'`key=value` format.')
            options[key] = parse_json(val)
        setattr(namespace, self.dest, options)

# ------------------------------------------------------------------------------
# /utils/tf_utils.py:
# ------------------------------------------------------------------------------
# python3.7
"""Contains the utility functions to handle import TensorFlow modules.

Basically, TensorFlow may not be supported in the current environment, or may
cause some warnings. This file provides functions to help ease TensorFlow
related imports, such as TensorBoard.
"""

import warnings

__all__ = ['import_tf', 'import_tb_writer']


def import_tf():
    """Imports TensorFlow module if possible.

    If `ImportError` is raised, `None` will be returned. Otherwise, the module
    `tensorflow` will be returned.

    `FutureWarning` is silenced during the import only. The warning filters
    that were active before the call are restored afterwards.
    """
    module = None
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=FutureWarning)
        try:
            import tensorflow as tf  # pylint: disable=import-outside-toplevel
            tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
            module = tf
        except ImportError:
            module = None
    return module


def import_tb_writer():
    """Imports the SummaryWriter of TensorBoard.

    If `ImportError` is raised, `None` will be returned. Otherwise, the class
    `SummaryWriter` will be returned.

    `FutureWarning` is silenced during the import only. The warning filters
    that were active before the call are restored afterwards.

    NOTE: This function attempts to import `SummaryWriter` from
    `torch.utils.tensorboard`. But it does not necessarily mean the import
    always succeeds because installing TensorBoard is not a duty of `PyTorch`.
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=FutureWarning)
        try:
            from torch.utils.tensorboard import SummaryWriter  # pylint: disable=import-outside-toplevel
        except ImportError:  # In case TensorBoard is not supported.
            SummaryWriter = None
    return SummaryWriter

# ------------------------------------------------------------------------------
# /utils/visualizers/__init__.py:
# ------------------------------------------------------------------------------
# python3.7
"""Collects all visualizers."""

from .grid_visualizer import GridVisualizer
from .gif_visualizer import GifVisualizer
from .html_visualizer import HtmlVisualizer
from .html_visualizer import HtmlReader
from .video_visualizer import VideoVisualizer
from .video_visualizer import VideoReader

__all__ = [
    'GridVisualizer', 'GifVisualizer', 'HtmlVisualizer', 'HtmlReader',
    'VideoVisualizer', 'VideoReader'
]

# ------------------------------------------------------------------------------
# /utils/visualizers/gif_visualizer.py:
# ------------------------------------------------------------------------------
# python3.7
"""Contains the visualizer to visualize images as a GIF."""

from PIL import Image

from ..image_utils import parse_image_size
from ..image_utils import load_image
from ..image_utils import resize_image
from ..image_utils import list_images_from_dir

__all__ = ['GifVisualizer']


class GifVisualizer(object):
    """Defines the visualizer that visualizes an image collection as GIF."""

    def __init__(self, image_size=None, duration=100, loop=0):
        """Initializes the GIF visualizer.

        Args:
            image_size: Size for image visualization. (default: None)
            duration: Duration between two frames, in milliseconds.
                (default: 100)
            loop: How many times to loop the GIF. `0` means infinite.
                (default: 0)
        """
        self.set_image_size(image_size)
        self.set_duration(duration)
        self.set_loop(loop)

    def set_image_size(self, image_size=None):
        """Sets the image size of the GIF."""
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_duration(self, duration=100):
        """Sets the GIF duration."""
        self.duration = duration

    def set_loop(self, loop=0):
        """Sets how many times the GIF will be looped. `0` means infinite."""
        self.loop = loop

    def _save_frames(self, frames, save_path):
        """Resizes the frames if needed and saves them as a GIF.

        Args:
            frames: An iterable of images (arrays). If the visualizer has no
                preset height / width, the size of the first frame is used.
            save_path: Path to save the GIF to.
        """
        target_size = None  # (height, width), decided by the first frame.
        pil_images = []
        for image in frames:
            if target_size is None:
                target_size = (self.image_height or image.shape[0],
                               self.image_width or image.shape[1])
            if image.shape[0:2] != target_size:
                height, width = target_size
                image = resize_image(image, (width, height))
            pil_images.append(Image.fromarray(image))
        pil_images[0].save(save_path, format='GIF', save_all=True,
                           append_images=pil_images[1:],
                           duration=self.duration,
                           loop=self.loop)

    def visualize_collection(self, images, save_path):
        """Visualizes a collection of images one by one."""
        self._save_frames(images, save_path)

    def visualize_list(self, image_list, save_path):
        """Visualizes a list of image files."""
        # Load lazily, so that each file is decoded exactly once and only one
        # raw frame is alive at a time.
        frames = (load_image(filename) for filename in image_list)
        self._save_frames(frames, save_path)

    def visualize_directory(self, directory, save_path):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list, save_path)
class GridVisualizer(object):
    """Defines the visualizer that visualizes images as a grid.

    Basically, given a collection of images, this visualizer stitches them one
    by one. Notably, this class also supports adding spaces between images,
    adding borders around images, and using white/black background.

    Example:

    grid = GridVisualizer(num_rows=num_rows, num_cols=num_cols)
    for i in range(num_rows):
        for j in range(num_cols):
            grid.add(i, j, image)
    grid.save('visualize.jpg')

    NOTE: The first positional argument of the constructor is `grid_size`, so
    `num_rows` and `num_cols` should be passed by keyword as shown above.
    """

    def __init__(self,
                 grid_size=0,
                 num_rows=0,
                 num_cols=0,
                 is_portrait=False,
                 image_size=None,
                 image_channels=0,
                 row_spacing=0,
                 col_spacing=0,
                 border_left=0,
                 border_right=0,
                 border_top=0,
                 border_bottom=0,
                 use_black_background=True):
        """Initializes the grid visualizer.

        Args:
            grid_size: Total number of cells, i.e., height * width. (default: 0)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Whether the grid should be portrait or landscape.
                This is only used when it requires to compute `num_rows` and
                `num_cols` automatically, which is done by the function
                `get_grid_shape()` from `image_utils`. (default: False)
            image_size: Size to visualize each image. (default: None)
            image_channels: Number of image channels. (default: 0)
            row_spacing: Spacing between rows. (default: 0)
            col_spacing: Spacing between columns. (default: 0)
            border_left: Width of left border. (default: 0)
            border_right: Width of right border. (default: 0)
            border_top: Width of top border. (default: 0)
            border_bottom: Width of bottom border. (default: 0)
            use_black_background: Whether to use black background.
                (default: True)
        """
        self.reset(grid_size, num_rows, num_cols, is_portrait)
        self.set_image_size(image_size)
        self.set_image_channels(image_channels)
        self.set_row_spacing(row_spacing)
        self.set_col_spacing(col_spacing)
        self.set_border_left(border_left)
        self.set_border_right(border_right)
        self.set_border_top(border_top)
        self.set_border_bottom(border_bottom)
        self.set_background(use_black_background)
        self.grid = None

    def reset(self,
              grid_size=0,
              num_rows=0,
              num_cols=0,
              is_portrait=False):
        """Resets the grid shape, i.e., number of rows/columns.

        If `grid_size` is positive, the shape is computed by
        `get_grid_shape()`, with `num_rows` and `num_cols` passed as its
        `height` and `width`. Otherwise, `num_rows` and `num_cols` are used
        as given. Any existing grid content is dropped.
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(grid_size,
                                                height=num_rows,
                                                width=num_cols,
                                                is_portrait=is_portrait)
        self.grid_size = num_rows * num_cols
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.grid = None

    def set_image_size(self, image_size=None):
        """Sets the image size of each cell in the grid."""
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_image_channels(self, image_channels=0):
        """Sets the number of channels of the grid."""
        self.image_channels = image_channels

    def set_row_spacing(self, row_spacing=0):
        """Sets the spacing between grid rows."""
        self.row_spacing = row_spacing

    def set_col_spacing(self, col_spacing=0):
        """Sets the spacing between grid columns."""
        self.col_spacing = col_spacing

    def set_border_left(self, border_left=0):
        """Sets the width of the left border of the grid."""
        self.border_left = border_left

    def set_border_right(self, border_right=0):
        """Sets the width of the right border of the grid."""
        self.border_right = border_right

    def set_border_top(self, border_top=0):
        """Sets the width of the top border of the grid."""
        self.border_top = border_top

    def set_border_bottom(self, border_bottom=0):
        """Sets the width of the bottom border of the grid."""
        self.border_bottom = border_bottom

    def set_background(self, use_black=True):
        """Sets the grid background."""
        self.use_black_background = use_black

    def init_grid(self):
        """Initializes the grid with a blank image.

        The grid shape, the cell size, and the number of channels must all be
        positive before calling this function.
        """
        assert self.num_rows > 0
        assert self.num_cols > 0
        assert self.image_height > 0
        assert self.image_width > 0
        assert self.image_channels > 0
        grid_height = (self.image_height * self.num_rows +
                       self.row_spacing * (self.num_rows - 1) +
                       self.border_top + self.border_bottom)
        grid_width = (self.image_width * self.num_cols +
                      self.col_spacing * (self.num_cols - 1) +
                      self.border_left + self.border_right)
        self.grid = get_blank_image(grid_height, grid_width,
                                    channels=self.image_channels,
                                    use_black=self.use_black_background)

    def add(self, i, j, image):
        """Adds an image into the grid.

        If the grid has not been created yet, it is created here, using the
        size and the number of channels of `image` for any setting that has
        not been configured.

        NOTE: The input image is assumed to be with `RGB` channel order.

        Args:
            i: Row index of the cell to fill.
            j: Column index of the cell to fill.
            image: Image array with shape `(H, W)` or `(H, W, C)`. It is
                resized if it does not match the cell size.
        """
        channels = 1 if image.ndim == 2 else image.shape[2]
        if self.grid is None:
            height, width = image.shape[0:2]
            height = self.image_height or height
            width = self.image_width or width
            channels = self.image_channels or channels
            self.set_image_size((height, width))
            self.set_image_channels(channels)
            self.init_grid()
        if image.shape[0:2] != (self.image_height, self.image_width):
            image = resize_image(image, (self.image_width, self.image_height))
        if image.ndim == 2:
            # The grid is indexed with three axes below, so a 2-D (grayscale)
            # image needs an explicit channel axis. Otherwise, the assignment
            # fails to broadcast `(H, W)` into `(H, W, 1)`.
            image = image[:, :, None]
        # NOTE(review): `channels` is the grid's channel count when the grid
        # is created by this call, but the image's channel count otherwise.
        # Confirm whether single-channel images are meant to fill every
        # channel of a multi-channel grid.
        y = self.border_top + i * (self.image_height + self.row_spacing)
        x = self.border_left + j * (self.image_width + self.col_spacing)
        self.grid[y:y + self.image_height,
                  x:x + self.image_width,
                  :channels] = image

    def _cell_position(self, idx, is_row_major=True):
        """Converts a flat index to the `(row, column)` index of a cell."""
        if is_row_major:
            return divmod(idx, self.num_cols)
        col_idx, row_idx = divmod(idx, self.num_rows)
        return row_idx, col_idx

    def visualize_collection(self,
                             images,
                             save_path=None,
                             num_rows=0,
                             num_cols=0,
                             is_portrait=False,
                             is_row_major=True):
        """Visualizes a collection of images one by one.

        Args:
            images: Collection of image arrays. Its length is used as the
                grid size.
            save_path: Path to save the grid to. Nothing is saved if this
                field is empty. (default: None)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Whether the grid should be portrait or landscape.
                (default: False)
            is_row_major: Whether to fill the grid row by row (otherwise
                column by column). (default: True)
        """
        # `reset()` also drops the grid from any previous visualization.
        self.reset(grid_size=len(images),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        for idx, image in enumerate(images):
            row_idx, col_idx = self._cell_position(idx, is_row_major)
            self.add(row_idx, col_idx, image)
        if save_path:
            self.save(save_path)

    def visualize_list(self,
                       image_list,
                       save_path=None,
                       num_rows=0,
                       num_cols=0,
                       is_portrait=False,
                       is_row_major=True):
        """Visualizes a list of image files.

        Images are loaded one at a time with `load_image()`. See
        `visualize_collection()` for the meaning of the other arguments.
        """
        # `reset()` also drops the grid from any previous visualization.
        self.reset(grid_size=len(image_list),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        for idx, filename in enumerate(image_list):
            row_idx, col_idx = self._cell_position(idx, is_row_major)
            self.add(row_idx, col_idx, load_image(filename))
        if save_path:
            self.save(save_path)

    def visualize_directory(self,
                            directory,
                            save_path=None,
                            num_rows=0,
                            num_cols=0,
                            is_portrait=False,
                            is_row_major=True):
        """Visualizes all images under a directory.

        The images are listed with `list_images_from_dir()` and then handled
        by `visualize_list()`.
        """
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list=image_list,
                            save_path=save_path,
                            num_rows=num_rows,
                            num_cols=num_cols,
                            is_portrait=is_portrait,
                            is_row_major=is_row_major)

    def save(self, path):
        """Saves the grid by passing it to `save_image()`."""
        save_image(path, self.grid)
35 | for idx in range(80): 36 | frame = video_reader.read() 37 | save_image(os.path.join(frame_dir, f'{idx:02d}.png'), frame) 38 | 39 | print('===== Testing `GirdVisualizer` =====') 40 | grid_visualizer = GridVisualizer() 41 | grid_visualizer.set_row_spacing(30) 42 | grid_visualizer.set_col_spacing(30) 43 | grid_visualizer.set_background(use_black=True) 44 | path = os.path.join(test_dir, 'portrait_row_major_ori_space30_black.png') 45 | grid_visualizer.visualize_directory(frame_dir, path, 46 | is_portrait=True, is_row_major=True) 47 | path = os.path.join( 48 | test_dir, 'landscape_col_major_downsample_space15_white.png') 49 | grid_visualizer.set_image_size(half_size) 50 | grid_visualizer.set_row_spacing(15) 51 | grid_visualizer.set_col_spacing(15) 52 | grid_visualizer.set_background(use_black=False) 53 | grid_visualizer.visualize_directory(frame_dir, path, 54 | is_portrait=False, is_row_major=False) 55 | 56 | print('===== Testing `HtmlVisualizer` =====') 57 | html_visualizer = HtmlVisualizer() 58 | path = os.path.join(test_dir, 'portrait_col_major_ori.html') 59 | html_visualizer.visualize_directory(frame_dir, path, 60 | is_portrait=True, is_row_major=False) 61 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html') 62 | html_visualizer.set_image_size(half_size) 63 | html_visualizer.visualize_directory(frame_dir, path, 64 | is_portrait=False, is_row_major=True) 65 | 66 | print('===== Testing `HtmlReader` =====') 67 | path = os.path.join(test_dir, 'landscape_row_major_downsample.html') 68 | html_reader = HtmlReader(path) 69 | for j in range(html_reader.num_cols): 70 | assert html_reader.get_header(j) == '' 71 | parsed_dir = os.path.join(test_dir, 'parsed_frames') 72 | os.makedirs(parsed_dir, exist_ok=True) 73 | for i in range(html_reader.num_rows): 74 | for j in range(html_reader.num_cols): 75 | idx = i * html_reader.num_cols + j 76 | assert html_reader.get_text(i, j).endswith(f'(index {idx:03d})') 77 | image = html_reader.get_image(i, j, 
image_size=frame_size) 78 | assert image.shape[0:2] == frame_size 79 | save_image(os.path.join(parsed_dir, f'{idx:02d}.png'), image) 80 | 81 | print('===== Testing `GifVisualizer` =====') 82 | gif_visualizer = GifVisualizer() 83 | path = os.path.join(test_dir, 'gif_ori.gif') 84 | gif_visualizer.visualize_directory(frame_dir, path) 85 | gif_visualizer.set_image_size(half_size) 86 | path = os.path.join(test_dir, 'gif_downsample.gif') 87 | gif_visualizer.visualize_directory(frame_dir, path) 88 | 89 | print('===== Testing `VideoVisualizer` =====') 90 | video_visualizer = VideoVisualizer() 91 | path = os.path.join(test_dir, 'video_ori.mp4') 92 | video_visualizer.visualize_directory(frame_dir, path) 93 | path = os.path.join(test_dir, 'video_downsample.mp4') 94 | video_visualizer.set_frame_size(half_size) 95 | video_visualizer.visualize_directory(frame_dir, path) 96 | 97 | print('========== Finish Visualizer Test ==========') 98 | -------------------------------------------------------------------------------- /utils/visualizers/video_visualizer.py: -------------------------------------------------------------------------------- 1 | # python3.7 2 | """Contains the visualizer to visualize images as a video. 3 | 4 | This file relies on `FFmpeg`. Use `sudo apt-get install ffmpeg` and 5 | `brew install ffmpeg` to install on Ubuntu and MacOS respectively. 
class VideoVisualizer(object):
    """Defines the video visualizer that presents images as a video."""

    def __init__(self,
                 path=None,
                 frame_size=None,
                 fps=25.0,
                 codec='libx264',
                 pix_fmt='yuv420p',
                 crf=1):
        """Initializes the video visualizer.

        Args:
            path: Path to write the video. (default: None)
            frame_size: Frame size, i.e., (height, width). (default: None)
            fps: Frames per second. (default: 25)
            codec: Codec. (default: `libx264`)
            pix_fmt: Pixel format. (default: `yuv420p`)
            crf: Constant rate factor, which controls the compression. The
                larger this field is, the higher compression and lower quality.
                `0` means no compression and consequently the highest quality.
                To enable QuickTime playing (requires YUV to be 4:2:0, but
                `crf = 0` results YUV to be 4:4:4), please set this field as
                at least 1. (default: 1)
        """
        self.set_path(path)
        self.set_frame_size(frame_size)
        self.set_fps(fps)
        self.set_codec(codec)
        self.set_pix_fmt(pix_fmt)
        self.set_crf(crf)
        self.video = None

    def set_path(self, path=None):
        """Sets the path to save the video."""
        self.path = path

    def set_frame_size(self, frame_size=None):
        """Sets the video frame size."""
        height, width = parse_image_size(frame_size)
        self.frame_height = height
        self.frame_width = width

    def set_fps(self, fps=25.0):
        """Sets the FPS (frame per second) of the video."""
        self.fps = fps

    def set_codec(self, codec='libx264'):
        """Sets the video codec."""
        self.codec = codec

    def set_pix_fmt(self, pix_fmt='yuv420p'):
        """Sets the video pixel format."""
        self.pix_fmt = pix_fmt

    def set_crf(self, crf=1):
        """Sets the CRF (constant rate factor) of the video."""
        self.crf = crf

    def init_video(self):
        """Initializes an empty video with expected settings.

        Raises:
            ValueError: If the path to save the video has not been set.
            AssertionError: If the target file already exists, or the frame
                size is not positive.
        """
        if self.path is None:
            # Fail with a clear message instead of the `TypeError` raised by
            # `os.path.exists(None)`.
            raise ValueError('Path to save the video is not set!')
        assert not os.path.exists(self.path), f'Video `{self.path}` existed!'
        assert self.frame_height > 0
        assert self.frame_width > 0

        video_setting = {
            '-r': f'{self.fps:.2f}',
            '-s': f'{self.frame_width}x{self.frame_height}',
            '-vcodec': f'{self.codec}',
            '-crf': f'{self.crf}',
            '-pix_fmt': f'{self.pix_fmt}',
        }
        self.video = FFmpegWriter(self.path, outputdict=video_setting)

    def add(self, frame):
        """Adds a frame into the video visualizer.

        If the video has not been opened yet, it is opened here, using the
        size of `frame` for any dimension that has not been configured.
        Frames that do not match the video frame size are resized.

        NOTE: The input frame is assumed to be with `RGB` channel order.
        """
        if self.video is None:
            height, width = frame.shape[0:2]
            height = self.frame_height or height
            width = self.frame_width or width
            self.set_frame_size((height, width))
            self.init_video()
        if frame.shape[0:2] != (self.frame_height, self.frame_width):
            frame = resize_image(frame, (self.frame_width, self.frame_height))
        self.video.writeFrame(frame)

    def visualize_collection(self, images, save_path=None):
        """Visualizes a collection of images one by one.

        Args:
            images: Iterable of image arrays, written as frames in order.
            save_path: Path to save the video to. If given and different from
                the current path, the current video (if any) is saved first
                and a new one is started. (default: None)
        """
        if save_path is not None and save_path != self.path:
            self.save()
            self.set_path(save_path)
        try:
            for image in images:
                self.add(image)
        finally:
            # Always close the writer, even if a frame fails to be loaded or
            # written, such that the opened video is not leaked.
            self.save()

    def visualize_list(self, image_list, save_path=None):
        """Visualizes a list of image files.

        Images are loaded one at a time with `load_image()`. See
        `visualize_collection()` for the meaning of `save_path`.
        """
        frames = (load_image(filename) for filename in image_list)
        self.visualize_collection(frames, save_path)

    def visualize_directory(self, directory, save_path=None):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list, save_path)

    def save(self):
        """Saves the video by closing the file.

        The path is cleared afterwards, so a new path must be set before
        starting the next video.
        """
        if self.video is not None:
            self.video.close()
            self.video = None
            self.set_path(None)


class VideoReader(object):
    """Defines the video reader.

    This class can be used to read frames from a given video.

    NOTE: Each frame can be read only once.
    TODO: Fix this?
    """

    def __init__(self, path, inputdict=None):
        """Initializes the video reader by loading the video from disk.

        Args:
            path: Path to the video to read.
            inputdict: Passed to `FFmpegReader` as `inputdict`.
                (default: None)
        """
        self.path = path
        self.video = FFmpegReader(path, inputdict=inputdict)

        self.length = self.video.inputframenum
        self.frame_height = self.video.inputheight
        self.frame_width = self.video.inputwidth
        self.fps = self.video.inputfps
        self.pix_fmt = self.video.pix_fmt

    def __del__(self):
        """Releases the opened video."""
        # `video` is missing if the constructor failed before opening it.
        video = getattr(self, 'video', None)
        if video is not None:
            video.close()

    def read(self, image_size=None):
        """Reads the next frame.

        Args:
            image_size: Size of the returned frame, parsed by
                `parse_image_size()`. Any dimension left unset falls back to
                the size of the raw frame. (default: None)

        Returns:
            The frame, resized if it does not match the requested size.
        """
        # A new generator is requested on every call. Presumably the reader
        # itself keeps the stream position, so this yields the next unread
        # frame -- TODO confirm against `skvideo`.
        frame = next(self.video.nextFrame())
        height, width = parse_image_size(image_size)
        height = height or frame.shape[0]
        width = width or frame.shape[1]
        if frame.shape[0:2] != (height, width):
            frame = resize_image(frame, (width, height))
        return frame