├── .gitignore ├── LICENSE.txt ├── README.md ├── adamw.py ├── assets ├── .DS_Store ├── TransGAN_1.png ├── Visual_results.png ├── cifar_visual.png ├── readme.png └── teaser_examples.jpg ├── celeba.py ├── cfg.py ├── datasets.py ├── dnnlib ├── __init__.py └── util.py ├── exps ├── __init__.py ├── celeba_hq_256_test.py ├── celeba_hq_256_train.py ├── church_256_train.py ├── cifar_test.py ├── cifar_train.py └── stl_train.py ├── fid_stat ├── fid_stats_celeba_hq_256.npz ├── fid_stats_church_256.npz ├── fid_stats_cifar10_train.npz └── stl10_train_unlabeled_fid_stats_48.npz ├── flops.py ├── functions.py ├── models_search ├── Celeba256_dis.py ├── Celeba256_gen.py ├── ViT_custom.py ├── ViT_custom_local544444_256_rp.py ├── ViT_custom_local544444_256_rp_noise.py ├── ViT_custom_rp.py ├── ViT_custom_scale2.py ├── ViT_custom_scale2_rp_noise.py ├── ViT_helper.py ├── ViT_scale3_local_new_rp.py ├── __init__.py ├── ada.py └── diff_aug.py ├── requirements.txt ├── test.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train_derived.py └── utils ├── __init__.py ├── cal_fid_stat.py ├── fid_score.py ├── inception.py ├── inception_model.py ├── inception_score.py ├── torch_fid_score.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # pytorch 2 | *.pth 3 | *.pth.tar 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | pretrained_weight/ 14 | Logs 15 | logs 16 | sampled_image* 17 | data/ 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | #lib/ 28 
| lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Yifan Jiang. All rights reserved. 2 | 3 | 4 | ======================================================================= 5 | 6 | 1. Definitions 7 | 8 | "Licensor" means any person or entity that distributes its Work. 9 | 10 | "Software" means the original work of authorship made available under 11 | this License. 12 | 13 | "Work" means the Software and any additions to or derivative works of 14 | the Software that are made available under this License. 15 | 16 | The terms "reproduce," "reproduction," "derivative works," and 17 | "distribution" have the meaning as provided under U.S. copyright law; 18 | provided, however, that for the purposes of this License, derivative 19 | works shall not include works that remain separable from, or merely 20 | link (or bind by name) to the interfaces of, the Work. 21 | 22 | Works, including the Software, are "made available" under this License 23 | by including in or with the Work either (a) a copyright notice 24 | referencing the applicability of this License to the Work, or (b) a 25 | copy of this License. 26 | 27 | 2. License Grants 28 | 29 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 30 | License, each Licensor grants to you a perpetual, worldwide, 31 | non-exclusive, royalty-free, copyright license to reproduce, 32 | prepare derivative works of, publicly display, publicly perform, 33 | sublicense and distribute its Work and any resulting derivative 34 | works in any form. 35 | 36 | 3. Limitations 37 | 38 | 3.1 Redistribution. You may reproduce or distribute the Work only 39 | if (a) you do so under this License, (b) you include a complete 40 | copy of this License with your distribution, and (c) you retain 41 | without modification any copyright, patent, trademark, or 42 | attribution notices that are present in the Work. 43 | 44 | 3.2 Derivative Works. You may specify that additional or different 45 | terms apply to the use, reproduction, and distribution of your 46 | derivative works of the Work ("Your Terms") only if (a) Your Terms 47 | provide that the use limitation in Section 3 applies to your 48 | derivative works, and (b) you identify the specific derivative 49 | works that are subject to Your Terms. Notwithstanding Your Terms, 50 | this License (including the redistribution requirements in Section 51 | 3.1) will continue to apply to the Work itself. 52 | 53 | 3.3 Use Limitation. The Work and any derivative works thereof only 54 | may be used or intended for use non-commercially. As used herein, 55 | "non-commercially" means for research or evaluation purposes only. 56 | 57 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 58 | against any Licensor (including any claim, cross-claim or 59 | counterclaim in a lawsuit) to enforce any patents that you allege 60 | are infringed by any Work, then your rights under this License from 61 | such Licensor (including the grants in Sections 2.1 and 2.2) will 62 | terminate immediately. 63 | 64 | 3.5 Trademarks. 
This License does not grant any rights to use any 65 | Licensor's or its affiliates' names, logos, or trademarks, except 66 | as necessary to reproduce the notices described in this License. 67 | 68 | 3.6 Termination. If you violate any term of this License, then your 69 | rights under this License (including the grants in Sections 2.1 and 70 | 2.2) will terminate immediately. 71 | 72 | 4. Disclaimer of Warranty. 73 | 74 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 75 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 76 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 77 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 78 | THIS LICENSE. 79 | 80 | 5. Limitation of Liability. 81 | 82 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 83 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 84 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 85 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 86 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 87 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 88 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 89 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 90 | THE POSSIBILITY OF SUCH DAMAGES. 91 | 92 | ======================================================================= 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransGAN: Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up 2 | Code used for [TransGAN: Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up](https://arxiv.org/abs/2102.07074). 
3 | 4 | ## Implementation 5 | - [ ] checkpoint gradient using torch.utils.checkpoint 6 | - [ ] 16bit precision training 7 | - [x] Distributed Training (Faster!) 8 | - [x] IS/FID Evaluation 9 | - [x] Gradient Accumulation 10 | - [x] Stronger Data Augmentation 11 | - [x] Self-Modulation 12 | 13 | ## Guidance 14 | #### Cifar training script 15 | ``` 16 | python exps/cifar_train.py 17 | ``` 18 | I disabled evaluation during the training job because it causes a strange bug. Please launch a separate evaluation job simultaneously by copying the `path` to the [test script](https://github.com/VITA-Group/TransGAN/blob/a13640fbf4699d651c1a9da0fd936f260f5f096d/exps/cifar_test.py#L58). 19 | #### Cifar test 20 | First download the [cifar checkpoint](https://drive.google.com/file/d/149I8kPnNOypp_4tU_27s7OAVdBR_ZR2Z/view?usp=sharing) and put it in `./cifar_checkpoint`. Then run the following script. 21 | ``` 22 | python exps/cifar_test.py 23 | ``` 24 | 25 | ## Main Pipeline 26 | ![Main Pipeline](assets/TransGAN_1.png) 27 | 28 | ## Representative Visual Results 29 | ![Cifar Visual Results](assets/cifar_visual.png) 30 | ![Visual Results](assets/teaser_examples.jpg) 31 | 32 | 33 | The rest of this README is still to be updated. 34 | ## Acknowledgement 35 | Codebase from [AutoGAN](https://github.com/VITA-Group/AutoGAN), [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) 36 | 37 | ## Citation 38 | If you find this repo helpful, please cite: 39 | ``` 40 | @article{jiang2021transgan, 41 | title={Transgan: Two pure transformers can make one strong gan, and that can scale up}, 42 | author={Jiang, Yifan and Chang, Shiyu and Wang, Zhangyang}, 43 | journal={Advances in Neural Information Processing Systems}, 44 | volume={34}, 45 | year={2021} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /adamw.py: -------------------------------------------------------------------------------- 1 | """ AdamW Optimizer 2 | Impl copied from PyTorch master 3 | """ 4 | 
import math
import torch
from torch.optim.optimizer import Optimizer


class AdamW(Optimizer):
    r"""Implements AdamW algorithm.

    The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
    The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    Raises:
        ValueError: if ``lr`` or ``eps`` is negative, or either beta lies
            outside ``[0, 1)``.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Decoupled Weight Decay Regularization:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=1e-2, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``amsgrad`` to False per group."""
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            # Param groups restored from older state may lack the key.
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                # Reject sparse gradients *before* touching the parameter so a
                # failed step never leaves the weights partially updated.
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                # Perform decoupled weight decay (shrinks weights directly,
                # independent of the gradient-based update below).
                p.data.mul_(1 - group['lr'] * group['weight_decay'])

                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient.
                # Keyword alpha=/value= replaces the deprecated positional-scalar
                # overloads add_(scalar, t) / addcmul_(scalar, t1, t2).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                step_size = group['lr'] / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/TransGAN/6b85440ca56716fd7a60bac964466cc0296ce663/assets/teaser_examples.jpg -------------------------------------------------------------------------------- /celeba.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 10/6/19 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | from functools import partial 8 | import torch 9 | import os 10 | import PIL 11 | from torchvision.datasets.vision import VisionDataset 12 | from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg 13 | from torch.utils.data import Dataset 14 | import glob 15 | 16 | 17 | 18 | class CelebA(Dataset): 19 | """ pyTorch Dataset wrapper for the generic flat directory images dataset """ 20 | 21 | def __setup_files(self): 22 | """ 23 | private helper for setting up the files_list 24 | :return: files => list of paths of files 25 | """ 26 | file_names = os.listdir(self.data_dir) 27 | files = [] # initialize to empty list 28 | 29 | for file_name in file_names: 30 | possible_file = os.path.join(self.data_dir, file_name) 31 | if os.path.isfile(possible_file): 32 | files.append(possible_file) 33 | 34 | # return the files list 35 | return files 36 | 37 | def __init__(self, root, transform=None): 38 | """ 39 | constructor for the class 40 | :param data_dir: path to the directory containing the data 41 | :param transform: transforms to be applied to the images 42 | """ 43 | # define the state of the object 44 | self.data_dir = root 45 | self.transform = transform 46 | 47 | # setup the files for reading 48 | self.files = self.__setup_files() 49 | 50 | def __len__(self): 51 | """ 52 | compute the length of the dataset 53 | :return: len => length of dataset 54 | """ 55 | return len(self.files) 56 | 57 | def __getitem__(self, 
idx): 58 | """ 59 | obtain the image (read and transform) 60 | :param idx: index of the file required 61 | :return: img => image array 62 | """ 63 | from PIL import Image 64 | 65 | # read the image: 66 | img_name = self.files[idx] 67 | if img_name[-4:] == ".npy": 68 | img = np.load(img_name) 69 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0)) 70 | else: 71 | img = Image.open(img_name) 72 | 73 | # apply the transforms on the image 74 | if self.transform is not None: 75 | img = self.transform(img) 76 | 77 | # return the image: 78 | return img, img 79 | 80 | 81 | class FFHQ(Dataset): 82 | """ pyTorch Dataset wrapper for the generic flat directory images dataset """ 83 | 84 | def __setup_files(self): 85 | """ 86 | private helper for setting up the files_list 87 | :return: files => list of paths of files 88 | """ 89 | file_names = glob.glob(os.path.join(self.data_dir, "./*/*.png")) + \ 90 | glob.glob(os.path.join(self.data_dir, "./*.jpg")) + \ 91 | [y for x in os.walk(self.data_dir) for y in glob.glob(os.path.join(x[0], "*.webp"))] 92 | files = [] # initialize to empty list 93 | 94 | for file_name in file_names: 95 | possible_file = os.path.join(self.data_dir, file_name) 96 | if os.path.isfile(possible_file): 97 | files.append(possible_file) 98 | 99 | # return the files list 100 | return files 101 | 102 | def __init__(self, root, transform=None): 103 | """ 104 | constructor for the class 105 | :param data_dir: path to the directory containing the data 106 | :param transform: transforms to be applied to the images 107 | """ 108 | # define the state of the object 109 | self.data_dir = root 110 | self.transform = transform 111 | 112 | # setup the files for reading 113 | self.files = self.__setup_files() 114 | 115 | def __len__(self): 116 | """ 117 | compute the length of the dataset 118 | :return: len => length of dataset 119 | """ 120 | return len(self.files) 121 | 122 | def __getitem__(self, idx): 123 | """ 124 | obtain the image (read and transform) 125 | 
:param idx: index of the file required 126 | :return: img => image array 127 | """ 128 | from PIL import Image 129 | 130 | # read the image: 131 | img_name = self.files[idx] 132 | if img_name[-4:] == ".npy": 133 | img = np.load(img_name) 134 | img = Image.fromarray(img.squeeze(0).transpose(1, 2, 0)) 135 | else: 136 | img = Image.open(img_name) 137 | 138 | # apply the transforms on the image 139 | if self.transform is not None: 140 | img = self.transform(img) 141 | 142 | # return the image: 143 | return img, img -------------------------------------------------------------------------------- /cfg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | import argparse 8 | 9 | 10 | def str2bool(v): 11 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 12 | return True 13 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--world-size', default=-1, type=int, 22 | help='number of nodes for distributed training') 23 | parser.add_argument('--rank', default=-1, type=int, 24 | help='node rank for distributed training') 25 | parser.add_argument('--loca_rank', default=-1, type=int, 26 | help='node rank for distributed training') 27 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 28 | help='url used to set up distributed training') 29 | parser.add_argument('--dist-backend', default='nccl', type=str, 30 | help='distributed backend') 31 | parser.add_argument('--seed', default=12345, type=int, 32 | help='seed for initializing training. 
') 33 | parser.add_argument('--gpu', default=None, type=int, 34 | help='GPU id to use.') 35 | parser.add_argument('--multiprocessing-distributed', action='store_true', 36 | help='Use multi-processing distributed training to launch ' 37 | 'N processes per node, which has N GPUs. This is the ' 38 | 'fastest way to use PyTorch for either single node or ' 39 | 'multi node data parallel training') 40 | parser.add_argument( 41 | '--max_epoch', 42 | type=int, 43 | default=200, 44 | help='number of epochs of training') 45 | parser.add_argument( 46 | '--max_iter', 47 | type=int, 48 | default=None, 49 | help='set the max iteration number') 50 | parser.add_argument( 51 | '-gen_bs', 52 | '--gen_batch_size', 53 | type=int, 54 | default=64, 55 | help='size of the batches') 56 | parser.add_argument( 57 | '-dis_bs', 58 | '--dis_batch_size', 59 | type=int, 60 | default=64, 61 | help='size of the batches') 62 | parser.add_argument( 63 | '--g_lr', 64 | type=float, 65 | default=0.0002, 66 | help='adam: gen learning rate') 67 | parser.add_argument( 68 | '--wd', 69 | type=float, 70 | default=0, 71 | help='adamw: gen weight decay') 72 | parser.add_argument( 73 | '--d_lr', 74 | type=float, 75 | default=0.0002, 76 | help='adam: disc learning rate') 77 | parser.add_argument( 78 | '--ctrl_lr', 79 | type=float, 80 | default=3.5e-4, 81 | help='adam: ctrl learning rate') 82 | parser.add_argument( 83 | '--lr_decay', 84 | action='store_true', 85 | help='learning rate decay or not') 86 | parser.add_argument( 87 | '--beta1', 88 | type=float, 89 | default=0.0, 90 | help='adam: decay of first order momentum of gradient') 91 | parser.add_argument( 92 | '--beta2', 93 | type=float, 94 | default=0.9, 95 | help='adam: decay of first order momentum of gradient') 96 | parser.add_argument( 97 | '--num_workers', 98 | type=int, 99 | default=8, 100 | help='number of cpu threads to use during batch generation') 101 | parser.add_argument( 102 | '--latent_dim', 103 | type=int, 104 | default=128, 105 | 
help='dimensionality of the latent space') 106 | parser.add_argument( 107 | '--img_size', 108 | type=int, 109 | default=32, 110 | help='size of each image dimension') 111 | parser.add_argument( 112 | '--channels', 113 | type=int, 114 | default=3, 115 | help='number of image channels') 116 | parser.add_argument( 117 | '--n_critic', 118 | type=int, 119 | default=1, 120 | help='number of training steps for discriminator per iter') 121 | parser.add_argument( 122 | '--val_freq', 123 | type=int, 124 | default=20, 125 | help='interval between each validation') 126 | parser.add_argument( 127 | '--print_freq', 128 | type=int, 129 | default=100, 130 | help='interval between each verbose') 131 | parser.add_argument( 132 | '--load_path', 133 | type=str, 134 | help='The reload model path') 135 | parser.add_argument( 136 | '--exp_name', 137 | type=str, 138 | help='The name of exp') 139 | parser.add_argument( 140 | '--d_spectral_norm', 141 | type=str2bool, 142 | default=False, 143 | help='add spectral_norm on discriminator?') 144 | parser.add_argument( 145 | '--g_spectral_norm', 146 | type=str2bool, 147 | default=False, 148 | help='add spectral_norm on generator?') 149 | parser.add_argument( 150 | '--dataset', 151 | type=str, 152 | default='cifar10', 153 | help='dataset type') 154 | parser.add_argument( 155 | '--data_path', 156 | type=str, 157 | default='./data', 158 | help='The path of data set') 159 | parser.add_argument('--init_type', type=str, default='normal', 160 | choices=['normal', 'orth', 'xavier_uniform', 'false'], 161 | help='The init type') 162 | parser.add_argument('--gf_dim', type=int, default=64, 163 | help='The base channel num of gen') 164 | parser.add_argument('--df_dim', type=int, default=64, 165 | help='The base channel num of disc') 166 | parser.add_argument( 167 | '--gen_model', 168 | type=str, 169 | help='path of gen model') 170 | parser.add_argument( 171 | '--dis_model', 172 | type=str, 173 | help='path of dis model') 174 | parser.add_argument( 175 | 
'--controller', 176 | type=str, 177 | default='controller', 178 | help='path of controller') 179 | parser.add_argument('--eval_batch_size', type=int, default=100) 180 | parser.add_argument('--num_eval_imgs', type=int, default=50000) 181 | parser.add_argument( 182 | '--bottom_width', 183 | type=int, 184 | default=4, 185 | help="the base resolution of the GAN") 186 | parser.add_argument('--random_seed', type=int, default=12345) 187 | 188 | # search 189 | parser.add_argument('--shared_epoch', type=int, default=15, 190 | help='the number of epoch to train the shared gan at each search iteration') 191 | parser.add_argument('--grow_step1', type=int, default=25, 192 | help='which iteration to grow the image size from 8 to 16') 193 | parser.add_argument('--grow_step2', type=int, default=55, 194 | help='which iteration to grow the image size from 16 to 32') 195 | parser.add_argument('--max_search_iter', type=int, default=90, 196 | help='max search iterations of this algorithm') 197 | parser.add_argument('--ctrl_step', type=int, default=30, 198 | help='number of steps to train the controller at each search iteration') 199 | parser.add_argument('--ctrl_sample_batch', type=int, default=1, 200 | help='sample size of controller of each step') 201 | parser.add_argument('--hid_size', type=int, default=100, 202 | help='the size of hidden vector') 203 | parser.add_argument('--baseline_decay', type=float, default=0.9, 204 | help='baseline decay rate in RL') 205 | parser.add_argument('--rl_num_eval_img', type=int, default=5000, 206 | help='number of images to be sampled in order to get the reward') 207 | parser.add_argument('--num_candidate', type=int, default=10, 208 | help='number of candidate architectures to be sampled') 209 | parser.add_argument('--topk', type=int, default=5, 210 | help='preserve topk models architectures after each stage' ) 211 | parser.add_argument('--entropy_coeff', type=float, default=1e-3, 212 | help='to encourage the exploration') 213 | 
parser.add_argument('--dynamic_reset_threshold', type=float, default=1e-3, 214 | help='var threshold') 215 | parser.add_argument('--dynamic_reset_window', type=int, default=500, 216 | help='the window size') 217 | parser.add_argument('--arch', nargs='+', type=int, 218 | help='the vector of a discovered architecture') 219 | parser.add_argument('--optimizer', type=str, default="adam", 220 | help='optimizer') 221 | parser.add_argument('--loss', type=str, default="hinge", 222 | help='loss function') 223 | parser.add_argument('--n_classes', type=int, default=0, 224 | help='classes') 225 | parser.add_argument('--phi', type=float, default=1, 226 | help='wgan-gp phi') 227 | parser.add_argument('--grow_steps', nargs='+', type=int, 228 | help='the vector of a discovered architecture') 229 | parser.add_argument('--D_downsample', type=str, default="avg", 230 | help='downsampling type') 231 | parser.add_argument('--fade_in', type=float, default=1, 232 | help='fade in step') 233 | parser.add_argument('--d_depth', type=int, default=7, 234 | help='Discriminator Depth') 235 | parser.add_argument('--g_depth', type=str, default="5,4,2", 236 | help='Generator Depth') 237 | parser.add_argument('--g_norm', type=str, default="ln", 238 | help='Generator Normalization') 239 | parser.add_argument('--d_norm', type=str, default="ln", 240 | help='Discriminator Normalization') 241 | parser.add_argument('--g_act', type=str, default="gelu", 242 | help='Generator activation Layer') 243 | parser.add_argument('--d_act', type=str, default="gelu", 244 | help='Discriminator activation layer') 245 | parser.add_argument('--patch_size', type=int, default=4, 246 | help='Discriminator Depth') 247 | parser.add_argument('--fid_stat', type=str, default="None", 248 | help='Discriminator Depth') 249 | parser.add_argument('--diff_aug', type=str, default="None", 250 | help='differentiable augmentation type') 251 | parser.add_argument('--accumulated_times', type=int, default=1, 252 | help='gradient accumulation') 
class ImageDataset(object):
    """Build the train / valid / test ``DataLoader`` objects for ``args.dataset``.

    Supported dataset names (case-insensitive): ``cifar10``, ``stl10``,
    ``celeba``, ``ffhq``, ``bedroom`` and ``church``.

    Args:
        args: Parsed configuration. Reads ``dataset``, ``data_path``,
            ``dis_batch_size``, ``img_size``, ``fade_in``, ``num_workers``
            and, when present, ``distributed``. For ``cifar10`` the
            attribute ``n_classes`` is set to 0 as a side effect.
        cur_img_size: Image size used instead of ``args.img_size`` when
            ``args.fade_in > 0``.
        bs: Batch size for every loader. Defaults to ``args.dis_batch_size``.
            (Previously this argument was accepted but silently ignored.)

    Attributes:
        train: Training loader (shuffled unless a sampler is in use).
        valid: Validation loader.
        test: Test loader; the same object as ``valid`` for cifar10/stl10,
            a separate loader over the validation set otherwise.
        train_sampler: The ``DistributedSampler`` of the training set, or
            ``None`` when running without distributed training.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """

    def __init__(self, args, cur_img_size=None, bs=None):
        batch_size = args.dis_batch_size if bs is None else bs
        img_size = cur_img_size if args.fade_in > 0 else args.img_size

        # Raises NotImplementedError for unknown names before any loader work.
        train_dataset, val_dataset, large_scale = self._build_datasets(args, img_size)

        if self._use_distributed(args):
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
        else:
            train_sampler = None
            val_sampler = None
        self.train_sampler = train_sampler

        # shuffle and sampler are mutually exclusive in DataLoader.
        self.train = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size, shuffle=(train_sampler is None),
            num_workers=args.num_workers, pin_memory=True,
            drop_last=large_scale, sampler=train_sampler)

        self.valid = self._eval_loader(val_dataset, val_sampler, batch_size, args.num_workers)
        if large_scale:
            # celeba / ffhq / lsun historically expose an independent test loader.
            self.test = self._eval_loader(val_dataset, val_sampler, batch_size, args.num_workers)
        else:
            self.test = self.valid

    @staticmethod
    def _use_distributed(args):
        """Return True when loaders should be sharded with ``DistributedSampler``.

        ``args.distributed`` wins when it is set; otherwise the decision
        falls back to whether a process group has been initialised, so a
        plain single-process run no longer crashes inside the sampler.
        """
        flag = getattr(args, 'distributed', None)
        if flag is not None:
            return bool(flag)
        dist = torch.distributed
        return dist.is_available() and dist.is_initialized()

    @staticmethod
    def _make_transform(resize_to):
        """Return the shared resize / flip / tensor / normalise pipeline."""
        return transforms.Compose([
            transforms.Resize(resize_to),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _eval_loader(dataset, sampler, batch_size, num_workers):
        """Return a non-shuffling loader used for validation and testing."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size, shuffle=False,
            num_workers=num_workers, pin_memory=True, sampler=sampler)

    @staticmethod
    def _build_datasets(args, img_size):
        """Return ``(train_dataset, val_dataset, large_scale)`` for ``args.dataset``.

        ``large_scale`` is True for the datasets whose training loader drops
        the last incomplete batch and whose test loader is a separate object.
        """
        name = args.dataset.lower()
        square = (img_size, img_size)

        if name == 'cifar10':
            transform = ImageDataset._make_transform(square)
            args.n_classes = 0
            train_dataset = datasets.CIFAR10(
                root=args.data_path, train=True, transform=transform, download=True)
            val_dataset = datasets.CIFAR10(
                root=args.data_path, train=False, transform=transform)
            return train_dataset, val_dataset, False

        if name == 'stl10':
            # An int size is passed here (not a pair), exactly as before.
            transform = ImageDataset._make_transform(img_size)
            train_dataset = datasets.STL10(
                root=args.data_path, split='train+unlabeled', transform=transform, download=True)
            val_dataset = datasets.STL10(
                root=args.data_path, split='test', transform=transform)
            return train_dataset, val_dataset, False

        if name in ('celeba', 'ffhq'):
            dataset_cls = CelebA if name == 'celeba' else FFHQ
            transform = ImageDataset._make_transform(square)
            train_dataset = dataset_cls(root=args.data_path, transform=transform)
            val_dataset = dataset_cls(root=args.data_path, transform=transform)
            return train_dataset, val_dataset, True

        if name in ('bedroom', 'church'):
            prefix = 'bedroom' if name == 'bedroom' else 'church_outdoor'
            transform = ImageDataset._make_transform(square)
            train_dataset = datasets.LSUN(
                root=args.data_path, classes=[prefix + '_train'], transform=transform)
            val_dataset = datasets.LSUN(
                root=args.data_path, classes=[prefix + '_val'], transform=transform)
            return train_dataset, val_dataset, True

        raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
