├── LICENSE ├── LICENSE-LPIPS ├── LICENSE-Stylegan2 ├── README.md ├── dataloader ├── __init__.py └── dataset.py ├── figs ├── cxr-seg.png ├── datasetgan_demo.png ├── face-parts-opt-steps.png ├── face-parts-seg.png ├── method.png ├── skin-lesion-seg.png └── teaser3.png ├── giistr-cla.md ├── models ├── __init__.py ├── encoder_model.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── stylegan2.py ├── stylegan2_seg.py └── utils.py ├── requirements.txt ├── semanticGAN ├── __init__.py ├── inference.py ├── losses.py ├── prepare_inception.py ├── preprocessing │ ├── __init__.py │ ├── face_postprocessing.py │ └── face_preprocessing.py ├── ranger.py ├── samplers.py ├── train_enc.py └── train_seg_gan.py └── utils ├── __init__.py ├── data_util.py ├── distributed.py ├── inception_utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2021 NVIDIA Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | -------------------------------------------------------------------------------- /LICENSE-Stylegan2: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SemanticGAN 2 | This is the official code for: 3 | 4 | #### Semantic Segmentation with Generative Models: Semi-Supervised Learning and Strong Out-of-Domain Generalization 5 | 6 | [Daiqing Li](https://scholar.google.ca/citations?user=8q2ISMIAAAAJ&hl=en), [Junlin Yang](https://scholar.google.com/citations?user=QYkscc4AAAAJ&hl=en), [Karsten Kreis](https://scholar.google.de/citations?user=rFd-DiAAAAAJ&hl=de), [Antonio Torralba](https://groups.csail.mit.edu/vision/torralbalab/), [Sanja Fidler](http://www.cs.toronto.edu/~fidler/) 7 | 8 | CVPR 2021 **[[Paper](https://arxiv.org/abs/2104.05833)] [[Supp](https://nv-tlabs.github.io/semanticGAN/resources/SemanticGAN_supp.pdf)] [[Page](https://nv-tlabs.github.io/semanticGAN/)]** 9 | 10 | 11 | 12 | 13 | ## Requirements 14 | - Python 3.6 or 3.7 are supported. 15 | - Pytorch 1.4.0 + is recommended. 16 | - This code is tested with CUDA 10.2 toolkit and CuDNN 7.5. 17 | - Please check the python package requirement from [`requirements.txt`](requirements.txt), and install using 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | ## Dataset 22 | We recently release MetFaces40 annotation we use as out-of-domain testing. Please notice this dataset is under the [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. To view a copy of this license, visit [LICENSE](https://github.com/nv-tlabs/semanticGAN_code/blob/main/LICENSE). Please see [GDrive](https://drive.google.com/drive/folders/1ibZzaWSUVoQ94OPoLNS0FrUufBeDvRP4?usp=sharing). 23 | ## Training 24 | 25 | To reproduce paper **Semantic Segmentation with Generative Models: Semi-Supervised Learning and Strong Out-of-Domain Generalization**: 26 | 27 | 1. Run **Step1: Semantic GAN training** 28 | 2. Run **Step2: Encoder training** 29 | 3. Run **Inference & Optimization**. 30 | 31 | 32 | --- 33 | #### 0. Prepare for FID calculation 34 | In order to calculate FID score, you need to prepare inception features for your dataset, 35 | 36 | ``` 37 | python prepare_inception.py \ 38 | --size [resolution of the image] \ 39 | --batch [batch size] \ 40 | --output [path to save the inception file, in .pkl] \ 41 | --dataset_name celeba-mask \ 42 | [positional argument 1, path to the image folder]] \ 43 | ``` 44 | #### 1. GAN Training 45 | 46 | For training GAN with both image and its label, 47 | 48 | ``` 49 | python train_seg_gan.py \ 50 | --img_dataset [path-to-img-folder] \ 51 | --seg_dataset [path-to-seg-folder] \ 52 | --inception [path-to-inception file] \ 53 | --seg_name celeba-mask \ 54 | --checkpoint_dir [path-to-ckpt-dir] \ 55 | ``` 56 | 57 | To use multi-gpus training in the cloud, 58 | 59 | ``` 60 | python -m torch.distributed.launch \ 61 | --nproc_per_node=N_GPU \ 62 | --master_port=PORTtrain_gan.py \ 63 | train_gan.py \ 64 | --img_dataset [path-to-img-folder] \ 65 | --inception [path-to-inception file] \ 66 | --dataset_name celeba-mask \ 67 | --checkpoint_dir [path-to-ckpt-dir] \ 68 | ``` 69 | 70 | #### 2. Encoder Triaining 71 | 72 | ``` 73 | python train_enc.py \ 74 | --img_dataset [path-to-img-folder] \ 75 | --seg_dataset [path-to-seg-folder] \ 76 | --ckpt [path-to-pretrained GAN model] \ 77 | --seg_name celeba-mask \ 78 | --enc_backboend [fpn|res] \ 79 | --checkpoint_dir [path-to-ckpt-dir] \ 80 | ``` 81 | 82 | ## Inference 83 | 84 | For Face Parts Segmentation Task 85 | 86 | ![img](./figs/face-parts-seg.png?lastModify=1616189357) 87 | 88 | ``` 89 | python inference.py \ 90 | --ckpt [path-to-ckpt] \ 91 | --img_dir [path-to-test-folder] \ 92 | --outdir [path-to-output-folder] \ 93 | --dataset_name celeba-mask \ 94 | --w_plus \ 95 | --image_mode RGB \ 96 | --seg_dim 8 \ 97 | --step 200 [optimization steps] \ 98 | ``` 99 | 100 | Visualization of different optimization steps 101 | 102 | ![img](./figs/face-parts-opt-steps.png) 103 | 104 | 105 | ## Citation 106 | 107 | Please cite the following paper if you used the code in this repository. 108 | 109 | ``` 110 | @inproceedings{semanticGAN, 111 | title={Semantic Segmentation with Generative Models: Semi-Supervised Learning and Strong Out-of-Domain Generalization}, 112 | booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)}, 113 | author={Li, Daiqing and Yang, Junlin and Kreis, Karsten and Torralba, Antonio and Fidler, Sanja}, 114 | year={2021}, 115 | } 116 | ``` 117 | 118 | 119 | 120 | ## License 121 | For any code dependency related to Stylegan2, the license is under the Nvidia Source Code License-NC. To view a copy of this license, visit https://nvlabs.github.io/stylegan2/license.html 122 | 123 | The work SemanticGAN is released under MIT License. 124 | 125 | ``` 126 | The MIT License (MIT) 127 | 128 | Copyright (c) 2021 NVIDIA Corporation. 129 | 130 | Permission is hereby granted, free of charge, to any person obtaining a copy of 131 | this software and associated documentation files (the "Software"), to deal in 132 | the Software without restriction, including without limitation the rights to 133 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 134 | the Software, and to permit persons to whom the Software is furnished to do so, 135 | subject to the following conditions: 136 | 137 | The above copyright notice and this permission notice shall be included in all 138 | copies or substantial portions of the Software. 139 | 140 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 141 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 142 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 143 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 144 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 145 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 146 | ``` -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ -------------------------------------------------------------------------------- /dataloader/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | from PIL import Image, ImageOps 24 | from torch.utils.data import Dataset 25 | from torchvision import transforms 26 | import os 27 | import numpy as np 28 | import torch 29 | import cv2 30 | import albumentations 31 | import albumentations.augmentations as A 32 | 33 | class HistogramEqualization(object): 34 | def __call__(self, img): 35 | img_eq = ImageOps.equalize(img) 36 | 37 | return img_eq 38 | 39 | class AdjustGamma(object): 40 | def __init__(self, gamma): 41 | self.gamma = gamma 42 | 43 | def __call__(self, img): 44 | img_gamma = transforms.functional.adjust_gamma(img, self.gamma) 45 | 46 | return img_gamma 47 | 48 | class CelebAMaskDataset(Dataset): 49 | def __init__(self, args, dataroot, unlabel_transform=None, latent_dir=None, is_label=True, phase='train', 50 | limit_size=None, unlabel_limit_size=None, aug=False, resolution=256): 51 | 52 | self.args = args 53 | self.is_label = is_label 54 | 55 | 56 | if is_label == True: 57 | self.latent_dir = latent_dir 58 | self.data_root = os.path.join(dataroot, 'label_data') 59 | 60 | if phase == 'train': 61 | if limit_size is None: 62 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 'train_full_list.txt'), dtype=str) 63 | else: 64 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 65 | 'train_{}_list.txt'.format(limit_size)), dtype=str).reshape(-1) 66 | elif phase == 'val': 67 | if limit_size is None: 68 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 'val_full_list.txt'), dtype=str) 69 | else: 70 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 71 | 'val_{}_list.txt'.format(limit_size)), dtype=str).reshape(-1) 72 | elif phase == 'train-val': 73 | # concat both train and val 74 | if limit_size is None: 75 | train_list = np.loadtxt(os.path.join(self.data_root, 'train_full_list.txt'), dtype=str) 76 | val_list = np.loadtxt(os.path.join(self.data_root, 'val_full_list.txt'), dtype=str) 77 | self.idx_list = list(train_list) + list(val_list) 78 | else: 79 | train_list = np.loadtxt(os.path.join(self.data_root, 80 | 'train_{}_list.txt'.format(limit_size)), dtype=str).reshape(-1) 81 | val_list = np.loadtxt(os.path.join(self.data_root, 82 | 'val_{}_list.txt'.format(limit_size)), dtype=str).reshape(-1) 83 | self.idx_list = list(train_list) + list(val_list) 84 | else: 85 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 'test_list.txt'), dtype=str) 86 | else: 87 | self.data_root = os.path.join(dataroot, 'unlabel_data') 88 | if unlabel_limit_size is None: 89 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 'unlabel_list.txt'), dtype=str) 90 | else: 91 | self.idx_list = np.loadtxt(os.path.join(self.data_root, 'unlabel_{}_list.txt'.format(unlabel_limit_size)), dtype=str) 92 | 93 | self.img_dir = os.path.join(self.data_root, 'image') 94 | self.label_dir = os.path.join(self.data_root, 'label') 95 | 96 | self.phase = phase 97 | self.color_map = { 98 | 0: [ 0, 0, 0], 99 | 1: [ 0,0,205], 100 | 2: [132,112,255], 101 | 3: [ 25,25,112], 102 | 4: [187,255,255], 103 | 5: [ 102,205,170], 104 | 6: [ 227,207,87], 105 | 7: [ 142,142,56] 106 | } 107 | 108 | self.data_size = len(self.idx_list) 109 | self.resolution = resolution 110 | 111 | self.aug = aug 112 | if aug == True: 113 | self.aug_t = albumentations.Compose([ 114 | A.transforms.HorizontalFlip(p=0.5), 115 | A.transforms.ShiftScaleRotate(shift_limit=0.1, 116 | scale_limit=0.2, 117 | rotate_limit=15, 118 | border_mode=cv2.BORDER_CONSTANT, 119 | value=0, 120 | mask_value=0, 121 | p=0.5), 122 | ]) 123 | 124 | self.unlabel_transform = unlabel_transform 125 | 126 | 127 | def _mask_labels(self, mask_np): 128 | label_size = len(self.color_map.keys()) 129 | labels = np.zeros((label_size, mask_np.shape[0], mask_np.shape[1])) 130 | for i in range(label_size): 131 | labels[i][mask_np==i] = 1.0 132 | 133 | return labels 134 | 135 | 136 | @staticmethod 137 | def preprocess(img): 138 | image_transform = transforms.Compose( 139 | [ 140 | transforms.ToTensor(), 141 | transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5), inplace=True) 142 | ] 143 | ) 144 | img_tensor = image_transform(img) 145 | # normalize 146 | # img_tensor = (img_tensor - img_tensor.min()) / (img_tensor.max() - img_tensor.min()) 147 | # img_tensor = (img_tensor - 0.5) / 0.5 148 | 149 | return img_tensor 150 | 151 | 152 | def __len__(self): 153 | if hasattr(self.args, 'n_gpu') == False: 154 | return self.data_size 155 | # make sure dataloader size is larger than batchxngpu size 156 | return max(self.args.batch*self.args.n_gpu, self.data_size) 157 | 158 | def __getitem__(self, idx): 159 | if idx >= self.data_size: 160 | idx = idx % (self.data_size) 161 | img_idx = self.idx_list[idx] 162 | img_pil = Image.open(os.path.join(self.img_dir, img_idx)).convert('RGB').resize((self.resolution, self.resolution)) 163 | mask_pil = Image.open(os.path.join(self.label_dir, img_idx)).convert('L').resize((self.resolution, self.resolution), resample=0) 164 | 165 | if self.is_label: 166 | if (self.phase == 'train' or self.phase == 'train-val') and self.aug: 167 | augmented = self.aug_t(image=np.array(img_pil), mask=np.array(mask_pil)) 168 | aug_img_pil = Image.fromarray(augmented['image']) 169 | # apply pixel-wise transformation 170 | img_tensor = self.preprocess(aug_img_pil) 171 | 172 | mask_np = np.array(augmented['mask']) 173 | labels = self._mask_labels(mask_np) 174 | 175 | mask_tensor = torch.tensor(labels, dtype=torch.float) 176 | mask_tensor = (mask_tensor - 0.5) / 0.5 177 | 178 | else: 179 | img_tensor = self.preprocess(img_pil) 180 | mask_np = np.array(mask_pil) 181 | labels = self._mask_labels(mask_np) 182 | 183 | mask_tensor = torch.tensor(labels, dtype=torch.float) 184 | mask_tensor = (mask_tensor - 0.5) / 0.5 185 | 186 | return { 187 | 'image': img_tensor, 188 | 'mask': mask_tensor 189 | } 190 | else: 191 | img_tensor = self.unlabel_transform(img_pil) 192 | return { 193 | 'image': img_tensor, 194 | } 195 | -------------------------------------------------------------------------------- /figs/cxr-seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/cxr-seg.png -------------------------------------------------------------------------------- /figs/datasetgan_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/datasetgan_demo.png -------------------------------------------------------------------------------- /figs/face-parts-opt-steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/face-parts-opt-steps.png -------------------------------------------------------------------------------- /figs/face-parts-seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/face-parts-seg.png -------------------------------------------------------------------------------- /figs/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/method.png -------------------------------------------------------------------------------- /figs/skin-lesion-seg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/skin-lesion-seg.png -------------------------------------------------------------------------------- /figs/teaser3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/figs/teaser3.png -------------------------------------------------------------------------------- /giistr-cla.md: -------------------------------------------------------------------------------- 1 | ## Individual Contributor License Agreement (CLA) 2 | 3 | **Thank you for submitting your contributions to this project.** 4 | 5 | By signing this CLA, you agree that the following terms apply to all of your past, present and future contributions 6 | to the project. 7 | 8 | ### License. 9 | 10 | You hereby represent that all present, past and future contributions are governed by the 11 | [MIT License](https://opensource.org/licenses/MIT) 12 | copyright statement. 13 | 14 | This entails that to the extent possible under law, you transfer all copyright and related or neighboring rights 15 | of the code or documents you contribute to the project itself or its maintainers. 16 | Furthermore you also represent that you have the authority to perform the above waiver 17 | with respect to the entirety of you contributions. 18 | 19 | ### Moral Rights. 20 | 21 | To the fullest extent permitted under applicable law, you hereby waive, and agree not to 22 | assert, all of your “moral rights” in or relating to your contributions for the benefit of the project. 23 | 24 | ### Third Party Content. 25 | 26 | If your Contribution includes or is based on any source code, object code, bug fixes, configuration changes, tools, 27 | specifications, documentation, data, materials, feedback, information or other works of authorship that were not 28 | authored by you (“Third Party Content”) or if you are aware of any third party intellectual property or proprietary 29 | rights associated with your Contribution (“Third Party Rights”), 30 | then you agree to include with the submission of your Contribution full details respecting such Third Party 31 | Content and Third Party Rights, including, without limitation, identification of which aspects of your 32 | Contribution contain Third Party Content or are associated with Third Party Rights, the owner/author of the 33 | Third Party Content and Third Party Rights, where you obtained the Third Party Content, and any applicable 34 | third party license terms or restrictions respecting the Third Party Content and Third Party Rights. For greater 35 | certainty, the foregoing obligations respecting the identification of Third Party Content and Third Party Rights 36 | do not apply to any portion of a Project that is incorporated into your Contribution to that same Project. 37 | 38 | ### Representations. 39 | 40 | You represent that, other than the Third Party Content and Third Party Rights identified by 41 | you in accordance with this Agreement, you are the sole author of your Contributions and are legally entitled 42 | to grant the foregoing licenses and waivers in respect of your Contributions. If your Contributions were 43 | created in the course of your employment with your past or present employer(s), you represent that such 44 | employer(s) has authorized you to make your Contributions on behalf of such employer(s) or such employer 45 | (s) has waived all of their right, title or interest in or to your Contributions. 46 | 47 | ### Disclaimer. 48 | 49 | To the fullest extent permitted under applicable law, your Contributions are provided on an "as is" 50 | basis, without any warranties or conditions, express or implied, including, without limitation, any implied 51 | warranties or conditions of non-infringement, merchantability or fitness for a particular purpose. You are not 52 | required to provide support for your Contributions, except to the extent you desire to provide support. 53 | 54 | ### No Obligation. 55 | 56 | You acknowledge that the maintainers of this project are under no obligation to use or incorporate your contributions 57 | into the project. The decision to use or incorporate your contributions into the project will be made at the 58 | sole discretion of the maintainers or their authorized delegates. 59 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/models/__init__.py -------------------------------------------------------------------------------- /models/encoder_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | from models.stylegan2_seg import EqualLinear, ConvLayer, ResBlock 28 | import math 29 | 30 | class Bottleneck(nn.Module): 31 | expansion = 4 32 | 33 | def __init__(self, in_planes, planes, stride=1): 34 | super(Bottleneck, self).__init__() 35 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 36 | self.bn1 = nn.BatchNorm2d(planes) 37 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 40 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 41 | 42 | self.shortcut = nn.Sequential() 43 | if stride != 1 or in_planes != self.expansion*planes: 44 | self.shortcut = nn.Sequential( 45 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 46 | nn.BatchNorm2d(self.expansion*planes) 47 | ) 48 | 49 | def forward(self, x): 50 | out = F.relu(self.bn1(self.conv1(x))) 51 | out = F.relu(self.bn2(self.conv2(out))) 52 | out = self.bn3(self.conv3(out)) 53 | out += self.shortcut(x) 54 | out = F.relu(out) 55 | return out 56 | 57 | 58 | class FPN(nn.Module): 59 | def __init__(self, input_dim, block, num_blocks): 60 | super(FPN, self).__init__() 61 | self.in_planes = 64 62 | self.feature_dim = 512 63 | 64 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | 67 | # Bottom-up layers 68 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 69 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 70 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 71 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 72 | 73 | # Top layer 74 | self.toplayer = nn.Conv2d(2048, self.feature_dim, kernel_size=1, stride=1, padding=0) # Reduce channels 75 | 76 | # Smooth layers 77 | self.smooth1 = nn.Conv2d(self.feature_dim, self.feature_dim, kernel_size=3, stride=1, padding=1) 78 | self.smooth2 = nn.Conv2d(self.feature_dim, self.feature_dim, kernel_size=3, stride=1, padding=1) 79 | self.smooth3 = nn.Conv2d(self.feature_dim, self.feature_dim, kernel_size=3, stride=1, padding=1) 80 | 81 | # Lateral layers 82 | self.latlayer1 = nn.Conv2d(1024, self.feature_dim, kernel_size=1, stride=1, padding=0) 83 | self.latlayer2 = nn.Conv2d( 512, self.feature_dim, kernel_size=1, stride=1, padding=0) 84 | self.latlayer3 = nn.Conv2d( 256, self.feature_dim, kernel_size=1, stride=1, padding=0) 85 | 86 | def _make_layer(self, block, planes, num_blocks, stride): 87 | strides = [stride] + [1]*(num_blocks-1) 88 | layers = [] 89 | for stride in strides: 90 | layers.append(block(self.in_planes, planes, stride)) 91 | self.in_planes = planes * block.expansion 92 | return nn.Sequential(*layers) 93 | 94 | def _upsample_add(self, x, y): 95 | '''Upsample and add two feature maps. 96 | Args: 97 | x: (Variable) top feature map to be upsampled. 98 | y: (Variable) lateral feature map. 99 | Returns: 100 | (Variable) added feature map. 101 | Note in PyTorch, when input size is odd, the upsampled feature map 102 | with `F.upsample(..., scale_factor=2, mode='nearest')` 103 | maybe not equal to the lateral feature map size. 104 | e.g. 105 | original input size: [N,_,15,15] -> 106 | conv2d feature map size: [N,_,8,8] -> 107 | upsampled feature map size: [N,_,16,16] 108 | So we choose bilinear upsample which supports arbitrary output sizes. 109 | ''' 110 | _,_,H,W = y.size() 111 | return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y 112 | 113 | def forward(self, x): 114 | # Bottom-up 115 | c1 = F.relu(self.bn1(self.conv1(x))) 116 | c1 = F.max_pool2d(c1, kernel_size=3, stride=2, padding=1) 117 | c2 = self.layer1(c1) 118 | c3 = self.layer2(c2) 119 | c4 = self.layer3(c3) 120 | c5 = self.layer4(c4) 121 | # Top-down 122 | p5 = self.toplayer(c5) 123 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 124 | p3 = self._upsample_add(p4, self.latlayer2(c3)) 125 | p2 = self._upsample_add(p3, self.latlayer3(c2)) 126 | # Smooth 127 | p4 = self.smooth1(p4) 128 | p3 = self.smooth2(p3) 129 | p2 = self.smooth3(p2) 130 | 131 | return p2, p3, p4 132 | 133 | def conv3x3(in_planes, out_planes, stride=1, has_bias=False): 134 | "3x3 convolution with padding" 135 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 136 | padding=1, bias=has_bias) 137 | 138 | 139 | def conv3x3_bn_relu(in_planes, out_planes, stride=1): 140 | return nn.Sequential( 141 | conv3x3(in_planes, out_planes, stride), 142 | nn.BatchNorm2d(out_planes), 143 | nn.ReLU(inplace=True), 144 | ) 145 | 146 | class ToStyleCode(nn.Module): 147 | def __init__(self, n_convs, input_dim=512, out_dim=512): 148 | super(ToStyleCode, self).__init__() 149 | self.convs = nn.ModuleList() 150 | self.out_dim = out_dim 151 | 152 | for i in range(n_convs): 153 | if i == 0: 154 | self.convs.append( 155 | nn.Conv2d(in_channels=input_dim, out_channels=out_dim, kernel_size=3, padding=1, stride=2)) 156 | #self.convs.append(nn.BatchNorm2d(out_dim)) 157 | #self.convs.append(nn.InstanceNorm2d(out_dim)) 158 | self.convs.append(nn.LeakyReLU(inplace=True)) 159 | else: 160 | self.convs.append(nn.Conv2d(in_channels=out_dim, out_channels=out_dim, kernel_size=3, padding=1, stride=2)) 161 | self.convs.append(nn.LeakyReLU(inplace=True)) 162 | 163 | self.convs = nn.Sequential(*self.convs) 164 | self.linear = EqualLinear(out_dim, out_dim) 165 | 166 | def forward(self, x): 167 | x = self.convs(x) 168 | x = x.view(-1, self.out_dim) 169 | x = self.linear(x) 170 | return x 171 | 172 | 173 | class ToStyleHead(nn.Module): 174 | def __init__(self, input_dim=512, out_dim=512): 175 | super(ToStyleHead, self).__init__() 176 | self.out_dim = out_dim 177 | 178 | self.convs = nn.Sequential( 179 | conv3x3_bn_relu(input_dim, input_dim, 1), 180 | nn.AdaptiveAvgPool2d(1), 181 | # output 1x1 182 | nn.Conv2d(in_channels=input_dim, out_channels=out_dim, kernel_size=1) 183 | ) 184 | 185 | def forward(self, x): 186 | x = self.convs(x) 187 | x = x.view(x.shape[0],self.out_dim) 188 | return x 189 | 190 | class FPNEncoder(nn.Module): 191 | def __init__(self, input_dim, n_latent=14, use_style_head=False, style_layers=[4,5,6]): 192 | super(FPNEncoder, self).__init__() 193 | 194 | self.n_latent = n_latent 195 | num_blocks = [3,4,6,3] #resnet 50 196 | self.FPN_module = FPN(input_dim, Bottleneck, num_blocks) 197 | # course block 0-2, 4x4->8x8 198 | self.course_styles = nn.ModuleList() 199 | for i in range(3): 200 | if use_style_head: 201 | self.course_styles.append(ToStyleHead()) 202 | else: 203 | self.course_styles.append(ToStyleCode(n_convs=style_layers[0])) 204 | # medium1 block 3-6 16x16->32x32 205 | self.medium_styles = nn.ModuleList() 206 | for i in range(4): 207 | if use_style_head: 208 | self.medium_styles.append(ToStyleHead()) 209 | else: 210 | self.medium_styles.append(ToStyleCode(n_convs=style_layers[1])) 211 | # fine block 7-13 64x64->256x256 212 | self.fine_styles = nn.ModuleList() 213 | for i in range(n_latent - 7): 214 | if use_style_head: 215 | self.fine_styles.append(ToStyleHead()) 216 | else: 217 | self.fine_styles.append(ToStyleCode(n_convs=style_layers[2])) 218 | 219 | def forward(self, x): 220 | styles = [] 221 | # FPN feature 222 | p2, p3, p4 = self.FPN_module(x) 223 | 224 | for style_map in self.course_styles: 225 | styles.append(style_map(p4)) 226 | 227 | for style_map in self.medium_styles: 228 | styles.append(style_map(p3)) 229 | 230 | for style_map in self.fine_styles: 231 | styles.append(style_map(p2)) 232 | 233 | styles = torch.stack(styles, dim=1) 234 | 235 | return styles 236 | 237 | 238 | class ResEncoder(nn.Module): 239 | def __init__(self, size, input_dim, n_latent, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 240 | super().__init__() 241 | 242 | self.channels = { 243 | 4: 512, 244 | 8: 512, 245 | 16: 512, 246 | 32: 512, 247 | 64: 256 * channel_multiplier, 248 | 128: 128 * channel_multiplier, 249 | 256: 64 * channel_multiplier, 250 | 512: 32 * channel_multiplier, 251 | 1024: 16 * channel_multiplier, 252 | } 253 | 254 | 255 | convs = [ConvLayer(input_dim, self.channels[size], 1)] 256 | 257 | log_size = int(math.log(size, 2)) 258 | 259 | in_channel = self.channels[size] 260 | 261 | for i in range(log_size, 2, -1): 262 | out_channel = self.channels[2 ** (i - 1)] 263 | 264 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 265 | 266 | in_channel = out_channel 267 | 268 | self.convs = nn.Sequential(*convs) 269 | 270 | self.n_latent = n_latent 271 | self.stddev_group = 4 272 | self.stddev_feat = 1 273 | 274 | self.final_conv = ConvLayer(in_channel + 1, self.channels[4], 3) 275 | self.final_linear = EqualLinear(self.channels[4] * 4 * 4, n_latent * 512) 276 | 277 | def _cal_stddev(self, x): 278 | batch, channel, height, width = x.shape 279 | group = min(batch, self.stddev_group) 280 | stddev = x.view( 281 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 282 | ) 283 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 284 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 285 | stddev = stddev.repeat(group, 1, height, width) 286 | x = torch.cat([x, stddev], 1) 287 | 288 | return x 289 | 290 | def forward(self, input): 291 | batch = input.shape[0] 292 | 293 | out = self.convs(input) 294 | 295 | out = self._cal_stddev(out) 296 | 297 | out = self.final_conv(out) 298 | 299 | out = out.view(batch, -1) 300 | out = self.final_linear(out) 301 | 302 | out = out.view(batch, self.n_latent, -1) 303 | 304 | return out 305 | -------------------------------------------------------------------------------- /models/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | 10 | from models.lpips import dist_model 11 | 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=False, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /models/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | 5 | class BaseModel(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | #pass 9 | def name(self): 10 | return 'BaseModel' 11 | 12 | def initialize(self, use_gpu=False, gpu_ids=[0]): 13 | self.use_gpu = use_gpu 14 | self.gpu_ids = gpu_ids 15 | 16 | def forward(self): 17 | pass 18 | 19 | def get_image_paths(self): 20 | pass 21 | 22 | def optimize_parameters(self): 23 | pass 24 | 25 | def get_current_visuals(self): 26 | return self.input 27 | 28 | def get_current_errors(self): 29 | return {} 30 | 31 | def save(self, label): 32 | pass 33 | 34 | # helper saving function that can be used by subclasses 35 | def save_network(self, network, path, network_label, epoch_label): 36 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 37 | save_path = os.path.join(path, save_filename) 38 | torch.save(network.state_dict(), save_path) 39 | 40 | # helper loading function that can be used by subclasses 41 | def load_network(self, network, network_label, epoch_label): 42 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | print('Loading network from %s'%save_path) 45 | network.load_state_dict(torch.load(save_path)) 46 | 47 | def update_learning_rate(): 48 | pass 49 | 50 | def get_image_paths(self): 51 | return self.image_paths 52 | 53 | def save_done(self, flag=False): 54 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 55 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 56 | 57 | -------------------------------------------------------------------------------- /models/lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy as np 3 | import torch 4 | import os 5 | from collections import OrderedDict 6 | from torch.autograd import Variable 7 | from .base_model import BaseModel 8 | from scipy.ndimage import zoom 9 | from . import networks_basic as networks 10 | from .. import lpips as util 11 | 12 | 13 | class DistModel(BaseModel): 14 | def name(self): 15 | return self.model_name 16 | 17 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 18 | use_gpu=False, printNet=False, spatial=False, 19 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 20 | ''' 21 | INPUTS 22 | model - ['net-lin'] for linearly calibrated network 23 | ['net'] for off-the-shelf network 24 | ['L2'] for L2 distance in Lab colorspace 25 | ['SSIM'] for ssim in RGB colorspace 26 | net - ['squeeze','alex','vgg'] 27 | model_path - if None, will look in weights/[NET_NAME].pth 28 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 29 | use_gpu - bool - whether or not to use a GPU 30 | printNet - bool - whether or not to print network architecture out 31 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 32 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 33 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 34 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 35 | is_train - bool - [True] for training mode 36 | lr - float - initial learning rate 37 | beta1 - float - initial momentum term for adam 38 | version - 0.1 for latest, 0.0 was original (with a bug) 39 | gpu_ids - int array - [0] by default, gpus to use 40 | ''' 41 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 42 | 43 | self.model = model 44 | self.net = net 45 | self.is_train = is_train 46 | self.spatial = spatial 47 | self.gpu_ids = gpu_ids 48 | self.model_name = '%s [%s]'%(model,net) 49 | 50 | if(self.model == 'net-lin'): # pretrained net + linear layer 51 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 52 | use_dropout=True, spatial=spatial, version=version, lpips=True) 53 | kw = {} 54 | if not use_gpu: 55 | kw['map_location'] = 'cpu' 56 | if(model_path is None): 57 | import inspect 58 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 59 | 60 | if(not is_train): 61 | print('Loading model from: %s'%model_path) 62 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 63 | 64 | elif(self.model=='net'): # pretrained network 65 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 66 | elif(self.model in ['L2','l2']): 67 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 68 | self.model_name = 'L2' 69 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 70 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 71 | self.model_name = 'SSIM' 72 | else: 73 | raise ValueError("Model [%s] not recognized." % self.model) 74 | 75 | #self.parameters = list(self.net.parameters()) 76 | 77 | if self.is_train: # training mode 78 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 79 | self.rankLoss = networks.BCERankingLoss() 80 | self.parameters += list(self.rankLoss.net.parameters()) 81 | self.lr = lr 82 | self.old_lr = lr 83 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 84 | else: # test mode 85 | self.net.eval() 86 | 87 | # if(use_gpu): 88 | # #self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 89 | # self.net = self.net.to(gpu_ids[0]) 90 | # if(self.is_train): 91 | # self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 92 | 93 | if(printNet): 94 | print('---------- Networks initialized -------------') 95 | networks.print_network(self.net) 96 | print('-----------------------------------------------') 97 | 98 | def forward(self, in0, in1, retPerLayer=False): 99 | ''' Function computes the distance between image patches in0 and in1 100 | INPUTS 101 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 102 | OUTPUT 103 | computed distances between in0 and in1 104 | ''' 105 | # check if input has 3 dim 106 | if in0.shape[1] == 1: 107 | in0 = in0.expand(-1,3,-1,-1) 108 | if in1.shape[1] == 1: 109 | in1 = in1.expand(-1,3,-1,-1) 110 | 111 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 112 | 113 | # ***** TRAINING FUNCTIONS ***** 114 | def optimize_parameters(self): 115 | self.forward_train() 116 | self.optimizer_net.zero_grad() 117 | self.backward_train() 118 | self.optimizer_net.step() 119 | self.clamp_weights() 120 | 121 | def clamp_weights(self): 122 | for module in self.net.modules(): 123 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 124 | module.weight.data = torch.clamp(module.weight.data,min=0) 125 | 126 | def set_input(self, data): 127 | self.input_ref = data['ref'] 128 | self.input_p0 = data['p0'] 129 | self.input_p1 = data['p1'] 130 | self.input_judge = data['judge'] 131 | 132 | if(self.use_gpu): 133 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 134 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 135 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 136 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 137 | 138 | self.var_ref = Variable(self.input_ref,requires_grad=True) 139 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 140 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 141 | 142 | def forward_train(self): # run forward pass 143 | # print(self.net.module.scaling_layer.shift) 144 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 145 | 146 | self.d0 = self.forward(self.var_ref, self.var_p0) 147 | self.d1 = self.forward(self.var_ref, self.var_p1) 148 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 149 | 150 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 151 | 152 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 153 | 154 | return self.loss_total 155 | 156 | def backward_train(self): 157 | torch.mean(self.loss_total).backward() 158 | 159 | def compute_accuracy(self,d0,d1,judge): 160 | ''' d0, d1 are Variables, judge is a Tensor ''' 161 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 204 | self.old_lr = lr 205 | 206 | def score_2afc_dataset(data_loader, func, name=''): 207 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 208 | distance function 'func' in dataloader 'data_loader' 209 | INPUTS 210 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 211 | func - callable distance function - calling d=func(in0,in1) should take 2 212 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 213 | OUTPUTS 214 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 215 | [1] - dictionary with following elements 216 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 217 | gts - N array in [0,1], preferred patch selected by human evaluators 218 | (closer to "0" for left patch p0, "1" for right patch p1, 219 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 220 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 221 | CONSTS 222 | N - number of test triplets in data_loader 223 | ''' 224 | 225 | d0s = [] 226 | d1s = [] 227 | gts = [] 228 | 229 | for data in data_loader.load_data(): 230 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 231 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 232 | gts+=data['judge'].cpu().numpy().flatten().tolist() 233 | 234 | d0s = np.array(d0s) 235 | d1s = np.array(d1s) 236 | gts = np.array(gts) 237 | scores = (d0s 2: 42 | dim += list(range(2, grad_input.ndim)) 43 | 44 | if bias: 45 | grad_bias = grad_input.sum(dim).detach() 46 | 47 | else: 48 | grad_bias = empty 49 | 50 | return grad_input, grad_bias 51 | 52 | @staticmethod 53 | def backward(ctx, gradgrad_input, gradgrad_bias): 54 | out, = ctx.saved_tensors 55 | gradgrad_out = fused.fused_bias_act( 56 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 86 | 87 | if not ctx.bias: 88 | grad_bias = None 89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) -------------------------------------------------------------------------------- /models/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | 10 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 11 | int act, int grad, float alpha, float scale); 12 | 13 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 14 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 15 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 18 | int act, int grad, float alpha, float scale) { 19 | CHECK_CUDA(input); 20 | CHECK_CUDA(bias); 21 | 22 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 27 | } -------------------------------------------------------------------------------- /models/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | 10 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 11 | int up_x, int up_y, int down_x, int down_y, 12 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 13 | 14 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 15 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 16 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 19 | int up_x, int up_y, int down_x, int down_y, 20 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 21 | CHECK_CUDA(input); 22 | CHECK_CUDA(kernel); 23 | 24 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 25 | } 26 | 27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 28 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 29 | } -------------------------------------------------------------------------------- /models/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import os 8 | 9 | import torch 10 | from torch.nn import functional as F 11 | from torch.autograd import Function 12 | from torch.utils.cpp_extension import load 13 | 14 | 15 | module_path = os.path.dirname(__file__) 16 | upfirdn2d_op = load( 17 | "upfirdn2d", 18 | sources=[ 19 | os.path.join(module_path, "upfirdn2d.cpp"), 20 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 21 | ], 22 | ) 23 | 24 | 25 | class UpFirDn2dBackward(Function): 26 | @staticmethod 27 | def forward( 28 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 29 | ): 30 | 31 | up_x, up_y = up 32 | down_x, down_y = down 33 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 34 | 35 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 36 | 37 | grad_input = upfirdn2d_op.upfirdn2d( 38 | grad_output, 39 | grad_kernel, 40 | down_x, 41 | down_y, 42 | up_x, 43 | up_y, 44 | g_pad_x0, 45 | g_pad_x1, 46 | g_pad_y0, 47 | g_pad_y1, 48 | ) 49 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 50 | 51 | ctx.save_for_backward(kernel) 52 | 53 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 54 | 55 | ctx.up_x = up_x 56 | ctx.up_y = up_y 57 | ctx.down_x = down_x 58 | ctx.down_y = down_y 59 | ctx.pad_x0 = pad_x0 60 | ctx.pad_x1 = pad_x1 61 | ctx.pad_y0 = pad_y0 62 | ctx.pad_y1 = pad_y1 63 | ctx.in_size = in_size 64 | ctx.out_size = out_size 65 | 66 | return grad_input 67 | 68 | @staticmethod 69 | def backward(ctx, gradgrad_input): 70 | kernel, = ctx.saved_tensors 71 | 72 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 73 | 74 | gradgrad_out = upfirdn2d_op.upfirdn2d( 75 | gradgrad_input, 76 | kernel, 77 | ctx.up_x, 78 | ctx.up_y, 79 | ctx.down_x, 80 | ctx.down_y, 81 | ctx.pad_x0, 82 | ctx.pad_x1, 83 | ctx.pad_y0, 84 | ctx.pad_y1, 85 | ) 86 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 87 | gradgrad_out = gradgrad_out.view( 88 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 89 | ) 90 | 91 | return gradgrad_out, None, None, None, None, None, None, None, None 92 | 93 | 94 | class UpFirDn2d(Function): 95 | @staticmethod 96 | def forward(ctx, input, kernel, up, down, pad): 97 | up_x, up_y = up 98 | down_x, down_y = down 99 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 100 | 101 | kernel_h, kernel_w = kernel.shape 102 | batch, channel, in_h, in_w = input.shape 103 | ctx.in_size = input.shape 104 | 105 | input = input.reshape(-1, in_h, in_w, 1) 106 | 107 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 108 | 109 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 110 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 111 | ctx.out_size = (out_h, out_w) 112 | 113 | ctx.up = (up_x, up_y) 114 | ctx.down = (down_x, down_y) 115 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 116 | 117 | g_pad_x0 = kernel_w - pad_x0 - 1 118 | g_pad_y0 = kernel_h - pad_y0 - 1 119 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 120 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 121 | 122 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 123 | 124 | out = upfirdn2d_op.upfirdn2d( 125 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 126 | ) 127 | # out = out.view(major, out_h, out_w, minor) 128 | out = out.view(-1, channel, out_h, out_w) 129 | 130 | return out 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output): 134 | kernel, grad_kernel = ctx.saved_tensors 135 | 136 | grad_input = UpFirDn2dBackward.apply( 137 | grad_output, 138 | kernel, 139 | grad_kernel, 140 | ctx.up, 141 | ctx.down, 142 | ctx.pad, 143 | ctx.g_pad, 144 | ctx.in_size, 145 | ctx.out_size, 146 | ) 147 | 148 | return grad_input, None, None, None, None 149 | 150 | 151 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 152 | if input.device.type == "cpu": 153 | out = upfirdn2d_native( 154 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 155 | ) 156 | 157 | else: 158 | out = UpFirDn2d.apply( 159 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 160 | ) 161 | 162 | return out 163 | 164 | 165 | def upfirdn2d_native( 166 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 167 | ): 168 | _, channel, in_h, in_w = input.shape 169 | input = input.reshape(-1, in_h, in_w, 1) 170 | 171 | _, in_h, in_w, minor = input.shape 172 | kernel_h, kernel_w = kernel.shape 173 | 174 | out = input.view(-1, in_h, 1, in_w, 1, minor) 175 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 176 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 177 | 178 | out = F.pad( 179 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 180 | ) 181 | out = out[ 182 | :, 183 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 184 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 185 | :, 186 | ] 187 | 188 | out = out.permute(0, 3, 1, 2) 189 | out = out.reshape( 190 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 191 | ) 192 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 193 | out = F.conv2d(out, w) 194 | out = out.reshape( 195 | -1, 196 | minor, 197 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 198 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 199 | ) 200 | out = out.permute(0, 2, 3, 1) 201 | out = out[:, ::down_y, ::down_x, :] 202 | 203 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 204 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 205 | 206 | return out.view(-1, channel, out_h, out_w) -------------------------------------------------------------------------------- /models/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /models/stylegan2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | from torch import nn 8 | from models.utils import * 9 | import numpy as np 10 | 11 | 12 | class Generator(nn.Module): 13 | def __init__( 14 | self, 15 | size, 16 | style_dim, 17 | n_mlp, 18 | channel_multiplier=2, 19 | blur_kernel=[1, 3, 3, 1], 20 | lr_mlp=0.01, 21 | randomize_noise=True, 22 | image_mode='RGB', 23 | ): 24 | super().__init__() 25 | 26 | self.size = size 27 | 28 | self.style_dim = style_dim 29 | 30 | layers = [PixelNorm()] 31 | 32 | for i in range(n_mlp): 33 | layers.append( 34 | EqualLinear( 35 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 36 | ) 37 | ) 38 | 39 | self.style = nn.Sequential(*layers) 40 | 41 | self.channels = { 42 | 4: 512, 43 | 8: 512, 44 | 16: 512, 45 | 32: 512, 46 | 64: 256 * channel_multiplier, 47 | 128: 128 * channel_multiplier, 48 | 256: 64 * channel_multiplier, 49 | 512: 32 * channel_multiplier, 50 | 1024: 16 * channel_multiplier, 51 | } 52 | 53 | self.input = ConstantInput(self.channels[4]) 54 | self.conv1 = StyledConv( 55 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 56 | ) 57 | 58 | if image_mode == 'RGB': 59 | self.rgb_channel = 3 60 | else: 61 | self.rgb_channel = 1 62 | 63 | 64 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, out_channel=self.rgb_channel, upsample=False) 65 | 66 | self.log_size = int(math.log(size, 2)) 67 | self.num_layers = (self.log_size - 2) * 2 + 1 68 | 69 | self.convs = nn.ModuleList() 70 | self.upsamples = nn.ModuleList() 71 | self.to_rgbs = nn.ModuleList() 72 | self.noises = nn.Module() 73 | self.randomize_noise = randomize_noise 74 | in_channel = self.channels[4] 75 | 76 | for layer_idx in range(self.num_layers): 77 | res = (layer_idx + 5) // 2 78 | shape = [1, 1, 2 ** res, 2 ** res] 79 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 80 | 81 | for i in range(3, self.log_size + 1): 82 | out_channel = self.channels[2 ** i] 83 | 84 | self.convs.append( 85 | StyledConv( 86 | in_channel, 87 | out_channel, 88 | 3, 89 | style_dim, 90 | upsample=True, 91 | blur_kernel=blur_kernel, 92 | ) 93 | ) 94 | 95 | self.convs.append( 96 | StyledConv( 97 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 98 | ) 99 | ) 100 | 101 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 102 | 103 | in_channel = out_channel 104 | 105 | self.n_latent = self.log_size * 2 - 2 106 | 107 | 108 | def make_noise(self): 109 | device = self.input.input.device 110 | 111 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 112 | 113 | for i in range(3, self.log_size + 1): 114 | for _ in range(2): 115 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 116 | 117 | return noises 118 | 119 | def make_mean_latent(self, n_latent): 120 | latent_in = torch.randn( 121 | n_latent, self.style_dim, device=self.input.input.device 122 | ) 123 | latent = self.style(latent_in).mean(0, keepdim=True) 124 | self.mean_latent = latent 125 | 126 | return latent 127 | 128 | def get_latent(self, input): 129 | 130 | style = self.style(input) 131 | return style 132 | 133 | def truncation(self, input): 134 | 135 | out = self.mean_latent.unsqueeze(1) + 0.7 * (input - self.mean_latent.unsqueeze(1)) 136 | 137 | return out 138 | 139 | 140 | def g_mapping(self, input): 141 | style = self.style(input) 142 | style = style.unsqueeze(1).repeat(1, self.n_latent, 1) 143 | 144 | return style 145 | 146 | def g_synthesis(self, latent): 147 | if self.randomize_noise: 148 | noise = [None] * self.num_layers 149 | else: 150 | noise = [ 151 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 152 | ] 153 | styles_feature = [] 154 | 155 | out = self.input(latent) 156 | styles_feature.append(out) 157 | 158 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 159 | styles_feature.append(out) 160 | 161 | skip = self.to_rgb1(out, latent[:, 1]) 162 | 163 | i = 1 164 | 165 | for conv1, conv2, noise1, noise2, to_rgb in zip( 166 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 167 | ): 168 | out = conv1(out, latent[:, i], noise=noise1) 169 | 170 | styles_feature.append(out) 171 | 172 | out = conv2(out, latent[:, i + 1], noise=noise2) 173 | 174 | styles_feature.append(out) 175 | 176 | skip = to_rgb(out, latent[:, i + 2], skip) 177 | 178 | i += 2 179 | 180 | image = skip 181 | 182 | return image, styles_feature 183 | 184 | def sample(self, num, latent_space_type='Z'): 185 | """Samples latent codes randomly. 186 | Args: 187 | num: Number of latent codes to sample. Should be positive. 188 | latent_space_type: Type of latent space from which to sample latent code. 189 | Only [`Z`, `W`, `WP`] are supported. Case insensitive. (default: `Z`) 190 | Returns: 191 | A `numpy.ndarray` as sampled latend codes. 192 | Raises: 193 | ValueError: If the given `latent_space_type` is not supported. 194 | """ 195 | latent_space_type = latent_space_type.upper() 196 | if latent_space_type == 'Z': 197 | latent_codes = np.random.randn(num, self.style_dim) 198 | elif latent_space_type == 'W': 199 | latent_codes = np.random.randn(num, self.style_dim) 200 | elif latent_space_type == 'WP': 201 | latent_codes = np.random.randn(num, self.n_latent, self.style_dim) 202 | else: 203 | raise ValueError(f'Latent space type `{latent_space_type}` is invalid!') 204 | 205 | return latent_codes.astype(np.float32) 206 | 207 | 208 | def forward( 209 | self, 210 | styles, 211 | return_latents=False, 212 | inject_index=None, 213 | 214 | input_is_latent=False, 215 | noise=None, 216 | randomize_noise=True, 217 | ): 218 | if not input_is_latent: 219 | styles = [self.style(s) for s in styles] 220 | 221 | 222 | if noise is None: 223 | if randomize_noise: 224 | noise = [None] * self.num_layers 225 | else: 226 | noise = [ 227 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 228 | ] 229 | 230 | 231 | 232 | styles = self.truncation(styles) 233 | 234 | if len(styles) < 2: 235 | inject_index = self.n_latent 236 | 237 | if styles[0].ndim < 3: 238 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 239 | 240 | else: 241 | latent = styles[0] 242 | 243 | else: 244 | if inject_index is None: 245 | inject_index = random.randint(1, self.n_latent - 1) 246 | 247 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 248 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 249 | 250 | latent = torch.cat([latent, latent2], 1) 251 | 252 | out = self.input(latent) 253 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 254 | 255 | skip = self.to_rgb1(out, latent[:, 1]) 256 | 257 | i = 1 258 | for conv1, conv2, noise1, noise2, to_rgb in zip( 259 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 260 | ): 261 | out = conv1(out, latent[:, i], noise=noise1) 262 | out = conv2(out, latent[:, i + 1], noise=noise2) 263 | skip = to_rgb(out, latent[:, i + 2], skip) 264 | 265 | i += 2 266 | 267 | image = skip 268 | 269 | if return_latents: 270 | return image, latent 271 | 272 | else: 273 | return image, None 274 | 275 | 276 | class Discriminator(nn.Module): 277 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 278 | super().__init__() 279 | 280 | channels = { 281 | 4: 512, 282 | 8: 512, 283 | 16: 512, 284 | 32: 512, 285 | 64: 256 * channel_multiplier, 286 | 128: 128 * channel_multiplier, 287 | 256: 64 * channel_multiplier, 288 | 512: 32 * channel_multiplier, 289 | 1024: 16 * channel_multiplier, 290 | } 291 | 292 | convs = [ConvLayer(3, channels[size], 1)] 293 | 294 | log_size = int(math.log(size, 2)) 295 | 296 | in_channel = channels[size] 297 | 298 | for i in range(log_size, 2, -1): 299 | out_channel = channels[2 ** (i - 1)] 300 | 301 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 302 | 303 | in_channel = out_channel 304 | 305 | self.convs = nn.Sequential(*convs) 306 | 307 | self.stddev_group = 4 308 | self.stddev_feat = 1 309 | 310 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 311 | self.final_linear = nn.Sequential( 312 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 313 | EqualLinear(channels[4], 1), 314 | ) 315 | 316 | def forward(self, input): 317 | out = self.convs(input) 318 | 319 | batch, channel, height, width = out.shape 320 | group = min(batch, self.stddev_group) 321 | stddev = out.view( 322 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 323 | ) 324 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 325 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 326 | stddev = stddev.repeat(group, 1, height, width) 327 | out = torch.cat([out, stddev], 1) 328 | 329 | out = self.final_conv(out) 330 | 331 | out = out.view(batch, -1) 332 | out = self.final_linear(out) 333 | 334 | return out 335 | 336 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available under the Nvidia Source Code License-NC. 4 | # To view a copy of this license, visit 5 | # https://nvlabs.github.io/stylegan2/license.html 6 | 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from models.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 13 | 14 | class Blur(nn.Module): 15 | def __init__(self, kernel, pad, upsample_factor=1): 16 | super().__init__() 17 | 18 | kernel = make_kernel(kernel) 19 | 20 | if upsample_factor > 1: 21 | kernel = kernel * (upsample_factor ** 2) 22 | 23 | self.register_buffer('kernel', kernel) 24 | 25 | self.pad = pad 26 | 27 | def forward(self, input): 28 | out = upfirdn2d(input, self.kernel, pad=self.pad) 29 | 30 | return out 31 | 32 | 33 | class ConvLayer(nn.Sequential): 34 | def __init__( 35 | self, 36 | in_channel, 37 | out_channel, 38 | kernel_size, 39 | downsample=False, 40 | blur_kernel=[1, 3, 3, 1], 41 | bias=True, 42 | activate=True, 43 | ): 44 | layers = [] 45 | 46 | if downsample: 47 | factor = 2 48 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 49 | pad0 = (p + 1) // 2 50 | pad1 = p // 2 51 | 52 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 53 | 54 | stride = 2 55 | self.padding = 0 56 | 57 | else: 58 | stride = 1 59 | self.padding = kernel_size // 2 60 | 61 | layers.append( 62 | EqualConv2d( 63 | in_channel, 64 | out_channel, 65 | kernel_size, 66 | padding=self.padding, 67 | stride=stride, 68 | bias=bias and not activate, 69 | ) 70 | ) 71 | 72 | if activate: 73 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 74 | 75 | super().__init__(*layers) 76 | 77 | 78 | class ResBlock(nn.Module): 79 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 80 | super().__init__() 81 | 82 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 83 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 84 | 85 | self.skip = ConvLayer( 86 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 87 | ) 88 | 89 | def forward(self, input): 90 | out = self.conv1(input) 91 | out = self.conv2(out) 92 | 93 | skip = self.skip(input) 94 | out = (out + skip) / math.sqrt(2) 95 | 96 | return out 97 | 98 | 99 | class PixelNorm(nn.Module): 100 | def __init__(self): 101 | super().__init__() 102 | 103 | def forward(self, input): 104 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 105 | 106 | 107 | def make_kernel(k): 108 | k = torch.tensor(k, dtype=torch.float32) 109 | 110 | if k.ndim == 1: 111 | k = k[None, :] * k[:, None] 112 | 113 | k /= k.sum() 114 | 115 | return k 116 | 117 | 118 | class Upsample(nn.Module): 119 | def __init__(self, kernel, factor=2): 120 | super().__init__() 121 | 122 | self.factor = factor 123 | kernel = make_kernel(kernel) * (factor ** 2) 124 | self.register_buffer("kernel", kernel) 125 | 126 | p = kernel.shape[0] - factor 127 | 128 | pad0 = (p + 1) // 2 + factor - 1 129 | pad1 = p // 2 130 | 131 | self.pad = (pad0, pad1) 132 | 133 | def forward(self, input): 134 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 135 | 136 | return out 137 | 138 | 139 | class Downsample(nn.Module): 140 | def __init__(self, kernel, factor=2): 141 | super().__init__() 142 | 143 | self.factor = factor 144 | kernel = make_kernel(kernel) 145 | self.register_buffer("kernel", kernel) 146 | 147 | p = kernel.shape[0] - factor 148 | 149 | pad0 = (p + 1) // 2 150 | pad1 = p // 2 151 | 152 | self.pad = (pad0, pad1) 153 | 154 | def forward(self, input): 155 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 156 | 157 | return out 158 | 159 | 160 | class EqualConv2d(nn.Module): 161 | def __init__( 162 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 163 | ): 164 | super().__init__() 165 | 166 | self.weight = nn.Parameter( 167 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 168 | ) 169 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 170 | 171 | self.stride = stride 172 | self.padding = padding 173 | 174 | if bias: 175 | self.bias = nn.Parameter(torch.zeros(out_channel)) 176 | 177 | else: 178 | self.bias = None 179 | 180 | def forward(self, input): 181 | out = F.conv2d( 182 | input, 183 | self.weight * self.scale, 184 | bias=self.bias, 185 | stride=self.stride, 186 | padding=self.padding, 187 | ) 188 | 189 | return out 190 | 191 | def __repr__(self): 192 | return ( 193 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 194 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 195 | ) 196 | 197 | 198 | class EqualLinear(nn.Module): 199 | def __init__( 200 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 201 | ): 202 | super().__init__() 203 | 204 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 205 | 206 | if bias: 207 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 208 | 209 | else: 210 | self.bias = None 211 | 212 | self.activation = activation 213 | 214 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 215 | self.lr_mul = lr_mul 216 | 217 | def forward(self, input): 218 | if self.activation: 219 | out = F.linear(input, self.weight * self.scale) 220 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 221 | 222 | else: 223 | out = F.linear( 224 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 225 | ) 226 | 227 | return out 228 | 229 | def __repr__(self): 230 | return ( 231 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 232 | ) 233 | 234 | 235 | class ModulatedConv2d(nn.Module): 236 | def __init__( 237 | self, 238 | in_channel, 239 | out_channel, 240 | kernel_size, 241 | style_dim, 242 | demodulate=True, 243 | upsample=False, 244 | downsample=False, 245 | blur_kernel=[1, 3, 3, 1], 246 | ): 247 | super().__init__() 248 | 249 | self.eps = 1e-8 250 | self.kernel_size = kernel_size 251 | self.in_channel = in_channel 252 | self.out_channel = out_channel 253 | self.upsample = upsample 254 | self.downsample = downsample 255 | 256 | if upsample: 257 | factor = 2 258 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 259 | pad0 = (p + 1) // 2 + factor - 1 260 | pad1 = p // 2 + 1 261 | 262 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 263 | 264 | if downsample: 265 | factor = 2 266 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 267 | pad0 = (p + 1) // 2 268 | pad1 = p // 2 269 | 270 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 271 | 272 | fan_in = in_channel * kernel_size ** 2 273 | self.scale = 1 / math.sqrt(fan_in) 274 | self.padding = kernel_size // 2 275 | 276 | self.weight = nn.Parameter( 277 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 278 | ) 279 | 280 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 281 | 282 | self.demodulate = demodulate 283 | 284 | def __repr__(self): 285 | return ( 286 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 287 | f"upsample={self.upsample}, downsample={self.downsample})" 288 | ) 289 | 290 | def forward(self, input, style): 291 | batch, in_channel, height, width = input.shape 292 | 293 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 294 | weight = self.scale * self.weight * style 295 | 296 | if self.demodulate: 297 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 298 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 299 | 300 | weight = weight.view( 301 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 302 | ) 303 | 304 | if self.upsample: 305 | input = input.view(1, batch * in_channel, height, width) 306 | weight = weight.view( 307 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 308 | ) 309 | weight = weight.transpose(1, 2).reshape( 310 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 311 | ) 312 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 313 | _, _, height, width = out.shape 314 | out = out.view(batch, self.out_channel, height, width) 315 | out = self.blur(out) 316 | 317 | elif self.downsample: 318 | input = self.blur(input) 319 | _, _, height, width = input.shape 320 | input = input.view(1, batch * in_channel, height, width) 321 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 322 | _, _, height, width = out.shape 323 | out = out.view(batch, self.out_channel, height, width) 324 | 325 | else: 326 | input = input.view(1, batch * in_channel, height, width) 327 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 328 | _, _, height, width = out.shape 329 | out = out.view(batch, self.out_channel, height, width) 330 | 331 | return out 332 | 333 | 334 | class CondStyledConv(nn.Module): 335 | def __init__( 336 | self, 337 | in_channel, 338 | out_channel, 339 | kernel_size, 340 | style_dim, 341 | upsample=False, 342 | blur_kernel=[1, 3, 3, 1], 343 | demodulate=True, 344 | ): 345 | super().__init__() 346 | 347 | self.conv = ModulatedConv2d( 348 | in_channel, 349 | out_channel, 350 | kernel_size, 351 | style_dim, 352 | upsample=upsample, 353 | blur_kernel=blur_kernel, 354 | demodulate=demodulate, 355 | ) 356 | 357 | self.noise = CondInjection() 358 | 359 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 360 | # self.activate = ScaledLeakyReLU(0.2) 361 | self.activate = FusedLeakyReLU(out_channel) 362 | 363 | def forward(self, input, style, labels): 364 | out = self.conv(input, style) 365 | out = self.noise(out, labels) 366 | # out = out + self.bias 367 | out = self.activate(out) 368 | 369 | return out 370 | 371 | class CondInjection(nn.Module): 372 | def __init__(self): 373 | super().__init__() 374 | 375 | self.weight = nn.Parameter(torch.zeros(1)) 376 | 377 | def forward(self, image, labels, noise=None): 378 | if noise is None: 379 | batch, _, height, width = image.shape 380 | noise = image.new_empty(batch, 1, height, width).normal_() 381 | 382 | labels = labels.view(-1, 1, 1, 1) 383 | batch, _, height, width = image.shape 384 | cond = image.new_ones(batch, 1, height, width) / (labels + 1) 385 | 386 | # return image + self.weight * cond 387 | return image + self.weight * noise 388 | 389 | 390 | 391 | class NoiseInjection(nn.Module): 392 | def __init__(self): 393 | super().__init__() 394 | 395 | self.weight = nn.Parameter(torch.zeros(1)) 396 | 397 | def forward(self, image, noise=None): 398 | if noise is None: 399 | batch, _, height, width = image.shape 400 | noise = image.new_empty(batch, 1, height, width).normal_() 401 | 402 | return image + self.weight * noise 403 | 404 | 405 | class ConstantInput(nn.Module): 406 | def __init__(self, channel, size=4): 407 | super().__init__() 408 | 409 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 410 | 411 | def forward(self, input): 412 | batch = input.shape[0] 413 | out = self.input.repeat(batch, 1, 1, 1) 414 | 415 | return out 416 | 417 | 418 | class StyledConv(nn.Module): 419 | def __init__( 420 | self, 421 | in_channel, 422 | out_channel, 423 | kernel_size, 424 | style_dim, 425 | upsample=False, 426 | blur_kernel=[1, 3, 3, 1], 427 | demodulate=True, 428 | ): 429 | super().__init__() 430 | 431 | self.conv = ModulatedConv2d( 432 | in_channel, 433 | out_channel, 434 | kernel_size, 435 | style_dim, 436 | upsample=upsample, 437 | blur_kernel=blur_kernel, 438 | demodulate=demodulate, 439 | ) 440 | 441 | self.noise = NoiseInjection() 442 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 443 | # self.activate = ScaledLeakyReLU(0.2) 444 | self.activate = FusedLeakyReLU(out_channel) 445 | 446 | def forward(self, input, style, noise=None): 447 | out = self.conv(input, style) 448 | out = self.noise(out, noise=noise) 449 | # out = out + self.bias 450 | out = self.activate(out) 451 | 452 | return out 453 | 454 | class ScaledLeakyReLU(nn.Module): 455 | def __init__(self, negative_slope=0.2): 456 | super().__init__() 457 | 458 | self.negative_slope = negative_slope 459 | 460 | def forward(self, input): 461 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 462 | 463 | return out * math.sqrt(2) 464 | 465 | class ToRGB(nn.Module): 466 | def __init__(self, in_channel, out_channel=3, style_dim=512, upsample=True, blur_kernel=[1, 3, 3, 1]): 467 | super().__init__() 468 | 469 | if upsample: 470 | self.upsample = Upsample(blur_kernel) 471 | 472 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False) 473 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 474 | 475 | def forward(self, input, style, skip=None): 476 | out = self.conv(input, style) 477 | out = out + self.bias 478 | 479 | if skip is not None: 480 | skip = self.upsample(skip) 481 | 482 | out = out + skip 483 | 484 | return out 485 | 486 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.5.2 2 | imageio==2.8.0 3 | imageio-ffmpeg==0.4.2 4 | imgaug==0.4.0 5 | lmdb==0.98 6 | scikit-image==0.17.2 7 | scipy==1.5.0 8 | 9 | -------------------------------------------------------------------------------- /semanticGAN/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/semanticGAN/__init__.py -------------------------------------------------------------------------------- /semanticGAN/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | class LogCoshLoss(torch.nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def forward(self, true, pred): 31 | loss = true - pred 32 | return torch.mean(torch.log(torch.cosh(loss + 1e-12))) 33 | 34 | class SoftmaxLoss(torch.nn.Module): 35 | def __init__(self, tau=1.0): 36 | super().__init__() 37 | self.tau = tau 38 | self.ce_loss = torch.nn.CrossEntropyLoss() 39 | 40 | def forward(self, pred, true): 41 | logits = pred / self.tau 42 | l = self.ce_loss(logits, true) 43 | 44 | return l 45 | 46 | class SoftBinaryCrossEntropyLoss(torch.nn.Module): 47 | def __init__(self, tau=1.0): 48 | super().__init__() 49 | self.tau = tau 50 | # for numerical stable reason 51 | self.bce_logit = torch.nn.BCEWithLogitsLoss() 52 | 53 | def forward(self, pred, true): 54 | logits = pred / self.tau 55 | l = self.bce_logit(logits, true) 56 | 57 | return l 58 | 59 | def noise_regularize(noises): 60 | loss = 0 61 | batch_size = noises[0].shape[0] 62 | for noise in noises: 63 | size = noise.shape[2] 64 | 65 | while True: 66 | loss = ( 67 | loss 68 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 69 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 70 | ) 71 | 72 | if size <= 8: 73 | break 74 | 75 | noise = noise.reshape([batch_size, 1, size // 2, 2, size // 2, 2]) 76 | noise = noise.mean([3, 5]) 77 | size //= 2 78 | 79 | return loss 80 | 81 | class FocalLoss(nn.Module): 82 | """ 83 | copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py 84 | This is a implementation of Focal Loss with smooth label cross entropy supported which is proposed in 85 | 'Focal Loss for Dense Object Detection. (https://arxiv.org/abs/1708.02002)' 86 | Focal_Loss= -1*alpha*(1-pt)*log(pt) 87 | :param num_class: 88 | :param alpha: (tensor) 3D or 4D the scalar factor for this criterion 89 | :param gamma: (float,double) gamma > 0 reduces the relative loss for well-classified examples (p>0.5) putting more 90 | focus on hard misclassified example 91 | :param smooth: (float,double) smooth value when cross entropy 92 | :param balance_index: (int) balance class index, should be specific when alpha is float 93 | :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch. 94 | """ 95 | 96 | def __init__(self, alpha=None, gamma=2, tau=1.0, balance_index=0, smooth=1e-5, size_average=True): 97 | super(FocalLoss, self).__init__() 98 | self.alpha = alpha 99 | self.gamma = gamma 100 | self.tau = tau 101 | self.balance_index = balance_index 102 | self.smooth = smooth 103 | self.size_average = size_average 104 | 105 | if self.smooth is not None: 106 | if self.smooth < 0 or self.smooth > 1.0: 107 | raise ValueError('smooth value should be in [0,1]') 108 | 109 | def _apply_nonlin(self, logit): 110 | num_class = logit.shape[1] 111 | if num_class == 1: 112 | logit = torch.sigmoid(logit / self.tau) 113 | else: 114 | logit = torch.softmax(logit / self.tau, dim=1) 115 | 116 | return logit 117 | 118 | def forward(self, logit, target): 119 | logit = self._apply_nonlin(logit) 120 | num_class = logit.shape[1] 121 | 122 | if logit.dim() > 2: 123 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 124 | logit = logit.view(logit.size(0), logit.size(1), -1) 125 | logit = logit.permute(0, 2, 1).contiguous() 126 | logit = logit.view(-1, logit.size(-1)) 127 | target = torch.squeeze(target, 1) 128 | target = target.view(-1, 1) 129 | 130 | alpha = self.alpha 131 | 132 | if alpha is None: 133 | alpha = torch.ones(num_class, 1) 134 | elif isinstance(alpha, (list, np.ndarray)): 135 | assert len(alpha) == num_class 136 | alpha = torch.FloatTensor(alpha).view(num_class, 1) 137 | alpha = alpha / alpha.sum() 138 | elif isinstance(alpha, float): 139 | alpha = torch.ones(num_class, 1) 140 | alpha = alpha * (1 - self.alpha) 141 | alpha[self.balance_index] = self.alpha 142 | 143 | else: 144 | raise TypeError('Not support alpha type') 145 | 146 | if alpha.device != logit.device: 147 | alpha = alpha.to(logit.device) 148 | 149 | idx = target.cpu().long() 150 | 151 | one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_() 152 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 153 | if one_hot_key.device != logit.device: 154 | one_hot_key = one_hot_key.to(logit.device) 155 | 156 | if self.smooth: 157 | one_hot_key = torch.clamp( 158 | one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth) 159 | pt = (one_hot_key * logit).sum(1) + self.smooth 160 | logpt = pt.log() 161 | 162 | gamma = self.gamma 163 | 164 | alpha = alpha[idx] 165 | alpha = torch.squeeze(alpha) 166 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 167 | 168 | if self.size_average: 169 | loss = loss.mean() 170 | else: 171 | loss = loss.sum() 172 | return loss 173 | 174 | class DiceLoss(nn.Module): 175 | """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. 176 | For multi-class segmentation `weight` parameter can be used to assign different weights per class. 177 | The input to the loss function is assumed to be a logit and will be normalized by the Sigmoid function. 178 | """ 179 | 180 | def __init__(self, weight=None, sigmoid_tau=0.3, include_bg=False): 181 | super().__init__() 182 | self.register_buffer('weight', weight) 183 | self.normalization = nn.Sigmoid() 184 | self.sigmoid_tau = sigmoid_tau 185 | self.include_bg = include_bg 186 | 187 | def _flatten(self, tensor): 188 | """Flattens a given tensor such that the channel axis is first. 189 | The shapes are transformed as follows: 190 | (N, C, D, H, W) -> (C, N * D * H * W) 191 | """ 192 | # number of channels 193 | C = tensor.size(1) 194 | # new axis order 195 | axis_order = (1, 0) + tuple(range(2, tensor.dim())) 196 | # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) 197 | transposed = tensor.permute(axis_order) 198 | # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) 199 | return transposed.contiguous().view(C, -1) 200 | 201 | def _compute_per_channel_dice(self, input, target, epsilon=1e-6, weight=None): 202 | """ 203 | Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. 204 | Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. 205 | Args: 206 | input (torch.Tensor): NxCxSpatial input tensor 207 | target (torch.Tensor): NxCxSpatial target tensor 208 | epsilon (float): prevents division by zero 209 | weight (torch.Tensor): Cx1 tensor of weight per channel/class 210 | """ 211 | 212 | # input and target shapes must match 213 | assert input.size() == target.size(), "'input' and 'target' must have the same shape" 214 | 215 | input = self._flatten(input) 216 | target = self._flatten(target) 217 | target = target.float() 218 | 219 | # compute per channel Dice Coefficient 220 | intersect = (input * target).sum(-1) 221 | if weight is not None: 222 | intersect = weight * intersect 223 | 224 | # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) 225 | denominator = (input * input).sum(-1) + (target * target).sum(-1) 226 | return 2 * (intersect / denominator.clamp(min=epsilon)) 227 | 228 | def dice(self, input, target, weight): 229 | return self._compute_per_channel_dice(input, target, weight=weight) 230 | 231 | def forward(self, input, target): 232 | # get probabilities from logits 233 | input = self.normalization(input / self.sigmoid_tau) 234 | 235 | # compute per channel Dice coefficient 236 | per_channel_dice = self.dice(input, target, weight=self.weight) 237 | 238 | # average Dice score across all channels/classes 239 | if self.include_bg: 240 | return 1. - torch.mean(per_channel_dice) 241 | else: 242 | return 1. - torch.mean(per_channel_dice[1:]) 243 | -------------------------------------------------------------------------------- /semanticGAN/prepare_inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn.functional as F 26 | from torch.utils.data import DataLoader, ConcatDataset 27 | 28 | import argparse 29 | from utils import inception_utils 30 | from dataloader import (CelebAMaskDataset) 31 | import pickle 32 | 33 | @torch.no_grad() 34 | def extract_features(args, loader, inception, device): 35 | pbar = loader 36 | 37 | pools, logits = [], [] 38 | 39 | for data in pbar: 40 | img = data['image'] 41 | 42 | # check img dim 43 | if img.shape[1] != 3: 44 | img = img.expand(-1,3,-1,-1) 45 | 46 | img = img.to(device) 47 | pool_val, logits_val = inception(img) 48 | 49 | pools.append(pool_val.cpu().numpy()) 50 | logits.append(F.softmax(logits_val, dim=1).cpu().numpy()) 51 | 52 | pools = np.concatenate(pools, axis=0) 53 | logits = np.concatenate(logits, axis=0) 54 | 55 | return pools, logits 56 | 57 | 58 | def get_dataset(args): 59 | if args.dataset_name == 'celeba-mask': 60 | unlabel_dataset = CelebAMaskDataset(args, args.path, is_label=False) 61 | train_val_dataset = CelebAMaskDataset(args, args.path, is_label=True, phase='train-val') 62 | dataset = ConcatDataset([unlabel_dataset, train_val_dataset]) 63 | else: 64 | raise Exception('No such a dataloader!') 65 | return dataset 66 | 67 | if __name__ == '__main__': 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 69 | 70 | parser = argparse.ArgumentParser( 71 | description='Calculate Inception v3 features for datasets' 72 | ) 73 | parser.add_argument('--size', type=int, default=256) 74 | parser.add_argument('--batch', default=64, type=int, help='batch size') 75 | parser.add_argument('--n_sample', type=int, default=50000) 76 | parser.add_argument('--output', type=str, required=True) 77 | parser.add_argument('--image_mode', type=str, default='RGB') 78 | parser.add_argument('--dataset_name', type=str, help='[celeba-mask]') 79 | parser.add_argument('path', metavar='PATH', help='path to datset dir') 80 | 81 | args = parser.parse_args() 82 | 83 | inception = inception_utils.load_inception_net() 84 | 85 | dset = get_dataset(args) 86 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 87 | 88 | pools, logits = extract_features(args, loader, inception, device) 89 | 90 | # pools = pools[: args.n_sample] 91 | # logits = logits[: args.n_sample] 92 | 93 | print(f'extracted {pools.shape[0]} features') 94 | 95 | print('Calculating inception metrics...') 96 | IS_mean, IS_std = inception_utils.calculate_inception_score(logits) 97 | print('Training data from dataloader has IS of %5.5f +/- %5.5f' % (IS_mean, IS_std)) 98 | print('Calculating means and covariances...') 99 | 100 | mean = np.mean(pools, axis=0) 101 | cov = np.cov(pools, rowvar=False) 102 | 103 | with open(args.output, 'wb') as f: 104 | pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f) 105 | -------------------------------------------------------------------------------- /semanticGAN/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/semanticGAN_code/342889ebbe817695c0e64133100ede8f9877f3de/semanticGAN/preprocessing/__init__.py -------------------------------------------------------------------------------- /semanticGAN/preprocessing/face_postprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import os 24 | import numpy as np 25 | from PIL import Image 26 | import json 27 | import argparse 28 | 29 | def find_coeffs(pa, pb): 30 | matrix = [] 31 | for p1, p2 in zip(pa, pb): 32 | matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0]*p1[0], -p2[0]*p1[1]]) 33 | matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1]*p1[0], -p2[1]*p1[1]]) 34 | 35 | A = np.matrix(matrix, dtype=np.float) 36 | B = np.array(pb).reshape(8) 37 | 38 | res = np.dot(np.linalg.inv(A.T * A) * A.T, B) 39 | return np.array(res).reshape(8) 40 | 41 | def main(args): 42 | img_list = sorted(os.listdir(args.img_dir)) 43 | meta_list = sorted(os.listdir(args.meta_dir)) 44 | raw_list = sorted(os.listdir(args.raw_dir)) 45 | 46 | for img_p, meta_p, raw_p in zip(img_list, meta_list, raw_list): 47 | img_n = img_p.split('.')[0] 48 | 49 | img_p = os.path.join(args.img_dir, img_p) 50 | meta_p = os.path.join(args.meta_dir, meta_p) 51 | raw_p = os.path.join(args.raw_dir, raw_p) 52 | 53 | with open(meta_p, 'r') as f: 54 | meta_json = json.load(f) 55 | 56 | kps = meta_json['quad'] 57 | crop_box = meta_json['bbox'] 58 | size = meta_json['size'] 59 | pad = meta_json['pad'] 60 | shrink = meta_json['shrink'] 61 | 62 | upper_left = kps[0:2] 63 | lower_left = kps[2:4] 64 | lower_right = kps[4:6] 65 | upper_right= kps[6:] 66 | all_kps = [upper_left, lower_left, lower_right, upper_right] 67 | pa = all_kps 68 | pb = [[0,0 ], [0, args.size], [args.size, args.size], [args.size,0]] 69 | 70 | coeffs = find_coeffs(pa, pb) 71 | 72 | left, top, right, bottom = crop_box 73 | 74 | width = size[0] 75 | height = size[1] 76 | 77 | img_pil = Image.open(img_p).convert('RGB') 78 | 79 | img_pil = img_pil.transform((width, height), Image.PERSPECTIVE, coeffs, Image.BILINEAR) 80 | 81 | #unpad 82 | img_np = np.array(img_pil) 83 | if (pad[0] == 0 and 84 | pad[1] == 0 and 85 | pad[2] == 0 and 86 | pad[3] == 0): 87 | pass 88 | else: 89 | if pad[3] != 0 and pad[2] != 0: 90 | img_np = img_np[pad[1]:-pad[3], pad[0]:-pad[2]] 91 | elif pad[3] == 0 and pad[2] != 0: 92 | img_np = img_np[pad[1]:, pad[0]:-pad[2]] 93 | elif pad[3] != 0 and pad[2] == 0: 94 | img_np = img_np[pad[1]:-pad[3], pad[0]:] 95 | else: 96 | img_np = img_np[pad[1]:, pad[0]:] 97 | 98 | crop_width = crop_box[2] - crop_box[0] 99 | crop_height = crop_box[3] - crop_box[1] 100 | #unshrink 101 | if shrink > 1: 102 | img_pil = Image.fromarray(img_np) 103 | rsize = (int(np.rint(float(img_pil.size[0]) * shrink)), int(np.rint(float(img_pil.size[1]) * shrink))) 104 | img_pil = img_pil.resize(rsize, resample=Image.LANCZOS) 105 | crop_width *= shrink 106 | crop_height *= shrink 107 | crop_box[3] *= shrink 108 | crop_box[2] *= shrink 109 | img_np = np.array(img_pil) 110 | 111 | assert crop_width == img_np.shape[1] 112 | assert crop_height == img_np.shape[0] 113 | 114 | img_ori_pil = Image.open(raw_p).convert('RGB') 115 | img_ori_np = np.array(img_ori_pil) 116 | 117 | img_ori_np[crop_box[1]:crop_box[3], crop_box[0]:crop_box[2]] = img_np 118 | 119 | img_ori_pil = Image.fromarray(img_ori_np) 120 | 121 | img_ori_pil.save(os.path.join(depth_out, img_n + '.png')) 122 | 123 | if __name__ == '__main__': 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('--raw_dir', type=str, required=True) 126 | parser.add_argument('--img_dir', type=str, required=True) 127 | parser.add_argument('--meta_dir', type=str, required=True) 128 | parser.add_argument('--outdir', type=str, required=True) 129 | parser.add_argument('--size', type=int, default=256) 130 | 131 | args = parser.parse_args() 132 | 133 | os.makedirs(args.outdir, exist_ok=True) 134 | 135 | main(args) -------------------------------------------------------------------------------- /semanticGAN/preprocessing/face_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import os 24 | import dlib 25 | import numpy as np 26 | import json 27 | import scipy.ndimage 28 | import PIL.Image 29 | import argparse 30 | 31 | 32 | def main(args): 33 | detector = dlib.get_frontal_face_detector() 34 | predictor = dlib.shape_predictor(args.detector) 35 | 36 | target_size = args.size 37 | supersampling = 4 38 | face_shrink = 2 39 | enable_padding = True 40 | 41 | 42 | img_out_dir = os.path.join(args.out_dir, 'image') 43 | meta_out_dir = os.path.join(args.out_dir, 'meta') 44 | 45 | img_list = sorted(os.listdir(args.img_dir)) 46 | 47 | os.makedirs(img_out_dir, exist_ok=True) 48 | os.makedirs(meta_out_dir, exist_ok=True) 49 | 50 | def rot90(v) -> np.ndarray: 51 | return np.array([-v[1], v[0]]) 52 | 53 | for img_n in img_list: 54 | img_p = os.path.join(args.img_dir, img_n) 55 | detector_img = dlib.load_rgb_image(img_p) 56 | 57 | # Ask the detector to find the bounding boxes of each face. The 1 in the 58 | # second argument indicates that we should upsample the image 1 time. This 59 | # will make everything bigger and allow us to detect more faces. 60 | dets = detector(detector_img, 1) 61 | print("Number of faces detected: {}".format(len(dets))) 62 | if len(dets) > 1: 63 | continue 64 | 65 | for k, d in enumerate(dets): 66 | 67 | # Get the landmarks/parts for the face in box d. 68 | shape = predictor(detector_img, d) 69 | all_parts = shape.parts() 70 | lm = np.array([ [item.x,item.y ] for item in all_parts]) 71 | landmarks = np.float32(lm) + 0.5 72 | assert landmarks.shape == (68, 2) 73 | 74 | lm_eye_left = landmarks[36 : 42] # left-clockwise 75 | lm_eye_right = landmarks[42 : 48] # left-clockwise 76 | lm_mouth_outer = landmarks[48 : 60] # left-clockwise 77 | 78 | # Calculate auxiliary vectors. 79 | eye_left = np.mean(lm_eye_left, axis=0) 80 | eye_right = np.mean(lm_eye_right, axis=0) 81 | eye_avg = (eye_left + eye_right) * 0.5 82 | eye_to_eye = eye_right - eye_left 83 | mouth_left = lm_mouth_outer[0] 84 | mouth_right = lm_mouth_outer[6] 85 | mouth_avg = (mouth_left + mouth_right) * 0.5 86 | eye_to_mouth = mouth_avg - eye_avg 87 | 88 | # Choose oriented crop rectangle. 89 | x = eye_to_eye - rot90(eye_to_mouth) 90 | x /= np.hypot(*x) 91 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 92 | y = rot90(x) 93 | c = eye_avg + eye_to_mouth * 0.1 94 | 95 | # Calculate auxiliary data. 96 | qsize = np.hypot(*x) * 2 97 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 98 | lo = np.min(quad, axis=0) 99 | hi = np.max(quad, axis=0) 100 | lm_rel = np.dot(landmarks - c, np.transpose([x, y])) / qsize**2 * 2 + 0.5 101 | rp = np.dot(np.random.RandomState(123).uniform(-1, 1, size=(1024, 2)), [x, y]) + c 102 | 103 | # Load. 104 | img_ori = PIL.Image.open(img_p).convert('RGB') 105 | img = PIL.Image.open(img_p).convert('RGB') 106 | 107 | # Shrink. 108 | shrink = int(np.floor(qsize / target_size * 0.5)) 109 | if shrink > 1: 110 | rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink))) 111 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 112 | quad /= shrink 113 | qsize /= shrink 114 | 115 | # Crop. 116 | border = max(int(np.rint(qsize * 0.1)), 3) 117 | crop = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 118 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1])) 119 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 120 | img = img.crop(crop) 121 | quad -= crop[0:2] 122 | 123 | # Pad. 124 | pad = (int(np.floor(min(quad[:,0]))), int(np.floor(min(quad[:,1]))), int(np.ceil(max(quad[:,0]))), int(np.ceil(max(quad[:,1])))) 125 | pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0), max(pad[3] - img.size[1] + border, 0)) 126 | if enable_padding and max(pad) > border - 4: 127 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 128 | img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 129 | h, w, _ = img.shape 130 | y, x, _ = np.ogrid[:h, :w, :1] 131 | mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w-1-x) / pad[2]), 1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h-1-y) / pad[3])) 132 | blur = qsize * 0.02 133 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 134 | img += (np.median(img, axis=(0,1)) - img) * np.clip(mask, 0.0, 1.0) 135 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') 136 | quad += pad[:2] 137 | else: 138 | pad = (0,0,0,0) 139 | 140 | meta = { 141 | 'bbox': list(crop), 142 | 'quad': list((quad.astype(float) + 0.5).flatten()), 143 | 'size': list(img.size), 144 | 'pad': [int(p) for p in list(pad)], 145 | 'shrink': shrink, 146 | } 147 | 148 | # Transform. 149 | super_size = target_size * supersampling 150 | img = img.transform((super_size, super_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR) 151 | if target_size < super_size: 152 | img = img.resize((target_size, target_size), PIL.Image.ANTIALIAS) 153 | 154 | img_name = os.path.basename(img_p).split('.')[0] 155 | 156 | # save 157 | with open(os.path.join(meta_out_dir, img_name + '.json'), 'w') as f: 158 | json.dump(meta, f) 159 | 160 | img.save(os.path.join(img_out_dir, img_name + '.png')) 161 | 162 | if __name__ == "__main__": 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument('--img_dir', type=str, required=True) 165 | parser.add_argument('--outdir', type=str, required=True) 166 | 167 | parser.add_argument('--detector', type=str, default='./shape_predictor_68_face_landmarks.dat') 168 | parser.add_argument('--size', type=int, default=256) 169 | 170 | args = parser.parse_args() -------------------------------------------------------------------------------- /semanticGAN/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 2020.9.4 10 | 11 | 12 | # Credits: 13 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 14 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 15 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 16 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 17 | 18 | # summary of changes: 19 | # 9/4/20 - updated addcmul_ signature to avoid warning. Integrates latest changes from GC developer (he did the work for this), and verified on performance on private dataloader. 20 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 21 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 22 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 23 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 24 | # changed eps to 1e-5 as better default than 1e-8. 25 | 26 | # Apache License 2.0 LICENSE code copy from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 27 | # please refer to https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer/blob/master/LICENSE 28 | 29 | 30 | import math 31 | import torch 32 | from torch.optim.optimizer import Optimizer, required 33 | 34 | 35 | def centralized_gradient(x, use_gc=True, gc_conv_only=False): 36 | '''credit - https://github.com/Yonghongwei/Gradient-Centralization ''' 37 | if use_gc: 38 | if gc_conv_only: 39 | if len(list(x.size())) > 3: 40 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 41 | else: 42 | if len(list(x.size())) > 1: 43 | x.add_(-x.mean(dim=tuple(range(1, len(list(x.size())))), keepdim=True)) 44 | return x 45 | 46 | 47 | class Ranger(Optimizer): 48 | 49 | def __init__(self, params, lr=1e-3, # lr 50 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 51 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 52 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 53 | use_gc=True, gc_conv_only=False, gc_loc=True 54 | ): 55 | 56 | # parameter checks 57 | if not 0.0 <= alpha <= 1.0: 58 | raise ValueError(f'Invalid slow update rate: {alpha}') 59 | if not 1 <= k: 60 | raise ValueError(f'Invalid lookahead steps: {k}') 61 | if not lr > 0: 62 | raise ValueError(f'Invalid Learning Rate: {lr}') 63 | if not eps > 0: 64 | raise ValueError(f'Invalid eps: {eps}') 65 | 66 | # parameter comments: 67 | # beta1 (momentum) of .95 seems to work better than .90... 68 | # N_sma_threshold of 5 seems better in testing than 4. 69 | # In both cases, worth testing on your dataloader (.90 vs .95, 4 vs 5) to make sure which works best for you. 70 | 71 | # prep defaults and init torch.optim base 72 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, 73 | N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 74 | super().__init__(params, defaults) 75 | 76 | # adjustable threshold 77 | self.N_sma_threshhold = N_sma_threshhold 78 | 79 | # look ahead params 80 | 81 | self.alpha = alpha 82 | self.k = k 83 | 84 | # radam buffer for state 85 | self.radam_buffer = [[None, None, None] for ind in range(10)] 86 | 87 | # gc on or off 88 | self.gc_loc = gc_loc 89 | self.use_gc = use_gc 90 | self.gc_conv_only = gc_conv_only 91 | # level of gradient centralization 92 | #self.gc_gradient_threshold = 3 if gc_conv_only else 1 93 | 94 | print( 95 | f"Ranger optimizer loaded. \nGradient Centralization usage = {self.use_gc}") 96 | if (self.use_gc and self.gc_conv_only == False): 97 | print(f"GC applied to both conv and fc layers") 98 | elif (self.use_gc and self.gc_conv_only == True): 99 | print(f"GC applied to conv layers only") 100 | 101 | def __setstate__(self, state): 102 | print("set state called") 103 | super(Ranger, self).__setstate__(state) 104 | 105 | def step(self, closure=None): 106 | loss = None 107 | # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 108 | # Uncomment if you need to use the actual closure... 109 | 110 | # if closure is not None: 111 | #loss = closure() 112 | 113 | # Evaluate averages and grad, update param tensors 114 | for group in self.param_groups: 115 | 116 | for p in group['params']: 117 | if p.grad is None: 118 | continue 119 | grad = p.grad.data.float() 120 | 121 | if grad.is_sparse: 122 | raise RuntimeError( 123 | 'Ranger optimizer does not support sparse gradients') 124 | 125 | p_data_fp32 = p.data.float() 126 | 127 | state = self.state[p] # get state dict for this param 128 | 129 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 130 | # if self.first_run_check==0: 131 | # self.first_run_check=1 132 | #print("Initializing slow buffer...should not see this at load from saved model!") 133 | state['step'] = 0 134 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 135 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 136 | 137 | # look ahead weight storage now in state dict 138 | state['slow_buffer'] = torch.empty_like(p.data) 139 | state['slow_buffer'].copy_(p.data) 140 | 141 | else: 142 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 143 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as( 144 | p_data_fp32) 145 | 146 | # begin computations 147 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 148 | beta1, beta2 = group['betas'] 149 | 150 | # GC operation for Conv layers and FC layers 151 | # if grad.dim() > self.gc_gradient_threshold: 152 | # grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 153 | if self.gc_loc: 154 | grad = centralized_gradient(grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only) 155 | 156 | state['step'] += 1 157 | 158 | # compute variance mov avg 159 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 160 | 161 | # compute mean moving avg 162 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 163 | 164 | buffered = self.radam_buffer[int(state['step'] % 10)] 165 | 166 | if state['step'] == buffered[0]: 167 | N_sma, step_size = buffered[1], buffered[2] 168 | else: 169 | buffered[0] = state['step'] 170 | beta2_t = beta2 ** state['step'] 171 | N_sma_max = 2 / (1 - beta2) - 1 172 | N_sma = N_sma_max - 2 * \ 173 | state['step'] * beta2_t / (1 - beta2_t) 174 | buffered[1] = N_sma 175 | if N_sma > self.N_sma_threshhold: 176 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 177 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 178 | else: 179 | step_size = 1.0 / (1 - beta1 ** state['step']) 180 | buffered[2] = step_size 181 | 182 | # if group['weight_decay'] != 0: 183 | # p_data_fp32.add_(-group['weight_decay'] 184 | # * group['lr'], p_data_fp32) 185 | 186 | # apply lr 187 | if N_sma > self.N_sma_threshhold: 188 | denom = exp_avg_sq.sqrt().add_(group['eps']) 189 | G_grad = exp_avg / denom 190 | else: 191 | G_grad = exp_avg 192 | 193 | if group['weight_decay'] != 0: 194 | G_grad.add_(p_data_fp32, alpha=group['weight_decay']) 195 | # GC operation 196 | if self.gc_loc == False: 197 | G_grad = centralized_gradient(G_grad, use_gc=self.use_gc, gc_conv_only=self.gc_conv_only) 198 | 199 | p_data_fp32.add_(G_grad, alpha=-step_size * group['lr']) 200 | p.data.copy_(p_data_fp32) 201 | 202 | # integrated look ahead... 203 | # we do it at the param level instead of group level 204 | if state['step'] % group['k'] == 0: 205 | # get access to slow param tensor 206 | slow_p = state['slow_buffer'] 207 | # (fast weights - slow weights) * alpha 208 | slow_p.add_(p.data - slow_p, alpha=self.alpha) 209 | # copy interpolated weights to RAdam param tensor 210 | p.data.copy_(slow_p) 211 | 212 | return loss -------------------------------------------------------------------------------- /semanticGAN/samplers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | 24 | from typing import Iterator, List, Optional, Union 25 | from operator import itemgetter 26 | from torch.utils.data import DistributedSampler, Dataset 27 | from torch.utils.data.sampler import BatchSampler, Sampler 28 | 29 | class DatasetFromSampler(Dataset): 30 | """Dataset of indexes from `Sampler`.""" 31 | 32 | def __init__(self, sampler: Sampler): 33 | """ 34 | Args: 35 | sampler (Sampler): @TODO: Docs. Contribution is welcome 36 | """ 37 | self.sampler = sampler 38 | self.sampler_list = None 39 | 40 | def __getitem__(self, index: int): 41 | """Gets element of the dataloader. 42 | Args: 43 | index (int): index of the element in the dataloader 44 | Returns: 45 | Single element by index 46 | """ 47 | if self.sampler_list is None: 48 | self.sampler_list = list(self.sampler) 49 | return self.sampler_list[index] 50 | 51 | def __len__(self) -> int: 52 | """ 53 | Returns: 54 | int: length of the dataloader 55 | """ 56 | return len(self.sampler) 57 | 58 | class DistributedSamplerWrapper(DistributedSampler): 59 | """ 60 | Wrapper over `Sampler` for distributed training. 61 | Allows you to use any sampler in distributed mode. 62 | It is especially useful in conjunction with 63 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 64 | process can pass a DistributedSamplerWrapper instance as a DataLoader 65 | sampler, and load a subset of subsampled data of the original dataloader 66 | that is exclusive to it. 67 | .. note:: 68 | Sampler is assumed to be of constant size. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | sampler, 74 | num_replicas: Optional[int] = None, 75 | rank: Optional[int] = None, 76 | shuffle: bool = True, 77 | ): 78 | """ 79 | Args: 80 | sampler: Sampler used for subsampling 81 | num_replicas (int, optional): Number of processes participating in 82 | distributed training 83 | rank (int, optional): Rank of the current process 84 | within ``num_replicas`` 85 | shuffle (bool, optional): If true (default), 86 | sampler will shuffle the indices 87 | """ 88 | super(DistributedSamplerWrapper, self).__init__( 89 | DatasetFromSampler(sampler), 90 | num_replicas=num_replicas, 91 | rank=rank, 92 | shuffle=shuffle, 93 | ) 94 | self.sampler = sampler 95 | 96 | def __iter__(self): 97 | """@TODO: Docs. Contribution is welcome.""" 98 | self.dataset = DatasetFromSampler(self.sampler) 99 | indexes_of_indexes = super().__iter__() 100 | subsampler_indexes = self.dataset 101 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ -------------------------------------------------------------------------------- /utils/data_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import numpy as np 24 | 25 | face_class = ['background', 'head', 'head***cheek', 'head***chin', 'head***ear', 'head***ear***helix', 26 | 'head***ear***lobule', 'head***eye***botton lid', 'head***eye***eyelashes', 'head***eye***iris', 27 | 'head***eye***pupil', 'head***eye***sclera', 'head***eye***tear duct', 'head***eye***top lid', 28 | 'head***eyebrow', 'head***forehead', 'head***frown', 'head***hair', 'head***hair***sideburns', 29 | 'head***jaw', 'head***moustache', 'head***mouth***inferior lip', 'head***mouth***oral comisure', 30 | 'head***mouth***superior lip', 'head***mouth***teeth', 'head***neck', 'head***nose', 31 | 'head***nose***ala of nose', 'head***nose***bridge', 'head***nose***nose tip', 'head***nose***nostril', 32 | 'head***philtrum', 'head***temple', 'head***wrinkles'] 33 | 34 | car_12_class = ['background', 'car_body', 'head light', 'tail light', 'licence plate', 35 | 'wind shield', 'wheel', 'door', 'handle' , 'wheelhub', 'window', 'mirror'] 36 | car_20_class = ['background', 'back_bumper', 'bumper', 'car_body', 'car_lights', 'door', 'fender','grilles','handles', 37 | 'hoods', 'licensePlate', 'mirror','roof', 'running_boards', 'tailLight','tire', 'trunk_lids','wheelhub', 'window', 'windshield'] 38 | 39 | 40 | car_20_palette =[ 255, 255, 255, # 0 background 41 | 238, 229, 102,# 1 back_bumper 42 | 0, 0, 0,# 2 bumper 43 | 124, 99 , 34, # 3 car 44 | 193 , 127, 15,# 4 car_lights 45 | 248 ,213 , 42, # 5 door 46 | 220 ,147 , 77, # 6 fender 47 | 99 , 83 , 3, # 7 grilles 48 | 116 , 116 , 138, # 8 handles 49 | 200 ,226 , 37, # 9 hoods 50 | 225 , 184 , 161, # 10 licensePlate 51 | 142 , 172 ,248, # 11 mirror 52 | 153 , 112 , 146, # 12 roof 53 | 38 ,112 , 254, # 13 running_boards 54 | 229 , 30 ,141, # 14 tailLight 55 | 52 , 83 ,84, # 15 tire 56 | 194 , 87 , 125, # 16 trunk_lids 57 | 225, 96 ,18, # 17 wheelhub 58 | 31 , 102 , 211, # 18 window 59 | 104 , 131 , 101# 19 windshield 60 | ] 61 | 62 | 63 | 64 | face_palette = [ 1.0000, 1.0000 , 1.0000, 65 | 0.4420, 0.5100 , 0.4234, 66 | 0.8562, 0.9537 , 0.3188, 67 | 0.2405, 0.4699 , 0.9918, 68 | 0.8434, 0.9329 ,0.7544, 69 | 0.3748, 0.7917 , 0.3256, 70 | 0.0190, 0.4943 , 0.3782, 71 | 0.7461 , 0.0137 , 0.5684, 72 | 0.1644, 0.2402 , 0.7324, 73 | 0.0200 , 0.4379 , 0.4100, 74 | 0.5853 , 0.8880 , 0.6137, 75 | 0.7991 , 0.9132 , 0.9720, 76 | 0.6816 , 0.6237 ,0.8562, 77 | 0.9981 , 0.4692 , 0.3849, 78 | 0.5351 , 0.8242 , 0.2731, 79 | 0.1747 , 0.3626 , 0.8345, 80 | 0.5323 , 0.6668 , 0.4922, 81 | 0.2122 , 0.3483 , 0.4707, 82 | 0.6844, 0.1238 , 0.1452, 83 | 0.3882 , 0.4664 , 0.1003, 84 | 0.2296, 0.0401 , 0.3030, 85 | 0.5751 , 0.5467 , 0.9835, 86 | 0.1308 , 0.9628, 0.0777, 87 | 0.2849 ,0.1846 , 0.2625, 88 | 0.9764 , 0.9420 , 0.6628, 89 | 0.3893 , 0.4456 , 0.6433, 90 | 0.8705 , 0.3957 , 0.0963, 91 | 0.6117 , 0.9702 , 0.0247, 92 | 0.3668 , 0.6694 , 0.3117, 93 | 0.6451 , 0.7302, 0.9542, 94 | 0.6171 , 0.1097, 0.9053, 95 | 0.3377 , 0.4950, 0.7284, 96 | 0.1655, 0.9254, 0.6557, 97 | 0.9450 ,0.6721, 0.6162] 98 | 99 | face_palette = [int(item * 255) for item in face_palette] 100 | 101 | 102 | 103 | 104 | 105 | car_12_palette =[ 255, 255, 255, # 0 background 106 | 124, 99 , 34, # 3 car 107 | 193 , 127, 15,# 4 car_lights 108 | 229 , 30 ,141, # 14 tailLight 109 | 225 , 184 , 161, # 10 licensePlate 110 | 104 , 131 , 101,# 19 windshield 111 | 52 , 83 ,84, # 15 tire 112 | 248 ,213 , 42, # 5 door 113 | 116 , 116 , 138, # 8 handles 114 | 225, 96 ,18, # 17 wheelhub 115 | 31 , 102 , 211, # 18 window 116 | 142 , 172 ,248, # 11 mirror 117 | ] 118 | 119 | 120 | 121 | car_32_palette =[ 255, 255, 255, 122 | 238, 229, 102, 123 | 0, 0, 0, 124 | 124, 99 , 34, 125 | 193 , 127, 15, 126 | 106, 177, 21, 127 | 248 ,213 , 42, 128 | 252 , 155, 83, 129 | 220 ,147 , 77, 130 | 99 , 83 , 3, 131 | 116 , 116 , 138, 132 | 63 ,182 , 24, 133 | 200 ,226 , 37, 134 | 225 , 184 , 161, 135 | 233 , 5 ,219, 136 | 142 , 172 ,248, 137 | 153 , 112 , 146, 138 | 38 ,112 , 254, 139 | 229 , 30 ,141, 140 | 115 ,208 , 131, 141 | 52 , 83 ,84, 142 | 229 , 63 , 110, 143 | 194 , 87 , 125, 144 | 225, 96 ,18, 145 | 73 ,139, 226, 146 | 172 , 143 , 16, 147 | 169 , 101 , 111, 148 | 31 , 102 , 211, 149 | 104 , 131 , 101, 150 | 70 ,168 ,156, 151 | 183 , 242 , 209, 152 | 72 ,184 , 226] 153 | 154 | bedroom_palette =[ 255, 255, 255, 155 | 238, 229, 102, 156 | 255, 72, 69, 157 | 124, 99 , 34, 158 | 193 , 127, 15, 159 | 106, 177, 21, 160 | 248 ,213 , 42, 161 | 252 , 155, 83, 162 | 220 ,147 , 77, 163 | 99 , 83 , 3, 164 | 116 , 116 , 138, 165 | 63 ,182 , 24, 166 | 200 ,226 , 37, 167 | 225 , 184 , 161, 168 | 233 , 5 ,219, 169 | 142 , 172 ,248, 170 | 153 , 112 , 146, 171 | 38 ,112 , 254, 172 | 229 , 30 ,141, 173 | 238, 229, 12, 174 | 255, 72, 6, 175 | 124, 9, 34, 176 | 193, 17, 15, 177 | 106, 17, 21, 178 | 28, 213, 2, 179 | 252, 155, 3, 180 | 20, 147, 77, 181 | 9, 83, 3, 182 | 11, 16, 138, 183 | 6, 12, 24, 184 | 20, 22, 37, 185 | 225, 14, 16, 186 | 23, 5, 29, 187 | 14, 12, 28, 188 | 15, 11, 16, 189 | 3, 12, 24, 190 | 22, 3, 11 191 | ] 192 | 193 | cat_palette = [255, 255, 255, 194 | 220, 220, 0, 195 | 190, 153, 153, 196 | 250, 170, 30, 197 | 220, 220, 0, 198 | 107, 142, 35, 199 | 102, 102, 156, 200 | 152, 251, 152, 201 | 119, 11, 32, 202 | 244, 35, 232, 203 | 220, 20, 60, 204 | 52 , 83 ,84, 205 | 194 , 87 , 125, 206 | 225, 96 ,18, 207 | 31 , 102 , 211, 208 | 104 , 131 , 101 209 | ] 210 | 211 | def trans_mask_stylegan_20classTo12(mask): 212 | final_mask = np.zeros(mask.shape) 213 | final_mask[(mask != 0)] = 1 # car 214 | final_mask[(mask == 4)] = 2 # head light 215 | final_mask[(mask == 14)] = 5 # tail light 216 | final_mask[(mask == 10)] = 3 # licence plate 217 | final_mask[ (mask == 19)] = 8 # wind shield 218 | final_mask[(mask == 15)] = 6 # wheel 219 | final_mask[(mask == 5)] = 9 # door 220 | final_mask[(mask == 8)] = 10 # handle 221 | final_mask[(mask == 17)] = 11 # wheelhub 222 | final_mask[(mask == 18)] = 7 # window 223 | final_mask[(mask == 11)] = 4 # mirror 224 | return final_mask 225 | 226 | 227 | def trans_mask(mask): 228 | return mask 229 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import pickle 24 | 25 | import torch 26 | from torch import distributed as dist 27 | 28 | 29 | def get_rank(): 30 | if not dist.is_available(): 31 | return 0 32 | 33 | if not dist.is_initialized(): 34 | return 0 35 | 36 | return dist.get_rank() 37 | 38 | 39 | def synchronize(): 40 | if not dist.is_available(): 41 | return 42 | 43 | if not dist.is_initialized(): 44 | return 45 | 46 | world_size = dist.get_world_size() 47 | 48 | if world_size == 1: 49 | return 50 | 51 | dist.barrier() 52 | 53 | 54 | def get_world_size(): 55 | if not dist.is_available(): 56 | return 1 57 | 58 | if not dist.is_initialized(): 59 | return 1 60 | 61 | return dist.get_world_size() 62 | 63 | 64 | def reduce_sum(tensor): 65 | if not dist.is_available(): 66 | return tensor 67 | 68 | if not dist.is_initialized(): 69 | return tensor 70 | 71 | tensor = tensor.clone() 72 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 73 | 74 | return tensor 75 | 76 | 77 | def gather_grad(params): 78 | world_size = get_world_size() 79 | 80 | if world_size == 1: 81 | return 82 | 83 | for param in params: 84 | if param.grad is not None: 85 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 86 | param.grad.data.div_(world_size) 87 | 88 | 89 | def all_gather(data): 90 | world_size = get_world_size() 91 | 92 | if world_size == 1: 93 | return [data] 94 | 95 | buffer = pickle.dumps(data) 96 | storage = torch.ByteStorage.from_buffer(buffer) 97 | tensor = torch.ByteTensor(storage).to('cuda') 98 | 99 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 100 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 101 | dist.all_gather(size_list, local_size) 102 | size_list = [int(size.item()) for size in size_list] 103 | max_size = max(size_list) 104 | 105 | tensor_list = [] 106 | for _ in size_list: 107 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 108 | 109 | if local_size != max_size: 110 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 111 | tensor = torch.cat((tensor, padding), 0) 112 | 113 | dist.all_gather(tensor_list, tensor) 114 | 115 | data_list = [] 116 | 117 | for size, tensor in zip(size_list, tensor_list): 118 | buffer = tensor.cpu().numpy().tobytes()[:size] 119 | data_list.append(pickle.loads(buffer)) 120 | 121 | return data_list 122 | 123 | 124 | def reduce_loss_dict(loss_dict): 125 | world_size = get_world_size() 126 | 127 | if world_size < 2: 128 | return loss_dict 129 | 130 | with torch.no_grad(): 131 | keys = [] 132 | losses = [] 133 | 134 | for k in sorted(loss_dict.keys()): 135 | keys.append(k) 136 | losses.append(loss_dict[k]) 137 | 138 | losses = torch.stack(losses, 0) 139 | dist.reduce(losses, dst=0) 140 | 141 | if dist.get_rank() == 0: 142 | losses /= world_size 143 | 144 | reduced_losses = {k: v for k, v in zip(keys, losses)} 145 | 146 | return reduced_losses 147 | -------------------------------------------------------------------------------- /utils/inception_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | 24 | import numpy as np 25 | from scipy import linalg # For numpy FID 26 | import time 27 | import pickle 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | from torch.nn import Parameter as P 32 | from torchvision.models.inception import inception_v3 33 | 34 | 35 | class WrapInception(nn.Module): 36 | def __init__(self, net): 37 | super(WrapInception,self).__init__() 38 | self.net = net 39 | self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1), 40 | requires_grad=False) 41 | self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1), 42 | requires_grad=False) 43 | 44 | def forward(self, x): 45 | # Normalize x 46 | x = (x + 1.) / 2.0 47 | x = (x - self.mean) / self.std 48 | # Upsample if necessary 49 | if x.shape[2] != 299 or x.shape[3] != 299: 50 | x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True) 51 | # 299 x 299 x 3 52 | x = self.net.Conv2d_1a_3x3(x) 53 | # 149 x 149 x 32 54 | x = self.net.Conv2d_2a_3x3(x) 55 | # 147 x 147 x 32 56 | x = self.net.Conv2d_2b_3x3(x) 57 | # 147 x 147 x 64 58 | x = F.max_pool2d(x, kernel_size=3, stride=2) 59 | # 73 x 73 x 64 60 | x = self.net.Conv2d_3b_1x1(x) 61 | # 73 x 73 x 80 62 | x = self.net.Conv2d_4a_3x3(x) 63 | # 71 x 71 x 192 64 | x = F.max_pool2d(x, kernel_size=3, stride=2) 65 | # 35 x 35 x 192 66 | x = self.net.Mixed_5b(x) 67 | # 35 x 35 x 256 68 | x = self.net.Mixed_5c(x) 69 | # 35 x 35 x 288 70 | x = self.net.Mixed_5d(x) 71 | # 35 x 35 x 288 72 | x = self.net.Mixed_6a(x) 73 | # 17 x 17 x 768 74 | x = self.net.Mixed_6b(x) 75 | # 17 x 17 x 768 76 | x = self.net.Mixed_6c(x) 77 | # 17 x 17 x 768 78 | x = self.net.Mixed_6d(x) 79 | # 17 x 17 x 768 80 | x = self.net.Mixed_6e(x) 81 | # 17 x 17 x 768 82 | # 17 x 17 x 768 83 | x = self.net.Mixed_7a(x) 84 | # 8 x 8 x 1280 85 | x = self.net.Mixed_7b(x) 86 | # 8 x 8 x 2048 87 | x = self.net.Mixed_7c(x) 88 | # 8 x 8 x 2048 89 | pool = torch.mean(x.view(x.size(0), x.size(1), -1), 2) 90 | # 1 x 1 x 2048 91 | logits = self.net.fc(F.dropout(pool, training=False).view(pool.size(0), -1)) 92 | # 1000 (num_classes) 93 | return pool, logits 94 | 95 | 96 | # A pytorch implementation of cov, from Modar M. Alfadly 97 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 98 | def torch_cov(m, rowvar=False): 99 | '''Estimate a covariance matrix given data. 100 | 101 | Covariance indicates the level to which two variables vary together. 102 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 103 | then the covariance matrix element `C_{ij}` is the covariance of 104 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 105 | 106 | Args: 107 | m: A 1-D or 2-D array containing multiple variables and observations. 108 | Each row of `m` represents a variable, and each column a single 109 | observation of all those variables. 110 | rowvar: If `rowvar` is True, then each row represents a 111 | variable, with observations in the columns. Otherwise, the 112 | relationship is transposed: each column represents a variable, 113 | while the rows contain observations. 114 | 115 | Returns: 116 | The covariance matrix of the variables. 117 | ''' 118 | if m.dim() > 2: 119 | raise ValueError('m has more than 2 dimensions') 120 | if m.dim() < 2: 121 | m = m.view(1, -1) 122 | if not rowvar and m.size(0) != 1: 123 | m = m.t() 124 | # m = m.type(torch.double) # uncomment this line if desired 125 | fact = 1.0 / (m.size(1) - 1) 126 | m -= torch.mean(m, dim=1, keepdim=True) 127 | mt = m.t() # if complex: mt = m.t().conj() 128 | 129 | return fact * m.matmul(mt).squeeze() 130 | 131 | 132 | # Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji 133 | # https://github.com/msubhransu/matrix-sqrt 134 | def sqrt_newton_schulz(A, numIters, dtype=None): 135 | with torch.no_grad(): 136 | if dtype is None: 137 | dtype = A.type() 138 | batchSize = A.shape[0] 139 | dim = A.shape[1] 140 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() 141 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)) 142 | I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 143 | Z = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 144 | for i in range(numIters): 145 | T = 0.5*(3.0*I - Z.bmm(Y)) 146 | Y = Y.bmm(T) 147 | Z = T.bmm(Z) 148 | sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) 149 | return sA 150 | 151 | 152 | # FID calculator from TTUR--consider replacing this with GPU-accelerated cov 153 | # calculations using torch? 154 | def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 155 | """Numpy implementation of the Frechet Distance. 156 | Taken from https://github.com/bioinf-jku/TTUR 157 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 158 | and X_2 ~ N(mu_2, C_2) is 159 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 160 | Stable version by Dougal J. Sutherland. 161 | Params: 162 | -- mu1 : Numpy array containing the activations of a layer of the 163 | inception net (like returned by the function 'get_predictions') 164 | for generated samples. 165 | -- mu2 : The sample mean over activations, precalculated on an 166 | representive data set. 167 | -- sigma1: The covariance matrix over activations for generated samples. 168 | -- sigma2: The covariance matrix over activations, precalculated on an 169 | representive data set. 170 | Returns: 171 | -- : The Frechet Distance. 172 | """ 173 | 174 | mu1 = np.atleast_1d(mu1) 175 | mu2 = np.atleast_1d(mu2) 176 | 177 | sigma1 = np.atleast_2d(sigma1) 178 | sigma2 = np.atleast_2d(sigma2) 179 | 180 | assert mu1.shape == mu2.shape, \ 181 | 'Training and test mean vectors have different lengths' 182 | assert sigma1.shape == sigma2.shape, \ 183 | 'Training and test covariances have different dimensions' 184 | 185 | diff = mu1 - mu2 186 | 187 | # Product might be almost singular 188 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 189 | if not np.isfinite(covmean).all(): 190 | msg = ('fid calculation produces singular product; ' 191 | 'adding %s to diagonal of cov estimates') % eps 192 | print(msg) 193 | offset = np.eye(sigma1.shape[0]) * eps 194 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 195 | 196 | # Numerical error might give slight imaginary component 197 | if np.iscomplexobj(covmean): 198 | print('wat') 199 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 200 | m = np.max(np.abs(covmean.imag)) 201 | raise ValueError('Imaginary component {}'.format(m)) 202 | covmean = covmean.real 203 | 204 | tr_covmean = np.trace(covmean) 205 | 206 | out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 207 | return out 208 | 209 | 210 | def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 211 | """Pytorch implementation of the Frechet Distance. 212 | Taken from https://github.com/bioinf-jku/TTUR 213 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 214 | and X_2 ~ N(mu_2, C_2) is 215 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 216 | Stable version by Dougal J. Sutherland. 217 | Params: 218 | -- mu1 : Numpy array containing the activations of a layer of the 219 | inception net (like returned by the function 'get_predictions') 220 | for generated samples. 221 | -- mu2 : The sample mean over activations, precalculated on an 222 | representive data set. 223 | -- sigma1: The covariance matrix over activations for generated samples. 224 | -- sigma2: The covariance matrix over activations, precalculated on an 225 | representive data set. 226 | Returns: 227 | -- : The Frechet Distance. 228 | """ 229 | 230 | 231 | assert mu1.shape == mu2.shape, \ 232 | 'Training and test mean vectors have different lengths' 233 | assert sigma1.shape == sigma2.shape, \ 234 | 'Training and test covariances have different dimensions' 235 | 236 | diff = mu1 - mu2 237 | # Run 50 itrs of newton-schulz to get the matrix sqrt of sigma1 dot sigma2 238 | covmean = sqrt_newton_schulz(sigma1.mm(sigma2).unsqueeze(0), 50).squeeze() 239 | out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2) 240 | - 2 * torch.trace(covmean)) 241 | return out 242 | 243 | 244 | # Calculate Inception Score mean + std given softmax'd logits and number of splits 245 | def calculate_inception_score(pred, num_splits=10): 246 | scores = [] 247 | for index in range(num_splits): 248 | pred_chunk = pred[index * (pred.shape[0] // num_splits): (index + 1) * (pred.shape[0] // num_splits), :] 249 | kl_inception = pred_chunk * (np.log(pred_chunk) - np.log(np.expand_dims(np.mean(pred_chunk, 0), 0))) 250 | kl_inception = np.mean(np.sum(kl_inception, 1)) 251 | scores.append(np.exp(kl_inception)) 252 | return np.mean(scores), np.std(scores) 253 | 254 | 255 | # Loop and run the sampler and the net until it accumulates num_inception_images 256 | # activations. Return the pool, the logits, and the labels (if one wants 257 | # Inception Accuracy the labels of the generated class will be needed) 258 | def accumulate_inception_activations(sample, net, num_inception_images=50000): 259 | pool, logits= [], [] 260 | while (torch.cat(logits, 0).shape[0] if len(logits) else 0) < num_inception_images: 261 | with torch.no_grad(): 262 | images = sample() 263 | if images.shape[1] != 3: 264 | images = images.expand(-1,3,-1,-1) 265 | 266 | pool_val, logits_val = net(images.float()) 267 | pool += [pool_val] 268 | logits += [F.softmax(logits_val, 1)] 269 | 270 | return torch.cat(pool, 0), torch.cat(logits, 0) 271 | 272 | 273 | # Load and wrap the Inception model 274 | def load_inception_net(parallel=False): 275 | inception_model = inception_v3(pretrained=True, transform_input=False) 276 | inception_model = WrapInception(inception_model.eval()).cuda() 277 | if parallel: 278 | print('Parallelizing Inception module...') 279 | inception_model = nn.DataParallel(inception_model) 280 | return inception_model 281 | 282 | 283 | # This produces a function which takes in an iterator which returns a set number of samples 284 | # and iterates until it accumulates config['num_inception_images'] images. 285 | # The iterator can return samples with a different batch size than used in 286 | # training, using the setting confg['inception_batchsize'] 287 | def prepare_inception_metrics(dataset, parallel, no_fid=False): 288 | # Load metrics; this is intentionally not in a try-except loop so that 289 | # the script will crash here if it cannot find the Inception moments. 290 | # By default, remove the "hdf5" from dataloader 291 | with open(dataset, 'rb') as f: 292 | embeds = pickle.load(f) 293 | data_mu = embeds['mean'] 294 | data_sigma = embeds['cov'] 295 | 296 | # Load network 297 | net = load_inception_net(parallel) 298 | def get_inception_metrics(sample, num_inception_images, num_splits=10, 299 | prints=True, use_torch=True): 300 | if prints: 301 | print('Gathering activations...') 302 | pool, logits = accumulate_inception_activations(sample, net, num_inception_images) 303 | if prints: 304 | print('Calculating Inception Score...') 305 | IS_mean, IS_std = calculate_inception_score(logits.cpu().numpy(), num_splits) 306 | if no_fid: 307 | FID = 9999.0 308 | else: 309 | if prints: 310 | print('Calculating means and covariances...') 311 | if use_torch: 312 | mu, sigma = torch.mean(pool, 0), torch_cov(pool, rowvar=False) 313 | else: 314 | mu, sigma = np.mean(pool.cpu().numpy(), axis=0), np.cov(pool.cpu().numpy(), rowvar=False) 315 | if prints: 316 | print('Covariances calculated, getting FID...') 317 | if use_torch: 318 | import pdb; pdb.set_trace() 319 | FID = torch_calculate_frechet_distance(mu, sigma, torch.tensor(data_mu).float().cuda(), torch.tensor(data_sigma).float().cuda()) 320 | FID = float(FID.cpu().numpy()) 321 | else: 322 | FID = numpy_calculate_frechet_distance(mu, sigma, data_mu, data_sigma) 323 | # Delete mu, sigma, pool, logits, and labels, just in case 324 | del mu, sigma, pool, logits 325 | return IS_mean, IS_std, FID 326 | 327 | return get_inception_metrics 328 | 329 | def sample_gema(g_ema, device, truncation, mean_latent, batch_size): 330 | with torch.no_grad(): 331 | g_ema.eval() 332 | 333 | sample_z = torch.randn(batch_size, 512, device=device) 334 | 335 | samples = g_ema([sample_z], truncation=truncation, truncation_latent=mean_latent) 336 | 337 | sample = samples[0] 338 | 339 | return sample -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2021 NVIDIA Corporation. All rights reserved. 3 | Licensed under The MIT License (MIT) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so, 10 | subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 17 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 18 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 19 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 20 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 21 | """ 22 | 23 | import torch 24 | from PIL import Image 25 | import numpy as np 26 | from torch import nn 27 | 28 | class Interpolate(nn.Module): 29 | def __init__(self, size, mode): 30 | super(Interpolate, self).__init__() 31 | self.interp = nn.functional.interpolate 32 | self.size = size 33 | self.mode = mode 34 | 35 | def forward(self, x): 36 | x = self.interp(x, size=self.size, mode=self.mode, align_corners=False) 37 | return x 38 | 39 | 40 | 41 | def multi_acc(y_pred, y_test): 42 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 43 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 44 | 45 | correct_pred = (y_pred_tags == y_test).float() 46 | acc = correct_pred.sum() / len(correct_pred) 47 | 48 | acc = acc * 100 49 | 50 | return acc 51 | 52 | 53 | def oht_to_scalar(y_pred): 54 | y_pred_softmax = torch.log_softmax(y_pred, dim=1) 55 | _, y_pred_tags = torch.max(y_pred_softmax, dim=1) 56 | 57 | return y_pred_tags 58 | 59 | def latent_to_image(g_all, upsamplers, latents, return_upsampled_layers=False, use_style_latents=False, 60 | style_latents=None, process_out=True, return_stylegan_latent=False, dim=512, return_only_im=False): 61 | '''Given a input latent code, generate corresponding image and concatenated feature maps''' 62 | 63 | # assert (len(latents) == 1) # for GPU memory constraints 64 | if not use_style_latents: 65 | # generate style_latents from latents 66 | style_latents = g_all.module.truncation(g_all.module.g_mapping(latents)) 67 | style_latents = style_latents.clone() # make different layers non-alias 68 | 69 | else: 70 | style_latents = latents 71 | 72 | # style_latents = latents 73 | if return_stylegan_latent: 74 | 75 | return style_latents 76 | img_list, affine_layers = g_all.module.g_synthesis(style_latents) 77 | 78 | if return_only_im: 79 | if process_out: 80 | if img_list.shape[-2] > 512: 81 | img_list = upsamplers[-1](img_list) 82 | 83 | img_list = img_list.cpu().detach().numpy() 84 | img_list = process_image(img_list) 85 | img_list = np.transpose(img_list, (0, 2, 3, 1)).astype(np.uint8) 86 | return img_list, style_latents 87 | 88 | number_feautre = 0 89 | 90 | for item in affine_layers: 91 | number_feautre += item.shape[1] 92 | 93 | 94 | affine_layers_upsamples = torch.FloatTensor(1, number_feautre, dim, dim).cuda() 95 | if return_upsampled_layers: 96 | 97 | start_channel_index = 0 98 | for i in range(len(affine_layers)): 99 | len_channel = affine_layers[i].shape[1] 100 | affine_layers_upsamples[:, start_channel_index:start_channel_index + len_channel] = upsamplers[i]( 101 | affine_layers[i]) 102 | start_channel_index += len_channel 103 | 104 | if img_list.shape[-2] != 512: 105 | img_list = upsamplers[-1](img_list) 106 | 107 | if process_out: 108 | img_list = img_list.cpu().detach().numpy() 109 | img_list = process_image(img_list) 110 | img_list = np.transpose(img_list, (0, 2, 3, 1)).astype(np.uint8) 111 | # print('start_channel_index',start_channel_index) 112 | 113 | 114 | return img_list, affine_layers_upsamples 115 | 116 | 117 | def process_image(images): 118 | drange = [-1, 1] 119 | scale = 255 / (drange[1] - drange[0]) 120 | images = images * scale + (0.5 - drange[0] * scale) 121 | 122 | images = images.astype(int) 123 | images[images > 255] = 255 124 | images[images < 0] = 0 125 | 126 | return images.astype(int) 127 | 128 | def colorize_mask(mask, palette): 129 | # mask: numpy array of the mask 130 | 131 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 132 | new_mask.putpalette(palette) 133 | return np.array(new_mask.convert('RGB')) 134 | 135 | 136 | def get_label_stas(data_loader): 137 | count_dict = {} 138 | for i in range(data_loader.__len__()): 139 | x, y = data_loader.__getitem__(i) 140 | if int(y.item()) not in count_dict: 141 | count_dict[int(y.item())] = 1 142 | else: 143 | count_dict[int(y.item())] += 1 144 | 145 | return count_dict 146 | --------------------------------------------------------------------------------