├── .gitignore ├── LICENSE ├── README.md ├── assets └── watermark.jpg ├── conf ├── README.md ├── attn_maps.yaml ├── config.yaml ├── default_conf.yaml ├── inference.yaml ├── meta_conf.yaml ├── runs │ └── ffhq_256.yaml └── tags ├── dataset └── dataset.py ├── images ├── architecture.png ├── ffhq256.png ├── fidparams.png └── header.png ├── main.py ├── models ├── basic_layers.py ├── discriminator.py ├── generator.py └── stylenat.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── requirements.txt ├── src ├── analysis.py ├── evaluate.py ├── inference.py ├── logging.py ├── throughput.py └── train.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── gen_utils.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py ├── training_stats.py └── utils_spectrum.py └── utils ├── CRDiffAug.py ├── distributed.py ├── fid_score.py ├── helpers.py ├── improved_precision_recall.py └── inception.py /.gitignore: -------------------------------------------------------------------------------- 1 | __*__ 2 | *.pyc 3 | *.sw? 4 | wandb/ 5 | *.pth 6 | *.tar 7 | *.gz 8 | checkpoints/ 9 | samples/ 10 | tags 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2022] [SHI-Labs] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## StyleNAT: Giving Each Head a New Perspective 2 | 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stylenat-giving-each-head-a-new-perspective/image-generation-on-ffhq-256-x-256)](https://paperswithcode.com/sota/image-generation-on-ffhq-256-x-256?p=stylenat-giving-each-head-a-new-perspective) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stylenat-giving-each-head-a-new-perspective/image-generation-on-ffhq-1024-x-1024)](https://paperswithcode.com/sota/image-generation-on-ffhq-1024-x-1024?p=stylenat-giving-each-head-a-new-perspective) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/stylenat-giving-each-head-a-new-perspective/image-generation-on-lsun-churches-256-x-256)](https://paperswithcode.com/sota/image-generation-on-lsun-churches-256-x-256?p=stylenat-giving-each-head-a-new-perspective) 7 | 8 | ##### Authors: [Steven Walton](https://github.com/stevenwalton), [Ali Hassani](https://github.com/alihassanijr), [Xingqian Xu](https://github.com/xingqian2018), Zhangyang Wang, [Humphrey Shi](https://github.com/honghuis) 9 | 10 | ![header](images/header.png) 11 | StyleNAT is a Style-based GAN that exploits [Neighborhood 12 | Attention](https://github.com/SHI-Labs/Neighborhood-Attention-Transformer) to 13 | extend the power of localized attention heads to capture long-range features and 14 | maximize information gain within the generative process. 15 | The flexibility of the system allows it to be adapted to various 16 | environments and datasets. 17 | 18 | ## Abstract: 19 | Image generation has been a long sought-after but challenging task, and performing the generation task in an efficient manner is similarly difficult. 20 | Often researchers attempt to create a "one size fits all" generator, where there are few differences in the parameter space for drastically different datasets. 21 | Herein, we present a new transformer-based framework, dubbed StyleNAT, targeting high-quality image generation with superior efficiency and flexibility. 22 | At the core of our model is a carefully designed framework that partitions attention heads to capture local and global information, which is achieved by using Neighborhood Attention (NA). 23 | With different heads able to pay attention to varying receptive fields, the model is able to better combine this information and adapt, in a highly flexible manner, to the data at hand. 24 | StyleNAT attains a new SOTA FID score on FFHQ-256 with 2.046, beating prior art including convolutional models such as StyleGAN-XL and transformers such as HIT and StyleSwin, and a new transformer SOTA on FFHQ-1024 with an FID score of 4.174. 25 | These results show a 6.4% improvement on FFHQ-256 scores when compared to StyleGAN-XL, with a 28% reduction in the number of parameters and a 56% improvement in sampling throughput.
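For intuition, the head partitioning described above corresponds directly to the per-level `kernels` and `dilations` lists in `conf/runs/ffhq_256.yaml`: each attention head gets a neighborhood kernel and a dilation, so some heads stay local while others cover a much wider receptive field. Below is a minimal, illustrative sketch of such a split; the `partition_heads` helper is ours and is not part of the actual `models/` implementation.

```python
# Illustrative only: evenly assign attention heads to the dilations used at one
# resolution level, e.g. kernels=[7] with dilations=[1, 8] at a mid-resolution stage.
def partition_heads(num_heads, dilations):
    """Return a per-head dilation: earlier groups are local, later groups are global."""
    per_group = num_heads // len(dilations)
    assignment = []
    for dilation in dilations:
        assignment.extend([dilation] * per_group)
    # Any leftover heads join the last (most dilated, most global) group.
    assignment.extend([dilations[-1]] * (num_heads - len(assignment)))
    return assignment

print(partition_heads(8, [1, 8]))  # -> [1, 1, 1, 1, 8, 8, 8, 8]
```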
26 | 27 | ## Architecture 28 | ![architecture](images/architecture.png) 29 | 30 | ## Performance 31 | ![compute](images/fidparams.png) 32 | 33 | Dataset | FID | Throughput (imgs/s) | Number of Parameters (M) | 34 | |:---:|:---:|:---:|:---:| 35 | FFHQ 256 | [2.046](https://shi-labs.com/projects/stylenat/checkpoints/FFHQ256_940k_flip.pt) | 32.56 | 48.92 | 36 | FFHQ 1024 | [4.174](https://shi-labs.com/projects/stylenat/checkpoints/FFHQ1024_700k.pt) | - | 49.45 | 37 | Church 256 | 3.400 | - | - | 38 | 39 | ## Building and Using StyleNAT 40 | We recommend building an environment with conda to get the best performance. The 41 | following build instructions should work, but your mileage may vary. 42 | ```bash 43 | conda create --name stylenat python=3.10 44 | conda activate stylenat 45 | conda install pytorch torchvision cudatoolkit=11.6 -c pytorch -c nvidia 46 | # Use xargs to install lines one at a time since natten requires torch to be installed first 47 | cat requirements.txt | xargs -L1 pip install 48 | ``` 49 | Note: some version issues can create poor FIDs. Always check your build 50 | environment first with the `evaluate` method. With the best FFHQ checkpoint you 51 | should always get an FID under 2.10 (hopefully closer to 2.05). 52 | 53 | Note: You may need to install torch and torchvision first due to the dependency order. Pip 54 | does not build packages sequentially, so NATTEN may fail to build. 55 | 56 | Notes: 57 | - [NATTEN can be sped up by using pre-built wheels directly.](https://shi-labs.com/natten/) 58 | 59 | - Some arguments and configurations have changed slightly. Everything should be 60 | backwards compatible, but if something isn't, please open an issue. 61 | 62 | - This is research code, not production code. There are plenty of optimizations that 63 | can be implemented easily. We are also obsessive about logging information and 64 | storing it in checkpoints. Official checkpoints may not contain all the information the 65 | current code tracks, due to research and development. The most important things 66 | should exist, but if you're missing something important, open an issue. Sorry, 67 | seeds and rng states are only available if they exist in the checkpoints. 68 | 69 | ## Inference 70 | Using Meta's hydra-core, running inference is easy. Simply run 71 | ```bash 72 | python main.py type=inference 73 | ``` 74 | Note that the first time you run this it will take some time while upfirdn2d compiles. 75 | 76 | By default this will create 10 random inference images from the checkpoint, each 77 | saved under the name of its random seed. 78 | 79 | You can specify seeds by using 80 | ```bash 81 | python main.py type=inference inference.seeds=[1,2,3,4] 82 | ``` 83 | If you would like to specify a set of seeds in a range, use the following command: 84 | `python main.py type=inference 'inference.seeds="range(start, stop, step)"'` 85 | 86 | 87 | ## Evaluation 88 | If you would like to check the performance of a model, we provide the `evaluate` 89 | mode. Simply run 90 | ```bash 91 | python main.py type=evaluate 92 | ``` 93 | See the config file to set the proper dataset, checkpoint, etc. 94 | 95 | ## Training 96 | If you would like to train a model from scratch, we provide the following mode 97 | ```bash 98 | python main.py type=train restart.ckpt=null 99 | ``` 100 | We suggest explicitly setting the checkpoint to null so that you don't 101 | accidentally load a checkpoint.
102 | It is also advised to create a new run file and call 103 | ```bash 104 | python main.py type=train restart.ckpt=null runs=my_new_run 105 | ``` 106 | We also support distributed training. Simply use torchrun 107 | ```bash 108 | torchrun --nnodes=$NUM_NODES --nproc_per_node=$NUM_GPUS --node_rank=$NODE_RANK \ 109 | main.py type=train 110 | ``` 111 | 112 | ## Modifying Hydra-Configs 113 | The `conf` directory holds yaml configs for different types of runs. If you would 114 | like to adjust parameters (such as changing checkpoints, inference, number of 115 | images, specifying seeds, and so on) you should edit these files. The `conf/runs` folder holds 116 | parameters for the model and training options. It is not advised to modify these 117 | files directly. It is better to copy them to a new file and use that if you wish to 118 | train a new model. 119 | 120 | ## "Secret" Hydra args 121 | There are a few undocumented hydra configs around wandb. We only provide a 122 | simple version here, but we also support `tags` and `description` under the `wandb` 123 | argument. 124 | 125 | 126 | 127 | ## Citation: 128 | ```bibtex 129 | @article{walton2022stylenat, 130 | title = {StyleNAT: Giving Each Head a New Perspective}, 131 | author = {Steven Walton and Ali Hassani and Xingqian Xu and Zhangyang Wang and Humphrey Shi}, 132 | year = 2022, 133 | url = {https://arxiv.org/abs/2211.05770}, 134 | eprint = {2211.05770}, 135 | archiveprefix = {arXiv}, 136 | primaryclass = {cs.CV} 137 | } 138 | ``` 139 | 140 | ## Acknowledgements 141 | This code heavily relies upon 142 | [StyleSwin](https://github.com/microsoft/StyleSwin), which in turn relies upon 143 | [rosinality's StyleGAN2-pytorch](https://github.com/rosinality/stylegan2-pytorch) library. 144 | We also utilize [mseitzer's pytorch-fid](https://github.com/mseitzer/pytorch-fid). 145 | Finally, we utilize SHI-Labs' [NATTEN](https://github.com/SHI-Labs/NATTEN/). 146 | 147 | We'd also like to thank the Intelligence Advanced Research Projects Activity 148 | (IARPA), the University of Oregon, the University of Illinois at Urbana-Champaign, and 149 | Picsart AI Research (PAIR) for their generous support. 150 | -------------------------------------------------------------------------------- /assets/watermark.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/assets/watermark.jpg -------------------------------------------------------------------------------- /conf/README.md: -------------------------------------------------------------------------------- 1 | # Hydra-core configurations 2 | 3 | We use [hydra](https://hydra.cc/docs/intro/) to track our hyper-parameters. 4 | The `runs` directory contains the settings that were used for each dataset. 5 | 6 | `meta_conf.yaml` is an explanation of all the possible arguments, their types, and 7 | what they are used for. 8 | 9 | `inference.yaml` is a simple inference example configuration. 10 | 11 | `config.yaml` is the bare minimum configuration for training.
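If you just want to inspect the configuration that hydra will compose, without launching a job, hydra's compose API is handy. A small sketch, assuming it is run from a script in the repository root so that the relative `conf` path resolves:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Build the same config that `python main.py` would see, with one override applied.
with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="config", overrides=["type=inference"])

print(cfg.type)                          # inference
print(OmegaConf.to_yaml(cfg.inference))  # num_images, save_path, batch
```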
12 | 13 | Here's a quick reference on hydra's override syntax: 14 | 15 | Overwriting an existing value 16 | ``` 17 | arg_key=new_value 18 | nested/key=new_value 19 | ``` 20 | 21 | Add a new argument 22 | ``` 23 | +new_arg=new_value 24 | +new/nested/arg=new_value 25 | ``` 26 | 27 | Remove an argument 28 | ``` 29 | ~arg_key 30 | ~arg_key=value 31 | ~nested/arg 32 | ~nested/arg=value 33 | ``` 34 | 35 | Use a different base config file 36 | ``` 37 | python main.py --config-name attn_map 38 | python main.py -cn inference 39 | ``` 40 | -------------------------------------------------------------------------------- /conf/attn_maps.yaml: -------------------------------------------------------------------------------- 1 | type: attention_map 2 | device: cuda 3 | distributed: False 4 | save_root: my/storage/path/ 5 | 6 | defaults: 7 | - _self_ 8 | - runs: ffhq_256 9 | 10 | restart: 11 | ckpt: FFHQ256_940k_flip.pt 12 | 13 | evaluation: 14 | attn_map: True 15 | save_attn_map: True 16 | attn_map_path: remove/if/want/save_root/as/dir 17 | const_attn_seed: True 18 | 19 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_conf 3 | - runs: ffhq_256 4 | - _self_ 5 | 6 | type: train 7 | device: cuda 8 | latent: 4096 9 | world_size: 1 10 | rank: 0 11 | local_rank: 0 12 | distributed: False 13 | workers: 0 14 | save_root: results 15 | 16 | dataset: 17 | name: ffhq 18 | path: /data/datasets/ffhq_256 19 | 20 | inference: 21 | num_images: 10 22 | save_path: sample_images 23 | batch: 1 24 | 25 | logging: 26 | wandb: False 27 | log_img_batch: False 28 | print_freq: 1000 29 | eval_freq: 50000 30 | save_freq: 25000 31 | checkpoint_path: checkpoints 32 | sample_path: eval_samples 33 | reuse_samplepath: True 34 | 35 | evaluation: 36 | gt_path: /data/datasets/ffhq_256/images/ 37 | num_batches: 12500 38 | total_size: 50000 39 | batch: 4 40 | attn_map: True 41 | save_attn_map: False 42 | attn_map_path: attn_maps 43 | const_attn_seed: True 44 | 45 | wandb: 46 | project_name: my_stylenat 47 | entity: my_wand_team 48 | run_name: ffhq_256_reproduce 49 | -------------------------------------------------------------------------------- /conf/default_conf.yaml: -------------------------------------------------------------------------------- 1 | type: train 2 | device: cuda 3 | latent: 4096 4 | world_size: 1 5 | rank: 0 6 | local_rank: 0 7 | distributed: False 8 | workers: 0 9 | save_root: "/tmp/" 10 | 11 | inference: 12 | num_images: 1 13 | batch: 1 14 | save_path: "inference_samples" 15 | 16 | logging: 17 | print_freq: 1000 18 | eval_freq: 50000 19 | save_freq: 25000 20 | checkpoint_path: "checkpoints" 21 | sample_path: "samples" 22 | reuse_samplepath: True 23 | log_img_batch: False 24 | wandb: False 25 | 26 | evaluation: 27 | total_size: 50000 28 | num_batches: 12500 29 | batch: 4 30 | attn_map: True 31 | 32 | analysis: 33 | save_path: "attn_maps" 34 | const_attn_seed: True 35 | attn_seed: 0 36 | 37 | restart: 38 | wandb_resume: False 39 | reuse_rng: False 40 | 41 | throughput: 42 | rounds: 500 43 | warmup: 10 44 | batch_size: 1 45 | 46 | misc: 47 | seed: null 48 | rng_state: null 49 | py_rng_state: null 50 | -------------------------------------------------------------------------------- /conf/inference.yaml: -------------------------------------------------------------------------------- 1 | type: inference 2 | device: cuda 3 | distributed: False 4 | save_root: storage/ 5 
| 6 | defaults: 7 | - _self_ 8 | - runs: ffhq_256 9 | 10 | restart: 11 | ckpt: FFHQ256_940k_flip.pt 12 | 13 | inference: 14 | seeds: range(0,20) 15 | save_path: tmp 16 | -------------------------------------------------------------------------------- /conf/meta_conf.yaml: -------------------------------------------------------------------------------- 1 | explanation_of_conf: |- 2 | This file is an explanation of the configuration YAML file 3 | and every variable that is potentially used within the code. 4 | Since hydra-core was made by META we'll be meta and write this meta conf. 5 | Each section will denote if it is required (req) or optional (opt). 6 | Each item in the section will similarly be noted and then given a type and 7 | description of what it does. 8 | 9 | type: (req:str) Type of job to perform. Supports {train,inference,evaluate} 10 | logging_level: (opt:str) python's logging level. Defaults to warning to have better readouts 11 | device: (opt:str) Which acceleration device to use. Supports {cuda,cpu} 12 | world_size: (opt:int) The world size for distributed training. This value is automatically set in src/train.py. Default 1 13 | rank: (opt:int) The rank of your GPU. This value is automatically set in src/train.py. Default 0 14 | local_rank: (opt:int) Automatically set. Default 0 15 | distributed: (opt:bool) Used in training. Determines whether distributed training is used; defaults based on whether multiple GPUs are detected 16 | truncation: (opt:float) Unused. Optional truncation used when generating images; edit the evaluation code to enable this 17 | workers: (opt:int) Number of workers to spawn for the dataloader. Default 0 (use this) 18 | save_root: (opt:str) Root path for saving. If other paths don't start with / then we assume they are relative to this 19 | 20 | defaults: 21 | - runs: (req:str) Which yaml file to use from the runs dir 22 | - _self_ 23 | 24 | dataset: 25 | (req for all) 26 | name: (req:str) name of the dataset 27 | path: (req:str) /path/to/dataset/ 28 | lmdb: (opt:bool) use the lmdb format; overrides other options 29 | 30 | logging: 31 | print_freq: (opt:int) frequency to print to stdout. Default 1000 32 | eval_freq: (opt:int) frequency to evaluate. Default -1 33 | save_freq: (opt:int) frequency to save model checkpoints 34 | checkpoint_path: (opt:str) where to save checkpoints. Default /tmp 35 | sample_path: (opt:str) where to save samples. Default /tmp 36 | reuse_samplepath: (opt:bool) write FID samples to the same directory (saves space) 37 | wandb: (opt:bool) enable wandb logging 38 | log_img_batch: (opt:bool) bool to log the first image batch 39 | 40 | evaluation: 41 | gt_path: (str) path to ground truth images/data 42 | total_size: (int) number of FID images, typically 50000 43 | batch: (opt:int) batch size for the generator during FID calculation 44 | attn_map: (opt:bool) generate attention maps (if wandb is enabled, we log them) 45 | save_attn_map: (opt:bool) do you want to save the attention maps? 46 | attn_map_path: (opt:str) path to save attention maps if save_attn_map 47 | const_attn_seed: (opt:bool or int) Specify true or an integer seed 48 | 49 | inference: needs either num_images or seeds 50 | num_images: (opt:int) number of images to sample 51 | seeds: (opt:list,range) list of seeds to sample. "range(a,b)" also accepted 52 | save_path: (str) where to save images relative to cwd 53 | batch: (opt:int) batch size of images to generate.
Default 1 54 | 55 | restart: options to restart a run 56 | ckpt: (str) path to checkpoint 57 | wandb_resume: (opt:str) see wandb.init resume 58 | wandb_id: (opt:str) wandb run id 59 | start_iter: (opt: int) you can manually specify the iteration but we'll try to figure it out 60 | reuse_rng: (opt:bool) attempt to use the checkpoint's rng information 61 | 62 | misc: 63 | seed: (opt:int) random seed for run 64 | rng_state: (opt:tensor) (don't set!) the rng_state of the run 65 | py_rng_state: (opt:list) (don't set!) the python random rng state 66 | 67 | wandb: 68 | entity: (str) your username 69 | project_name: (str): name of wandb project 70 | run_name: (opt:str) name of your run 71 | tags: (opt:list) list of tags for the run 72 | description: (opt:str) description information for your run 73 | 74 | 75 | -------------------------------------------------------------------------------- /conf/runs/ffhq_256.yaml: -------------------------------------------------------------------------------- 1 | size: 256 2 | 3 | generator: 4 | n_mlp: 8 5 | block_type: nat 6 | style_dim: 512 7 | lr: 0.00004 8 | channel_multiplier: 2 9 | mlp_ratio: 4 10 | use_checkpoint: False 11 | lr_mlp: 0.01 12 | enable_full_resolution: 8 13 | min_heads: 4 14 | qkv_bias: True 15 | qk_scale: null 16 | proj_drop: 0. 17 | attn_drop: 0. 18 | kernels: [[3],[7],[7],[7],[7],[7],[7],[7],[7]] 19 | dilations: [[1],[1],[1,2],[1,4],[1,8],[1,16],[1,32],[1,64],[1,128]] 20 | reg_every: null 21 | params: 0 22 | 23 | discriminator: 24 | lr: 0.0002 25 | channel_multiplier: 2 26 | blur_kernel: [1, 3, 3, 1] 27 | sn: True 28 | ssd: False 29 | reg_every: 16 30 | params: 0 31 | 32 | training: 33 | iter: 1000000 34 | batch: 8 35 | use_flip: True 36 | ttur: True 37 | r1: 10 38 | bcr: True 39 | bcr_fake_lambda: 10 40 | bcr_real_lambda: 10 41 | beta1: 0.0 42 | beta2: 0.99 43 | start_dim: 512 44 | workers: 8 45 | lr_decay: True 46 | lr_decay_start_steps: 775000 47 | gan_weight: 1 48 | -------------------------------------------------------------------------------- /conf/tags: -------------------------------------------------------------------------------- 1 | !_TAG_FILE_SORTED 2 /0=unsorted, 1=sorted, 2=foldcase/ 2 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from torchvision import transforms as T 9 | from torchvision import datasets 10 | 11 | from utils.distributed import get_rank 12 | 13 | 14 | class MultiResolutionDataset(Dataset): 15 | def __init__(self, path, transform, resolution=256): 16 | self.env = lmdb.open( 17 | path, 18 | max_readers=32, 19 | readonly=True, 20 | lock=False, 21 | readahead=False, 22 | meminit=False, 23 | ) 24 | 25 | if not self.env: 26 | raise IOError('Cannot open lmdb dataset', path) 27 | 28 | with self.env.begin(write=False) as txn: 29 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 30 | 31 | self.resolution = resolution 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return self.length 36 | 37 | def __getitem__(self, index): 38 | with self.env.begin(write=False) as txn: 39 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 40 | img_bytes = txn.get(key) 41 | 42 | buffer = BytesIO(img_bytes) 43 | img = 
Image.open(buffer) 44 | img = self.transform(img) 45 | 46 | return img 47 | 48 | def unnormalize(image): 49 | if image.dim() == 4: 50 | image[:, 0, :, :] = image[:, 0, :, :] * 0.229 + 0.485 51 | image[:, 1, :, :] = image[:, 1, :, :] * 0.224 + 0.456 52 | image[:, 2, :, :] = image[:, 2, :, :] * 0.225 + 0.406 53 | elif image.dim() == 3: 54 | image[0, :, :] = image[0, :, :] * 0.229 + 0.485 55 | image[1, :, :] = image[1, :, :] * 0.224 + 0.456 56 | image[2, :, :] = image[2, :, :] * 0.225 + 0.406 57 | else: 58 | raise NotImplementedError(f"Can't handle image of dimension {image.dim()}, please use a 3 or 4 dimensional image") 59 | return image 60 | 61 | def get_dataset(args, evaluation=True): 62 | normalize = T.Normalize(mean=[0.485, 0.456, 0.406], 63 | std=[0.229, 0.224, 0.225]) 64 | # Resize then center crop. Make sure data integrity is good 65 | transforms = [T.Resize(args.runs.size), 66 | T.CenterCrop(args.runs.size), 67 | ] 68 | if args.runs.training.use_flip and not evaluation: 69 | transforms.append(T.RandomHorizontalFlip()) 70 | transforms.append(T.ToTensor()) 71 | transforms.append(normalize) 72 | transforms = T.Compose(transforms) 73 | 74 | if "lmdb" in args.dataset and args.dataset.lmdb: 75 | if get_rank() == 0: 76 | print(f"Using LMDB with {args.dataset.path}") 77 | dataset = MultiResolutionDataset(path=args.dataset.path, 78 | transform=transforms, 79 | resolution=args.runs.size) 80 | elif args.dataset.name in ["cifar10"]: 81 | if get_rank() == 0: 82 | print("Loading CIFAR-10") 83 | dataset = datasets.CIFAR10(root=args.dataset.path, 84 | transform=transforms) 85 | else: 86 | if get_rank() == 0: 87 | print(f"Loading ImageFolder dataset from {args.dataset.path}") 88 | dataset = datasets.ImageFolder(root=args.dataset.path, 89 | transform=transforms) 90 | return dataset 91 | 92 | def data_sampler(dataset, shuffle, distributed): 93 | if distributed: 94 | return DistributedSampler(dataset, shuffle=shuffle) 95 | if shuffle: 96 | return RandomSampler(dataset) 97 | else: 98 | return SequentialSampler(dataset) 99 | 100 | def get_loader(args, dataset, batch_size=1): 101 | loader = DataLoader(dataset, 102 | batch_size=batch_size, 103 | num_workers=args.workers, 104 | sampler=data_sampler(dataset, 105 | shuffle=True, 106 | distributed=args.distributed), 107 | drop_last=True, 108 | ) 109 | return loader 110 | -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/architecture.png -------------------------------------------------------------------------------- /images/ffhq256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/ffhq256.png -------------------------------------------------------------------------------- /images/fidparams.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/fidparams.png -------------------------------------------------------------------------------- /images/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/images/header.png
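Before moving on to `main.py`, here is a minimal, hypothetical usage sketch for the helpers in `dataset/dataset.py` above. The config is built by hand to mirror the fields in `conf/config.yaml`, and the dataset path is a placeholder:

```python
from omegaconf import OmegaConf
from dataset.dataset import get_dataset, get_loader

# Hand-built stand-in for the hydra config; only the fields the helpers read.
args = OmegaConf.create({
    "workers": 0,
    "distributed": False,
    "dataset": {"name": "ffhq", "path": "/data/datasets/ffhq_256"},  # placeholder path
    "runs": {"size": 256, "training": {"use_flip": True}},
})

dataset = get_dataset(args, evaluation=False)     # ImageFolder over dataset.path
loader = get_loader(args, dataset, batch_size=8)  # random sampler, drop_last=True
images, _ = next(iter(loader))                    # (8, 3, 256, 256) normalized tensors
```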
-------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | import hydra 3 | import os 4 | import random 5 | import warnings 6 | from datetime import timedelta 7 | import logging 8 | import torch 9 | from omegaconf import OmegaConf, open_dict 10 | import natten 11 | 12 | from models.generator import Generator 13 | from utils.distributed import get_rank, synchronize, get_world_size 14 | from utils import helpers 15 | 16 | from src.train import train 17 | from src.inference import inference 18 | from src.evaluate import evaluate 19 | from src.analysis import visualize_attention 20 | from src.throughput import throughput 21 | 22 | torch.backends.cudnn.benchmark = True 23 | torch.backends.cuda.matmul.allow_tf32 = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | 26 | @hydra.main(version_base=None, config_path="conf", config_name="config") 27 | def main(args): 28 | if "logging_level" in args: 29 | if type(args.logging_level) == str: 30 | _logging_level = {"DEBUG": logging.DEBUG, "INFO":logging.INFO, 31 | "WARNING": logging.WARNING, "ERROR": logging.ERROR, 32 | "CRITICAL": logging.CRITICAL}[args.logging_level.upper()] 33 | else: 34 | _logging_level = int(args.logging_level) 35 | logging.getLogger().setLevel(_logging_level) 36 | else: 37 | logging.getLogger().setLevel(logging.WARNING) 38 | helpers.validate_args(args) 39 | ckpt = None 40 | if "restart" in args and "ckpt" in args.restart and args.restart.ckpt: 41 | assert(os.path.exists(args.restart.ckpt)),f"Can't find a checkpoint "\ 42 | f"at {args.restart.ckpt}" 43 | ckpt = torch.load(args.restart.ckpt, map_location=lambda storage, loc: storage) 44 | if "start_iter" not in args.restart: 45 | with open_dict(args): 46 | try: 47 | args.restart.start_iter = \ 48 | int(os.path.basename(args.restart.ckpt)\ 49 | .split(".pt")[0]) 50 | except: 51 | args.restart.start_iter = 0 52 | 53 | helpers.rng_reproducibility(args, ckpt) 54 | #if "WORLD_SIZE" in os.environ: 55 | # # Single node multi GPU 56 | # n_gpu = int(os.environ["WORLD_SIZE"]) 57 | #else: 58 | # n_gpu = torch.cuda.device_count() 59 | #args.distributed = n_gpu > 1 60 | 61 | if args.distributed: 62 | try: 63 | args.local_rank = int(os.environ["LOCAL_RANK"]) 64 | except: 65 | args.local_rank = 0 66 | torch.cuda.set_device(args.local_rank) 67 | torch.distributed.init_process_group(backend="nccl", 68 | init_method="env://", 69 | timeout=timedelta(0, 180000)) 70 | args.rank = get_rank() 71 | args.world_size = get_world_size() 72 | args.device = f"cuda:{args.local_rank}" 73 | torch.cuda.set_device(args.local_rank) 74 | synchronize() 75 | 76 | if get_rank() == 0: 77 | if args.save_root[-1] != "/": args.save_root += "/" 78 | if not os.path.exists(args.save_root): 79 | print(f"[bold yellow]WARNING:[/] Save root {args.save_root} path "\ 80 | f"does not exist. 
Creating...") 81 | os.mkdir(args.save_root) 82 | if get_rank() == 0 and args.type in ['train']: 83 | # Make sample path 84 | if "sample_path" not in args.logging: 85 | samp_path = args.save_root + "samples" 86 | else: 87 | samp_path = args.logging.sample_path 88 | if args.logging.sample_path[0] != "/": 89 | samp_path = args.save_root + samp_path 90 | if not os.path.exists(samp_path): 91 | print(f"====> MAKING SAMPLE DIRECTORY: {samp_path}") 92 | os.mkdir(samp_path) 93 | # make checkpoint path 94 | if "checkpoint_path" not in args.logging: 95 | ckpt_path = args.save_root + "checkpoints" 96 | else: 97 | ckpt_path = args.logging.checkpoint_path 98 | if args.logging.checkpoint_path[0] != "/": 99 | ckpt_path = args.save_root + ckpt_path 100 | if not os.path.exists(ckpt_path): 101 | print(f"====> MAKING CHECKPOINT DIRECTORY: {ckpt_path}") 102 | os.mkdir(ckpt_path) 103 | 104 | 105 | # Only load gen if training, to save space 106 | if args.type == "train": 107 | generator = Generator(args=args.runs.generator, size=args.runs.size).to(args.device) 108 | g_ema = Generator(args=args.runs.generator, size=args.runs.size).to(args.device) 109 | 110 | if hasattr(g_ema, "num_params"): 111 | args.runs.generator.params = g_ema.num_params() / 1e6 112 | else: 113 | num_params = sum([m.numel() for m in g_ema.parameters()]) 114 | if hasattr(args.runs.generator, "params"): 115 | args.runs.generator.params = num_params / 1e6 116 | else: 117 | with open_dict(args): 118 | args.runs.generator.params = num_params / 1e6 119 | 120 | # Load generator checkpoint 121 | if ckpt is not None: 122 | # Load generator checkpoints. But only load g if training 123 | if get_rank() == 0: 124 | print(f"Loading Generative Model") 125 | if 'state_dicts' in ckpt.keys(): 126 | if args.type == "train": 127 | generator.load_state_dict(ckpt["state_dicts"]["g"]) 128 | g_ema.load_state_dict(ckpt["state_dicts"]["g_ema"]) 129 | elif set(['g', 'g_ema']).issubset(ckpt.keys()): # Old 130 | if args.type == "train": 131 | generator = Generator(args=args.runs.generator, 132 | size=args.runs.size, legacy=True).to(args.device) 133 | try: 134 | generator.load_state_dict(ckpt['g']) 135 | except Exception as e: 136 | print(e) 137 | print(f"[bold red]ERROR:[/] Failed to load checkpoint. " \ 138 | f"Likely a mismatch between kernel size or dilation " \ 139 | f"in the config and the checkpoint. "\ 140 | f"Checkpoint has kernels {args.runs.generator.kernels}, " \ 141 | f"and dilations {args.runs.generator.dilations}") 142 | exit(1) 143 | g_ema = Generator(args=args.runs.generator, 144 | size=args.runs.size, legacy=True).to(args.device) 145 | try: 146 | g_ema.load_state_dict(ckpt["g_ema"]) 147 | except Exception as e: 148 | print(e) 149 | print(f"[bold red]ERROR:[/] Failed to load checkpoint. " \ 150 | f"Likely a mismatch between kernel size or dilation " \ 151 | f"in the config and the checkpoint. 
"\ 152 | f"\nCheckpoint has " \ 153 | f"\n\tkernels {args.runs.generator.kernels}, " \ 154 | f"\n\tdilations {args.runs.generator.dilations}") 155 | exit(1) 156 | else: 157 | raise ValueError(f"Checkpoint dict broken:\n"\ 158 | f"Checkpoint name: {args.restart.ckpt}\n" 159 | f"Keys: {ckpt.keys()}") 160 | g_ema.eval() 161 | 162 | # Print mode in a nice format 163 | if get_rank() == 0: 164 | print("\n" + ("=" * 50)) 165 | print(f" Mode: {args.type} ".center(49, "=")) 166 | print("=" * 50, "\n") 167 | 168 | if args.type == "train": 169 | train(args=args, 170 | generator=generator, 171 | g_ema=g_ema, 172 | ckpt=ckpt, 173 | ) 174 | elif args.type == "inference": 175 | inference(args=args, generator=g_ema) 176 | elif args.type == "evaluate": 177 | evaluate(args=args, generator=g_ema) 178 | elif args.type == "attention_map": 179 | visualize_attention(args, g_ema, 180 | save_maps=args.analysis.save_path, 181 | ) 182 | elif args.type == "throughput": 183 | throughput(generator=g_ema, 184 | style_dim=args.runs.generator.style_dim, 185 | batch_size=args.throughput.batch_size, 186 | rounds=args.throughput.rounds, 187 | warmup=args.throughput.warmup, 188 | ) 189 | 190 | 191 | if __name__ == '__main__': 192 | main() 193 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Most of the Modules from here come from rosinality's StyleGAN2-pytorch 3 | https://github.com/rosinality/stylegan2-pytorch 4 | 5 | Slight changes to these have been made to include Spectral Norm. 6 | StyleSwin changed the code to include this option. 7 | ''' 8 | import math 9 | import torch 10 | from op import FusedLeakyReLU, upfirdn2d 11 | from torch import nn 12 | from torch.nn import functional as F 13 | from torch.nn.utils import spectral_norm 14 | 15 | from models.basic_layers import (Blur, Downsample, EqualConv2d, EqualLinear, 16 | ScaledLeakyReLU) 17 | 18 | 19 | class ConvLayer(nn.Sequential): 20 | def __init__( 21 | self, 22 | in_channel, 23 | out_channel, 24 | kernel_size, 25 | downsample=False, 26 | blur_kernel=[1, 3, 3, 1], 27 | bias=True, 28 | activate=True, 29 | sn=False 30 | ): 31 | layers = [] 32 | 33 | if downsample: 34 | factor = 2 35 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 36 | pad0 = (p + 1) // 2 37 | pad1 = p // 2 38 | 39 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 40 | 41 | stride = 2 42 | self.padding = 0 43 | 44 | else: 45 | stride = 1 46 | self.padding = kernel_size // 2 47 | 48 | if sn: 49 | # Not use equal conv2d when apply SN 50 | layers.append( 51 | spectral_norm(nn.Conv2d( 52 | in_channel, 53 | out_channel, 54 | kernel_size, 55 | padding=self.padding, 56 | stride=stride, 57 | bias=bias and not activate, 58 | )) 59 | ) 60 | else: 61 | layers.append( 62 | EqualConv2d( 63 | in_channel, 64 | out_channel, 65 | kernel_size, 66 | padding=self.padding, 67 | stride=stride, 68 | bias=bias and not activate, 69 | ) 70 | ) 71 | 72 | if activate: 73 | if bias: 74 | layers.append(FusedLeakyReLU(out_channel)) 75 | else: 76 | layers.append(ScaledLeakyReLU(0.2)) 77 | 78 | super().__init__(*layers) 79 | 80 | 81 | class ConvBlock(nn.Module): 82 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], sn=False): 83 | super().__init__() 84 | 85 | self.conv1 = ConvLayer(in_channel, in_channel, 3, sn=sn) 86 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, sn=sn) 87 | 88 | def forward(self, input): 89 | out = self.conv1(input) 
90 | out = self.conv2(out) 91 | 92 | return out 93 | 94 | 95 | def get_haar_wavelet(in_channels): 96 | haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2) 97 | haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2) 98 | haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0] 99 | 100 | haar_wav_ll = haar_wav_l.T * haar_wav_l 101 | haar_wav_lh = haar_wav_h.T * haar_wav_l 102 | haar_wav_hl = haar_wav_l.T * haar_wav_h 103 | haar_wav_hh = haar_wav_h.T * haar_wav_h 104 | 105 | return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh 106 | 107 | 108 | class HaarTransform(nn.Module): 109 | def __init__(self, in_channels): 110 | super().__init__() 111 | 112 | ll, lh, hl, hh = get_haar_wavelet(in_channels) 113 | 114 | self.register_buffer('ll', ll) 115 | self.register_buffer('lh', lh) 116 | self.register_buffer('hl', hl) 117 | self.register_buffer('hh', hh) 118 | 119 | def forward(self, input): 120 | ll = upfirdn2d(input, self.ll, down=2) 121 | lh = upfirdn2d(input, self.lh, down=2) 122 | hl = upfirdn2d(input, self.hl, down=2) 123 | hh = upfirdn2d(input, self.hh, down=2) 124 | 125 | return torch.cat((ll, lh, hl, hh), 1) 126 | 127 | 128 | class InverseHaarTransform(nn.Module): 129 | def __init__(self, in_channels): 130 | super().__init__() 131 | 132 | ll, lh, hl, hh = get_haar_wavelet(in_channels) 133 | 134 | self.register_buffer('ll', ll) 135 | self.register_buffer('lh', -lh) 136 | self.register_buffer('hl', -hl) 137 | self.register_buffer('hh', hh) 138 | 139 | def forward(self, input): 140 | ll, lh, hl, hh = input.chunk(4, 1) 141 | ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0)) 142 | lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0)) 143 | hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0)) 144 | hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0)) 145 | 146 | return ll + lh + hl + hh 147 | 148 | 149 | class FromRGB(nn.Module): 150 | def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1], sn=False): 151 | super().__init__() 152 | 153 | self.downsample = downsample 154 | 155 | if downsample: 156 | self.iwt = InverseHaarTransform(3) 157 | self.downsample = Downsample(blur_kernel) 158 | self.dwt = HaarTransform(3) 159 | 160 | self.conv = ConvLayer(3 * 4, out_channel, 1, sn=sn) 161 | 162 | def forward(self, input, skip=None): 163 | if self.downsample: 164 | input = self.iwt(input) 165 | input = self.downsample(input) 166 | input = self.dwt(input) 167 | 168 | out = self.conv(input) 169 | 170 | if skip is not None: 171 | out = out + skip 172 | 173 | return input, out 174 | 175 | 176 | class Discriminator(nn.Module): 177 | def __init__(self, 178 | args, 179 | size, 180 | ): 181 | super().__init__() 182 | 183 | channels = { 184 | 4: 512, 185 | 8: 512, 186 | 16: 512, 187 | 32: 512, 188 | 64: 256 * args.channel_multiplier, 189 | 128: 128 * args.channel_multiplier, 190 | 256: 64 * args.channel_multiplier, 191 | 512: 32 * args.channel_multiplier, 192 | 1024: 16 * args.channel_multiplier, 193 | } 194 | self.size = size 195 | self.args = args 196 | 197 | self.dwt = HaarTransform(3) 198 | 199 | self.from_rgbs = nn.ModuleList() 200 | self.convs = nn.ModuleList() 201 | 202 | log_size = int(math.log(self.size, 2)) - 1 203 | 204 | in_channel = channels[self.size] 205 | 206 | for i in range(log_size, 2, -1): 207 | out_channel = channels[2 ** (i - 1)] 208 | 209 | self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size, sn=args.sn)) 210 | self.convs.append(ConvBlock(in_channel, out_channel, args.blur_kernel, sn=args.sn)) 211 | 212 | in_channel = out_channel 213 | 214 | self.from_rgbs.append(FromRGB(channels[4], 
sn=args.sn)) 215 | 216 | self.stddev_group = 4 217 | self.stddev_feat = 1 218 | 219 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, sn=args.sn) 220 | if args.sn: 221 | self.final_linear = nn.Sequential( 222 | spectral_norm(nn.Linear(channels[4] * 4 * 4, channels[4])), 223 | FusedLeakyReLU(channels[4]), 224 | spectral_norm(nn.Linear(channels[4], 1)), 225 | ) 226 | else: 227 | self.final_linear = nn.Sequential( 228 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 229 | EqualLinear(channels[4], 1), 230 | ) 231 | 232 | def forward(self, input): 233 | input = self.dwt(input) 234 | out = None 235 | 236 | for from_rgb, conv in zip(self.from_rgbs, self.convs): 237 | input, out = from_rgb(input, out) 238 | out = conv(out) 239 | 240 | _, out = self.from_rgbs[-1](input, out) 241 | 242 | batch, channel, height, width = out.shape 243 | group = min(batch, self.stddev_group) 244 | stddev = out.view( 245 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 246 | ) 247 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 248 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 249 | stddev = stddev.repeat(group, 1, height, width) 250 | out = torch.cat([out, stddev], 1) 251 | 252 | out = self.final_conv(out) 253 | 254 | out = out.view(batch, -1) 255 | out = self.final_linear(out) 256 | 257 | return out 258 | 259 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 5 | from .upfirdn2d import upfirdn2d 6 | 7 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | # From StyleSwin 2 | 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.autograd import Function 9 | from torch.utils.cpp_extension import load 10 | 11 | from torch.cuda.amp import custom_fwd, custom_bwd 12 | 13 | module_path = os.path.dirname(__file__) 14 | fused = load( 15 | "fused", 16 | sources=[ 17 | os.path.join(module_path, "fused_bias_act.cpp"), 18 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 19 | ], 20 | ) 21 | 22 | 23 | class FusedLeakyReLUFunctionBackward(Function): 24 | @staticmethod 25 | def forward(ctx, grad_output, out, negative_slope, scale): 26 | ctx.save_for_backward(out) 27 | ctx.negative_slope = negative_slope 28 | ctx.scale = scale 29 | 30 | empty = grad_output.new_empty(0) 31 | 32 | grad_input = fused.fused_bias_act( 33 | grad_output, empty, out, 3, 1, negative_slope, scale 34 | ) 35 | 36 | dim = [0] 37 | 38 | if grad_input.ndim > 2: 39 | dim += list(range(2, grad_input.ndim)) 40 | 41 | grad_bias = grad_input.sum(dim).detach() 42 | 43 | return grad_input, grad_bias 44 | 45 | @staticmethod 46 | def backward(ctx, gradgrad_input, gradgrad_bias): 47 | out, = ctx.saved_tensors 48 | gradgrad_out = fused.fused_bias_act( 49 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 50 | ) 51 | 52 | return gradgrad_out, None, None, None 53 | 54 | 55 | class FusedLeakyReLUFunction(Function): 56 | @staticmethod 57 | @custom_fwd(cast_inputs=torch.float32) 58 | def forward(ctx, input, bias, negative_slope, scale): 59 | empty = input.new_empty(0) 60 | out = 
fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 61 | ctx.save_for_backward(out) 62 | ctx.negative_slope = negative_slope 63 | ctx.scale = scale 64 | 65 | return out 66 | 67 | @staticmethod 68 | @custom_bwd 69 | def backward(ctx, grad_output): 70 | out, = ctx.saved_tensors 71 | 72 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 73 | grad_output, out, ctx.negative_slope, ctx.scale 74 | ) 75 | 76 | return grad_input, grad_bias, None, None 77 | 78 | 79 | class FusedLeakyReLU(nn.Module): 80 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 81 | super().__init__() 82 | 83 | self.bias = nn.Parameter(torch.zeros(channel)) 84 | self.negative_slope = negative_slope 85 | self.scale = scale 86 | 87 | def forward(self, input): 88 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 89 | 90 | 91 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 92 | if input.device.type == "cpu": 93 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 94 | return ( 95 | F.leaky_relu( 96 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 97 | ) 98 | * scale 99 | ) 100 | 101 | else: 102 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 103 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #include <torch/extension.h> 5 | 6 | 7 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 15 | int act, int grad, float alpha, float scale) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(bias); 18 | 19 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 24 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | #include <torch/extension.h> 5 | 6 | 7 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 8 | int up_x, int up_y, int down_x, int down_y, 9 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 16 | int up_x, int up_y, int down_x, int down_y, 17 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(kernel); 20 | 21 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 26 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | 6 | import torch 7 | from torch.nn import functional as F 8 | from torch.autograd import Function 9 | from torch.utils.cpp_extension import load 10 | 11 | 12 | module_path = os.path.dirname(__file__) 13 | upfirdn2d_op = load( 14 | "upfirdn2d", 15 | sources=[ 16 | os.path.join(module_path, "upfirdn2d.cpp"), 17 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 18 | ], 19 | ) 20 | 21 | 22 | class UpFirDn2dBackward(Function): 23 | @staticmethod 24 | def forward( 25 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 26 | ): 27 | 28 | up_x, up_y = up 29 | down_x, down_y = down 30 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 31 | 32 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 33 | 34 | grad_input = upfirdn2d_op.upfirdn2d( 35 | grad_output, 36 | grad_kernel, 37 | down_x, 38 | down_y, 39 | up_x, 40 | up_y, 41 | g_pad_x0, 42 | g_pad_x1, 43 | g_pad_y0, 44 | g_pad_y1, 45 | ) 46 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 47 | 48 | ctx.save_for_backward(kernel) 49 | 50 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 51 | 52 | ctx.up_x = up_x 53 | ctx.up_y = up_y 54 | ctx.down_x = down_x 55 | ctx.down_y = down_y 56 | ctx.pad_x0 = pad_x0 57 | ctx.pad_x1 = pad_x1 58 | ctx.pad_y0 = pad_y0 59 | ctx.pad_y1 = pad_y1 60 | ctx.in_size = in_size 61 | ctx.out_size = out_size 62 | 63 | return grad_input 64 | 65 | @staticmethod 66 | def backward(ctx, gradgrad_input): 67 | kernel, = ctx.saved_tensors 68 | 69 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 70 | 71 | gradgrad_out = upfirdn2d_op.upfirdn2d( 72 | gradgrad_input, 73 | kernel, 74 | ctx.up_x, 75 | ctx.up_y, 76 | ctx.down_x, 77 | ctx.down_y, 78 | ctx.pad_x0, 79 | ctx.pad_x1, 80 | ctx.pad_y0, 81 | ctx.pad_y1, 82 | ) 83 | gradgrad_out = gradgrad_out.view( 84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 85 | ) 86 | 87 | return gradgrad_out, None, None, None, None, None, None, None, None 88 | 89 | 90 | class UpFirDn2d(Function): 91 | @staticmethod 92 | def forward(ctx, input, kernel, up, down, pad): 93 | up_x, up_y = up 94 | down_x, down_y = down 95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 96 | 97 | kernel_h, kernel_w = kernel.shape 98 | batch, channel, in_h, in_w = input.shape 99 | ctx.in_size = input.shape 100 | 101 | input = 
input.reshape(-1, in_h, in_w, 1) 102 | 103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 104 | 105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 107 | ctx.out_size = (out_h, out_w) 108 | 109 | ctx.up = (up_x, up_y) 110 | ctx.down = (down_x, down_y) 111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 112 | 113 | g_pad_x0 = kernel_w - pad_x0 - 1 114 | g_pad_y0 = kernel_h - pad_y0 - 1 115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 117 | 118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 119 | 120 | out = upfirdn2d_op.upfirdn2d( 121 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 122 | ) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = UpFirDn2dBackward.apply( 132 | grad_output, 133 | kernel, 134 | grad_kernel, 135 | ctx.up, 136 | ctx.down, 137 | ctx.pad, 138 | ctx.g_pad, 139 | ctx.in_size, 140 | ctx.out_size, 141 | ) 142 | 143 | return grad_input, None, None, None, None 144 | 145 | 146 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 147 | if input.device.type == "cpu": 148 | out = upfirdn2d_native( 149 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 150 | ) 151 | 152 | else: 153 | out = UpFirDn2d.apply( 154 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 155 | ) 156 | 157 | return out 158 | 159 | 160 | def upfirdn2d_native( 161 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 162 | ): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad( 174 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 175 | ) 176 | out = out[ 177 | :, 178 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 179 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 180 | :, 181 | ] 182 | 183 | out = out.permute(0, 3, 1, 2) 184 | out = out.reshape( 185 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 186 | ) 187 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 188 | out = F.conv2d(out, w) 189 | out = out.reshape( 190 | -1, 191 | minor, 192 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 193 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 194 | ) 195 | out = out.permute(0, 2, 3, 1) 196 | out = out[:, ::down_y, ::down_x, :] 197 | 198 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 199 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 200 | 201 | return out.view(-1, channel, out_h, out_w) 202 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template <typename scalar_t> 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, 108 | int kernel_h, int kernel_w, int tile_out_h, int tile_out_w> 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch
2 |
torchvision 3 | lmdb 4 | timm 5 | scipy 6 | scikit-learn 7 | einops 8 | tqdm 9 | wandb 10 | ninja 11 | natten>=0.14.6 12 | hydra-core 13 | hydra_colorlog 14 | joblib 15 | dill 16 | imageio-ffmpeg 17 | ftfy 18 | matplotlib 19 | rich 20 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | import time 3 | import os 4 | from tqdm import tqdm 5 | from joblib import Parallel, delayed 6 | import wandb 7 | import torch 8 | from torchvision import transforms as T 9 | from torchvision.utils import make_grid, save_image 10 | 11 | from dataset.dataset import unnormalize 12 | from utils import fid_score 13 | from utils.improved_precision_recall import IPR 14 | from utils.distributed import get_rank 15 | 16 | @torch.inference_mode() 17 | def save_images_batched(args, generator, steps=None, log_first_batch=True): 18 | if args.logging.sample_path[0] != "/": 19 | path = args.save_root + args.logging.sample_path 20 | else: 21 | path = args.logging.sample_path 22 | if steps is not None: 23 | path += f"eval_{str(steps)}" 24 | # We're only going to make the paths for training 25 | if not os.path.exists(path): 26 | os.mkdir(path) 27 | assert(args.evaluation.total_size % args.evaluation.batch == 0),\ 28 | f"Evaluation total size % batch should be zero. Got "\ 29 | f"{args.evaluation.total_size % args.evaluation.batch}" 30 | print(f"Saving {args.evaluation.total_size} "\ 31 | f"Images to {path}") 32 | cnt = 0 33 | nbatches = args.evaluation.total_size // args.evaluation.batch 34 | for _ in tqdm(range(nbatches), desc="Saving Images"): 35 | noise = torch.randn((args.evaluation.batch, args.runs.generator.style_dim)).to(args.device) 36 | sample, _ = generator(noise) 37 | sample = unnormalize(sample) 38 | if args.logging.wandb and get_rank() == 0 and log_first_batch and cnt == 0: 39 | grid = make_grid(sample, nrow=args.evaluation.batch) 40 | wandb.log({"samples": wandb.Image(grid)}, commit=False) 41 | Parallel(n_jobs=args.evaluation.batch)(delayed(save_image) 42 | (img, f"{path}/{str(cnt+j).zfill(6)}.png", 43 | nrow=1, padding=0, normalize=True, value_range=(0,1),) 44 | for j, img in enumerate(sample)) 45 | cnt += args.evaluation.batch 46 | 47 | @torch.inference_mode() 48 | def clear_directory(args): 49 | if args.logging.sample_path[0] != "/": 50 | path = args.save_root + args.logging.sample_path 51 | else: 52 | path = args.logging.sample_path 53 | if not os.path.exists(path): 54 | print(f"[bold yellow]WARNING:[/] (Evaluate) {path} does not exist. Creating...") 55 | os.mkdir(path) 56 | _files = os.listdir(path) 57 | if _files == []: 58 | print(f"Directory {path} is already empty. Worry if not first time") 59 | return 60 | assert(".png" in _files[0]) 61 | Parallel(n_jobs=32)(delayed(os.remove) 62 | (path + img) for img in _files) 63 | 64 | 65 | @torch.inference_mode() 66 | def evaluate(args, 67 | generator, # Should be g_ema 68 | steps=None, # Save to specific eval dir or common 69 | log_first_batch=True, # Log first batch of images to wandb? 
70 | ): 71 | #print(f" Parameters ".center(40, "-")) 72 | print(f"> Generator has:".ljust(19),f"{args.runs.generator.params:.4f} M Parameters") 73 | if not hasattr(args.logging, "sample_path"): 74 | path = args.save_root 75 | else: 76 | if args.logging.sample_path[0] != "/": 77 | path = args.save_root+ args.logging.sample_path 78 | else: 79 | path = args.logging.sample_path 80 | assert(type(path) == str),f"Path needs to be a string not {type(path)}" 81 | if path[-1] != "/": path = path + "/" 82 | if steps is not None: 83 | path += f"eval_{str(steps)}" 84 | # We're only going to make the paths for training 85 | if not os.path.exists(path): 86 | os.mkdir(path) 87 | # Save ALL the images 88 | if args.logging.reuse_samplepath: 89 | clear_directory(args) 90 | torch.cuda.synchronize() 91 | save_images_batched(args=args, 92 | generator=generator, 93 | steps=steps, 94 | log_first_batch=log_first_batch) 95 | if args.evaluation.gt_path[0] != "/": 96 | gt_path = args.dataset.path + args.evaluation.gt_path 97 | else: 98 | gt_path = args.evaluation.gt_path 99 | fid = fid_score.calculate_fid_given_paths([path, gt_path], 100 | batch_size=args.evaluation.batch, 101 | device=args.device, 102 | dims=2048, 103 | num_workers=args.workers, 104 | N=args.evaluation.total_size) 105 | print(f"{args.dataset.name} FID-{args.evaluation.total_size//1000}k : "\ 106 | f"{fid:.3f} for {steps} steps") 107 | return fid 108 | 109 | -------------------------------------------------------------------------------- /src/inference.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | import os 3 | import time 4 | from PIL import Image 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.utils import make_grid, save_image 8 | 9 | from dataset.dataset import unnormalize 10 | 11 | toPIL = T.ToPILImage() 12 | toTensor = T.ToTensor() 13 | 14 | @torch.inference_mode() 15 | def add_watermark(image, im_size, watermark_path="assets/watermark.jpg", 16 | wmsize=16, bbuf=5, opacity=0.9): 17 | ''' 18 | Creates a watermark on the saved inference image. 19 | We request that you do not remove this to properly assign credit to 20 | Shi-Lab's work. 21 | ''' 22 | image = image.cpu() 23 | watermark = Image.open(watermark_path).resize((wmsize, wmsize)) 24 | watermark = toTensor(watermark) 25 | loc = im_size - wmsize - bbuf 26 | image[:,:,loc:-bbuf, loc:-bbuf] = watermark 27 | return image 28 | 29 | def extract_range(args): 30 | ''' 31 | Helper function to allow for a range of inputs from hydra-core 32 | This only works if you specify a range from the config file. 
33 | seeds: range(start, end, step) 34 | 35 | If you are using the command line pass in a comma delimited sequence with 36 | `seq` 37 | python main type=inference inference.seeds=[`seq -s, start step_size end`] 38 | The back ticks are important because they run a bash command 39 | ''' 40 | start = 0 41 | step_size=1 42 | start_end = args.inference.seeds.split("range(")[1].split(")")[0] 43 | if "," in start_end: 44 | nums = start_end.split(',') 45 | match len(nums): 46 | case 1: 47 | end = nums[0] 48 | case 2: 49 | start = nums[0] 50 | end = nums[1] 51 | case 3: 52 | start = nums[0] 53 | end = nums[1] 54 | step_size = nums[2] 55 | case _: 56 | raise ValueError 57 | else: 58 | end = start_end 59 | args.inference.seeds = list(range(int(start), int(end), int(step_size))) 60 | print(f"Using Range of Seeds from {start} to {end} with a step size of {step_size}") 61 | 62 | @torch.inference_mode() 63 | def inference(args, generator): 64 | save_path = args.inference.save_path 65 | if save_path[0] != "/": 66 | save_path = args.save_root + save_path 67 | if not os.path.exists(save_path): 68 | print(f"[bold yellow]WARNING:[/] (Inference) {save_path} does not exist. Creating...") 69 | os.mkdir(save_path) 70 | assert('num_images' in args.inference or 'seeds' in args.inference),\ 71 | f"Inference must either specify a number of images "\ 72 | f"(random seed generation) or a set of seeds to use to generate." 73 | if 'num_images' in args.inference: 74 | num_imgs = args.inference.num_images 75 | print(f"Got {num_imgs=}, {args.inference.num_images=}") 76 | if "seeds" in args.inference and args.inference.seeds is not None: 77 | # Handles "range(start, end)" input from hydra file 78 | if "range" in args.inference.seeds: 79 | extract_range(args) 80 | num_imgs = len(args.inference.seeds) 81 | else: 82 | num_imgs = len(args.inference.seeds) 83 | if "num_images" in args.inference and args.inference.num_images != num_imgs: 84 | print(f"[bold red]WARNING:[/] You asked for "\ 85 | f"{args.inference.num_images} image in the config but specified " \ 86 | f"seeds. 
Seeds overrides and you will get {num_imgs} images.") 87 | 88 | for i in range(num_imgs): 89 | if "seeds" in args.inference and args.inference.seeds is not None and \ 90 | i < len(args.inference.seeds): 91 | seed = args.inference.seeds[i] 92 | torch.random.manual_seed(seed) 93 | else: 94 | seed = torch.random.seed() 95 | print(f"Using Seed: {seed}") 96 | 97 | noise = torch.randn((args.inference.batch, 98 | args.runs.generator.style_dim)).to(args.device) 99 | sample, latent = generator(noise) 100 | sample = unnormalize(sample) 101 | sample = add_watermark(sample, im_size=args.runs.size) 102 | save_image(sample, f"{save_path}/{seed}.png", 103 | nrow=1, padding=0, normalize=True, value_range=(0,1)) 104 | print(f"Saved {save_path}/{seed}.png") 105 | -------------------------------------------------------------------------------- /src/logging.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/StyleNAT/80927246c9d1a48a1f5b20b65f53b946e2075bb7/src/logging.py -------------------------------------------------------------------------------- /src/throughput.py: -------------------------------------------------------------------------------- 1 | from rich import print 2 | import torch 3 | 4 | @torch.inference_mode() 5 | def throughput(generator, 6 | rounds=100, 7 | warmup=10, 8 | batch_size=1, 9 | style_dim=512, 10 | ): 11 | #torch.backends.cuda.matmul.allow_tf32 = True 12 | times = torch.empty(rounds) 13 | noise = torch.randn((batch_size, style_dim), device="cuda") 14 | for _ in range(warmup): 15 | imgs, _ = generator(noise) 16 | for i in range(rounds): 17 | starter = torch.cuda.Event(enable_timing=True) 18 | ender = torch.cuda.Event(enable_timing=True) 19 | starter.record() 20 | imgs, _ = generator(noise) 21 | ender.record() 22 | torch.cuda.synchronize() 23 | times[i] = starter.elapsed_time(ender)/1000 24 | #print(f"{imgs.shape=}") 25 | total_time = times.sum() 26 | total_images = rounds * batch_size 27 | imgs_per_second = total_images / total_time 28 | print(f"{torch.std_mean(times)=}") 29 | print(f"{total_images=}, {total_time=}") 30 | print(f"{imgs_per_second=}") 31 | #print(f"{torch.std_mean(imgs_per_second)=}") 32 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 
98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. 
Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to temporarily suppress known warnings in torch.jit.trace(). 68 | # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 69 | 70 | @contextlib.contextmanager 71 | def suppress_tracer_warnings(): 72 | flt = ('ignore', None, torch.jit.TracerWarning, None, 0) 73 | warnings.filters.insert(0, flt) 74 | yield 75 | warnings.filters.remove(flt) 76 | 77 | #---------------------------------------------------------------------------- 78 | # Assert that the shape of a tensor matches the given list of integers. 79 | # None indicates that the size of a dimension is allowed to vary. 80 | # Performs symbolic assertion when used in torch.jit.trace(). 
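# Example (illustrative, added for clarity): for x = torch.zeros(8, 512, 4, 4),
#   assert_shape(x, [None, 512, 4, 4])  passes, since None lets the batch dimension vary, while
#   assert_shape(x, [8, 512, 8, 8])     raises AssertionError ("Wrong size for dimension 2: got 4, expected 8").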
81 | 82 | def assert_shape(tensor, ref_shape): 83 | if tensor.ndim != len(ref_shape): 84 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 85 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 86 | if ref_size is None: 87 | pass 88 | elif isinstance(ref_size, torch.Tensor): 89 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 90 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 91 | elif isinstance(size, torch.Tensor): 92 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 93 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 94 | elif size != ref_size: 95 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 96 | 97 | #---------------------------------------------------------------------------- 98 | # Function decorator that calls torch.autograd.profiler.record_function(). 99 | 100 | def profiled_function(fn): 101 | def decorator(*args, **kwargs): 102 | with torch.autograd.profiler.record_function(fn.__name__): 103 | return fn(*args, **kwargs) 104 | decorator.__name__ = fn.__name__ 105 | return decorator 106 | 107 | #---------------------------------------------------------------------------- 108 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 109 | # indefinitely, shuffling items as it goes. 110 | 111 | class InfiniteSampler(torch.utils.data.Sampler): 112 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 113 | assert len(dataset) > 0 114 | assert num_replicas > 0 115 | assert 0 <= rank < num_replicas 116 | assert 0 <= window_size <= 1 117 | super().__init__(dataset) 118 | self.dataset = dataset 119 | self.rank = rank 120 | self.num_replicas = num_replicas 121 | self.shuffle = shuffle 122 | self.seed = seed 123 | self.window_size = window_size 124 | 125 | def __iter__(self): 126 | order = np.arange(len(self.dataset)) 127 | rnd = None 128 | window = 0 129 | if self.shuffle: 130 | rnd = np.random.RandomState(self.seed) 131 | rnd.shuffle(order) 132 | window = int(np.rint(order.size * self.window_size)) 133 | 134 | idx = 0 135 | while True: 136 | i = idx % order.size 137 | if idx % self.num_replicas == self.rank: 138 | yield order[i] 139 | if window >= 2: 140 | j = (i - rnd.randint(window)) % order.size 141 | order[i], order[j] = order[j], order[i] 142 | idx += 1 143 | 144 | #---------------------------------------------------------------------------- 145 | # Utilities for operating with torch.nn.Module parameters and buffers. 
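# Example (illustrative, placeholder module names):
#   copy_params_and_buffers(src_module=loaded_g, dst_module=g, require_all=False)
#   copies every parameter and buffer whose name matches from loaded_g into g, and
#   check_ddp_consistency(g) then verifies that all DDP ranks hold identical tensors.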
146 | def spectral_to_cpu(model: torch.nn.Module): 147 | def wrapped_in_spectral(m): return hasattr(m, 'weight_v') 148 | children = get_children(model) 149 | for child in children: 150 | if wrapped_in_spectral(child): 151 | child.weight = child.weight.cpu() 152 | return model 153 | 154 | def get_children(model: torch.nn.Module): 155 | children = list(model.children()) 156 | flatt_children = [] 157 | if children == []: 158 | return model 159 | else: 160 | for child in children: 161 | try: 162 | flatt_children.extend(get_children(child)) 163 | except TypeError: 164 | flatt_children.append(get_children(child)) 165 | return flatt_children 166 | 167 | def params_and_buffers(module): 168 | assert isinstance(module, torch.nn.Module) 169 | return list(module.parameters()) + list(module.buffers()) 170 | 171 | def named_params_and_buffers(module): 172 | assert isinstance(module, torch.nn.Module) 173 | return list(module.named_parameters()) + list(module.named_buffers()) 174 | 175 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 176 | assert isinstance(src_module, torch.nn.Module) 177 | assert isinstance(dst_module, torch.nn.Module) 178 | src_tensors = dict(named_params_and_buffers(src_module)) 179 | for name, tensor in named_params_and_buffers(dst_module): 180 | assert (name in src_tensors) or (not require_all) 181 | if name in src_tensors: 182 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 183 | 184 | #---------------------------------------------------------------------------- 185 | # Context manager for easily enabling/disabling DistributedDataParallel 186 | # synchronization. 187 | 188 | @contextlib.contextmanager 189 | def ddp_sync(module, sync): 190 | assert isinstance(module, torch.nn.Module) 191 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 192 | yield 193 | else: 194 | with module.no_sync(): 195 | yield 196 | 197 | #---------------------------------------------------------------------------- 198 | # Check DistributedDataParallel consistency across processes. 199 | 200 | def check_ddp_consistency(module, ignore_regex=None): 201 | assert isinstance(module, torch.nn.Module) 202 | for name, tensor in named_params_and_buffers(module): 203 | fullname = type(module).__name__ + '.' + name 204 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 205 | continue 206 | tensor = tensor.detach() 207 | if tensor.is_floating_point(): 208 | tensor = nan_to_num(tensor) 209 | other = tensor.clone() 210 | torch.distributed.broadcast(tensor=other, src=0) 211 | assert (tensor == other).all(), fullname 212 | 213 | #---------------------------------------------------------------------------- 214 | # Print summary table of module hierarchy. 215 | 216 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 217 | assert isinstance(module, torch.nn.Module) 218 | assert not isinstance(module, torch.jit.ScriptModule) 219 | assert isinstance(inputs, (tuple, list)) 220 | 221 | # Register hooks. 
222 | entries = [] 223 | nesting = [0] 224 | def pre_hook(_mod, _inputs): 225 | nesting[0] += 1 226 | def post_hook(mod, _inputs, outputs): 227 | nesting[0] -= 1 228 | if nesting[0] <= max_nesting: 229 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 230 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 231 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 232 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 233 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 234 | 235 | # Run module. 236 | outputs = module(*inputs) 237 | for hook in hooks: 238 | hook.remove() 239 | 240 | # Identify unique outputs, parameters, and buffers. 241 | tensors_seen = set() 242 | for e in entries: 243 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 244 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 245 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 246 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 247 | 248 | # Filter out redundant entries. 249 | if skip_redundant: 250 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 251 | 252 | # Construct table. 253 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 254 | rows += [['---'] * len(rows[0])] 255 | param_total = 0 256 | buffer_total = 0 257 | submodule_names = {mod: name for name, mod in module.named_modules()} 258 | for e in entries: 259 | name = '' if e.mod is module else submodule_names[e.mod] 260 | param_size = sum(t.numel() for t in e.unique_params) 261 | buffer_size = sum(t.numel() for t in e.unique_buffers) 262 | output_shapes = [str(list(t.shape)) for t in e.outputs] 263 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 264 | rows += [[ 265 | name + (':0' if len(e.outputs) >= 2 else ''), 266 | str(param_size) if param_size else '-', 267 | str(buffer_size) if buffer_size else '-', 268 | (output_shapes + ['-'])[0], 269 | (output_dtypes + ['-'])[0], 270 | ]] 271 | for idx in range(1, len(e.outputs)): 272 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 273 | param_total += param_size 274 | buffer_total += buffer_size 275 | rows += [['---'] * len(rows[0])] 276 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 277 | 278 | # Print table. 279 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 280 | print() 281 | for row in rows: 282 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 283 | print() 284 | return outputs 285 | 286 | #---------------------------------------------------------------------------- 287 | 288 | # Added by Katja 289 | import os 290 | 291 | def get_ckpt_path(run_dir): 292 | return os.path.join(run_dir, f'network-snapshot.pkl') 293 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
-------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
-------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType<T>::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
-------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
-------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import numpy as np
13 | import torch
14 | import dnnlib
15 |
16 | from .. import custom_ops
17 | from ..
import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 
77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments. 130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | 15 | # pylint: disable=redefined-builtin 16 | # pylint: disable=arguments-differ 17 | # pylint: disable=protected-access 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | enabled = False # Enable the custom op by setting this to true. 22 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 23 | 24 | @contextlib.contextmanager 25 | def no_weight_gradients(disable=True): 26 | global weight_gradients_disabled 27 | old = weight_gradients_disabled 28 | if disable: 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | return True 54 | 55 | def _tuple_of_ints(xs, ndim): 56 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 57 | assert len(xs) == ndim 58 | assert all(isinstance(x, int) for x in xs) 59 | return xs 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | _conv2d_gradfix_cache = dict() 64 | _null_tensor = torch.empty([0]) 65 | 66 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 67 | # Parse arguments. 68 | ndim = 2 69 | weight_shape = tuple(weight_shape) 70 | stride = _tuple_of_ints(stride, ndim) 71 | padding = _tuple_of_ints(padding, ndim) 72 | output_padding = _tuple_of_ints(output_padding, ndim) 73 | dilation = _tuple_of_ints(dilation, ndim) 74 | 75 | # Lookup from cache. 76 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 77 | if key in _conv2d_gradfix_cache: 78 | return _conv2d_gradfix_cache[key] 79 | 80 | # Validate arguments. 
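    # Editor's note: each cached class produced here is specialized to one fixed
    # (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    # configuration; the assertions below reject argument combinations that this
    # specialized op cannot express.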
81 | assert groups >= 1 82 | assert len(weight_shape) == ndim + 2 83 | assert all(stride[i] >= 1 for i in range(ndim)) 84 | assert all(padding[i] >= 0 for i in range(ndim)) 85 | assert all(dilation[i] >= 0 for i in range(ndim)) 86 | if not transpose: 87 | assert all(output_padding[i] == 0 for i in range(ndim)) 88 | else: # transpose 89 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 90 | 91 | # Helpers. 92 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 93 | def calc_output_padding(input_shape, output_shape): 94 | if transpose: 95 | return [0, 0] 96 | return [ 97 | input_shape[i + 2] 98 | - (output_shape[i + 2] - 1) * stride[i] 99 | - (1 - 2 * padding[i]) 100 | - dilation[i] * (weight_shape[i + 2] - 1) 101 | for i in range(ndim) 102 | ] 103 | 104 | # Forward & backward. 105 | class Conv2d(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, input, weight, bias): 108 | assert weight.shape == weight_shape 109 | ctx.save_for_backward( 110 | input if weight.requires_grad else _null_tensor, 111 | weight if input.requires_grad else _null_tensor, 112 | ) 113 | ctx.input_shape = input.shape 114 | 115 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 116 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 117 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 118 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 119 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 120 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 121 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 122 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 123 | 124 | # General case => cuDNN. 125 | if transpose: 126 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 127 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | input, weight = ctx.saved_tensors 132 | input_shape = ctx.input_shape 133 | grad_input = None 134 | grad_weight = None 135 | grad_bias = None 136 | 137 | if ctx.needs_input_grad[0]: 138 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 139 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 140 | grad_input = op.apply(grad_output, weight, None) 141 | assert grad_input.shape == input_shape 142 | 143 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 144 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 145 | assert grad_weight.shape == weight_shape 146 | 147 | if ctx.needs_input_grad[2]: 148 | grad_bias = grad_output.sum([0, 2, 3]) 149 | 150 | return grad_input, grad_weight, grad_bias 151 | 152 | # Gradient with respect to the weights. 
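    # Editor's note: the weight gradient gets its own autograd.Function so that it can
    # itself be differentiated: its backward re-enters Conv2d.apply and _conv2d_gradfix,
    # which is what lets arbitrarily high order gradients (e.g. for R1 or path-length
    # regularization) flow through the convolution.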
153 | class Conv2dGradWeight(torch.autograd.Function): 154 | @staticmethod 155 | def forward(ctx, grad_output, input): 156 | ctx.save_for_backward( 157 | grad_output if input.requires_grad else _null_tensor, 158 | input if grad_output.requires_grad else _null_tensor, 159 | ) 160 | ctx.grad_output_shape = grad_output.shape 161 | ctx.input_shape = input.shape 162 | 163 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 164 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 165 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 166 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 167 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 168 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 169 | 170 | # General case => cuDNN. 171 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 172 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 173 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 174 | 175 | @staticmethod 176 | def backward(ctx, grad2_grad_weight): 177 | grad_output, input = ctx.saved_tensors 178 | grad_output_shape = ctx.grad_output_shape 179 | input_shape = ctx.input_shape 180 | grad2_grad_output = None 181 | grad2_input = None 182 | 183 | if ctx.needs_input_grad[0]: 184 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 185 | assert grad2_grad_output.shape == grad_output_shape 186 | 187 | if ctx.needs_input_grad[1]: 188 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 189 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 190 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 191 | assert grad2_input.shape == input_shape 192 | 193 | return grad2_grad_output, grad2_input 194 | 195 | _conv2d_gradfix_cache[key] = Conv2d 196 | return Conv2d 197 | 198 | #---------------------------------------------------------------------------- 199 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
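    # Editor's note: a 1x1 kernel acts pointwise across channels, so it commutes with
    # the FIR resampling; the two 1x1 fast paths below order the steps so that the
    # convolution always runs on the smaller of the two tensors.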
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 
17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
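# Editor's note (illustrative, not part of the original file): training code would
# typically enable the custom ops once at startup, e.g.
#
#     from torch_utils.ops import conv2d_gradfix, grid_sample_gradfix
#     conv2d_gradfix.enabled = True
#     grid_sample_gradfix.enabled = True
#
# so that loss terms which differentiate through augmentation and convolution keep
# working under higher-order gradients; exact usage depends on this repo's training loop.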
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr<float>(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size.
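    // Editor's note: spec.tileOutW < 0 selects the generic "large" kernel, launched with a
    // 4x32 thread block that loops over the output; otherwise a shape-specialized "small"
    // kernel is launched with 256 threads per block over tileOutW x tileOutH output tiles.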
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto.
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
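
    Example (illustrative sketch; statistic names and variables are arbitrary):

        training_stats.report('Loss/G/total', g_loss)            # call from anywhere
        training_stats.report0('Progress/kimg', cur_nimg / 1e3)  # rank 0 only
        ...
        collector = training_stats.Collector(regex='Loss/.*')
        collector.update()
        print(collector.mean('Loss/G/total'))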
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
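    # Editor's note: the per-name moment vectors were stacked into a single
    # [len(names), _num_moments] tensor above, so one all_reduce call below
    # synchronizes every tracked statistic at once.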
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /torch_utils/utils_spectrum.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fft import fftn 3 | 4 | 5 | def roll_quadrants(data, backwards=False): 6 | """ 7 | Shift low frequencies to the center of fourier transform, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1] 8 | Args: 9 | data: fourier transform, (NxHxW) 10 | backwards: bool, if True shift high frequencies back to center 11 | 12 | Returns: 13 | Shifted fourier transform. 14 | """ 15 | dim = data.ndim - 1 16 | 17 | if dim != 2: 18 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 19 | if any(s % 2 == 0 for s in data.shape[1:]): 20 | raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.') 21 | 22 | # for each dimension swap left and right half 23 | dims = tuple(range(1, dim+1)) # add one for batch dimension 24 | shifts = torch.tensor(data.shape[1:]) // 2 #.div(2, rounding_mode='floor') # N/2 if N even, (N-1)/2 if N odd 25 | if backwards: 26 | shifts *= -1 27 | return data.roll(shifts.tolist(), dims=dims) 28 | 29 | 30 | def batch_fft(data, normalize=False): 31 | """ 32 | Compute fourier transform of batch. 33 | Args: 34 | data: input tensor, (NxHxW) 35 | 36 | Returns: 37 | Batch fourier transform of input data. 38 | """ 39 | 40 | dim = data.ndim - 1 # subtract one for batch dimension 41 | if dim != 2: 42 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 43 | 44 | dims = tuple(range(1, dim + 1)) # add one for batch dimension 45 | if normalize: 46 | norm = 'ortho' 47 | else: 48 | norm = 'backward' 49 | 50 | if not torch.is_complex(data): 51 | data = torch.complex(data, torch.zeros_like(data)) 52 | freq = fftn(data, dim=dims, norm=norm) 53 | 54 | return freq 55 | 56 | 57 | def azimuthal_average(image, center=None): 58 | # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/ 59 | """ 60 | Calculate the azimuthally averaged radial profile. 61 | Requires low frequencies to be at the center of the image. 62 | Args: 63 | image: Batch of 2D images, NxHxW 64 | center: The [x,y] pixel coordinates used as the center. The default is 65 | None, which then uses the center of the image (including 66 | fracitonal pixels). 67 | 68 | Returns: 69 | Azimuthal average over the image around the center 70 | """ 71 | # Check input shapes 72 | assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \ 73 | f'(but it is len(center)={len(center)}.' 
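    # Approach used below: each pixel is assigned an integer radius bin (floor of its distance
    # to `center`), the pixel values are sorted by radius and cumulatively summed, and the
    # cumulative sum is differenced at the bin boundaries so each bin's total can be divided
    # by its pixel count, yielding the radial profile.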
74 | # Calculate the indices from the image 75 | H, W = image.shape[-2:] 76 | h, w = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 77 | 78 | if center is None: 79 | center = torch.tensor([(w.max() - w.min()) / 2.0, (h.max() - h.min()) / 2.0]) 80 | 81 | # Compute radius for each pixel wrt center 82 | r = torch.stack([w-center[0], h-center[1]]).norm(2, 0) 83 | 84 | # Get sorted radii 85 | r_sorted, ind = r.flatten().sort() 86 | i_sorted = image.flatten(-2, -1)[..., ind] 87 | 88 | # Get the integer part of the radii (bin size = 1) 89 | r_int = r_sorted.long() # attribute to the smaller integer 90 | 91 | # Find all pixels that fall within each radial bin. 92 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented, computes bin change between subsequent radii 93 | rind = torch.where(deltar)[0] # location of changed radius 94 | 95 | # compute number of elements in each bin 96 | nind = rind + 1 # number of elements = idx + 1 97 | nind = torch.cat([torch.tensor([0]), nind, torch.tensor([H*W])]) # add borders 98 | nr = nind[1:] - nind[:-1] # number of radius bin, i.e. counter for bins belonging to each radius 99 | 100 | # Cumulative sum to figure out sums for each radius bin 101 | if H % 2 == 0: 102 | raise NotImplementedError('Not sure if implementation correct, please check') 103 | rind = torch.cat([torch.tensor([0]), rind, torch.tensor([H * W - 1])]) # add borders 104 | else: 105 | rind = torch.cat([rind, torch.tensor([H * W - 1])]) # add borders 106 | csim = i_sorted.cumsum(-1, dtype=torch.float64) # integrate over all values with smaller radius 107 | tbin = csim[..., rind[1:]] - csim[..., rind[:-1]] 108 | # add mean 109 | tbin = torch.cat([csim[:, 0:1], tbin], 1) 110 | 111 | radial_prof = tbin / nr.to(tbin.device) # normalize by counted bins 112 | 113 | return radial_prof 114 | 115 | 116 | def get_spectrum(data, normalize=False): 117 | dim = data.ndim - 1 # subtract one for batch dimension 118 | if dim != 2: 119 | raise AttributeError(f'Data must be 2d but it is {dim}d.') 120 | 121 | freq = batch_fft(data, normalize=normalize) 122 | power_spec = freq.real ** 2 + freq.imag ** 2 123 | N = data.shape[1] 124 | if N % 2 == 0: # duplicate value for N/2 so it is put at the end of the spectrum 125 | # and is not averaged with the mean value 126 | N_2 = N//2 127 | power_spec = torch.cat([power_spec[:, :N_2+1], power_spec[:, N_2:N_2+1], power_spec[:, N_2+1:]], dim=1) 128 | power_spec = torch.cat([power_spec[:, :, :N_2+1], power_spec[:, :, N_2:N_2+1], power_spec[:, :, N_2+1:]], dim=2) 129 | 130 | power_spec = roll_quadrants(power_spec) 131 | power_spec = azimuthal_average(power_spec) 132 | return power_spec 133 | 134 | 135 | def plot_std(mean, std, x=None, ax=None, **kwargs): 136 | import matplotlib.pyplot as plt 137 | if ax is None: 138 | fig, ax = plt.subplots(1) 139 | 140 | # plot error margins in same color as line 141 | err_kwargs = { 142 | 'alpha': 0.3 143 | } 144 | 145 | if 'c' in kwargs.keys(): 146 | err_kwargs['color'] = kwargs['c'] 147 | elif 'color' in kwargs.keys(): 148 | err_kwargs['color'] = kwargs['color'] 149 | 150 | if x is None: 151 | x = torch.linspace(0, 1, len(mean)) # use normalized x axis 152 | ax.plot(x, mean, **kwargs) 153 | ax.fill_between(x, mean-std, mean+std, **err_kwargs) 154 | 155 | return ax 156 | -------------------------------------------------------------------------------- /utils/CRDiffAug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def CR_DiffAug(x, flip=True, translation=True, color=True, cutout=True): 9 | if flip: 10 | x = random_flip(x, 0.5) 11 | if translation: 12 | x = rand_translation(x, 1/8) 13 | if color: 14 | aug_list = [rand_brightness, rand_saturation, rand_contrast] 15 | for func in aug_list: 16 | x = func(x) 17 | if cutout: 18 | x = rand_cutout(x) 19 | if flip or translation: 20 | x = x.contiguous() 21 | return x 22 | 23 | 24 | def random_flip(x, p): 25 | x_out = x.clone() 26 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 27 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) 28 | flip_mask = flip_prob < p 29 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) 30 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) 31 | return x_out 32 | 33 | 34 | def rand_brightness(x): 35 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 36 | return x 37 | 38 | 39 | def rand_saturation(x): 40 | x_mean = x.mean(dim=1, keepdim=True) 41 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 42 | return x 43 | 44 | 45 | def rand_contrast(x): 46 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 47 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 48 | return x 49 | 50 | 51 | def rand_translation(x, ratio=0.125): 52 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 53 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 54 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 55 | grid_batch, grid_x, grid_y = torch.meshgrid( 56 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 57 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 58 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 59 | indexing='ij' 60 | ) 61 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 62 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 63 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 64 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 65 | return x 66 | 67 | 68 | def rand_cutout(x, ratio=0.5): 69 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 70 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 71 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 72 | grid_batch, grid_x, grid_y = torch.meshgrid( 73 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 74 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 75 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 76 | indexing='ij' 77 | ) 78 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 79 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 80 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 81 | mask[grid_batch, grid_x, grid_y] = 0 82 | x = x * mask.unsqueeze(1) 83 | return x 84 | -------------------------------------------------------------------------------- /utils/distributed.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | From StyleSwin 3 | ''' 4 | 5 | import pickle 6 | 7 | import torch 8 | from torch import distributed as dist 9 | 10 | 11 | def get_rank(): 12 | if not dist.is_available(): 13 | return 0 14 | 15 | if not dist.is_initialized(): 16 | return 0 17 | 18 | return dist.get_rank() 19 | 20 | 21 | def synchronize(): 22 | if not dist.is_available(): 23 | return 24 | 25 | if not dist.is_initialized(): 26 | return 27 | 28 | world_size = dist.get_world_size() 29 | 30 | if world_size == 1: 31 | return 32 | 33 | dist.barrier() 34 | 35 | 36 | def get_world_size(): 37 | if not dist.is_available(): 38 | return 1 39 | 40 | if not dist.is_initialized(): 41 | return 1 42 | 43 | return dist.get_world_size() 44 | 45 | 46 | def reduce_sum(tensor): 47 | if not dist.is_available(): 48 | return tensor 49 | 50 | if not dist.is_initialized(): 51 | return tensor 52 | 53 | tensor = tensor.clone() 54 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 55 | 56 | return tensor 57 | 58 | 59 | def gather_grad(params): 60 | world_size = get_world_size() 61 | 62 | if world_size == 1: 63 | return 64 | 65 | for param in params: 66 | if param.grad is not None: 67 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 68 | param.grad.data.div_(world_size) 69 | 70 | 71 | def all_gather(data): 72 | world_size = get_world_size() 73 | 74 | if world_size == 1: 75 | return [data] 76 | 77 | buffer = pickle.dumps(data) 78 | storage = torch.ByteStorage.from_buffer(buffer) 79 | tensor = torch.ByteTensor(storage).to('cuda') 80 | 81 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 82 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 83 | dist.all_gather(size_list, local_size) 84 | size_list = [int(size.item()) for size in size_list] 85 | max_size = max(size_list) 86 | 87 | tensor_list = [] 88 | for _ in size_list: 89 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 90 | 91 | if local_size != max_size: 92 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 93 | tensor = torch.cat((tensor, padding), 0) 94 | 95 | dist.all_gather(tensor_list, tensor) 96 | 97 | data_list = [] 98 | 99 | for size, tensor in zip(size_list, tensor_list): 100 | buffer = tensor.cpu().numpy().tobytes()[:size] 101 | data_list.append(pickle.loads(buffer)) 102 | 103 | return data_list 104 | 105 | 106 | def reduce_loss_dict(loss_dict): 107 | world_size = get_world_size() 108 | 109 | if world_size < 2: 110 | return loss_dict 111 | 112 | with torch.no_grad(): 113 | keys = [] 114 | losses = [] 115 | 116 | for k in sorted(loss_dict.keys()): 117 | keys.append(k) 118 | losses.append(loss_dict[k]) 119 | 120 | losses = torch.stack(losses, 0) 121 | dist.reduce(losses, dst=0) 122 | 123 | if dist.get_rank() == 0: 124 | losses /= world_size 125 | 126 | reduced_losses = {k: v for k, v in zip(keys, losses)} 127 | 128 | return reduced_losses 129 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import os 3 | import random 4 | import warnings 5 | import logging 6 | from omegaconf import OmegaConf, open_dict 7 | import torch 8 | from rich import print 9 | # For legacy 10 | import argparse 11 | 12 | _allowed_run_types=['train', 'inference', 'evaluate', 'attention_map', 'throughput'] 13 | 14 | def check_and_set_hydra(args, key : str, value : Any) -> None: 15 | if 
hasattr(args, key): 16 | args[key] = value 17 | else: 18 | with open_dict(args): 19 | args[key] = value 20 | logging.info(f"{args}.{key} = {value}") 21 | 22 | def validate_args(args): 23 | ''' 24 | Check some of the args and do some sanity checking 25 | We'll define default values here so that users don't need to 26 | set them themselves. Reduce user burden, reduce user error. 27 | ''' 28 | assert(args.type in _allowed_run_types),f"Type must be from {_allowed_run_types} but got {args.type}" 29 | arg_keys = args.keys() 30 | with open_dict(args): 31 | if "rank" not in arg_keys: check_and_set_hydra(args,"rank",0) 32 | if "device" not in arg_keys: check_and_set_hydra(args,"device","cpu") 33 | if "world_size" not in arg_keys: check_and_set_hydra(args,"world_size",1) 34 | if "rank" not in arg_keys: check_and_set_hydra(args,"rank",0) 35 | if "local_rank" not in arg_keys: check_and_set_hydra(args,"local_rank",0) 36 | if "distributed" not in arg_keys: #args.distributed = False 37 | if "WORLD_SIZE" in os.environ: 38 | # Single node multi GPU 39 | n_gpu = int(os.environ["WORLD_SIZE"]) 40 | else: 41 | n_gpu = torch.cuda.device_count() 42 | check_and_set_hydra(args,"distributed",n_gpu > 1) 43 | if "workers" not in arg_keys: args.workers = 0 44 | # Validate training args 45 | if args.type == "train": 46 | # logging 47 | assert(hasattr(args, "logging")) 48 | if "print_freq" not in args.logging: 49 | check_and_set_hydra(args.logging,"print_freq",10000) 50 | if "eval_freq" not in args.logging: 51 | check_and_set_hydra(args.logging,"eval_freq",-1) 52 | if "save_freq" not in args.logging: 53 | check_and_set_hydra(args.logging,"save_freq",-1) 54 | if "checkpoint_path" in args.logging: 55 | if args.logging.checkpoint_path[-1] != "/": 56 | args.logging.checkpoint_path += "/" 57 | else: 58 | check_and_set_hydra(args.logging,"checkpoint_path","./") 59 | if "sample_path" in args.logging: 60 | if args.logging.sample_path[-1] != "/": 61 | args.logging.sample_path += "/" 62 | else: 63 | check_and_set_hydra(args.logging, "sample_path", "./") 64 | if "reuse_samplepath" not in args.logging: 65 | check_and_set_hydra(args.logging,"reuse_samplepath",False) 66 | if args.type == "evaluate" or args.type == "train": 67 | assert(hasattr(args.evaluation, "gt_path")),f"You must specify "\ 68 | f"the ground truth data path" 69 | assert(hasattr(args.evaluation, "total_size")),f"You must specify "\ 70 | f"the number of images for FID" 71 | if "batch" not in args.evaluation: 72 | check_and_set_hydra(args.evaluation,"batch", 1) 73 | if "save_root" not in args: 74 | check_and_set_hydra(args,"save_root","/tmp/") 75 | if args.type == "train": 76 | logging.warning("Save root path not set, using /tmp") 77 | if args.save_root[-1] != "/": 78 | args.save_root += "/" 79 | if args.type == "inference": 80 | if "batch" not in args.inference: 81 | check_and_set_hydra(args.inference,"batch",1) 82 | # assert(hasattr(args.evaluation, "gt_path")) 83 | if "misc" not in args.keys(): 84 | check_and_set_hydra(args,"misc",{}) 85 | if "seed" not in args.misc: 86 | check_and_set_hydra(args.misc,"seed",None) 87 | if "rng_state" not in args.misc: 88 | check_and_set_hydra(args.misc,"rng_state",None) 89 | if "py_rng_state" not in args.misc: 90 | check_and_set_hydra(args.misc,"py_rng_state",None) 91 | 92 | def rng_reproducibility(args, ckpt=None): 93 | # Store RNG info 94 | # Cumbersome but reproducibility is not super easy 95 | with open_dict(args): 96 | if args.misc.seed is None: 97 | args.misc.seed = torch.initial_seed() 98 | else: 99 |
torch.manual_seed(args.misc.seed) 100 | if args.misc.rng_state is None: 101 | args.misc.rng_state = torch.get_rng_state().tolist() 102 | else: 103 | torch.set_rng_state(args.misc.rng_state) 104 | if args.misc.py_rng_state is None: 105 | args.misc.py_rng_state = random.getstate() 106 | else: 107 | random.setstate(args.misc.py_rng_state) 108 | 109 | if ckpt is not None \ 110 | and "reuse_rng" in args.restart \ 111 | and args.restart.reuse_rng: 112 | with open_dict(args): 113 | #if "misc" in ckpt['args'].keys() and "seed" in ckpt['args']['misc'].keys(): 114 | try: 115 | if hasattr(ckpt['args'], "misc"): 116 | if hasattr(ckpt['args']['misc'], "seed"): 117 | try: 118 | args.misc.seed = ckpt['args']['misc']['seed'] 119 | print(f"[bold green]RNG Seed successfully loaded") 120 | except: 121 | print("[bold yellow]Seed couldn't be loaded (new style ckpt)") 122 | else: 123 | print("[bold yellow]Couldn't find seed (new style ckpt)") 124 | if hasattr(ckpt['args']['misc'], 'rng_state'): 125 | try: 126 | args.misc.rng_state = ckpt['args']['misc']['rng_state'] 127 | print(f"[bold green]RNG State successfully loaded") 128 | except: 129 | print("[bold yellow]RNG State couldn't be loaded (new style ckpt)") 130 | else: 131 | print("[bold yellow]Couldn't find RNG State (new style ckpt)") 132 | if hasattr(ckpt['args']['misc'], 'py_rng_state'): 133 | try: 134 | args.misc.py_rng_state = ckpt['args']['misc']['py_rng_state'] 135 | print(f"[bold green] Py-RNG State successfully loaded") 136 | except: 137 | print("[bold yellow]Py-RNG State couldn't be loaded (new style ckpt)") 138 | else: 139 | print("[bold yellow]Couldn't find Py-RNG State (new style ckpt)") 140 | elif type(ckpt['args']) == argparse.Namespace: 141 | try: 142 | args.misc.seed = ckpt['args'].seed 143 | print(f"[bold green]RNG Seed successfully loaded") 144 | except: 145 | print("[bold yellow]Seed couldn't be loaded (old style ckpt)") 146 | try: 147 | args.misc.rng_state = ckpt['args'].rng_state.tolist() 148 | print(f"[bold green]RNG State successfully loaded") 149 | except: 150 | print("[bold yellow]RNG State couldn't be loaded (old style ckpt)") 151 | try: 152 | args.misc.py_rng_state = ckpt['args'].rng_state.tolist() 153 | print(f"[bold green] Py-RNG State successfully loaded") 154 | except: 155 | print("[bold yellow]Py-RNG State couldn't be loaded (old style ckpt)") 156 | else: 157 | print("[bold yellow]No rng loading. 
{type(ckpt['args'])=}") 158 | except: 159 | print("[bold yellow]Seeds couldn't be loaded and don't know why") 160 | print(f"[bold yellow]\t {type(ckpt['args'])}") 161 | print(f"{type(ckpt['args']) == argparse.Namespace}") 162 | try: 163 | torch.manual_seed(args.misc.seed) 164 | _seed = "[bold green]True[/]" 165 | except: 166 | print("[bold yellow]Unable to set manual_seed") 167 | _seed = "[bold red]False[/]" 168 | try: 169 | torch.set_rng_state(torch.as_tensor( 170 | args.misc.rng_state, dtype=torch.uint8), 171 | ) 172 | _pt_rng = "[bold green]True[/]" 173 | except: 174 | print("[bold yellow]Unable to set ptyroch's rng state") 175 | _pt_rng = "[bold red]False[/]" 176 | try: 177 | l = tuple(args.misc.py_rng_state) 178 | random.setstate((l[0], tuple(l[1]), l[2])) 179 | _py_rng = "[bold green]True[/]" 180 | except: 181 | print("[bold yellow]Unable to set python's rng state") 182 | _py_rng = "[bold red]False[/]" 183 | print(f"[bold green]RNG Loading Success States:[/]\n"\ 184 | f"\tSeed: {_seed}, PyTorch RNG: {_pt_rng}, Python RNG {_py_rng}") 185 | -------------------------------------------------------------------------------- /utils/inception.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From mseitzer's Pytorch-FID 3 | https://github.com/mseitzer/pytorch-fid 4 | License included in fid_score 5 | ''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | 12 | try: 13 | from torchvision.models.utils import load_state_dict_from_url 14 | except ImportError: 15 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 16 | 17 | # Inception weights ported to Pytorch from 18 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 19 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 20 | 21 | 22 | class InceptionV3(nn.Module): 23 | """Pretrained InceptionV3 network returning feature maps""" 24 | 25 | # Index of default block of inception to return, 26 | # corresponds to output of final average pooling 27 | DEFAULT_BLOCK_INDEX = 3 28 | 29 | # Maps feature dimensionality to their output blocks indices 30 | BLOCK_INDEX_BY_DIM = { 31 | 64: 0, # First max pooling features 32 | 192: 1, # Second max pooling featurs 33 | 768: 2, # Pre-aux classifier features 34 | 2048: 3 # Final average pooling features 35 | } 36 | 37 | def __init__(self, 38 | output_blocks=(DEFAULT_BLOCK_INDEX,), 39 | resize_input=True, 40 | normalize_input=True, 41 | requires_grad=False, 42 | use_fid_inception=True): 43 | """Build pretrained InceptionV3 44 | 45 | Parameters 46 | ---------- 47 | output_blocks : list of int 48 | Indices of blocks to return features of. Possible values are: 49 | - 0: corresponds to output of first max pooling 50 | - 1: corresponds to output of second max pooling 51 | - 2: corresponds to output which is fed to aux classifier 52 | - 3: corresponds to output of final average pooling 53 | resize_input : bool 54 | If true, bilinearly resizes input to width and height 299 before 55 | feeding input to model. 
As the network without fully connected 56 | layers is fully convolutional, it should be able to handle inputs 57 | of arbitrary size, so resizing might not be strictly needed 58 | normalize_input : bool 59 | If true, scales the input from range (0, 1) to the range the 60 | pretrained Inception network expects, namely (-1, 1) 61 | requires_grad : bool 62 | If true, parameters of the model require gradients. Possibly useful 63 | for finetuning the network 64 | use_fid_inception : bool 65 | If true, uses the pretrained Inception model used in Tensorflow's 66 | FID implementation. If false, uses the pretrained Inception model 67 | available in torchvision. The FID Inception model has different 68 | weights and a slightly different structure from torchvision's 69 | Inception model. If you want to compute FID scores, you are 70 | strongly advised to set this parameter to true to get comparable 71 | results. 72 | """ 73 | super(InceptionV3, self).__init__() 74 | 75 | self.resize_input = resize_input 76 | self.normalize_input = normalize_input 77 | self.output_blocks = sorted(output_blocks) 78 | self.last_needed_block = max(output_blocks) 79 | 80 | assert self.last_needed_block <= 3, \ 81 | 'Last possible output block index is 3' 82 | 83 | self.blocks = nn.ModuleList() 84 | 85 | if use_fid_inception: 86 | inception = fid_inception_v3() 87 | else: 88 | inception = _inception_v3(pretrained=True) 89 | 90 | # Block 0: input to maxpool1 91 | block0 = [ 92 | inception.Conv2d_1a_3x3, 93 | inception.Conv2d_2a_3x3, 94 | inception.Conv2d_2b_3x3, 95 | nn.MaxPool2d(kernel_size=3, stride=2) 96 | ] 97 | self.blocks.append(nn.Sequential(*block0)) 98 | 99 | # Block 1: maxpool1 to maxpool2 100 | if self.last_needed_block >= 1: 101 | block1 = [ 102 | inception.Conv2d_3b_1x1, 103 | inception.Conv2d_4a_3x3, 104 | nn.MaxPool2d(kernel_size=3, stride=2) 105 | ] 106 | self.blocks.append(nn.Sequential(*block1)) 107 | 108 | # Block 2: maxpool2 to aux classifier 109 | if self.last_needed_block >= 2: 110 | block2 = [ 111 | inception.Mixed_5b, 112 | inception.Mixed_5c, 113 | inception.Mixed_5d, 114 | inception.Mixed_6a, 115 | inception.Mixed_6b, 116 | inception.Mixed_6c, 117 | inception.Mixed_6d, 118 | inception.Mixed_6e, 119 | ] 120 | self.blocks.append(nn.Sequential(*block2)) 121 | 122 | # Block 3: aux classifier to final avgpool 123 | if self.last_needed_block >= 3: 124 | block3 = [ 125 | inception.Mixed_7a, 126 | inception.Mixed_7b, 127 | inception.Mixed_7c, 128 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 129 | ] 130 | self.blocks.append(nn.Sequential(*block3)) 131 | 132 | for param in self.parameters(): 133 | param.requires_grad = requires_grad 134 | 135 | def forward(self, inp): 136 | """Get Inception feature maps 137 | 138 | Parameters 139 | ---------- 140 | inp : torch.autograd.Variable 141 | Input tensor of shape Bx3xHxW. 
Values are expected to be in 142 | range (0, 1) 143 | 144 | Returns 145 | ------- 146 | List of torch.autograd.Variable, corresponding to the selected output 147 | block, sorted ascending by index 148 | """ 149 | outp = [] 150 | x = inp 151 | 152 | if self.resize_input: 153 | x = F.interpolate(x, 154 | size=(299, 299), 155 | mode='bilinear', 156 | align_corners=False) 157 | 158 | if self.normalize_input: 159 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 160 | 161 | for idx, block in enumerate(self.blocks): 162 | x = block(x) 163 | if idx in self.output_blocks: 164 | outp.append(x) 165 | 166 | if idx == self.last_needed_block: 167 | break 168 | 169 | return outp 170 | 171 | 172 | def _inception_v3(*args, **kwargs): 173 | """Wraps `torchvision.models.inception_v3` 174 | 175 | Skips default weight inititialization if supported by torchvision version. 176 | See https://github.com/mseitzer/pytorch-fid/issues/28. 177 | """ 178 | try: 179 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 180 | except ValueError: 181 | # Just a caution against weird version strings 182 | version = (0,) 183 | 184 | if version >= (0, 6): 185 | kwargs['init_weights'] = False 186 | 187 | return torchvision.models.inception_v3(*args, **kwargs) 188 | 189 | 190 | def fid_inception_v3(): 191 | """Build pretrained Inception model for FID computation 192 | 193 | The Inception model for FID computation uses a different set of weights 194 | and has a slightly different structure than torchvision's Inception. 195 | 196 | This method first constructs torchvision's Inception and then patches the 197 | necessary parts that are different in the FID Inception model. 198 | """ 199 | inception = _inception_v3(num_classes=1008, 200 | aux_logits=False, 201 | pretrained=False) 202 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 203 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 204 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 205 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 206 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 207 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 208 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 209 | inception.Mixed_7b = FIDInceptionE_1(1280) 210 | inception.Mixed_7c = FIDInceptionE_2(2048) 211 | 212 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 213 | inception.load_state_dict(state_dict) 214 | return inception 215 | 216 | 217 | class FIDInceptionA(torchvision.models.inception.InceptionA): 218 | """InceptionA block patched for FID computation""" 219 | def __init__(self, in_channels, pool_features): 220 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 221 | 222 | def forward(self, x): 223 | branch1x1 = self.branch1x1(x) 224 | 225 | branch5x5 = self.branch5x5_1(x) 226 | branch5x5 = self.branch5x5_2(branch5x5) 227 | 228 | branch3x3dbl = self.branch3x3dbl_1(x) 229 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 230 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 231 | 232 | # Patch: Tensorflow's average pool does not use the padded zero's in 233 | # its average calculation 234 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 235 | count_include_pad=False) 236 | branch_pool = self.branch_pool(branch_pool) 237 | 238 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 239 | return torch.cat(outputs, 1) 240 | 241 | 242 | class FIDInceptionC(torchvision.models.inception.InceptionC): 243 | """InceptionC block patched 
for FID computation""" 244 | def __init__(self, in_channels, channels_7x7): 245 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 246 | 247 | def forward(self, x): 248 | branch1x1 = self.branch1x1(x) 249 | 250 | branch7x7 = self.branch7x7_1(x) 251 | branch7x7 = self.branch7x7_2(branch7x7) 252 | branch7x7 = self.branch7x7_3(branch7x7) 253 | 254 | branch7x7dbl = self.branch7x7dbl_1(x) 255 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 256 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 257 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 258 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 259 | 260 | # Patch: Tensorflow's average pool does not use the padded zero's in 261 | # its average calculation 262 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 263 | count_include_pad=False) 264 | branch_pool = self.branch_pool(branch_pool) 265 | 266 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 267 | return torch.cat(outputs, 1) 268 | 269 | 270 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 271 | """First InceptionE block patched for FID computation""" 272 | def __init__(self, in_channels): 273 | super(FIDInceptionE_1, self).__init__(in_channels) 274 | 275 | def forward(self, x): 276 | branch1x1 = self.branch1x1(x) 277 | 278 | branch3x3 = self.branch3x3_1(x) 279 | branch3x3 = [ 280 | self.branch3x3_2a(branch3x3), 281 | self.branch3x3_2b(branch3x3), 282 | ] 283 | branch3x3 = torch.cat(branch3x3, 1) 284 | 285 | branch3x3dbl = self.branch3x3dbl_1(x) 286 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 287 | branch3x3dbl = [ 288 | self.branch3x3dbl_3a(branch3x3dbl), 289 | self.branch3x3dbl_3b(branch3x3dbl), 290 | ] 291 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 292 | 293 | # Patch: Tensorflow's average pool does not use the padded zero's in 294 | # its average calculation 295 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 296 | count_include_pad=False) 297 | branch_pool = self.branch_pool(branch_pool) 298 | 299 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 300 | return torch.cat(outputs, 1) 301 | 302 | 303 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 304 | """Second InceptionE block patched for FID computation""" 305 | def __init__(self, in_channels): 306 | super(FIDInceptionE_2, self).__init__(in_channels) 307 | 308 | def forward(self, x): 309 | branch1x1 = self.branch1x1(x) 310 | 311 | branch3x3 = self.branch3x3_1(x) 312 | branch3x3 = [ 313 | self.branch3x3_2a(branch3x3), 314 | self.branch3x3_2b(branch3x3), 315 | ] 316 | branch3x3 = torch.cat(branch3x3, 1) 317 | 318 | branch3x3dbl = self.branch3x3dbl_1(x) 319 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 320 | branch3x3dbl = [ 321 | self.branch3x3dbl_3a(branch3x3dbl), 322 | self.branch3x3dbl_3b(branch3x3dbl), 323 | ] 324 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 325 | 326 | # Patch: The FID Inception model uses max pooling instead of average 327 | # pooling. This is likely an error in this specific Inception 328 | # implementation, as other Inception models use average pooling here 329 | # (which matches the description in the paper). 330 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 331 | branch_pool = self.branch_pool(branch_pool) 332 | 333 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 334 | return torch.cat(outputs, 1) 335 | --------------------------------------------------------------------------------
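A minimal usage sketch for the `InceptionV3` wrapper above (illustrative only: the batch size, random placeholder images, and device selection are assumptions, not the repository's actual FID pipeline, which presumably drives this module from `utils/fid_score.py`):

import torch
from utils.inception import InceptionV3

# Request the 2048-dim final average-pool block, the standard choice for FID statistics.
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = InceptionV3([block_idx]).to(device).eval()

# Placeholder batch: real code would pass ground-truth or generated images in [0, 1], shape (N, 3, H, W).
images = torch.rand(8, 3, 256, 256, device=device)

with torch.no_grad():
    pooled = model(images)[0]                 # one tensor per requested block; block 3 is (N, 2048, 1, 1)
features = pooled.squeeze(-1).squeeze(-2)     # (N, 2048) activations for the FID mean/covariance

The wrapper resizes inputs to 299x299 and rescales them from (0, 1) to (-1, 1) internally (both enabled by default), so callers only need to supply images in the (0, 1) range.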