#!/usr/bin/env python
"""Evaluate the CelebA-HQ 256x256 TransGAN checkpoint by launching ``test.py``.

The child process is started without a shell, so the user-supplied
``--rank`` value is passed as a plain argument instead of being spliced
into a shell command line, and the child's exit status becomes this
script's exit status.
"""

import argparse
import os
import shlex
import subprocess
import sys

# GPUs made visible to the child process.
VISIBLE_GPUS = "0,1,2,3,4,5,6,7"

# Fixed hyper-parameters forwarded verbatim to test.py.
FIXED_ARGS = """
    -gen_bs 32
    -dis_bs 16
    --accumulated_times 4
    --g_accumulated_times 4
    --dist-url tcp://localhost:10641
    --dist-backend nccl
    --multiprocessing-distributed
    --world-size 1
    --dataset celeba
    --data_path ./celeba_hq
    --bottom_width 8
    --img_size 256
    --max_iter 500000
    --gen_model Celeba256_gen
    --dis_model Celeba256_dis
    --g_window_size 16
    --d_window_size 4
    --g_norm pn
    --df_dim 384
    --d_depth 3
    --g_depth 5,4,4,4,4,4
    --latent_dim 512
    --gf_dim 1024
    --num_workers 32
    --g_lr 0.0001
    --d_lr 0.0001
    --optimizer adam
    --loss wgangp-eps
    --wd 1e-3
    --beta1 0
    --beta2 0.99
    --phi 1
    --eval_batch_size 10
    --num_eval_imgs 50000
    --init_type xavier_uniform
    --n_critic 4
    --val_freq 10
    --print_freq 50
    --grow_steps 0 0
    --fade_in 0
    --patch_size 2
    --diff_aug filter,translation,erase_ratio,color,hue
    --fid_stat fid_stat/fid_stats_celeba_hq_256.npz
    --ema 0.995
    --load_path ./celeba_256_checkpoint
    --exp_name celeba_hq_256
"""


def parse_args():
    """Parse the launcher's own command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=str, default="0",
                        help='node rank forwarded to test.py')
    parser.add_argument('--node', type=str, default="0015",
                        help='unused; accepted for backward compatibility')
    opt = parser.parse_args()

    return opt


args = parse_args()

command = [sys.executable, "test.py", *shlex.split(FIXED_ARGS), "--rank", args.rank]
child_env = dict(os.environ, CUDA_VISIBLE_DEVICES=VISIBLE_GPUS)
# Propagate the child's exit status so callers can detect a failed run.
sys.exit(subprocess.call(command, env=child_env))
#!/usr/bin/env python
"""Train TransGAN on LSUN Church 256x256 by launching ``train_derived.py``.

The child process is started without a shell, so the user-supplied
``--rank`` value is passed as a plain argument instead of being spliced
into a shell command line, and the child's exit status becomes this
script's exit status.
"""

import argparse
import os
import shlex
import subprocess
import sys

# GPUs made visible to the child process.
VISIBLE_GPUS = "0,1,2,3,4,5,6,7"

# Fixed hyper-parameters forwarded verbatim to train_derived.py.
FIXED_ARGS = """
    -gen_bs 16
    -dis_bs 16
    --accumulated_times 4
    --g_accumulated_times 8
    --dist-url tcp://localhost:10641
    --dist-backend nccl
    --multiprocessing-distributed
    --world-size 1
    --dataset church
    --data_path ./lsun
    --bottom_width 8
    --img_size 256
    --max_iter 500000
    --gen_model ViT_custom_local544444_256_rp_noise
    --dis_model ViT_scale3_local_new_rp
    --g_window_size 16
    --d_window_size 16
    --g_norm pn
    --df_dim 384
    --d_depth 3
    --g_depth 5,4,4,4,4,4
    --latent_dim 512
    --gf_dim 1024
    --num_workers 0
    --g_lr 0.0001
    --d_lr 0.0001
    --optimizer adam
    --loss wgangp-eps
    --wd 1e-3
    --beta1 0
    --beta2 0.99
    --phi 1
    --eval_batch_size 10
    --num_eval_imgs 50000
    --init_type xavier_uniform
    --n_critic 4
    --val_freq 5000
    --print_freq 50
    --grow_steps 0 0
    --fade_in 0
    --patch_size 4
    --diff_aug translation,erase_ratio,color
    --fid_stat fid_stat/fid_stats_church_256.npz
    --ema 0.995
    --exp_name church_256
"""


def parse_args():
    """Parse the launcher's own command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=str, default="0",
                        help='node rank forwarded to train_derived.py')
    parser.add_argument('--node', type=str, default="0015",
                        help='unused; accepted for backward compatibility')
    opt = parser.parse_args()

    return opt


args = parse_args()

command = [sys.executable, "train_derived.py", *shlex.split(FIXED_ARGS), "--rank", args.rank]
child_env = dict(os.environ, CUDA_VISIBLE_DEVICES=VISIBLE_GPUS)
# Propagate the child's exit status so callers can detect a failed run.
sys.exit(subprocess.call(command, env=child_env))
#!/usr/bin/env python
"""Train TransGAN on STL-10 (48x48) by launching ``train_derived.py``.

The child process is started without a shell, so the user-supplied
``--rank`` value is passed as a plain argument instead of being spliced
into a shell command line, and the child's exit status becomes this
script's exit status.
"""

import argparse
import os
import shlex
import subprocess
import sys

# GPUs made visible to the child process.
VISIBLE_GPUS = "0,1,2,3,4,5,6,7"

# Fixed hyper-parameters forwarded verbatim to train_derived.py.
FIXED_ARGS = """
    -gen_bs 64
    -dis_bs 32
    --dist-url tcp://localhost:14256
    --dist-backend nccl
    --multiprocessing-distributed
    --world-size 1
    --dataset stl10
    --bottom_width 12
    --img_size 48
    --max_iter 500000
    --gen_model ViT_custom_rp
    --dis_model ViT_custom_scale2_rp_noise
    --df_dim 384
    --g_norm pn
    --d_norm pn
    --d_heads 4
    --d_depth 3
    --g_depth 7,5,3
    --dropout 0
    --latent_dim 512
    --gf_dim 1024
    --num_workers 16
    --g_lr 0.0001
    --d_lr 0.0001
    --optimizer adam
    --loss wgangp-eps
    --wd 1e-3
    --beta1 0
    --beta2 0.99
    --phi 1
    --eval_batch_size 8
    --num_eval_imgs 20000
    --init_type xavier_uniform
    --n_critic 5
    --val_freq 100000
    --print_freq 50
    --grow_steps 0 0
    --fade_in 0
    --D_downsample pixel
    --arch 1 0 1 1 1 0 0 1 1 1 0 1 0 3
    --patch_size 2
    --ema_kimg 500
    --ema_warmup 0.05
    --ema 0.9999
    --diff_aug translation,stl_erase_ratio,color
    --exp_name stl_train_latent512_stl_erase
"""


def parse_args():
    """Parse the launcher's own command-line options."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--rank', type=str, default="0",
                        help='node rank forwarded to train_derived.py')
    parser.add_argument('--test', action='store_true',
                        help='skip training; no evaluation command is defined here')
    opt = parser.parse_args()

    return opt


args = parse_args()

# With --test the script intentionally does nothing, as it always has.
if not args.test:
    command = [sys.executable, "train_derived.py", *shlex.split(FIXED_ARGS), "--rank", args.rank]
    child_env = dict(os.environ, CUDA_VISIBLE_DEVICES=VISIBLE_GPUS)
    # Propagate the child's exit status so callers can detect a failed run.
    sys.exit(subprocess.call(command, env=child_env))
https://raw.githubusercontent.com/VITA-Group/TransGAN/6b85440ca56716fd7a60bac964466cc0296ce663/fid_stat/fid_stats_church_256.npz -------------------------------------------------------------------------------- /fid_stat/fid_stats_cifar10_train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/TransGAN/6b85440ca56716fd7a60bac964466cc0296ce663/fid_stat/fid_stats_cifar10_train.npz -------------------------------------------------------------------------------- /fid_stat/stl10_train_unlabeled_fid_stats_48.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VITA-Group/TransGAN/6b85440ca56716fd7a60bac964466cc0296ce663/fid_stat/stl10_train_unlabeled_fid_stats_48.npz -------------------------------------------------------------------------------- /flops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-10-01 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import cfg 12 | import models_search 13 | import datasets 14 | from functions import train, validate, LinearLrDecay, load_params, copy_params, cur_stages 15 | from utils.utils import set_log_dir, save_checkpoint, create_logger 16 | from utils.inception_score import _init_inception 17 | from utils.fid_score import create_inception_graph, check_or_download_inception 18 | 19 | import torch 20 | import os 21 | import numpy as np 22 | import torch.nn as nn 23 | from tensorboardX import SummaryWriter 24 | from tqdm import tqdm 25 | from copy import deepcopy 26 | from adamw import AdamW 27 | import random 28 | 29 | torch.backends.cudnn.enabled = True 30 | torch.backends.cudnn.benchmark = True 31 | from models_search.ViT_8_8 import 
def main():
    """Profile the generator selected by ``args.gen_model`` with ``thop``.

    Seeds the CUDA, NumPy and Python random generators from
    ``args.random_seed``, builds the generator on the GPU, and prints the
    MAC count and parameter count reported by ``thop.profile``.

    NOTE(review): ``matmul`` / ``count_matmul`` come from the module-level
    import of ``models_search.ViT_8_8``, which does not appear in the
    repository listing — confirm this script still runs before relying on it.
    """
    # thop is only needed by this script, so it is imported on demand.
    import thop

    args = cfg.parse_args()
    torch.cuda.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.backends.cudnn.deterministic = True

    # Resolve the model class by attribute lookup rather than eval() on a
    # string assembled from command-line input.
    generator_cls = getattr(models_search, args.gen_model).Generator
    gen_net = generator_cls(args=args).cuda()
    gen_net.set_arch(args.arch, cur_stage=2)

    # The latent width follows the configuration instead of a fixed 1024,
    # matching how the evaluation code samples z (batch, args.latent_dim).
    dummy_latent = torch.randn(1, args.latent_dim).cuda()
    macs, params = thop.profile(gen_net, inputs=(dummy_latent, ),
                                custom_ops={matmul: count_matmul})
    flops, params = thop.clever_format([macs, params], "%.3f")
    print('Flops (GB):\t', flops)
    print('Params Size (MB):\t', params)


if __name__ == '__main__':
    main()
# From PyTorch internals
def _ntuple(n):
    """Return a function that expands a scalar into an ``n``-tuple.

    The returned ``parse(x)`` gives back ``x`` unchanged when it is already
    iterable (this includes strings) and ``(x,) * n`` otherwise.

    ``collections.abc.Iterable`` is used instead of the private
    ``torch._six.container_abcs`` alias, which newer PyTorch releases no
    longer provide.
    """
    from collections.abc import Iterable
    from itertools import repeat

    def parse(x):
        if isinstance(x, Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
72 | # Get upper and lower cdf values 73 | l = norm_cdf((a - mean) / std) 74 | u = norm_cdf((b - mean) / std) 75 | 76 | # Uniformly fill tensor with values from [l, u], then translate to 77 | # [2l-1, 2u-1]. 78 | tensor.uniform_(2 * l - 1, 2 * u - 1) 79 | 80 | # Use inverse cdf transform for normal distribution to get truncated 81 | # standard normal 82 | tensor.erfinv_() 83 | 84 | # Transform to proper mean, std 85 | tensor.mul_(std * math.sqrt(2.)) 86 | tensor.add_(mean) 87 | 88 | # Clamp to ensure it's in the proper range 89 | tensor.clamp_(min=a, max=b) 90 | return tensor 91 | 92 | 93 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 94 | # type: (Tensor, float, float, float, float) -> Tensor 95 | r"""Fills the input Tensor with values drawn from a truncated 96 | normal distribution. The values are effectively drawn from the 97 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 98 | with values outside :math:`[a, b]` redrawn until they are within 99 | the bounds. The method used for generating the random values works 100 | best when :math:`a \leq \text{mean} \leq b`. 
_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``.

    Args:
        value: Anything ``np.asarray`` accepts.
        shape: Optional target shape the value is broadcast to.
        dtype: Tensor dtype; the torch default dtype when omitted.
        device: Target device; CPU when omitted.
        memory_format: Memory format; ``torch.contiguous_format`` when omitted.

    Returns:
        The tensor for this exact combination of arguments. Repeated calls
        with the same arguments return the same tensor object.
    """
    value = np.asarray(value)
    shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    # The raw bytes are part of the key so equal-looking arrays of
    # different content never collide.
    cache_key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key, None)
    if cached is not None:
        return cached

    result = torch.as_tensor(value.copy(), dtype=dtype, device=device)
    if shape is not None:
        result = torch.broadcast_tensors(result, torch.empty(shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result
def translate2d(tx, ty, **kwargs):
    """Return the 3x3 matrix whose last column is ``(tx, ty, 1)``."""
    rows = (
        [1, 0, tx],
        [0, 1, ty],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)

def translate3d(tx, ty, tz, **kwargs):
    """Return the 4x4 matrix whose last column is ``(tx, ty, tz, 1)``."""
    rows = (
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)

def scale2d(sx, sy, **kwargs):
    """Return the 3x3 matrix with ``(sx, sy, 1)`` on the diagonal."""
    rows = (
        [sx, 0, 0],
        [0, sy, 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)

def scale3d(sx, sy, sz, **kwargs):
    """Return the 4x4 matrix with ``(sx, sy, sz, 1)`` on the diagonal."""
    rows = (
        [sx, 0, 0, 0],
        [0, sy, 0, 0],
        [0, 0, sz, 0],
        [0, 0, 0, 1],
    )
    return matrix(*rows, **kwargs)

def rotate2d(theta, **kwargs):
    """Return the 3x3 rotation matrix for the angle tensor ``theta``."""
    cos_t = torch.cos(theta)
    rows = (
        [cos_t, torch.sin(-theta), 0],
        [torch.sin(theta), torch.cos(theta), 0],
        [0, 0, 1],
    )
    return matrix(*rows, **kwargs)

def rotate3d(v, theta, **kwargs):
    """Return the 4x4 axis-angle rotation matrix for axis ``v`` and angle ``theta``.

    The last dimension of ``v`` supplies the x, y and z components.
    NOTE(review): the formula presumably expects ``v`` to be unit length —
    confirm against callers.
    """
    vx = v[..., 0]
    vy = v[..., 1]
    vz = v[..., 2]
    s = torch.sin(theta)
    c = torch.cos(theta)
    cc = 1 - c
    row_x = [vx*vx*cc+c, vx*vy*cc-vz*s, vx*vz*cc+vy*s, 0]
    row_y = [vy*vx*cc+vz*s, vy*vy*cc+c, vy*vz*cc-vx*s, 0]
    row_z = [vz*vx*cc-vy*s, vz*vy*cc+vx*s, vz*vz*cc+c, 0]
    row_w = [0, 0, 0, 1]
    return matrix(row_x, row_y, row_z, row_w, **kwargs)

def translate2d_inv(tx, ty, **kwargs):
    """Return the inverse of ``translate2d(tx, ty)``."""
    neg_tx = -tx
    neg_ty = -ty
    return translate2d(neg_tx, neg_ty, **kwargs)

def scale2d_inv(sx, sy, **kwargs):
    """Return the inverse of ``scale2d(sx, sy)``."""
    inv_sx = 1 / sx
    inv_sy = 1 / sy
    return scale2d(inv_sx, inv_sy, **kwargs)

def rotate2d_inv(theta, **kwargs):
    """Return the inverse of ``rotate2d(theta)``."""
    neg_theta = -theta
    return rotate2d(neg_theta, **kwargs)
-------------------------------------------------------------------------------- 1 | imageio 2 | scipy 3 | six 4 | numpy 5 | einops 6 | pillow 7 | python-dateutil==2.7.3 8 | torch==1.7.1 9 | torchvision==0.8.2 10 | tensorboard==1.12.2 11 | tensorboardX==1.6 12 | tensorflow==2.4 13 | tqdm==4.29.1 14 | opencv-python 15 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import logging 6 | import operator 7 | import os 8 | from copy import deepcopy 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from imageio import imsave 14 | from torchvision.utils import make_grid, save_image 15 | from tqdm import tqdm 16 | import cv2 17 | 18 | from utils.fid_score import calculate_fid_given_paths 19 | # from utils.torch_fid_score import get_fid 20 | # from utils.inception_score import get_inception_score 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | import cfg 25 | import models_search 26 | from functions import validate 27 | from utils.utils import set_log_dir, create_logger 28 | from utils.inception_score import _init_inception 29 | from utils.fid_score import create_inception_graph, check_or_download_inception 30 | 31 | import torch 32 | import os 33 | import numpy as np 34 | from tensorboardX import SummaryWriter 35 | from utils.inception_score import get_inception_score 36 | 37 | torch.backends.cudnn.enabled = True 38 | torch.backends.cudnn.benchmark = True 39 | 40 | 41 | def validate(args, fixed_z, fid_stat, epoch, gen_net: nn.Module, writer_dict, clean_dir=True): 42 | writer = writer_dict['writer'] 43 | global_steps = writer_dict['valid_global_steps'] 44 | 45 | # eval mode 46 | gen_net.eval() 47 | 48 | # generate images 49 | with torch.no_grad(): 50 | # sample_imgs = gen_net(fixed_z, 
epoch) 51 | # img_grid = make_grid(sample_imgs, nrow=5, normalize=True, scale_each=True) 52 | 53 | 54 | eval_iter = args.num_eval_imgs // args.eval_batch_size 55 | img_list = list() 56 | for iter_idx in tqdm(range(eval_iter), desc='sample images'): 57 | z = torch.cuda.FloatTensor(np.random.normal(0, 1, (args.eval_batch_size, args.latent_dim))) 58 | 59 | # Generate a batch of images 60 | gen_imgs = gen_net(z, epoch).mul_(127.5).add_(127.5).clamp_(0.0, 255.0).permute(0, 2, 3, 1).to('cpu', torch.uint8).numpy() 61 | img_list.extend(list(gen_imgs)) 62 | 63 | # mean, std = 0, 0 64 | # get fid score 65 | # mean, std = get_inception_score(img_list) 66 | # print(f"IS score: {mean}") 67 | print('=> calculate fid score') if args.rank == 0 else 0 68 | fid_score = calculate_fid_given_paths([img_list, fid_stat], inception_path=None) 69 | # fid_score = 10000 70 | print(f"FID score: {fid_score}") if args.rank == 0 else 0 71 | with open(f'output/{args.exp_name}.txt', 'a') as f: 72 | print('fid:' + str(fid_score) + 'epoch' + str(epoch), file=f) 73 | 74 | if args.rank == 0: 75 | # writer.add_scalar('Inception_score/mean', mean, global_steps) 76 | # writer.add_scalar('Inception_score/std', std, global_steps) 77 | writer.add_scalar('FID_score', fid_score, global_steps) 78 | 79 | # writer_dict['valid_global_steps'] = global_steps + 1 80 | 81 | return 0, fid_score 82 | 83 | def main(): 84 | args = cfg.parse_args() 85 | torch.cuda.manual_seed(args.random_seed) 86 | assert args.exp_name 87 | # assert args.load_path.endswith('.pth') 88 | assert os.path.exists(args.load_path) 89 | args.path_helper = set_log_dir('logs_eval', args.exp_name) 90 | logger = create_logger(args.path_helper['log_path'], phase='test') 91 | 92 | # set tf env 93 | _init_inception() 94 | inception_path = check_or_download_inception(None) 95 | create_inception_graph(inception_path) 96 | 97 | # import network 98 | gen_net = eval('models_search.'+args.gen_model+'.Generator')(args=args).cuda() 99 | gen_net = 
torch.nn.DataParallel(gen_net.to("cuda:0"), device_ids=[0]) 100 | 101 | # fid stat 102 | if args.dataset.lower() == 'cifar10': 103 | fid_stat = 'fid_stat/fid_stats_cifar10_train.npz' 104 | elif args.dataset.lower() == 'cifar10_flip': 105 | fid_stat = 'fid_stat/fid_stats_cifar10_train.npz' 106 | elif args.dataset.lower() == 'stl10': 107 | fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz' 108 | elif args.fid_stat is not None: 109 | fid_stat = args.fid_stat 110 | else: 111 | raise NotImplementedError(f'no fid stat for {args.dataset.lower()}') 112 | assert os.path.exists(fid_stat) 113 | 114 | # initial 115 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (4, args.latent_dim))) 116 | 117 | # set writer 118 | logger.info(f'=> resuming from {args.load_path}') 119 | checkpoint_file = args.load_path 120 | assert os.path.exists(checkpoint_file) 121 | checkpoint = torch.load(checkpoint_file) 122 | 123 | if 'avg_gen_state_dict' in checkpoint: 124 | gen_net.load_state_dict(checkpoint['avg_gen_state_dict']) 125 | epoch = checkpoint['epoch'] 126 | logger.info(f'=> loaded checkpoint {checkpoint_file} (epoch {epoch})') 127 | else: 128 | gen_net.load_state_dict(checkpoint) 129 | logger.info(f'=> loaded checkpoint {checkpoint_file}') 130 | 131 | logger.info(args) 132 | writer_dict = { 133 | 'writer': SummaryWriter(args.path_helper['log_path']), 134 | 'valid_global_steps': 0, 135 | } 136 | inception_score, fid_score = validate(args, fixed_z, fid_stat, epoch, gen_net, writer_dict, clean_dir=False) 137 | logger.info(f'Inception score: {inception_score}, FID score: {fid_score}.') 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 

def _find_compiler_bindir():
    """Locate a Visual Studio compiler directory on disk.

    The glob patterns are tried in order; for the first pattern that matches
    anything, the lexicographically last match is returned.

    Returns:
        The matching directory path, or None when no pattern matches.
    """
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

# Maps module_name -> imported plugin module, filled in by get_plugin().
_cached_plugins = dict()

def get_plugin(module_name, sources, **build_kwargs):
    """Compile (if needed), load and return a C++/CUDA extension module.

    Successfully loaded modules are memoized in ``_cached_plugins`` so each
    plugin is set up at most once per process. Progress is printed according
    to the module-level ``verbosity`` setting.

    Args:
        module_name: name under which the extension is built and imported.
        sources: paths of the source files to compile.
        **build_kwargs: forwarded to ``torch.utils.cpp_extension.load``.

    Returns:
        The imported extension module.

    Raises:
        RuntimeError: on Windows, when ``cl.exe`` is not on PATH and no
            compiler directory can be located.
        Any exception raised by the build or import is re-raised after
        'Failed!' is printed in 'brief' mode.
    """
    assert verbosity in ['none', 'brief', 'full']

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)

    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Compile and load.
        verbose_build = (verbosity == 'full')

        # Incremental build md5sum trickery.  Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files.  Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        source_dirs_set = set(os.path.dirname(source) for source in sources)
        if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ):
            # Every regular file in the directory is hashed and copied, not
            # only the entries of `sources`.
            all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file()))

            # Compute a combined hash digest for all source files in the same
            # custom op directory (usually .cu, .cpp, .py and .h files).
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())
            build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest())

            if not os.path.isdir(digest_build_dir):
                os.makedirs(digest_build_dir, exist_ok=True)
                # The lock file decides which process performs the copy.
                baton = FileBaton(os.path.join(digest_build_dir, 'lock'))
                if baton.try_acquire():
                    try:
                        for src in all_source_files:
                            shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src)))
                    finally:
                        baton.release()
                else:
                    # Someone else is copying source files under the digest dir,
                    # wait until done and continue.
                    baton.wait()
            # Build from the copies inside the digest directory.
            digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir,
                verbose=verbose_build, sources=digest_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
        module = importlib.import_module(module_name)

    except:
        # Bare except is deliberate: finish the status line, then re-raise
        # whatever was thrown, unchanged.
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------

--------------------------------------------------------------------------------
/torch_utils/misc.py:
--------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import re
import contextlib
import numpy as np
import torch
import warnings
import dnnlib

#----------------------------------------------------------------------------
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = {}

def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    """Return a cached tensor holding ``value``.

    Args:
        value: anything ``np.asarray`` accepts.
        shape: optional shape the value is broadcast to.
        dtype: tensor dtype; defaults to ``torch.get_default_dtype()``.
        device: target device; defaults to CPU.
        memory_format: defaults to ``torch.contiguous_format``.

    Returns:
        A tensor shared between all calls made with identical arguments --
        callers get the same object back on every call.
    """
    array = np.asarray(value)
    target_shape = None if shape is None else tuple(shape)
    dtype = torch.get_default_dtype() if dtype is None else dtype
    device = torch.device('cpu') if device is None else device
    memory_format = torch.contiguous_format if memory_format is None else memory_format

    # The raw bytes are part of the key, so equal values share one tensor.
    cache_key = (array.shape, array.dtype, array.tobytes(), target_shape, dtype, device, memory_format)
    cached = _constant_cache.get(cache_key)
    if cached is not None:
        return cached

    result = torch.as_tensor(array.copy(), dtype=dtype, device=device)
    if target_shape is not None:
        # The empty tensor only supplies the shape to broadcast against.
        result = torch.broadcast_tensors(result, torch.empty(target_shape))[0]
    result = result.contiguous(memory_format=memory_format)
    _constant_cache[cache_key] = result
    return result

#----------------------------------------------------------------------------
# Replace NaN/Inf with specified numerical values.

if hasattr(torch, 'nan_to_num'):
    nan_to_num = torch.nan_to_num # 1.8.0a0
else:
    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
        """Fallback for PyTorch builds without ``torch.nan_to_num``.

        Only ``nan == 0`` is supported: NaNs are zeroed through ``nansum``
        over a singleton dimension, then values are clamped to
        ``[neginf, posinf]`` (the finite range of the dtype by default).
        """
        assert isinstance(input, torch.Tensor)
        limits = torch.finfo(input.dtype)
        upper = limits.max if posinf is None else posinf
        lower = limits.min if neginf is None else neginf
        assert nan == 0
        return torch.clamp(input.unsqueeze(0).nansum(0), min=lower, max=upper, out=out)

#----------------------------------------------------------------------------
# Symbolic assert.

if hasattr(torch, '_assert'):
    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
else:
    symbolic_assert = torch.Assert # 1.7.0

#----------------------------------------------------------------------------
# Context manager to suppress known warnings in torch.jit.trace().

class suppress_tracer_warnings(warnings.catch_warnings):
    """``warnings.catch_warnings`` variant that ignores ``TracerWarning``.

    Entering the context installs an 'ignore' filter for
    ``torch.jit.TracerWarning`` and returns the context manager itself.
    """

    def __enter__(self):
        super().__enter__()
        warnings.simplefilter('ignore', category=torch.jit.TracerWarning)
        return self

#----------------------------------------------------------------------------
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().

def assert_shape(tensor, ref_shape):
    """Check ``tensor.shape`` against ``ref_shape``.

    Args:
        tensor: object exposing ``ndim`` and ``shape``.
        ref_shape: one entry per dimension; ``None`` accepts any size, a
            tensor entry is checked with ``symbolic_assert``, anything else
            is compared with ``!=``.

    Raises:
        AssertionError: on a rank mismatch or a plain size mismatch.
    """
    expected_rank = len(ref_shape)
    if tensor.ndim != expected_rank:
        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {expected_rank}')

    for dim, (actual, expected) in enumerate(zip(tensor.shape, ref_shape)):
        if expected is None:
            continue
        if isinstance(expected, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(torch.as_tensor(actual), expected), f'Wrong size for dimension {dim}')
            continue
        if isinstance(actual, torch.Tensor):
            with suppress_tracer_warnings(): # as_tensor results are registered as constants
                symbolic_assert(torch.equal(actual, torch.as_tensor(expected)), f'Wrong size for dimension {dim}: expected {expected}')
            continue
        if actual != expected:
            raise AssertionError(f'Wrong size for dimension {dim}: got {actual}, expected {expected}')

#----------------------------------------------------------------------------
# Function decorator that calls torch.autograd.profiler.record_function().

def profiled_function(fn):
    """Wrap ``fn`` so each call runs inside a profiler ``record_function`` range.

    The range is labelled with ``fn.__name__``. The wrapper carries over the
    wrapped function's metadata (name, qualname, docstring, module) through
    ``functools.wraps``; only ``__name__`` was copied before.
    """
    import functools # local import: keeps the file-level import block untouched

    @functools.wraps(fn)
    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)
    return decorator

#----------------------------------------------------------------------------
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.

class InfiniteSampler(torch.utils.data.Sampler):
    """Endless index sampler with optional windowed reshuffling.

    Args:
        dataset: sized collection; only ``len(dataset)`` is used.
        rank: index of this replica, ``0 <= rank < num_replicas``.
        num_replicas: number of replicas sharing the index stream.
        shuffle: shuffle the order up front and keep perturbing it.
        seed: seed of the ``np.random.RandomState`` used for shuffling.
        window_size: fraction of the dataset (0..1) that bounds how far back
            an index can be swapped after it has been visited.
    """

    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        """Yield dataset indices forever; position ``idx`` goes to replica ``idx % num_replicas``."""
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            # The swap runs for every position, not only the yielded ones,
            # so the random stream does not depend on this replica's rank.
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1

#----------------------------------------------------------------------------
# Utilities for operating with torch.nn.Module parameters and buffers.
144 | 145 | def params_and_buffers(module): 146 | assert isinstance(module, torch.nn.Module) 147 | return list(module.parameters()) + list(module.buffers()) 148 | 149 | def named_params_and_buffers(module): 150 | assert isinstance(module, torch.nn.Module) 151 | return list(module.named_parameters()) + list(module.named_buffers()) 152 | 153 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 154 | assert isinstance(src_module, torch.nn.Module) 155 | assert isinstance(dst_module, torch.nn.Module) 156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 157 | for name, tensor in named_params_and_buffers(dst_module): 158 | assert (name in src_tensors) or (not require_all) 159 | if name in src_tensors: 160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 161 | 162 | #---------------------------------------------------------------------------- 163 | # Context manager for easily enabling/disabling DistributedDataParallel 164 | # synchronization. 165 | 166 | @contextlib.contextmanager 167 | def ddp_sync(module, sync): 168 | assert isinstance(module, torch.nn.Module) 169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 170 | yield 171 | else: 172 | with module.no_sync(): 173 | yield 174 | 175 | #---------------------------------------------------------------------------- 176 | # Check DistributedDataParallel consistency across processes. 177 | 178 | def check_ddp_consistency(module, ignore_regex=None): 179 | assert isinstance(module, torch.nn.Module) 180 | for name, tensor in named_params_and_buffers(module): 181 | fullname = type(module).__name__ + '.' 
+ name 182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 183 | continue 184 | tensor = tensor.detach() 185 | other = tensor.clone() 186 | torch.distributed.broadcast(tensor=other, src=0) 187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname 188 | 189 | #---------------------------------------------------------------------------- 190 | # Print summary table of module hierarchy. 191 | 192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 193 | assert isinstance(module, torch.nn.Module) 194 | assert not isinstance(module, torch.jit.ScriptModule) 195 | assert isinstance(inputs, (tuple, list)) 196 | 197 | # Register hooks. 198 | entries = [] 199 | nesting = [0] 200 | def pre_hook(_mod, _inputs): 201 | nesting[0] += 1 202 | def post_hook(mod, _inputs, outputs): 203 | nesting[0] -= 1 204 | if nesting[0] <= max_nesting: 205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 210 | 211 | # Run module. 212 | outputs = module(*inputs) 213 | for hook in hooks: 214 | hook.remove() 215 | 216 | # Identify unique outputs, parameters, and buffers. 217 | tensors_seen = set() 218 | for e in entries: 219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 223 | 224 | # Filter out redundant entries. 
225 | if skip_redundant: 226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 227 | 228 | # Construct table. 229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 230 | rows += [['---'] * len(rows[0])] 231 | param_total = 0 232 | buffer_total = 0 233 | submodule_names = {mod: name for name, mod in module.named_modules()} 234 | for e in entries: 235 | name = '' if e.mod is module else submodule_names[e.mod] 236 | param_size = sum(t.numel() for t in e.unique_params) 237 | buffer_size = sum(t.numel() for t in e.unique_buffers) 238 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 240 | rows += [[ 241 | name + (':0' if len(e.outputs) >= 2 else ''), 242 | str(param_size) if param_size else '-', 243 | str(buffer_size) if buffer_size else '-', 244 | (output_shapes + ['-'])[0], 245 | (output_dtypes + ['-'])[0], 246 | ]] 247 | for idx in range(1, len(e.outputs)): 248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 249 | param_total += param_size 250 | buffer_total += buffer_size 251 | rows += [['---'] * len(rows[0])] 252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 253 | 254 | # Print table. 255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 256 | print() 257 | for row in rows: 258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 259 | print() 260 | return outputs 261 | 262 | #---------------------------------------------------------------------------- 263 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? 
xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 

// NOTE(review): the angle-bracket include target and every template
// parameter/argument list in this file were missing from this copy. They are
// restored from how T (element type) and A (activation index) are used
// below and from the double/float/half dispatch in bias_act.cpp -- confirm
// against the build.
#include <c10/util/Half.h>
#include "bias_act.h"

//------------------------------------------------------------------------
// Helpers.

// Maps the storage type T to the type used for arithmetic inside the kernel.
template <class T> struct InternalType;
template <> struct InternalType<double>     { typedef double scalar_t; };
template <> struct InternalType<float>      { typedef float  scalar_t; };
template <> struct InternalType<c10::Half>  { typedef float  scalar_t; };

//------------------------------------------------------------------------
// CUDA kernel.

// T is the element type of the buffers; A is the activation index matched by
// the `A == n` branches below. p.grad (G) selects the forward pass (0) or the
// first/second-order gradient formulas (1/2).
template <class T, int A>
__global__ void bias_act_kernel(bias_act_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;
    int G                 = p.grad;
    scalar_t alpha        = (scalar_t)p.alpha;
    scalar_t gain         = (scalar_t)p.gain;
    scalar_t clamp        = (scalar_t)p.clamp;
    scalar_t one          = (scalar_t)1;
    scalar_t two          = (scalar_t)2;
    scalar_t expRange     = (scalar_t)80;
    scalar_t halfExpRange = (scalar_t)40;
    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;
    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;

    // Loop over elements: each thread visits up to p.loopX elements that are
    // blockDim.x apart.
    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
    {
        // Load.
        scalar_t x = (scalar_t)((const T*)p.x)[xi];
        scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
        scalar_t yy = (gain != 0) ? yref / gain : 0;
        scalar_t y = 0;

        // Apply bias: to x in the forward pass, to xref in gradient passes.
        ((G == 0) ? x : xref) += b;

        // linear
        if (A == 1)
        {
            if (G == 0) y = x;
            if (G == 1) y = x;
        }

        // relu
        if (A == 2)
        {
            if (G == 0) y = (x > 0) ? x : 0;
            if (G == 1) y = (yy > 0) ? x : 0;
        }

        // lrelu
        if (A == 3)
        {
            if (G == 0) y = (x > 0) ? x : x * alpha;
            if (G == 1) y = (yy > 0) ? x : x * alpha;
        }

        // tanh
        if (A == 4)
        {
            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
            if (G == 1) y = x * (one - yy * yy);
            if (G == 2) y = x * (one - yy * yy) * (-two * yy);
        }

        // sigmoid
        if (A == 5)
        {
            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
            if (G == 1) y = x * yy * (one - yy);
            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
        }

        // elu
        if (A == 6)
        {
            if (G == 0) y = (x >= 0) ? x : exp(x) - one;
            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
        }

        // selu
        if (A == 7)
        {
            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
        }

        // softplus
        if (A == 8)
        {
            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
            if (G == 1) y = x * (one - exp(-yy));
            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
        }

        // swish
        if (A == 9)
        {
            if (G == 0)
                y = (x < -expRange) ? 0 : x / (exp(-x) + one);
            else
            {
                scalar_t c = exp(xref);
                scalar_t d = c + one;
                if (G == 1)
                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
                else
                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
                // Recompute yref from xref so the clamp test below has the
                // forward output to compare against.
                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
            }
        }

        // Apply gain.
        y *= gain * dy;

        // Clamp: a negative clamp value disables clamping.
        if (clamp >= 0)
        {
            if (G == 0)
                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
            else
                y = (yref > -clamp & yref < clamp) ? y : 0;
        }

        // Store.
        ((T*)p.y)[xi] = (T)y;
    }
}

//------------------------------------------------------------------------
// CUDA kernel selection.

// Returns the kernel instantiated for activation index p.act, or NULL when
// the index is not one of 1..9.
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
{
    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
    return NULL;
}

//------------------------------------------------------------------------
// Template specializations.

// Explicit instantiations for the three element types the host code
// dispatches on.
template void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);
template void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);

//------------------------------------------------------------------------

--------------------------------------------------------------------------------
/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto.
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. 
import misc

#----------------------------------------------------------------------------

# Table of supported activation functions.
#   func:         reference PyTorch implementation.
#   def_alpha:    default shape parameter passed to `func`.
#   def_gain:     default output scaling.
#   cuda_idx:     activation index handed to the CUDA plugin.
#   ref:          which tensor ('x' and/or 'y') is saved for the backward pass.
#   has_2nd_grad: whether the second-order gradient path is evaluated.
activation_funcs = {
    'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False),
    'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False),
    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False),
    'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True),
    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True),
    'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True),
    'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,    def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True),
    'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True),
    'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True),
}

#----------------------------------------------------------------------------

_inited = False
_plugin = None
_null_tensor = torch.empty([0])

def _init():
    """Build/load the CUDA plugin on first use.

    The build is attempted only once; a failure is reported as a warning and
    remembered, so later calls return quickly.

    Returns:
        True if the plugin is available, False if the caller should fall back
        to the reference implementation.
    """
    global _inited, _plugin
    if not _inited:
        _inited = True
        sources = ['bias_act.cpp', 'bias_act.cu']
        sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
        try:
            _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
        except Exception: # Not a bare `except:` -- KeyboardInterrupt/SystemExit must propagate.
            warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
    return _plugin is not None

#----------------------------------------------------------------------------

def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
    r"""Fused bias and activation function.

    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
    and scales the result by `gain`. Each of the steps is optional. In most cases,
    the fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports first and second order gradients,
    but not third order gradients.

    Args:
        x:      Input activation tensor. Can be of any shape.
        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type
                as `x`. The shape must be known, and it must match the dimension of `x`
                corresponding to `dim`.
        dim:    The dimension in `x` corresponding to the elements of `b`.
                The value of `dim` is ignored if `b` is not specified.
        act:    Name of the activation function to evaluate, or `"linear"` to disable.
                Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
                See `activation_funcs` for a full list. `None` is not allowed.
        alpha:  Shape parameter for the activation function, or `None` to use the default.
        gain:   Scaling factor for the output tensor, or `None` to use default.
                See `activation_funcs` for the default scaling of each activation function.
                If unsure, consider specifying 1.
        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
                the clamping (default).
        impl:   Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).

    Returns:
        Tensor of the same shape and datatype as `x`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    # The CUDA path is used only for CUDA tensors and only if the plugin built.
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)

#----------------------------------------------------------------------------

@misc.profiled_function
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Slow reference implementation of `bias_act()` using standard PyTorch ops.

    Takes the same arguments as `bias_act()` except `impl`.
    """
    assert isinstance(x, torch.Tensor)
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1) # -1 means "no clamping".

    # Add bias.
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.ndim == 1
        assert 0 <= dim < x.ndim
        assert b.shape[0] == x.shape[dim]
        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])

    # Evaluate activation function.
    x = spec.func(x, alpha=alpha)

    # Scale by gain.
    if gain != 1:
        x = x * gain

    # Clamp.
    if clamp >= 0:
        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
    return x

#----------------------------------------------------------------------------

_bias_act_cuda_cache = dict()

def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
    """Fast CUDA implementation of `bias_act()` using custom ops.

    Returns a `torch.autograd.Function` subclass specialized for the given
    arguments; call its `.apply(x, b)`. Classes are cached per argument tuple.
    Requires `_init()` to have loaded the plugin.
    """
    # Parse arguments.
    assert clamp is None or clamp >= 0
    spec = activation_funcs[act]
    alpha = float(alpha if alpha is not None else spec.def_alpha)
    gain = float(gain if gain is not None else spec.def_gain)
    clamp = float(clamp if clamp is not None else -1)

    # Lookup from cache.
    key = (dim, act, alpha, gain, clamp)
    if key in _bias_act_cuda_cache:
        return _bias_act_cuda_cache[key]

    # Forward op.
    class BiasActCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, b): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format
            x = x.contiguous(memory_format=ctx.memory_format)
            b = b.contiguous() if b is not None else _null_tensor
            y = x
            # Skip the kernel launch entirely when the op is an identity.
            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
            # Save only the tensors that the backward pass of this activation reads.
            ctx.save_for_backward(
                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
                y if 'y' in spec.ref else _null_tensor)
            return y

        @staticmethod
        def backward(ctx, dy): # pylint: disable=arguments-differ
            dy = dy.contiguous(memory_format=ctx.memory_format)
            x, b, y = ctx.saved_tensors
            dx = None
            db = None

            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                dx = dy
                if act != 'linear' or gain != 1 or clamp >= 0:
                    dx = BiasActCudaGrad.apply(dy, x, b, y)

            if ctx.needs_input_grad[1]:
                # Bias gradient: reduce over every dimension except `dim`.
                db = dx.sum([i for i in range(dx.ndim) if i != dim])

            return dx, db

    # Backward op.
    class BiasActCudaGrad(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format
            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
            ctx.save_for_backward(
                dy if spec.has_2nd_grad else _null_tensor,
                x, b, y)
            return dx

        @staticmethod
        def backward(ctx, d_dx): # pylint: disable=arguments-differ
            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
            dy, x, b, y = ctx.saved_tensors
            d_dy = None
            d_x = None
            d_b = None
            d_y = None

            if ctx.needs_input_grad[0]:
                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)

            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)

            if spec.has_2nd_grad and ctx.needs_input_grad[2]:
                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

            return d_dy, d_x, d_b, d_y

    # Add to cache.
    _bias_act_cuda_cache[key] = BiasActCuda
    return BiasActCuda

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/ops/conv2d_gradfix.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.
# Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.conv2d` that supports
arbitrarily high order gradients with zero performance penalty."""

import warnings
import contextlib
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False                     # Enable the custom op by setting this to true.
weight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.

@contextlib.contextmanager
def no_weight_gradients():
    """Context manager that disables weight-gradient computation in the custom op.

    The previous value of `weight_gradients_disabled` is restored on exit,
    including when the body of the `with` statement raises.
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even on error so a failed step cannot leave gradients disabled.
        weight_gradients_disabled = old

#----------------------------------------------------------------------------

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in for `torch.nn.functional.conv2d`; uses the custom op when applicable."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Drop-in for `torch.nn.functional.conv_transpose2d`; uses the custom op when applicable."""
    if _should_use_custom_op(input):
        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)

#----------------------------------------------------------------------------

def _should_use_custom_op(input):
    """Return True if the custom op is enabled and usable for `input`.

    Requires the module-level `enabled` flag, cuDNN, a CUDA tensor and a
    PyTorch version whose string starts with '1.7.', '1.8.' or '1.9';
    on other versions a warning is issued and False is returned.
    """
    assert isinstance(input, torch.Tensor)
    if (not enabled) or (not torch.backends.cudnn.enabled):
        return False
    if input.device.type != 'cuda':
        return False
    if torch.__version__.startswith(('1.7.', '1.8.', '1.9')):
        return True
    warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return False

def _tuple_of_ints(xs, ndim):
    """Normalize a scalar or a list/tuple into a tuple of `ndim` ints."""
    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
    assert len(xs) == ndim
    assert all(isinstance(x, int) for x in xs)
    return xs

#----------------------------------------------------------------------------

_conv2d_gradfix_cache = dict()

def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Return a cached `torch.autograd.Function` class specialized for the given
    convolution configuration; call its `.apply(input, weight, bias)`."""
    # Parse arguments.
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Lookup from cache.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate arguments.
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    assert all(dilation[i] >= 0 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else: # transpose
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    # Helpers.
    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
    def calc_output_padding(input_shape, output_shape):
        # Output padding for the opposite-direction convolution used to
        # compute the input gradient; always zero for the transposed op.
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    # Forward & backward.
    class Conv2d(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else: # transpose
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                # Input gradient = the opposite-direction convolution of grad_output.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    # Gradient with respect to the weights.
    class Conv2dGradWeight(torch.autograd.Function):
        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
            grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            return grad2_grad_output, grad2_input

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/ops/conv2d_resample.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.
# All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""2D convolution with optional up/downsampling."""

import torch

from .. import misc
from . import conv2d_gradfix
from . import upfirdn2d
from .upfirdn2d import _parse_padding
from .upfirdn2d import _get_filter_size

#----------------------------------------------------------------------------

def _get_weight_shape(w):
    """Return the shape of `w` as a list of Python ints."""
    with misc.suppress_tracer_warnings(): # this value will be treated as a constant
        shape = [int(sz) for sz in w.shape]
    misc.assert_shape(w, shape)
    return shape

#----------------------------------------------------------------------------

def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.

    With `flip_weight=False` the kernel is flipped along its two spatial
    axes before the call; `transpose=True` selects `conv_transpose2d()`.
    """
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)

    # Flip weight if requested.
    if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
        w = w.flip([2, 3])

    # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
    # 1x1 kernel + memory_format=channels_last + less than 64 channels.
    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
        if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
            if out_channels <= 4 and groups == 1:
                # Very few output channels: evaluate the 1x1 conv as a matrix product.
                in_shape = x.shape
                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
                x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
            else:
                # Otherwise run the conv in contiguous format.
                x = x.to(memory_format=torch.contiguous_format)
                w = w.to(memory_format=torch.contiguous_format)
                x = conv2d_gradfix.conv2d(x, w, groups=groups)
            # Both branches hand the result back in channels_last format.
            return x.to(memory_format=torch.channels_last)

    # Otherwise => execute using conv2d_gradfix.
    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
    return op(x, w, stride=stride, padding=padding, groups=groups)

#----------------------------------------------------------------------------

@misc.profiled_function
def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.

    Args:
        x:              Input tensor of shape
                        `[batch_size, in_channels, in_height, in_width]`.
        w:              Weight tensor of shape
                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by
                        calling upfirdn2d.setup_filter(). None = identity (default).
        up:             Integer upsampling factor (default: 1).
        down:           Integer downsampling factor (default: 1).
        padding:        Padding with respect to the upsampled image. Can be a single number
                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                        (default: 0).
        groups:         Split input channels into N groups (default: 1).
        flip_weight:    False = convolution, True = correlation (default: True).
        flip_filter:    False = convolution, True = correlation (default: False).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    assert isinstance(groups, int) and (groups >= 1)
    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
    fw, fh = _get_filter_size(f)
    px0, px1, py0, py1 = _parse_padding(padding)

    # Adjust padding to account for up/downsampling.
    if up > 1:
        px0 += (fw + up - 1) // 2
        px1 += (fw - up) // 2
        py0 += (fh + up - 1) // 2
        py1 += (fh - up) // 2
    if down > 1:
        px0 += (fw - down + 1) // 2
        px1 += (fw - down) // 2
        py0 += (fh - down + 1) // 2
        py1 += (fh - down) // 2

    # NOTE: the fast paths below are tested in order and each one returns, so
    # a later path only sees the cases the earlier ones did not handle.

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
    if up > 1:
        # Swap the in/out channel axes of the weight (per group) for the transposed conv.
        if groups == 1:
            w = w.transpose(0, 1)
        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        # pxt/pyt: the non-negative part of the padding handed to the transposed
        # conv itself; the remainder is applied by upfirdn2d below.
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
        if down > 1:
            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)

    # Fallback: Generic reference implementation.
    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
    return x

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/ops/fma.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""

import torch

#----------------------------------------------------------------------------

def fma(a, b, c): # => a * b + c
    """Return `a * b + c` with broadcasting, differentiable w.r.t. all three inputs."""
    return _FusedMultiplyAdd.apply(a, b, c)

#----------------------------------------------------------------------------

class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
    """Autograd function behind `fma()`."""

    @staticmethod
    def forward(ctx, a, b, c): # pylint: disable=arguments-differ
        out = torch.addcmul(c, a, b)
        ctx.save_for_backward(a, b)
        ctx.c_shape = c.shape # Only the shape of `c` is needed by backward().
        return out

    @staticmethod
    def backward(ctx, dout): # pylint: disable=arguments-differ
        a, b = ctx.saved_tensors
        c_shape = ctx.c_shape
        da = None
        db = None
        dc = None

        # Each gradient is reduced back to the (possibly broadcast) input shape.
        if ctx.needs_input_grad[0]:
            da = _unbroadcast(dout * b, a.shape)

        if ctx.needs_input_grad[1]:
            db = _unbroadcast(dout * a, b.shape)

        if ctx.needs_input_grad[2]:
            dc = _unbroadcast(dout, c_shape)

        return da, db, dc

#----------------------------------------------------------------------------

def _unbroadcast(x, shape):
    """Sum `x` over its broadcast dimensions so that the result has `shape`.

    Args:
        x:      Tensor whose shape is `shape` broadcast to a larger shape.
        shape:  Target shape; may have fewer dimensions than `x`, including
                zero dimensions (a scalar tensor).

    Returns:
        Tensor of shape `shape`.
    """
    extra_dims = x.ndim - len(shape)
    assert extra_dims >= 0
    # Reduce the leading extra dims and every dim where the target has size 1.
    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
    if len(dim):
        x = x.sum(dim=dim, keepdim=True)
    if extra_dims:
        # The leading extra dims all have size 1 here, so the element count
        # already matches `shape`. Reshaping straight to `shape` also covers a
        # 0-dimensional target, which `reshape(-1, ...)` would turn into [1].
        x = x.reshape(tuple(shape))
    assert x.shape == shape
    return x

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/ops/grid_sample_gradfix.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Custom replacement for `torch.nn.functional.grid_sample` that
supports arbitrarily high order gradients between the input and output.
Only works on 2D images and assumes
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""

import warnings
import torch

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
# pylint: disable=protected-access

#----------------------------------------------------------------------------

enabled = False  # Enable the custom op by setting this to true.

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    """Bilinear, zero-padded, `align_corners=False` grid sampling of `input` at `grid`.

    Dispatches to the custom autograd op when it is enabled and supported,
    and to `torch.nn.functional.grid_sample` otherwise.
    """
    if not _should_use_custom_op():
        return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
    return _GridSample2dForward.apply(input, grid)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    """Tell whether the custom op may be used (enabled flag plus PyTorch version check)."""
    if enabled:
        version = torch.__version__
        for prefix in ['1.7.', '1.8.', '1.9']:
            if version.startswith(prefix):
                return True
        warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    """Forward sampling op; its backward is itself a differentiable op."""

    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        sampled = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return sampled

    @staticmethod
    def backward(ctx, grad_output):
        saved_input, saved_grid = ctx.saved_tensors
        grads = _GridSample2dBackward.apply(grad_output, saved_input, saved_grid)
        return grads[0], grads[1]

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    """Backward of the sampling op, differentiable w.r.t. `grad_output` only."""

    @staticmethod
    def forward(ctx, grad_output, input, grid):
        backward_op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = backward_op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        # `grad2_grad_grid` is intentionally ignored.
        (saved_grid,) = ctx.saved_tensors

        grad2_grad_output = None
        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, saved_grid)

        # Second-order gradients with respect to the grid are not supported.
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, None, None

#----------------------------------------------------------------------------

# --------------------------------------------------------------------------------
# /torch_utils/ops/upfirdn2d.cpp:
# --------------------------------------------------------------------------------
# // Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 
72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel. 91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 
13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. 
#----------------------------------------------------------------------------

_version            = 6         # internal version number
_decorators         = set()     # {decorator_class, ...}
_import_hooks       = []        # [hook_function, ...]
_module_to_src_dict = dict()    # {module: src, ...}
_src_to_module_dict = dict()    # {src: module, ...}

#----------------------------------------------------------------------------

def persistent_class(orig_class):
    r"""Class decorator that extends a given class to save its source code
    when pickled.

    Example:

        from torch_utils import persistence

        @persistence.persistent_class
        class MyNetwork(torch.nn.Module):
            def __init__(self, num_inputs, num_outputs):
                super().__init__()
                self.fc = MyLayer(num_inputs, num_outputs)
                ...

        @persistence.persistent_class
        class MyLayer(torch.nn.Module):
            ...

    When pickled, any instance of `MyNetwork` and `MyLayer` will save its
    source code alongside other internal state (e.g., parameters, buffers,
    and submodules). This way, any previously exported pickle will remain
    usable even if the class definitions have been modified or are no
    longer available.

    The decorator saves the source code of the entire Python module
    containing the decorated class. It does *not* save the source code of
    any imported modules. Thus, the imported modules must be available
    during unpickling, also including `torch_utils.persistence` itself.

    It is ok to call functions defined in the same module from the
    decorated class. However, if the decorated class depends on other
    classes defined in the same module, they must be decorated as well.
    This is illustrated in the above example in the case of `MyLayer`.

    It is also possible to employ the decorator just-in-time before
    calling the constructor. For example:

        cls = MyLayer
        if want_to_make_it_persistent:
            cls = persistence.persistent_class(cls)
        layer = cls(num_inputs, num_outputs)

    As an additional feature, the decorator also keeps track of the
    arguments that were used to construct each instance of the decorated
    class. The arguments can be queried via `obj.init_args` and
    `obj.init_kwargs`, and they are automatically pickled alongside other
    object state. A typical use case is to first unpickle a previous
    instance of a persistent class, and then upgrade it to use the latest
    version of the source code:

        with open('old_pickle.pkl', 'rb') as f:
            old_net = pickle.load(f)
        new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
        misc.copy_params_and_buffers(old_net, new_net, require_all=True)
    """
    assert isinstance(orig_class, type)
    # Decorating twice is a no-op: an already persistent class is returned as-is.
    if is_persistent(orig_class):
        return orig_class

    assert orig_class.__module__ in sys.modules
    orig_module = sys.modules[orig_class.__module__]
    orig_module_src = _module_to_src(orig_module)

    class Decorator(orig_class):
        # Source text of the defining module and the class name within it;
        # both are embedded into every pickle by `__reduce__`.
        _orig_module_src = orig_module_src
        _orig_class_name = orig_class.__name__

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Deep copies, so later mutation of the caller's objects does
            # not change the recorded constructor arguments.
            self._init_args = copy.deepcopy(args)
            self._init_kwargs = copy.deepcopy(kwargs)
            # The class must be reachable by name at module level, because
            # unpickling looks it up as `module.__dict__[class_name]`.
            assert orig_class.__name__ in orig_module.__dict__
            _check_pickleable(self.__reduce__())

        @property
        def init_args(self):
            return copy.deepcopy(self._init_args)

        @property
        def init_kwargs(self):
            return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))

        def __reduce__(self):
            fields = list(super().__reduce__())
            # Pad to at least (func, args, state).
            fields += [None] * max(3 - len(fields), 0)
            # Rewrite the reduce tuple only if a base class has not already done so.
            if fields[0] is not _reconstruct_persistent_obj:
                meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
                fields[0] = _reconstruct_persistent_obj # reconstruct func
                fields[1] = (meta,) # reconstruct args
                fields[2] = None # state dict
            return tuple(fields)

    Decorator.__name__ = orig_class.__name__
    _decorators.add(Decorator)
    return Decorator

#----------------------------------------------------------------------------

def is_persistent(obj):
    r"""Test whether the given object or class is persistent, i.e.,
    whether it will save its source code when pickled.
    """
    try:
        if obj in _decorators:
            return True
    except TypeError:
        pass # unhashable objects cannot be set members; fall through to the type check
    return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck

#----------------------------------------------------------------------------

def import_hook(hook):
    r"""Register an import hook that is called whenever a persistent object
    is being unpickled. A typical use case is to patch the pickled source
    code to avoid errors and inconsistencies when the API of some imported
    module has changed.

    The hook should have the following signature:

        hook(meta) -> modified meta

    `meta` is an instance of `dnnlib.EasyDict` with the following fields:

        type:       Type of the persistent object, e.g. `'class'`.
        version:    Internal version number of `torch_utils.persistence`.
        module_src: Original source code of the Python module.
        class_name: Class name in the original Python module.
        state:      Internal state of the object.

    Example:

        @persistence.import_hook
        def wreck_my_network(meta):
            if meta.class_name == 'MyNetwork':
                print('MyNetwork is being imported. I will wreck it!')
                meta.module_src = meta.module_src.replace("True", "False")
            return meta

    Note: this function returns None, so when it is used as a decorator
    (as in the example) the decorated name is rebound to None. The hook
    itself stays registered.
    """
    assert callable(hook)
    _import_hooks.append(hook)

#----------------------------------------------------------------------------

def _reconstruct_persistent_obj(meta):
    r"""Hook that is called internally by the `pickle` module to unpickle
    a persistent object.
    """
    meta = dnnlib.EasyDict(meta)
    meta.state = dnnlib.EasyDict(meta.state)
    # Hooks run in registration order; each one must return the (possibly
    # modified) meta.
    for hook in _import_hooks:
        meta = hook(meta)
        assert meta is not None

    assert meta.version == _version
    module = _src_to_module(meta.module_src)

    assert meta.type == 'class'
    orig_class = module.__dict__[meta.class_name]
    decorator_class = persistent_class(orig_class)
    # Allocate without running __init__; the state is restored below.
    obj = decorator_class.__new__(decorator_class)

    setstate = getattr(obj, '__setstate__', None)
    if callable(setstate):
        setstate(meta.state) # pylint: disable=not-callable
    else:
        obj.__dict__.update(meta.state)
    return obj

#----------------------------------------------------------------------------

def _module_to_src(module):
    r"""Query the source code of a given Python module.
    """
    src = _module_to_src_dict.get(module, None)
    if src is None:
        src = inspect.getsource(module)
        # Cache in both directions so identical source maps back to this module.
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
    return src

def _src_to_module(src):
    r"""Get or create a Python module for the given source code.
    """
    module = _src_to_module_dict.get(src, None)
    if module is None:
        module_name = "_imported_module_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        # Registered (and cached) before the source runs.
        sys.modules[module_name] = module
        _module_to_src_dict[module] = src
        _src_to_module_dict[src] = module
        # SECURITY: this executes source text taken from the pickle being
        # loaded. Only unpickle files from trusted origins.
        exec(src, module.__dict__) # pylint: disable=exec-used
    return module

#----------------------------------------------------------------------------

def _check_pickleable(obj):
    r"""Check that the given object is pickleable, raising an exception if
    it is not. This function is expected to be considerably more efficient
    than actually pickling the object.
    """
    def recurse(obj):
        # Builds a skeleton in which values of known-pickleable types are
        # replaced by None, so only unrecognized objects get pickled.
        if isinstance(obj, (list, tuple, set)):
            return [recurse(x) for x in obj]
        if isinstance(obj, dict):
            return [[recurse(x), recurse(y)] for x, y in obj.items()]
        if isinstance(obj, (str, int, float, bool, bytes, bytearray)):
            return None # Python primitive types are pickleable.
        if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']:
            return None # NumPy arrays and PyTorch tensors are pickleable.
        if is_persistent(obj):
            return None # Persistent objects are pickleable, by virtue of the constructor check.
        return obj
    with io.BytesIO() as f:
        pickle.dump(recurse(obj), f)

#----------------------------------------------------------------------------
#----------------------------------------------------------------------------

_num_moments    = 3              # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype   = torch.float32  # Data type to use for initial per-tensor reduction.
_counter_dtype  = torch.float64  # Data type to use for the internal counters.
_rank           = 0              # Rank of the current process.
_sync_device    = None           # Device to use for multiprocess communication. None = single-process.
_sync_called    = False          # Has _sync() been called yet?
_counters       = dict()         # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative     = dict()         # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor

#----------------------------------------------------------------------------

def init_multiprocessing(rank, sync_device):
    r"""Configure this module for multi-process statistics collection.

    Must run after `torch.distributed.init_process_group()` and before
    the first `Collector.update()`. Single-process runs can skip it.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device used for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank, _sync_device = rank, sync_device

#----------------------------------------------------------------------------

@misc.profiled_function
def report(name, value):
    r"""Accumulate a set of scalars under `name` for later collection.

    Count, sum and sum of squares of `value` are added to a running
    counter kept per name and per device; `Collector` instances pick the
    counters up on their next `update()`.

    Warning: every process is expected to report the same set of names,
    at least once each and in the same order. A process with nothing to
    contribute can call `report(name, [])`.

    Args:
        name:   Name of the statistic. Each unique name is averaged
                separately.
        value:  Scalars to add: list, tuple, NumPy array, PyTorch tensor,
                or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    per_device = _counters.setdefault(name, dict())

    samples = torch.as_tensor(value)
    if samples.numel() == 0:
        return value

    samples = samples.detach().flatten().to(_reduce_dtype)
    count = torch.ones_like(samples).sum()
    total = samples.sum()
    total_sq = samples.square().sum()
    moments = torch.stack([count, total, total_sq])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    dev = moments.device
    if dev not in per_device:
        per_device[dev] = torch.zeros_like(moments)
    per_device[dev].add_(moments)
    return value

#----------------------------------------------------------------------------

def report0(name, value):
    r"""Like `report()`, but only the first process (`rank = 0`)
    contributes; all other ranks report an empty set under the same name.
    """
    if _rank == 0:
        report(name, value)
    else:
        report(name, [])
    return value

#----------------------------------------------------------------------------

class Collector:
    r"""Turns the scalars passed to `report()` / `report0()` into
    averages (mean and standard deviation) over user-defined periods.

    Scalars first land in module-level counters that are not directly
    visible to the user. `update()` copies what has accumulated since the
    previous `update()` into this object and resets the period; the
    result is then available through `mean()`, `std()`, `as_dict()`, etc.

    Args:
        regex:          Regular expression selecting the statistics to
                        collect. The default collects everything.
        keep_previous:  Whether to keep the previous averages for a
                        statistic that received no scalars in a round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = {}
        self._moments = {}
        # Take a baseline, then discard whatever that first update picked up.
        self.update()
        self._moments.clear()

    def names(self):
        r"""Returns the names of all statistics reported so far that
        fully match the regular expression given at construction time.
        """
        return list(filter(self._regex.fullmatch, _counters))

    def update(self):
        r"""Copies the current internal counters into the user-visible
        state and starts a new accumulation round.

        With `keep_previous=True`, statistics that received no scalars
        since the last update keep their previous averages.

        Involves GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`; meant to be called periodically
        from the main training loop, typically once every N steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for stat_name, running_total in _sync(self.names()):
            if stat_name not in self._cumulative:
                self._cumulative[stat_name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            last_seen = self._cumulative[stat_name]
            increment = running_total - last_seen
            last_seen.copy_(running_total)
            if float(increment[0]) != 0:
                self._moments[stat_name] = increment

    def _get_delta(self, name):
        r"""Returns the raw moments accumulated for `name` between the
        last two calls to `update()`, or zeros if there were none.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns how many scalars were accumulated for `name` between
        the last two calls to `update()` (zero if none).
        """
        return int(self._get_delta(name)[0])

    def mean(self, name):
        r"""Returns the mean of the scalars accumulated for `name`
        between the last two calls to `update()`, or NaN if there were
        none.
        """
        moments = self._get_delta(name)
        if int(moments[0]) == 0:
            return float('nan')
        return float(moments[1] / moments[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars accumulated for
        `name` between the last two calls to `update()`, or NaN if there
        were none.
        """
        moments = self._get_delta(name)
        count = int(moments[0])
        if count == 0 or not np.isfinite(float(moments[1])):
            return float('nan')
        if count == 1:
            return float(0)
        avg = float(moments[1] / moments[0])
        avg_sq = float(moments[2] / moments[0])
        # Clamp at zero: rounding can push the difference slightly negative.
        return np.sqrt(max(avg_sq - np.square(avg), 0))

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=INT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        summary = dnnlib.EasyDict()
        for stat_name in self.names():
            summary[stat_name] = dnnlib.EasyDict(
                num=self.num(stat_name),
                mean=self.mean(stat_name),
                std=self.std(stat_name),
            )
        return summary

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)

#----------------------------------------------------------------------------

def _sync(names):
    r"""Fold the per-device counters of the given statistics into the
    global cumulative counters, summing across processes when a sync
    device is configured. Called internally by `Collector.update()`.

    Returns a list of `(name, cumulative_tensor)` pairs in the order of
    `names`.
    """
    global _sync_called
    if len(names) == 0:
        return []
    _sync_called = True

    # Collect deltas within current rank.
    target = torch.device('cpu') if _sync_device is None else _sync_device
    per_name = []
    for stat_name in names:
        acc = torch.zeros([_num_moments], dtype=_counter_dtype, device=target)
        for running in _counters[stat_name].values():
            acc.add_(running.to(target))
            running.zero_()  # restart the running counter for the next round
        per_name.append(acc)
    stacked = torch.stack(per_name)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(stacked)

    # Update cumulative values.
    stacked = stacked.cpu()
    for row, stat_name in enumerate(names):
        if stat_name not in _cumulative:
            _cumulative[stat_name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[stat_name].add_(stacked[row])

    # Return name-value pairs.
    return [(stat_name, _cumulative[stat_name]) for stat_name in names]

#----------------------------------------------------------------------------
def main():
    """Parse the configuration, seed the RNGs and start the training worker(s).

    With `args.multiprocessing_distributed` set, one `main_worker` process
    is spawned per visible GPU and `args.world_size` is scaled by the GPU
    count; otherwise `main_worker` is called directly in this process.
    """
    # Function-scope import: the file-level imports do not include
    # `warnings`, so the warning below used to raise NameError.
    import warnings

    args = cfg.parse_args()

    if args.seed is not None:
        # NOTE(review): the guard reads `args.seed` but the value used is
        # `args.random_seed` -- confirm against cfg.py that both exist and
        # that this split is intended.
        torch.manual_seed(args.random_seed)
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        np.random.seed(args.random_seed)
        random.seed(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
92 | else: 93 | raise NotImplementedError('{} unknown inital type'.format(args.init_type)) 94 | # elif classname.find('Linear') != -1: 95 | # if args.init_type == 'normal': 96 | # nn.init.normal_(m.weight.data, 0.0, 0.02) 97 | # elif args.init_type == 'orth': 98 | # nn.init.orthogonal_(m.weight.data) 99 | # elif args.init_type == 'xavier_uniform': 100 | # nn.init.xavier_uniform(m.weight.data, 1.) 101 | # else: 102 | # raise NotImplementedError('{} unknown inital type'.format(args.init_type)) 103 | elif classname.find('BatchNorm2d') != -1: 104 | nn.init.normal_(m.weight.data, 1.0, 0.02) 105 | nn.init.constant_(m.bias.data, 0.0) 106 | 107 | # import network 108 | 109 | 110 | if not torch.cuda.is_available(): 111 | print('using CPU, this will be slow') 112 | elif args.distributed: 113 | # For multiprocessing distributed, DistributedDataParallel constructor 114 | # should always set the single device scope, otherwise, 115 | # DistributedDataParallel will use all available devices. 116 | if args.gpu is not None: 117 | 118 | torch.cuda.set_device(args.gpu) 119 | gen_net = eval('models_search.'+args.gen_model+'.Generator')(args=args) 120 | dis_net = eval('models_search.'+args.dis_model+'.Discriminator')(args=args) 121 | 122 | gen_net.apply(weights_init) 123 | dis_net.apply(weights_init) 124 | gen_net.cuda(args.gpu) 125 | dis_net.cuda(args.gpu) 126 | # When using a single GPU per process and per 127 | # DistributedDataParallel, we need to divide the batch size 128 | # ourselves based on the total number of GPUs we have 129 | args.dis_batch_size = int(args.dis_batch_size / ngpus_per_node) 130 | args.gen_batch_size = int(args.gen_batch_size / ngpus_per_node) 131 | args.batch_size = args.dis_batch_size 132 | 133 | args.num_workers = int((args.num_workers + ngpus_per_node - 1) / ngpus_per_node) 134 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net, device_ids=[args.gpu], find_unused_parameters=True) 135 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net, 
device_ids=[args.gpu], find_unused_parameters=True) 136 | else: 137 | gen_net.cuda() 138 | dis_net.cuda() 139 | # DistributedDataParallel will divide and allocate batch_size to all 140 | # available GPUs if device_ids are not set 141 | gen_net = torch.nn.parallel.DistributedDataParallel(gen_net) 142 | dis_net = torch.nn.parallel.DistributedDataParallel(dis_net) 143 | elif args.gpu is not None: 144 | torch.cuda.set_device(args.gpu) 145 | gen_net.cuda(args.gpu) 146 | dis_net.cuda(args.gpu) 147 | else: 148 | gen_net = torch.nn.DataParallel(gen_net).cuda() 149 | dis_net = torch.nn.DataParallel(dis_net).cuda() 150 | print(dis_net) if args.rank == 0 else 0 151 | 152 | 153 | # set optimizer 154 | if args.optimizer == "adam": 155 | gen_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gen_net.parameters()), 156 | args.g_lr, (args.beta1, args.beta2)) 157 | dis_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, dis_net.parameters()), 158 | args.d_lr, (args.beta1, args.beta2)) 159 | elif args.optimizer == "adamw": 160 | gen_optimizer = AdamW(filter(lambda p: p.requires_grad, gen_net.parameters()), 161 | args.g_lr, weight_decay=args.wd) 162 | dis_optimizer = AdamW(filter(lambda p: p.requires_grad, dis_net.parameters()), 163 | args.g_lr, weight_decay=args.wd) 164 | gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic) 165 | dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic) 166 | 167 | # fid stat 168 | if args.dataset.lower() == 'cifar10': 169 | fid_stat = 'fid_stat/fid_stats_cifar10_train.npz' 170 | elif args.dataset.lower() == 'stl10': 171 | fid_stat = 'fid_stat/stl10_train_unlabeled_fid_stats_48.npz' 172 | elif args.fid_stat is not None: 173 | fid_stat = args.fid_stat 174 | else: 175 | raise NotImplementedError(f'no fid stat for {args.dataset.lower()}') 176 | assert os.path.exists(fid_stat) 177 | 178 | 179 | # epoch number for dis_net 180 | args.max_epoch = 
args.max_epoch * args.n_critic 181 | dataset = datasets.ImageDataset(args, cur_img_size=8) 182 | train_loader = dataset.train 183 | train_sampler = dataset.train_sampler 184 | print(len(train_loader)) 185 | if args.max_iter: 186 | args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader)) 187 | 188 | # initial 189 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (100, args.latent_dim))) 190 | avg_gen_net = deepcopy(gen_net).cpu() 191 | gen_avg_param = copy_params(avg_gen_net) 192 | del avg_gen_net 193 | start_epoch = 0 194 | best_fid = 1e4 195 | 196 | # set writer 197 | writer = None 198 | if args.load_path: 199 | print(f'=> resuming from {args.load_path}') 200 | assert os.path.exists(args.load_path) 201 | checkpoint_file = os.path.join(args.load_path) 202 | assert os.path.exists(checkpoint_file) 203 | loc = 'cuda:{}'.format(args.gpu) 204 | checkpoint = torch.load(checkpoint_file, map_location=loc) 205 | start_epoch = checkpoint['epoch'] 206 | best_fid = checkpoint['best_fid'] 207 | 208 | 209 | dis_net.load_state_dict(checkpoint['dis_state_dict']) 210 | gen_optimizer.load_state_dict(checkpoint['gen_optimizer']) 211 | dis_optimizer.load_state_dict(checkpoint['dis_optimizer']) 212 | 213 | # avg_gen_net = deepcopy(gen_net) 214 | gen_net.load_state_dict(checkpoint['avg_gen_state_dict']) 215 | gen_avg_param = copy_params(gen_net, mode='gpu') 216 | gen_net.load_state_dict(checkpoint['gen_state_dict']) 217 | fixed_z = checkpoint['fixed_z'] 218 | # del avg_gen_net 219 | # gen_avg_param = list(p.cuda().to(f"cuda:{args.gpu}") for p in gen_avg_param) 220 | 221 | 222 | 223 | args.path_helper = checkpoint['path_helper'] 224 | logger = create_logger(args.path_helper['log_path']) if args.rank == 0 else None 225 | print(f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})') 226 | writer = SummaryWriter(args.path_helper['log_path']) if args.rank == 0 else None 227 | del checkpoint 228 | else: 229 | # create new log dir 230 | assert args.exp_name 
231 | if args.rank == 0: 232 | args.path_helper = set_log_dir('logs', args.exp_name) 233 | logger = create_logger(args.path_helper['log_path']) 234 | writer = SummaryWriter(args.path_helper['log_path']) 235 | 236 | if args.rank == 0: 237 | logger.info(args) 238 | writer_dict = { 239 | 'writer': writer, 240 | 'train_global_steps': start_epoch * len(train_loader), 241 | 'valid_global_steps': start_epoch // args.val_freq, 242 | } 243 | 244 | # train loop 245 | for epoch in range(int(start_epoch), int(args.max_epoch)): 246 | train_sampler.set_epoch(epoch) 247 | lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None 248 | cur_stage = cur_stages(epoch, args) 249 | print("cur_stage " + str(cur_stage)) if args.rank==0 else 0 250 | print(f"path: {args.path_helper['prefix']}") if args.rank==0 else 0 251 | train(args, gen_net, dis_net, gen_optimizer, dis_optimizer, gen_avg_param, train_loader, epoch, writer_dict,fixed_z, 252 | lr_schedulers) 253 | 254 | if args.rank == 0 and args.show: 255 | backup_param = copy_params(gen_net) 256 | load_params(gen_net, gen_avg_param, args, mode="cpu") 257 | save_samples(args, fixed_z, fid_stat, epoch, gen_net, writer_dict) 258 | load_params(gen_net, backup_param, args) 259 | 260 | if epoch and epoch % args.val_freq == 0 or epoch == int(args.max_epoch)-1: 261 | backup_param = copy_params(gen_net) 262 | load_params(gen_net, gen_avg_param, args, mode="cpu") 263 | inception_score, fid_score = validate(args, fixed_z, fid_stat, epoch, gen_net, writer_dict) 264 | if args.rank==0: 265 | logger.info(f'Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}.') 266 | load_params(gen_net, backup_param, args) 267 | if fid_score < best_fid: 268 | best_fid = fid_score 269 | is_best = True 270 | else: 271 | is_best = False 272 | else: 273 | is_best = False 274 | 275 | avg_gen_net = deepcopy(gen_net) 276 | load_params(avg_gen_net, gen_avg_param, args) 277 | if not args.multiprocessing_distributed or 
(args.multiprocessing_distributed 278 | and args.rank == 0): 279 | save_checkpoint({ 280 | 'epoch': epoch + 1, 281 | 'gen_model': args.gen_model, 282 | 'dis_model': args.dis_model, 283 | 'gen_state_dict': gen_net.state_dict(), 284 | 'dis_state_dict': dis_net.state_dict(), 285 | 'avg_gen_state_dict': avg_gen_net.state_dict(), 286 | 'gen_optimizer': gen_optimizer.state_dict(), 287 | 'dis_optimizer': dis_optimizer.state_dict(), 288 | 'best_fid': best_fid, 289 | 'path_helper': args.path_helper, 290 | 'fixed_z': fixed_z 291 | }, is_best, args.path_helper['ckpt_path'], filename="checkpoint") 292 | del avg_gen_net 293 | 294 | 295 | if __name__ == '__main__': 296 | main() 297 | 298 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-25 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from utils import utils 12 | -------------------------------------------------------------------------------- /utils/cal_fid_stat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Date : 2019-07-26 3 | # @Author : Xinyu Gong (xy_gong@tamu.edu) 4 | # @Link : None 5 | # @Version : 0.0 6 | 7 | 8 | import os 9 | import glob 10 | import argparse 11 | import numpy as np 12 | from imageio import imread 13 | import tensorflow as tf 14 | 15 | import utils.fid_score as fid 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | '--data_path', 22 | type=str, 23 | required=True, 24 | help='set path to training set jpg images dir') 25 | parser.add_argument( 26 | '--output_file', 27 | type=str, 28 | 
default='fid_stat/fid_stats_cifar10_train.npz', 29 | help='path for where to store the statistics') 30 | 31 | opt = parser.parse_args() 32 | print(opt) 33 | return opt 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | 39 | ######## 40 | # PATHS 41 | ######## 42 | data_path = args.data_path 43 | output_path = args.output_file 44 | # if you have downloaded and extracted 45 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 46 | # set this path to the directory where the extracted files are, otherwise 47 | # just set it to None and the script will later download the files for you 48 | inception_path = None 49 | print("check for inception model..", end=" ", flush=True) 50 | inception_path = fid.check_or_download_inception(inception_path) # download inception if necessary 51 | print("ok") 52 | 53 | # loads all images into memory (this might require a lot of RAM!) 54 | print("load images..", end=" ", flush=True) 55 | image_list = glob.glob(os.path.join(data_path, '*.jpg')) 56 | images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list]) 57 | print("%d images found and loaded" % len(images)) 58 | 59 | print("create inception graph..", end=" ", flush=True) 60 | fid.create_inception_graph(inception_path) # load the graph into the current TF graph 61 | print("ok") 62 | 63 | print("calculte FID stats..", end=" ", flush=True) 64 | config = tf.ConfigProto() 65 | config.gpu_options.allow_growth = True 66 | with tf.Session(config=config) as sess: 67 | sess.run(tf.global_variables_initializer()) 68 | mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100) 69 | np.savez_compressed(output_path, mu=mu, sigma=sigma) 70 | print("finished") 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /utils/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. 
As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in Tensorflow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results. 65 | """ 66 | super(InceptionV3, self).__init__() 67 | 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = models.inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 
103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | 143 | if self.resize_input: 144 | x = F.interpolate(x, 145 | size=(299, 299), 146 | mode='bilinear', 147 | align_corners=False) 148 | 149 | if self.normalize_input: 150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 151 | 152 | for idx, block in enumerate(self.blocks): 153 | x = block(x) 154 | if idx in self.output_blocks: 155 | outp.append(x) 156 | 157 | if idx == self.last_needed_block: 158 | break 159 | 160 | return outp 161 | 162 | 163 | def fid_inception_v3(): 164 | """Build pretrained Inception model for FID computation 165 | The Inception model for FID computation uses a different set of weights 166 | and has a slightly different structure than torchvision's Inception. 167 | This method first constructs torchvision's Inception and then patches the 168 | necessary parts that are different in the FID Inception model. 
169 | """ 170 | inception = models.inception_v3(num_classes=1008, 171 | aux_logits=False, 172 | pretrained=False, 173 | init_weights=False) 174 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 175 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 176 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 177 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 178 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 179 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 180 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 181 | inception.Mixed_7b = FIDInceptionE_1(1280) 182 | inception.Mixed_7c = FIDInceptionE_2(2048) 183 | 184 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 185 | inception.load_state_dict(state_dict) 186 | return inception 187 | 188 | 189 | class FIDInceptionA(models.inception.InceptionA): 190 | """InceptionA block patched for FID computation""" 191 | 192 | def __init__(self, in_channels, pool_features): 193 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 194 | 195 | def forward(self, x): 196 | branch1x1 = self.branch1x1(x) 197 | 198 | branch5x5 = self.branch5x5_1(x) 199 | branch5x5 = self.branch5x5_2(branch5x5) 200 | 201 | branch3x3dbl = self.branch3x3dbl_1(x) 202 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 203 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 204 | 205 | # Patch: Tensorflow's average pool does not use the padded zero's in 206 | # its average calculation 207 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 208 | count_include_pad=False) 209 | branch_pool = self.branch_pool(branch_pool) 210 | 211 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 212 | return torch.cat(outputs, 1) 213 | 214 | 215 | class FIDInceptionC(models.inception.InceptionC): 216 | """InceptionC block patched for FID computation""" 217 | 218 | def __init__(self, in_channels, channels_7x7): 219 | super(FIDInceptionC, 
self).__init__(in_channels, channels_7x7) 220 | 221 | def forward(self, x): 222 | branch1x1 = self.branch1x1(x) 223 | 224 | branch7x7 = self.branch7x7_1(x) 225 | branch7x7 = self.branch7x7_2(branch7x7) 226 | branch7x7 = self.branch7x7_3(branch7x7) 227 | 228 | branch7x7dbl = self.branch7x7dbl_1(x) 229 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 230 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 231 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 233 | 234 | # Patch: Tensorflow's average pool does not use the padded zero's in 235 | # its average calculation 236 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 237 | count_include_pad=False) 238 | branch_pool = self.branch_pool(branch_pool) 239 | 240 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 241 | return torch.cat(outputs, 1) 242 | 243 | 244 | class FIDInceptionE_1(models.inception.InceptionE): 245 | """First InceptionE block patched for FID computation""" 246 | 247 | def __init__(self, in_channels): 248 | super(FIDInceptionE_1, self).__init__(in_channels) 249 | 250 | def forward(self, x): 251 | branch1x1 = self.branch1x1(x) 252 | 253 | branch3x3 = self.branch3x3_1(x) 254 | branch3x3 = [ 255 | self.branch3x3_2a(branch3x3), 256 | self.branch3x3_2b(branch3x3), 257 | ] 258 | branch3x3 = torch.cat(branch3x3, 1) 259 | 260 | branch3x3dbl = self.branch3x3dbl_1(x) 261 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 262 | branch3x3dbl = [ 263 | self.branch3x3dbl_3a(branch3x3dbl), 264 | self.branch3x3dbl_3b(branch3x3dbl), 265 | ] 266 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 267 | 268 | # Patch: Tensorflow's average pool does not use the padded zero's in 269 | # its average calculation 270 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 271 | count_include_pad=False) 272 | branch_pool = self.branch_pool(branch_pool) 273 | 274 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 275 | 
return torch.cat(outputs, 1) 276 | 277 | 278 | class FIDInceptionE_2(models.inception.InceptionE): 279 | """Second InceptionE block patched for FID computation""" 280 | 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /utils/inception_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | Parameters 39 | ---------- 40 | output_blocks : list of int 41 | Indices of blocks to return features of. 
Possible values are: 42 | - 0: corresponds to output of first max pooling 43 | - 1: corresponds to output of second max pooling 44 | - 2: corresponds to output which is fed to aux classifier 45 | - 3: corresponds to output of final average pooling 46 | resize_input : bool 47 | If true, bilinearly resizes input to width and height 299 before 48 | feeding input to model. As the network without fully connected 49 | layers is fully convolutional, it should be able to handle inputs 50 | of arbitrary size, so resizing might not be strictly needed 51 | normalize_input : bool 52 | If true, scales the input from range (0, 1) to the range the 53 | pretrained Inception network expects, namely (-1, 1) 54 | requires_grad : bool 55 | If true, parameters of the model require gradients. Possibly useful 56 | for finetuning the network 57 | use_fid_inception : bool 58 | If true, uses the pretrained Inception model used in Tensorflow's 59 | FID implementation. If false, uses the pretrained Inception model 60 | available in torchvision. The FID Inception model has different 61 | weights and a slightly different structure from torchvision's 62 | Inception model. If you want to compute FID scores, you are 63 | strongly advised to set this parameter to true to get comparable 64 | results. 
65 | """ 66 | super(InceptionV3, self).__init__() 67 | 68 | self.resize_input = resize_input 69 | self.normalize_input = normalize_input 70 | self.output_blocks = sorted(output_blocks) 71 | self.last_needed_block = max(output_blocks) 72 | 73 | assert self.last_needed_block <= 3, \ 74 | 'Last possible output block index is 3' 75 | 76 | self.blocks = nn.ModuleList() 77 | 78 | if use_fid_inception: 79 | inception = fid_inception_v3() 80 | else: 81 | inception = models.inception_v3(pretrained=True) 82 | 83 | # Block 0: input to maxpool1 84 | block0 = [ 85 | inception.Conv2d_1a_3x3, 86 | inception.Conv2d_2a_3x3, 87 | inception.Conv2d_2b_3x3, 88 | nn.MaxPool2d(kernel_size=3, stride=2) 89 | ] 90 | self.blocks.append(nn.Sequential(*block0)) 91 | 92 | # Block 1: maxpool1 to maxpool2 93 | if self.last_needed_block >= 1: 94 | block1 = [ 95 | inception.Conv2d_3b_1x1, 96 | inception.Conv2d_4a_3x3, 97 | nn.MaxPool2d(kernel_size=3, stride=2) 98 | ] 99 | self.blocks.append(nn.Sequential(*block1)) 100 | 101 | # Block 2: maxpool2 to aux classifier 102 | if self.last_needed_block >= 2: 103 | block2 = [ 104 | inception.Mixed_5b, 105 | inception.Mixed_5c, 106 | inception.Mixed_5d, 107 | inception.Mixed_6a, 108 | inception.Mixed_6b, 109 | inception.Mixed_6c, 110 | inception.Mixed_6d, 111 | inception.Mixed_6e, 112 | ] 113 | self.blocks.append(nn.Sequential(*block2)) 114 | 115 | # Block 3: aux classifier to final avgpool 116 | if self.last_needed_block >= 3: 117 | block3 = [ 118 | inception.Mixed_7a, 119 | inception.Mixed_7b, 120 | inception.Mixed_7c, 121 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 122 | ] 123 | self.blocks.append(nn.Sequential(*block3)) 124 | 125 | for param in self.parameters(): 126 | param.requires_grad = requires_grad 127 | 128 | def forward(self, inp): 129 | """Get Inception feature maps 130 | Parameters 131 | ---------- 132 | inp : torch.autograd.Variable 133 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 134 | range (0, 1) 135 | Returns 136 | ------- 137 | List of torch.autograd.Variable, corresponding to the selected output 138 | block, sorted ascending by index 139 | """ 140 | outp = [] 141 | x = inp 142 | 143 | if self.resize_input: 144 | x = F.interpolate(x, 145 | size=(299, 299), 146 | mode='bilinear', 147 | align_corners=False) 148 | 149 | if self.normalize_input: 150 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 151 | 152 | for idx, block in enumerate(self.blocks): 153 | x = block(x) 154 | if idx in self.output_blocks: 155 | outp.append(x) 156 | 157 | if idx == self.last_needed_block: 158 | break 159 | 160 | return outp 161 | 162 | 163 | def fid_inception_v3(): 164 | """Build pretrained Inception model for FID computation 165 | The Inception model for FID computation uses a different set of weights 166 | and has a slightly different structure than torchvision's Inception. 167 | This method first constructs torchvision's Inception and then patches the 168 | necessary parts that are different in the FID Inception model. 
169 | """ 170 | inception = models.inception_v3(num_classes=1008, 171 | aux_logits=False, 172 | pretrained=False) 173 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 174 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 175 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 176 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 177 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 178 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 179 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 180 | inception.Mixed_7b = FIDInceptionE_1(1280) 181 | inception.Mixed_7c = FIDInceptionE_2(2048) 182 | 183 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 184 | inception.load_state_dict(state_dict) 185 | return inception 186 | 187 | 188 | class FIDInceptionA(models.inception.InceptionA): 189 | """InceptionA block patched for FID computation""" 190 | 191 | def __init__(self, in_channels, pool_features): 192 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 193 | 194 | def forward(self, x): 195 | branch1x1 = self.branch1x1(x) 196 | 197 | branch5x5 = self.branch5x5_1(x) 198 | branch5x5 = self.branch5x5_2(branch5x5) 199 | 200 | branch3x3dbl = self.branch3x3dbl_1(x) 201 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 202 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 203 | 204 | # Patch: Tensorflow's average pool does not use the padded zero's in 205 | # its average calculation 206 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 207 | count_include_pad=False) 208 | branch_pool = self.branch_pool(branch_pool) 209 | 210 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 211 | return torch.cat(outputs, 1) 212 | 213 | 214 | class FIDInceptionC(models.inception.InceptionC): 215 | """InceptionC block patched for FID computation""" 216 | 217 | def __init__(self, in_channels, channels_7x7): 218 | super(FIDInceptionC, self).__init__(in_channels, 
channels_7x7) 219 | 220 | def forward(self, x): 221 | branch1x1 = self.branch1x1(x) 222 | 223 | branch7x7 = self.branch7x7_1(x) 224 | branch7x7 = self.branch7x7_2(branch7x7) 225 | branch7x7 = self.branch7x7_3(branch7x7) 226 | 227 | branch7x7dbl = self.branch7x7dbl_1(x) 228 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 229 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 230 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 231 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 232 | 233 | # Patch: Tensorflow's average pool does not use the padded zero's in 234 | # its average calculation 235 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 236 | count_include_pad=False) 237 | branch_pool = self.branch_pool(branch_pool) 238 | 239 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 240 | return torch.cat(outputs, 1) 241 | 242 | 243 | class FIDInceptionE_1(models.inception.InceptionE): 244 | """First InceptionE block patched for FID computation""" 245 | 246 | def __init__(self, in_channels): 247 | super(FIDInceptionE_1, self).__init__(in_channels) 248 | 249 | def forward(self, x): 250 | branch1x1 = self.branch1x1(x) 251 | 252 | branch3x3 = self.branch3x3_1(x) 253 | branch3x3 = [ 254 | self.branch3x3_2a(branch3x3), 255 | self.branch3x3_2b(branch3x3), 256 | ] 257 | branch3x3 = torch.cat(branch3x3, 1) 258 | 259 | branch3x3dbl = self.branch3x3dbl_1(x) 260 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 261 | branch3x3dbl = [ 262 | self.branch3x3dbl_3a(branch3x3dbl), 263 | self.branch3x3dbl_3b(branch3x3dbl), 264 | ] 265 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 266 | 267 | # Patch: Tensorflow's average pool does not use the padded zero's in 268 | # its average calculation 269 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 270 | count_include_pad=False) 271 | branch_pool = self.branch_pool(branch_pool) 272 | 273 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 274 | return torch.cat(outputs, 1) 275 
| 276 | 277 | class FIDInceptionE_2(models.inception.InceptionE): 278 | """Second InceptionE block patched for FID computation""" 279 | 280 | def __init__(self, in_channels): 281 | super(FIDInceptionE_2, self).__init__(in_channels) 282 | 283 | def forward(self, x): 284 | branch1x1 = self.branch1x1(x) 285 | 286 | branch3x3 = self.branch3x3_1(x) 287 | branch3x3 = [ 288 | self.branch3x3_2a(branch3x3), 289 | self.branch3x3_2b(branch3x3), 290 | ] 291 | branch3x3 = torch.cat(branch3x3, 1) 292 | 293 | branch3x3dbl = self.branch3x3dbl_1(x) 294 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 295 | branch3x3dbl = [ 296 | self.branch3x3dbl_3a(branch3x3dbl), 297 | self.branch3x3dbl_3b(branch3x3dbl), 298 | ] 299 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 300 | 301 | # Patch: The FID Inception model uses max pooling instead of average 302 | # pooling. This is likely an error in this specific Inception 303 | # implementation, as other Inception models use average pooling here 304 | # (which matches the description in the paper). 
305 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 306 | branch_pool = self.branch_pool(branch_pool) 307 | 308 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 309 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /utils/inception_score.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | import os 8 | import os.path 9 | import sys 10 | import tarfile 11 | 12 | import numpy as np 13 | import tensorflow.compat.v1 as tf 14 | tf.disable_v2_behavior() 15 | from six.moves import urllib 16 | from tqdm import tqdm 17 | 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 20 | MODEL_DIR = '/tmp/imagenet' 21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 22 | softmax = None 23 | config = tf.ConfigProto() 24 | # config = tf.ConfigProto(device_count = {'GPU': 0}) 25 | config.gpu_options.visible_device_list= '0' 26 | config.gpu_options.allow_growth = True 27 | 28 | 29 | # Call this function with list of images. Each of elements should be a 30 | # numpy array with values ranging from 0 to 255. 
def _inception_score_per_split(preds, splits):
    """Return exp(mean KL(p(y|x) || p(y))) for each of ``splits`` slices of ``preds``.

    ``preds`` is a 2-D array of class probabilities, one row per image.
    Rows are partitioned into ``splits`` contiguous slices using integer
    arithmetic, so slice sizes differ by at most one row.
    """
    total = preds.shape[0]
    scores = []
    for idx in range(splits):
        part = preds[(idx * total // splits):((idx + 1) * total // splits), :]
        marginal = np.expand_dims(np.mean(part, 0), 0)
        # A probability that underflowed to exactly 0 contributes 0 to the KL
        # sum (lim p->0 of p*log p), but evaluating it literally gives
        # 0 * -inf = NaN and poisons the whole score.  Mask those entries.
        with np.errstate(divide='ignore', invalid='ignore'):
            kl = np.where(part > 0, part * (np.log(part) - np.log(marginal)), 0.0)
        kl = np.mean(np.sum(kl, 1))
        scores.append(np.exp(kl))
    return scores


def get_inception_score(images, splits=10):
    """Compute the Inception Score of a list of images.

    Args:
        images: non-empty list of numpy arrays.  The first element is checked
            to be a rank-3 ndarray whose maximum exceeds 10 and whose minimum
            is non-negative, i.e. pixel values on a 0-255 scale rather than
            0-1.  (Axis order is presumably height x width x channels --
            confirm against the graph's ``ExpandDims:0`` input.)
        splits: number of contiguous slices the predictions are divided into;
            one score is computed per slice.

    Returns:
        Tuple ``(mean, std)`` of the per-slice scores.

    Raises:
        AssertionError: if ``images`` fails any of the checks above.
        RuntimeError: if the module-level ``softmax`` op has not been built
            yet (``_init_inception()`` has not run).
    """
    # Explicit raises (rather than ``assert`` statements) keep the validation
    # active under ``python -O`` while preserving the exception type.
    if not isinstance(images, list):
        raise AssertionError('images must be a list of numpy arrays')
    if not images:
        # Previously surfaced as an IndexError from images[0].
        raise AssertionError('images must not be empty')
    first = images[0]
    if not isinstance(first, np.ndarray):
        raise AssertionError('each image must be a numpy array')
    if first.ndim != 3:
        raise AssertionError('each image must be a rank-3 array')
    if not np.max(first) > 10:
        raise AssertionError('pixel values must be on a 0-255 scale')
    if not np.min(first) >= 0.0:
        raise AssertionError('pixel values must be non-negative')

    if softmax is None:
        raise RuntimeError(
            'Inception graph is not initialised; call _init_inception() first')

    batch_size = 128
    preds = []
    with tf.Session(config=config) as sess:
        starts = range(0, len(images), batch_size)
        for start in tqdm(starts, desc="Calculate inception score"):
            sys.stdout.flush()
            # Convert one batch at a time instead of materialising a float32
            # copy of the whole dataset before the first forward pass.
            batch = np.stack(images[start:start + batch_size]).astype(np.float32)
            preds.append(sess.run(softmax, {'ExpandDims:0': batch}))
    # The session is closed by the ``with`` block; the remaining work is
    # pure numpy.
    preds = np.concatenate(preds, 0)
    scores = _inception_score_per_split(preds, splits)
    return np.mean(scores), np.std(scores)


# This function is called automatically.
def _init_inception():
    """Load the pretrained Inception graph and publish its softmax tensor.

    Steps performed:
      1. Ensure ``MODEL_DIR`` exists and download the archive at ``DATA_URL``
         into it when the archive file is not already present.
      2. Extract the archive into ``MODEL_DIR`` and import
         ``classify_image_graph_def.pb`` into the default graph.
      3. Rewrite the static shape of every op output so that a leading
         dimension of 1 becomes ``None`` (arbitrary batch size).
      4. Rebuild logits from ``pool_3:0`` and the weights feeding
         ``softmax/logits/MatMul`` and store ``tf.nn.softmax(logits)`` in the
         module-level ``softmax`` variable.
    """
    global softmax
    os.makedirs(MODEL_DIR, exist_ok=True)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): the archive is fetched over plain HTTP and extracted without
    # member filtering; consider verifying it (or using a tarfile extraction
    # filter where the running Python supports one) before trusting its paths.
    with tarfile.open(filepath, 'r:gz') as archive:  # close the handle after extraction
        archive.extractall(MODEL_DIR)
    with tf.gfile.FastGFile(os.path.join(
            MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # Works with an arbitrary minibatch size.
    with tf.Session(config=config) as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        for op in pool3.graph.get_operations():
            for o in op.outputs:
                shape = o.get_shape()
                if shape._dims != []:
                    dims = [s.value for s in shape]
                    # Only a leading dimension equal to 1 is relaxed to None.
                    new_shape = [None if (j == 0 and s == 1) else s
                                 for j, s in enumerate(dims)]
                    o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
        softmax = tf.nn.softmax(logits)


# --------------------------------------------------------------------------------
# /utils/utils.py:
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Date    : 2019-07-25
# @Author  : Xinyu Gong (xy_gong@tamu.edu)
# @Link    : None
# @Version : 0.0

import collections
import logging
import math
import os
import pathlib
import time
import warnings
from datetime import datetime
from typing import Union, Optional, List, Tuple, Text, BinaryIO

import dateutil.tz
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont, ImageColor


@torch.no_grad()
def make_grid(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = False,
    value_range: Optional[Tuple[int, int]] = None,
    scale_each: bool = False,
    pad_value: int = 0,
    **kwargs
) -> torch.Tensor:
    """
    Make a grid of images.

    Args:
        tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W)
            or a list of images all of the same size.
        nrow (int, optional): Number of images displayed in each row of the grid.
            The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
        padding (int, optional): amount of padding. Default: ``2``.
        normalize (bool, optional): If True, shift the image to the range (0, 1),
            by the min and max values specified by :attr:`value_range`. Default: ``False``.
        value_range (tuple, optional): tuple (min, max) where min and max are numbers,
            then these numbers are used to normalize the image. Values outside
            the range are clamped to it first. By default, min and max
            are computed from the tensor.
        scale_each (bool, optional): If ``True``, scale each image in the batch of
            images separately rather than the (min, max) over all images. Default: ``False``.
        pad_value (float, optional): Value for the padded pixels. Default: ``0``.
        **kwargs: only the deprecated ``range`` key is read; it overrides
            ``value_range`` and triggers a warning.

    Returns:
        grid (Tensor): the tensor containing grid of images. A batch holding a
        single image is returned as that image (C x H x W) without padding.

    Raises:
        TypeError: if ``tensor`` is neither a tensor nor a list of tensors.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if "range" in kwargs:
        warning = "range will be deprecated, please use value_range instead."
        warnings.warn(warning)
        value_range = kwargs["range"]

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)

    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize is True:
        tensor = tensor.clone()  # avoid modifying tensor in-place
        if value_range is not None:
            assert isinstance(value_range, tuple), \
                "value_range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, low, high):
            # In-place clamp: the out-of-place ``clamp`` would return a new
            # tensor and leave ``img`` untouched, silently skipping the clamp.
            img.clamp_(min=low, max=high)
            img.sub_(low).div_(max(high - low, 1e-5))

        def norm_range(t, value_range):
            if value_range is not None:
                norm_ip(t, value_range[0], value_range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each is True:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, value_range)
        else:
            norm_range(tensor, value_range)

    if tensor.size(0) == 1:
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            # Tensor.copy_() is a valid method but seems to be missing from the stubs
            # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_
            grid.narrow(1, y * height + padding, height - padding).narrow(  # type: ignore[attr-defined]
                2, x * width + padding, width - padding
            ).copy_(tensor[k])
            k = k + 1
    return grid


@torch.no_grad()
def save_image(
    tensor: Union[torch.Tensor, List[torch.Tensor]],
    fp: Union[Text, pathlib.Path, BinaryIO],
    format: Optional[str] = None,
    **kwargs
) -> None:
    """
    Save a given Tensor into an image file.

    Args:
        tensor (Tensor or list): Image to be saved. If given a mini-batch tensor,
            saves the tensor as a grid of images by calling ``make_grid``.
        fp (string or file object): A filename or a file object
        format(Optional): If omitted, the format to use is determined from the filename extension.
            If a file object was used instead of a filename, this parameter should always be used.
        **kwargs: Other arguments are documented in ``make_grid``.
    """

    grid = make_grid(tensor, **kwargs)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(fp, format=format)


def create_logger(log_dir, phase='train'):
    """Configure the root logger to write to a timestamped file and the console.

    The log file is named ``<YYYY-mm-dd-HH-MM>_<phase>.log`` inside ``log_dir``.
    Every call attaches one more console ``StreamHandler`` to the root logger.

    Args:
        log_dir: directory in which the log file is created.
        phase: label embedded in the log file name.

    Returns:
        The root logger, with its level set to ``INFO``.
    """
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger


def set_log_dir(root_dir, exp_name):
    """Create a timestamped experiment directory tree and return its paths.

    The experiment directory is ``<root_dir>/<exp_name>_<timestamp>`` and gets
    the sub-directories ``Model``, ``Log`` and ``Samples``.

    Args:
        root_dir: parent directory; created if it does not exist.
        exp_name: experiment name used as the directory prefix.

    Returns:
        dict with the keys ``prefix``, ``ckpt_path``, ``log_path`` and
        ``sample_path``.
    """
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict


def save_checkpoint(states, is_best, output_dir,
                    filename='checkpoint.pth'):
    """Save ``states`` with ``torch.save``.

    Args:
        states: object to serialize.
        is_best: when truthy, a second copy is written as ``checkpoint_best.pth``.
        output_dir: directory receiving the checkpoint file(s).
        filename: file name of the regular checkpoint.
    """
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        torch.save(states, os.path.join(output_dir, 'checkpoint_best.pth'))


class RunningStats:
    """Mean and variance over a sliding window of the last ``WIN_SIZE`` values.

    While the window is filling up, mean and the sum of squared deviations
    (``run_var``) are updated incrementally per new value; once it is full,
    each push replaces the oldest value and adjusts both quantities.
    """

    def __init__(self, WIN_SIZE):
        self.mean = 0
        self.run_var = 0  # running sum of squared deviations from the mean
        self.WIN_SIZE = WIN_SIZE

        self.window = collections.deque(maxlen=WIN_SIZE)

    def clear(self):
        """Drop all stored values and reset the statistics."""
        self.window.clear()
        self.mean = 0
        self.run_var = 0

    def is_full(self):
        """Return True when the window holds ``WIN_SIZE`` values."""
        return len(self.window) == self.WIN_SIZE

    def push(self, x):
        """Add ``x`` to the window, evicting the oldest value when full."""
        if len(self.window) == self.WIN_SIZE:
            # Adjusting variance
            x_removed = self.window.popleft()
            self.window.append(x)
            old_m = self.mean
            self.mean += (x - x_removed) / self.WIN_SIZE
            self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
        else:
            # Calculating first variance
            self.window.append(x)
            delta = x - self.mean
            self.mean += delta / len(self.window)
            self.run_var += delta * (x - self.mean)

    def get_mean(self):
        """Return the window mean, or 0.0 for an empty window."""
        return self.mean if len(self.window) else 0.0

    def get_var(self):
        """Return ``run_var / len(window)``, or 0.0 with fewer than two values.

        The result is clamped at 0: the incremental updates can leave
        ``run_var`` marginally negative through rounding, which would make
        ``get_std`` fail in ``math.sqrt``.
        """
        if len(self.window) <= 1:
            return 0.0
        return max(self.run_var / len(self.window), 0.0)

    def get_std(self):
        """Return the square root of ``get_var()``."""
        return math.sqrt(self.get_var())

    def get_all(self):
        """Return the window contents as a list, oldest first."""
        return list(self.window)

    def __str__(self):
        return "Current window values: {}".format(list(self.window))