├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── dataset └── dataset.py ├── imgs ├── celebahq.jpg ├── ffhq.jpg ├── latent_interpolation.jpg └── teaser.png ├── models ├── basic_layers.py ├── discriminator.py └── generator.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── requirements.txt ├── train_styleswin.py └── utils ├── CRDiffAug.py ├── distributed.py ├── fid_score.py ├── inception.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleSwin 2 | 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/styleswin-transformer-based-gan-for-high-1/image-generation-on-celeba-hq-1024x1024)](https://paperswithcode.com/sota/image-generation-on-celeba-hq-1024x1024?p=styleswin-transformer-based-gan-for-high-1) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/styleswin-transformer-based-gan-for-high-1/image-generation-on-celeba-hq-256x256)](https://paperswithcode.com/sota/image-generation-on-celeba-hq-256x256?p=styleswin-transformer-based-gan-for-high-1) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/styleswin-transformer-based-gan-for-high-1/image-generation-on-ffhq-256-x-256)](https://paperswithcode.com/sota/image-generation-on-ffhq-256-x-256?p=styleswin-transformer-based-gan-for-high-1) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/styleswin-transformer-based-gan-for-high-1/image-generation-on-lsun-churches-256-x-256)](https://paperswithcode.com/sota/image-generation-on-lsun-churches-256-x-256?p=styleswin-transformer-based-gan-for-high-1) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/styleswin-transformer-based-gan-for-high-1/image-generation-on-ffhq-1024-x-1024)](https://paperswithcode.com/sota/image-generation-on-ffhq-1024-x-1024?p=styleswin-transformer-based-gan-for-high-1) 8 | [![WebDemo](https://img.shields.io/badge/%F0%9F%A4%97%20Web%20Demo-Huggingface-blue)](https://huggingface.co/spaces/hysts/StyleSwin) 9 | 10 | ![Teaser](imgs/teaser.png) 11 | 12 | This repo is the official implementation of "[StyleSwin: Transformer-based GAN for High-resolution Image Generation](https://arxiv.org/abs/2112.10762)" (CVPR 2022). 13 | 14 | By [Bowen Zhang](http://home.ustc.edu.cn/~zhangbowen), [Shuyang Gu](http://home.ustc.edu.cn/~gsy777/), [Bo Zhang](https://bo-zhang.me/), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](http://www.dongchen.pro/), [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/), [Yong Wang](https://auto.ustc.edu.cn/2021/0510/c25976a484888/page.htm) and [Baining Guo](microsoft.com/en-us/research/people/bainguo/). 15 | 16 | ## Abstract 17 | 18 | > Despite the tantalizing success in a broad of vision tasks, transformers have not yet demonstrated on-par ability as ConvNets in high-resolution image generative modeling. In this paper, we seek to explore using pure transformers to build a generative adversarial network for high-resolution image synthesis. To this end, we believe that local attention is crucial to strike the balance between computational efficiency and modeling capacity. Hence, the proposed generator adopts Swin transformer in a style-based architecture. To achieve a larger receptive field, we propose double attention which simultaneously leverages the context of the local and the shifted windows, leading to improved generation quality. 
Moreover, we show that offering the knowledge of the absolute position that has been lost in window-based transformers greatly benefits the generation quality. The proposed StyleSwin is scalable to high resolutions, with both the coarse geometry and fine structures benefit from the strong expressivity of transformers. However, blocking artifacts occur during high-resolution synthesis because performing the local attention in a block-wise manner may break the spatial coherency. To solve this, we empirically investigate various solutions, among which we find that employing a wavelet discriminator to examine the spectral discrepancy effectively suppresses the artifacts. Extensive experiments show the superiority over prior transformer-based GANs, especially on high resolutions, e.g., 1024x1024. The StyleSwin, without complex training strategies, excels over StyleGAN on CelebA-HQ 1024x1024, and achieves on-par performance on FFHQ 1024x1024, proving the promise of using transformers for high-resolution image generation. 19 | 20 | ## Quantitative Results 21 | 22 | | Dataset | Resolution | FID | Pretrained Model | 23 | | :-: | :-: | :-: | :-: | 24 | | FFHQ | 256x256 | 2.81 | [Google Drive](https://drive.google.com/file/d/1OjYZ1zEWGNdiv0RFKv7KhXRmYko72LjO/view?usp=sharing)/[Azure Storage](https://facevcstandard.blob.core.windows.net/v-bowenz/output/styleswin_final_results/FFHQ256/FFHQ_256.pt?sv=2020-10-02&st=2022-03-14T12%3A36%3A35Z&se=2099-12-31T15%3A59%3A00Z&sr=b&sp=r&sig=QBETIToFQ8MtlnnVLpNlbcPB8MPTZkiDDTNjlgovf%2Fo%3D) | 25 | | LSUN Church | 256x256 | 2.95 | [Google Drive](https://drive.google.com/file/d/1HF0wFNuz1WFrqGEbPhOXjL4QrY05Zu_m/view?usp=sharing)/[Azure Storage](https://facevcstandard.blob.core.windows.net/v-bowenz/output/styleswin_final_results/LSUNChurch256/LSUNChurch_256.pt?sv=2020-10-02&st=2022-03-14T12%3A37%3A41Z&se=2099-12-31T15%3A59%3A00Z&sr=b&sp=r&sig=VPWMsvHbJKUj8v6a9gp2u424OAS9o%2BL1qKKfGtYWMN8%3D) | 26 | | CelebA-HQ | 256x256 | 3.25 | [Google Drive](https://drive.google.com/file/d/1YtIJOgLFfkaMI_KL2gBQNABFb1cwOzvM/view?usp=sharing)/[Azure Storage](https://facevcstandard.blob.core.windows.net/v-bowenz/output/styleswin_final_results/CelebAHQ256/CelebAHQ_256.pt?sv=2020-10-02&st=2022-03-14T12%3A39%3A42Z&se=2099-12-31T15%3A59%3A00Z&sr=b&sp=r&sig=xasn7w5ou739bM9NAwmA3HEFkxKXOrqddH76EviXewo%3D) | 27 | | FFHQ | 1024x1024 | 5.07 | [Google Drive](https://drive.google.com/file/d/17-ILwzLBoHq4HTdAPeaCug7iBvxKWkvp/view?usp=sharing)/[Azure Storage](https://facevcstandard.blob.core.windows.net/v-bowenz/output/styleswin_final_results/FFHQ1024/FFHQ_1024.pt?sv=2020-10-02&st=2022-03-14T12%3A40%3A20Z&se=2099-12-31T15%3A59%3A00Z&sr=b&sp=r&sig=Di7J57LLvayVVTmEymjI61y42q%2BxS9pxCmBHbay6t%2Bk%3D) | 28 | | CelebA-HQ | 1024x1024 | 4.43 | [Google Drive](https://drive.google.com/file/d/1y3wkykjvCbteTaGTRF8EedkG-N1Z8jFf/view?usp=sharing)/[Azure Storage](https://facevcstandard.blob.core.windows.net/v-bowenz/output/styleswin_final_results/CelebAHQ1024/CelebAHQ_1024.pt?sv=2020-10-02&st=2022-03-14T12%3A40%3A49Z&se=2099-12-31T15%3A59%3A00Z&sr=b&sp=r&sig=gicenvMBClfUmFr1gew06exUoN033JWADmNLCCtyu4w%3D) | 29 | 30 | ## Requirements 31 | 32 | To install the dependencies: 33 | 34 | ```bash 35 | python -m pip install -r requirements.txt 36 | ``` 37 | 38 | ## Generating image samples with pretrained model 39 | 40 | To generate 50k image samples of resolution **1024** and evaluate the fid score: 41 | 42 | ```bash 43 | python -m torch.distributed.launch --nproc_per_node=1 train_styleswin.py --sample_path 
/path_to_save_generated_samples --size 1024 --ckpt /path/to/checkpoint --eval --val_num_batches 12500 --val_batch_size 4 --eval_gt_path /path_to_real_images_50k 44 | ``` 45 | 46 | To generate 50k image samples of resolution **256** and evaluate the fid score: 47 | 48 | ```bash 49 | python -m torch.distributed.launch --nproc_per_node=1 train_styleswin.py --sample_path /path_to_save_generated_samples --size 256 --G_channel_multiplier 2 --ckpt /path/to/checkpoint --eval --val_num_batches 12500 --val_batch_size 4 --eval_gt_path /path_to_real_images_50k 50 | ``` 51 | 52 | ## Training 53 | 54 | ### Data preparing 55 | 56 | When training FFHQ and CelebA-HQ, we use `ImageFolder` datasets. The data structure is like this: 57 | 58 | ``` 59 | FFHQ 60 | ├── images 61 | │ ├── 000001.png 62 | │ ├── ... 63 | ``` 64 | 65 | When training LSUN Church, please follow [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch#usage) to create a lmdb dataset first. After this, the data structure is like this: 66 | 67 | ``` 68 | LSUN Church 69 | ├── data.mdb 70 | └── lock.mdb 71 | ``` 72 | 73 | ### FFHQ-1024 74 | 75 | To train a new model of **FFHQ-1024** from scratch: 76 | 77 | ```bash 78 | python -m torch.distributed.launch --nproc_per_node=8 train_styleswin.py --batch 2 --path /path_to_ffhq_1024 --checkpoint_path /tmp --sample_path /tmp --size 1024 --D_lr 0.0002 --D_sn --ttur --eval_gt_path /path_to_ffhq_real_images_50k --lr_decay --lr_decay_start_steps 600000 79 | ``` 80 | 81 | ### CelebA-HQ 1024 82 | 83 | To train a new model of **CelebA-HQ 1024** from scratch: 84 | 85 | ```bash 86 | python -m torch.distributed.launch --nproc_per_node=8 train_styleswin.py --batch 2 --path /path_to_celebahq_1024 --checkpoint_path /tmp --sample_path /tmp --size 1024 --D_lr 0.0002 --D_sn --ttur --eval_gt_path /path_to_celebahq_real_images_50k 87 | ``` 88 | 89 | ### FFHQ-256 90 | 91 | To train a new model of **FFHQ-256** from scratch: 92 | 93 | ```bash 94 | python -m torch.distributed.launch --nproc_per_node=8 train_styleswin.py --batch 4 --path /path_to_ffhq_256 --checkpoint_path /tmp --sample_path /tmp --size 256 --G_channel_multiplier 2 --bcr --D_lr 0.0002 --D_sn --ttur --eval_gt_path /path_to_ffhq_real_images_50k --lr_decay --lr_decay_start_steps 775000 --iter 1000000 95 | ``` 96 | 97 | ### CelebA-HQ 256 98 | 99 | To train a new model of **CelebA-HQ 256** from scratch: 100 | 101 | ```bash 102 | python -m torch.distributed.launch --nproc_per_node=8 train_styleswin.py --batch 4 --path /path_to_celebahq_256 --checkpoint_path /tmp --sample_path /tmp --size 256 --G_channel_multiplier 2 --bcr --r1 5 --D_lr 0.0002 --D_sn --ttur --eval_gt_path /path_to_celebahq_real_images_50k --lr_decay --lr_decay_start_steps 500000 103 | ``` 104 | 105 | ### LSUN Church 256 106 | 107 | To train a new model of **LSUN Church 256** from scratch: 108 | 109 | ```bash 110 | python -m torch.distributed.launch --nproc_per_node=8 train_styleswin.py --batch 4 --path /path_to_lsun_church_256 --checkpoint_path /tmp --sample_path /tmp --size 256 --G_channel_multiplier 2 --use_flip --r1 5 --lmdb --D_lr 0.0002 --D_sn --ttur --eval_gt_path /path_to_lsun_church_real_images_50k --lr_decay --lr_decay_start_steps 1300000 --iter 1500000 111 | ``` 112 | 113 | **Notice**: When training on 16 GB GPUs, you could add `--use_checkpoint` to save GPU memory. 
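If you would rather sample a few images directly from Python than go through the distributed evaluation entry point above, the sketch below shows the general idea. It only relies on the `Generator` constructor from `models/generator.py`; the checkpoint key (`g_ema`), the latent dimension of 512, `n_mlp=8`, and the behaviour of the forward pass are assumptions on our part — `train_styleswin.py` is the authoritative reference for loading and sampling.

```python
# Minimal, hypothetical sampling sketch. The checkpoint key, latent size and
# forward signature are assumptions; see train_styleswin.py for the real code path.
import torch
from torchvision import utils

from models.generator import Generator

device = 'cuda'
ckpt = torch.load('/path/to/checkpoint', map_location='cpu')

# Constructor arguments follow models/generator.py; pass channel_multiplier=2
# for the 256x256 models (the --G_channel_multiplier 2 flag above).
g_ema = Generator(size=1024, style_dim=512, n_mlp=8).to(device)
g_ema.load_state_dict(ckpt['g_ema'])  # assumed checkpoint key
g_ema.eval()

with torch.no_grad():
    z = torch.randn(4, 512, device=device)  # one 512-d latent code per sample
    out = g_ema(z)  # assumed to return images (or a tuple starting with them)
    imgs = out[0] if isinstance(out, (tuple, list)) else out
    utils.save_image(imgs, 'samples.png', nrow=2, normalize=True)
```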
114 | 115 | ## Qualitative Results 116 | 117 | Image samples of FFHQ-1024 generated by StyleSwin: 118 | 119 | ![](imgs/ffhq.jpg) 120 | 121 | Image samples of CelebA-HQ 1024 generated by StyleSwin: 122 | 123 | ![](imgs/celebahq.jpg) 124 | 125 | Latent code interpolation examples of FFHQ-1024 between the left-most and the right-most images: 126 | 127 | ![](imgs/latent_interpolation.jpg) 128 | 129 | ## Citing StyleSwin 130 | 131 | ``` 132 | @misc{zhang2021styleswin, 133 | title={StyleSwin: Transformer-based GAN for High-resolution Image Generation}, 134 | author={Bowen Zhang and Shuyang Gu and Bo Zhang and Jianmin Bao and Dong Chen and Fang Wen and Yong Wang and Baining Guo}, 135 | year={2021}, 136 | eprint={2112.10762}, 137 | archivePrefix={arXiv}, 138 | primaryClass={cs.CV} 139 | } 140 | ``` 141 | 142 | ## Responsible AI Considerations 143 | 144 | Our work does not directly modify existing images, a use that may alter the identity or expression of the people depicted. We discourage applying our work to such tasks, as it is not designed for them. We have quantitatively verified that the proposed method does not show evident disparity across gender or age, as the model mostly follows the dataset distribution; however, we encourage additional care if you intend to use the system on certain demographic groups. We also encourage the use of fair and representative data when training on customized data. We caution that the high-resolution images produced by our model may potentially be misused for impersonating humans; viable ways to avoid this include adding tags or watermarks when distributing the generated photos. 145 | 146 | ## Acknowledgements 147 | 148 | This code borrows heavily from [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch) and [Swin-Transformer](https://github.com/microsoft/Swin-Transformer). We also thank the contributors of [Positional Encoding in GANs](https://github.com/open-mmlab/mmgeneration/blob/master/configs/positional_encoding_in_gans/README.md), [DiffAug](https://github.com/mit-han-lab/data-efficient-gans), [StudioGAN](https://github.com/POSTECH-CVLab/PyTorch-StudioGAN) and [GIQA](https://github.com/cientgu/GIQA). 149 | 150 | ## Maintenance 151 | 152 | This is the codebase for our research work. Please open a GitHub issue for any help. If you have any questions regarding the technical details, feel free to contact [zhangbowen@mail.ustc.edu.cn](mailto:zhangbowen@mail.ustc.edu.cn) or [zhanbo@microsoft.com](mailto:zhanbo@microsoft.com). 153 | 154 | 155 | ## License 156 | The code and the pretrained models in this repository are released under the MIT license, as specified by the LICENSE file. 157 | 158 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
159 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from io import BytesIO 5 | 6 | import lmdb 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class MultiResolutionDataset(Dataset): 12 | def __init__(self, path, transform, resolution=256): 13 | self.env = lmdb.open( 14 | path, 15 | max_readers=32, 16 | readonly=True, 17 | lock=False, 18 | readahead=False, 19 | meminit=False, 20 | ) 21 | 22 | if not self.env: 23 | raise IOError('Cannot open lmdb dataset', path) 24 | 25 | with self.env.begin(write=False) as txn: 26 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 27 | 28 | self.resolution = resolution 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return self.length 33 | 34 | def __getitem__(self, index): 35 | with self.env.begin(write=False) as txn: 36 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 37 | img_bytes = txn.get(key) 38 | 39 | buffer = BytesIO(img_bytes) 40 | img = Image.open(buffer) 41 | img = self.transform(img) 42 | 43 | return img 44 | -------------------------------------------------------------------------------- /imgs/celebahq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/StyleSwin/a7c28b3121ca86f0577696d6f1d699e13d368672/imgs/celebahq.jpg -------------------------------------------------------------------------------- /imgs/ffhq.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/StyleSwin/a7c28b3121ca86f0577696d6f1d699e13d368672/imgs/ffhq.jpg -------------------------------------------------------------------------------- /imgs/latent_interpolation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/StyleSwin/a7c28b3121ca86f0577696d6f1d699e13d368672/imgs/latent_interpolation.jpg -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/StyleSwin/a7c28b3121ca86f0577696d6f1d699e13d368672/imgs/teaser.png -------------------------------------------------------------------------------- /models/basic_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | 6 | import numpy as np 7 | import torch 8 | from op import fused_leaky_relu, upfirdn2d 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class Blur(nn.Module): 14 | def __init__(self, kernel, pad, upsample_factor=1): 15 | super().__init__() 16 | 17 | kernel = make_kernel(kernel) 18 | 19 | if upsample_factor > 1: 20 | kernel = kernel * (upsample_factor ** 2) 21 | 22 | self.register_buffer('kernel', kernel) 23 | 24 | self.pad = pad 25 | 26 | def forward(self, input): 27 | out = upfirdn2d(input, self.kernel, pad=self.pad) 28 | 29 | return out 30 | 31 | 32 | class EqualConv2d(nn.Module): 33 | def __init__( 34 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 35 | ): 36 | super().__init__() 37 | 38 | self.weight = nn.Parameter( 39 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 40 | ) 41 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 42 | 43 | self.stride = stride 44 | self.padding = padding 45 | 46 | if bias: 47 | self.bias = nn.Parameter(torch.zeros(out_channel)) 48 | 49 | else: 50 | self.bias = None 51 | 52 | def forward(self, input): 53 | out = F.conv2d( 54 | input, 55 | self.weight * self.scale, 56 | bias=self.bias, 57 | stride=self.stride, 58 | padding=self.padding, 59 | ) 60 | 61 | return out 62 | 63 | def __repr__(self): 64 | return ( 65 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 66 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 67 | ) 68 | 69 | 70 | class EqualLinear(nn.Module): 71 | def __init__( 72 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 73 | ): 74 | super().__init__() 75 | 76 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 77 | 78 | if bias: 79 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 80 | 81 | else: 82 | self.bias = None 83 | 84 | self.activation = activation 85 | 86 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 87 | self.lr_mul = lr_mul 88 | 89 | def forward(self, input): 90 | if self.activation: 91 | out = F.linear(input, self.weight * self.scale) 92 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 93 | 94 | else: 95 | out = F.linear( 96 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 97 | ) 98 | 99 | return out 100 | 101 | def __repr__(self): 102 | return ( 103 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 104 | ) 105 | 106 | 107 | class PixelNorm(nn.Module): 108 | def __init__(self): 109 | super().__init__() 110 | 111 | def forward(self, input): 112 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 113 | 114 | 115 | def make_kernel(k): 116 | k = torch.tensor(k, dtype=torch.float32) 117 | 118 | if k.ndim == 1: 119 | k = k[None, :] * k[:, None] 120 | 121 | k /= k.sum() 122 | 123 | return k 124 | 125 | 126 | class Upsample(nn.Module): 127 | def __init__(self, kernel, factor=2): 128 | super().__init__() 129 | 130 | self.factor = factor 131 | kernel = make_kernel(kernel) * (factor ** 2) 132 | self.register_buffer('kernel', kernel) 133 | 134 | p = kernel.shape[0] - 
factor 135 | 136 | pad0 = (p + 1) // 2 + factor - 1 137 | pad1 = p // 2 138 | 139 | self.pad = (pad0, pad1) 140 | 141 | def forward(self, input): 142 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 143 | 144 | return out 145 | 146 | 147 | class Downsample(nn.Module): 148 | def __init__(self, kernel, factor=2): 149 | super().__init__() 150 | 151 | self.factor = factor 152 | kernel = make_kernel(kernel) 153 | self.register_buffer('kernel', kernel) 154 | 155 | p = kernel.shape[0] - factor 156 | 157 | pad0 = (p + 1) // 2 158 | pad1 = p // 2 159 | 160 | self.pad = (pad0, pad1) 161 | 162 | def forward(self, input): 163 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 164 | 165 | return out 166 | 167 | 168 | class ScaledLeakyReLU(nn.Module): 169 | def __init__(self, negative_slope=0.2): 170 | super().__init__() 171 | 172 | self.negative_slope = negative_slope 173 | 174 | def forward(self, input): 175 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 176 | 177 | return out * math.sqrt(2) 178 | 179 | 180 | class ModulatedConv2d(nn.Module): 181 | def __init__( 182 | self, 183 | in_channel, 184 | out_channel, 185 | kernel_size, 186 | style_dim, 187 | demodulate=True, 188 | upsample=False, 189 | downsample=False, 190 | blur_kernel=[1, 3, 3, 1], 191 | ): 192 | super().__init__() 193 | 194 | self.eps = 1e-8 195 | self.kernel_size = kernel_size 196 | self.in_channel = in_channel 197 | self.out_channel = out_channel 198 | self.upsample = upsample 199 | self.downsample = downsample 200 | 201 | if upsample: 202 | factor = 2 203 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 204 | pad0 = (p + 1) // 2 + factor - 1 205 | pad1 = p // 2 + 1 206 | 207 | self.blur = Blur(blur_kernel, pad=( 208 | pad0, pad1), upsample_factor=factor) 209 | 210 | if downsample: 211 | factor = 2 212 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 213 | pad0 = (p + 1) // 2 214 | pad1 = p // 2 215 | 216 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 217 | 218 | fan_in = in_channel * kernel_size ** 2 219 | self.scale = 1 / math.sqrt(fan_in) 220 | self.padding = kernel_size // 2 221 | 222 | self.weight = nn.Parameter( 223 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 224 | ) 225 | 226 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 227 | 228 | self.demodulate = demodulate 229 | 230 | def __repr__(self): 231 | return ( 232 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 233 | f'upsample={self.upsample}, downsample={self.downsample})' 234 | ) 235 | 236 | def forward(self, input, style): 237 | batch, in_channel, height, width = input.shape 238 | 239 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 240 | weight = self.scale * self.weight * style 241 | 242 | if self.demodulate: 243 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 244 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 245 | 246 | weight = weight.view( 247 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 248 | ) 249 | 250 | if self.upsample: 251 | input = input.view(1, batch * in_channel, height, width) 252 | weight = weight.view( 253 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 254 | ) 255 | weight = weight.transpose(1, 2).reshape( 256 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 257 | ) 258 | out = F.conv_transpose2d( 259 | input, weight, padding=0, stride=2, groups=batch) 260 | _, _, 
height, width = out.shape 261 | out = out.view(batch, self.out_channel, height, width) 262 | out = self.blur(out) 263 | 264 | elif self.downsample: 265 | input = self.blur(input) 266 | _, _, height, width = input.shape 267 | input = input.view(1, batch * in_channel, height, width) 268 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 269 | _, _, height, width = out.shape 270 | out = out.view(batch, self.out_channel, height, width) 271 | 272 | else: 273 | input = input.reshape(1, batch * in_channel, height, width) 274 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 275 | _, _, height, width = out.shape 276 | out = out.view(batch, self.out_channel, height, width) 277 | 278 | return out 279 | 280 | 281 | class SinusoidalPositionalEmbedding(nn.Module): 282 | """Sinusoidal Positional Embedding 1D or 2D (SPE/SPE2d). 283 | 284 | This module is a modified from: 285 | https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py # noqa 286 | 287 | Based on the original SPE in single dimension, we implement a 2D sinusoidal 288 | positional encodding (SPE2d), as introduced in Positional Encoding as 289 | Spatial Inductive Bias in GANs, CVPR'2021. 290 | 291 | Args: 292 | embedding_dim (int): The number of dimensions for the positional 293 | encoding. 294 | padding_idx (int | list[int]): The index for the padding contents. The 295 | padding positions will obtain an encoding vector filling in zeros. 296 | init_size (int, optional): The initial size of the positional buffer. 297 | Defaults to 1024. 298 | div_half_dim (bool, optional): If true, the embedding will be divided 299 | by :math:`d/2`. Otherwise, it will be divided by 300 | :math:`(d/2 -1)`. Defaults to False. 301 | center_shift (int | None, optional): Shift the center point to some 302 | index. Defaults to None. 303 | """ 304 | 305 | def __init__(self, 306 | embedding_dim, 307 | padding_idx, 308 | init_size=1024, 309 | div_half_dim=False, 310 | center_shift=None): 311 | super().__init__() 312 | self.embedding_dim = embedding_dim 313 | self.padding_idx = padding_idx 314 | self.div_half_dim = div_half_dim 315 | self.center_shift = center_shift 316 | 317 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 318 | init_size, embedding_dim, padding_idx, self.div_half_dim) 319 | 320 | self.register_buffer('_float_tensor', torch.FloatTensor(1)) 321 | 322 | self.max_positions = int(1e5) 323 | 324 | @staticmethod 325 | def get_embedding(num_embeddings, 326 | embedding_dim, 327 | padding_idx=None, 328 | div_half_dim=False): 329 | """Build sinusoidal embeddings. 330 | 331 | This matches the implementation in tensor2tensor, but differs slightly 332 | from the description in Section 3.5 of "Attention Is All You Need". 333 | """ 334 | assert embedding_dim % 2 == 0, ( 335 | 'In this version, we request ' 336 | f'embedding_dim divisible by 2 but got {embedding_dim}') 337 | 338 | # there is a little difference from the original paper. 
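        # Concretely, for position p and channel index i in [0, d/2) with
        # d = embedding_dim, the lines below compute
        #     freq_i = exp(-i * log(10000) / (d/2 - 1))   # or / (d/2) when div_half_dim
        #     emb[p]  = [sin(p * freq_0), ..., sin(p * freq_{d/2-1}),
        #                cos(p * freq_0), ..., cos(p * freq_{d/2-1})]
        # so the sin terms fill the first half of the channels and the cos terms
        # the second half, instead of interleaving as in the original formulation.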
339 | half_dim = embedding_dim // 2 340 | if not div_half_dim: 341 | emb = np.log(10000) / (half_dim - 1) 342 | else: 343 | emb = np.log(1e4) / half_dim 344 | # compute exp(-log10000 / d * i) 345 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) 346 | emb = torch.arange( 347 | num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) 348 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], 349 | dim=1).view(num_embeddings, -1) 350 | if padding_idx is not None: 351 | emb[padding_idx, :] = 0 352 | 353 | return emb 354 | 355 | def forward(self, input, **kwargs): 356 | """Input is expected to be of size [bsz x seqlen]. 357 | 358 | Returned tensor is expected to be of size [bsz x seq_len x emb_dim] 359 | """ 360 | assert input.dim() == 2 or input.dim( 361 | ) == 4, 'Input dimension should be 2 (1D) or 4(2D)' 362 | 363 | if input.dim() == 4: 364 | return self.make_grid2d_like(input, **kwargs) 365 | 366 | b, seq_len = input.shape 367 | max_pos = self.padding_idx + 1 + seq_len 368 | 369 | if self.weights is None or max_pos > self.weights.size(0): 370 | # recompute/expand embedding if needed 371 | self.weights = SinusoidalPositionalEmbedding.get_embedding( 372 | max_pos, self.embedding_dim, self.padding_idx) 373 | self.weights = self.weights.to(self._float_tensor) 374 | 375 | positions = self.make_positions(input, self.padding_idx).to( 376 | self._float_tensor.device) 377 | 378 | return self.weights.index_select(0, positions.view(-1)).view( 379 | b, seq_len, self.embedding_dim).detach() 380 | 381 | def make_positions(self, input, padding_idx): 382 | mask = input.ne(padding_idx).int() 383 | return (torch.cumsum(mask, dim=1).type_as(mask) * 384 | mask).long() + padding_idx 385 | 386 | def make_grid2d(self, height, width, num_batches=1, center_shift=None): 387 | h, w = height, width 388 | # if `center_shift` is not given from the outside, use 389 | # `self.center_shift` 390 | if center_shift is None: 391 | center_shift = self.center_shift 392 | 393 | h_shift = 0 394 | w_shift = 0 395 | # center shift to the input grid 396 | if center_shift is not None: 397 | # if h/w is even, the left center should be aligned with 398 | # center shift 399 | if h % 2 == 0: 400 | h_left_center = h // 2 401 | h_shift = center_shift - h_left_center 402 | else: 403 | h_center = h // 2 + 1 404 | h_shift = center_shift - h_center 405 | 406 | if w % 2 == 0: 407 | w_left_center = w // 2 408 | w_shift = center_shift - w_left_center 409 | else: 410 | w_center = w // 2 + 1 411 | w_shift = center_shift - w_center 412 | 413 | # Note that the index is started from 1 since zero will be padding idx. 414 | # axis -- (b, h or w) 415 | x_axis = torch.arange(1, w + 1).unsqueeze(0).repeat(num_batches, 416 | 1) + w_shift 417 | y_axis = torch.arange(1, h + 1).unsqueeze(0).repeat(num_batches, 418 | 1) + h_shift 419 | 420 | # emb -- (b, emb_dim, h or w) 421 | x_emb = self(x_axis).transpose(1, 2) 422 | y_emb = self(y_axis).transpose(1, 2) 423 | 424 | # make grid for x/y axis 425 | # Note that repeat will copy data. If use learned emb, expand may be 426 | # better. 
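        # Shapes: x_emb (b, emb_dim, w) -> x_grid (b, emb_dim, h, w) by repeating
        # along the height axis; y_emb (b, emb_dim, h) -> y_grid (b, emb_dim, h, w)
        # by repeating along the width axis.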
427 | x_grid = x_emb.unsqueeze(2).repeat(1, 1, h, 1) 428 | y_grid = y_emb.unsqueeze(3).repeat(1, 1, 1, w) 429 | 430 | # cat grid -- (b, 2 x emb_dim, h, w) 431 | grid = torch.cat([x_grid, y_grid], dim=1) 432 | return grid.detach() 433 | 434 | def make_grid2d_like(self, x, center_shift=None): 435 | """Input tensor with shape of (b, ..., h, w) Return tensor with shape 436 | of (b, 2 x emb_dim, h, w) 437 | 438 | Note that the positional embedding highly depends on the the function, 439 | ``make_positions``. 440 | """ 441 | h, w = x.shape[-2:] 442 | grid = self.make_grid2d(h, w, x.size(0), center_shift) 443 | 444 | return grid.to(x) 445 | -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | 6 | import torch 7 | from op import FusedLeakyReLU, upfirdn2d 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn.utils import spectral_norm 11 | 12 | from models.basic_layers import (Blur, Downsample, EqualConv2d, EqualLinear, 13 | ScaledLeakyReLU) 14 | 15 | 16 | class ConvLayer(nn.Sequential): 17 | def __init__( 18 | self, 19 | in_channel, 20 | out_channel, 21 | kernel_size, 22 | downsample=False, 23 | blur_kernel=[1, 3, 3, 1], 24 | bias=True, 25 | activate=True, 26 | sn=False 27 | ): 28 | layers = [] 29 | 30 | if downsample: 31 | factor = 2 32 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 33 | pad0 = (p + 1) // 2 34 | pad1 = p // 2 35 | 36 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 37 | 38 | stride = 2 39 | self.padding = 0 40 | 41 | else: 42 | stride = 1 43 | self.padding = kernel_size // 2 44 | 45 | if sn: 46 | # Not use equal conv2d when apply SN 47 | layers.append( 48 | spectral_norm(nn.Conv2d( 49 | in_channel, 50 | out_channel, 51 | kernel_size, 52 | padding=self.padding, 53 | stride=stride, 54 | bias=bias and not activate, 55 | )) 56 | ) 57 | else: 58 | layers.append( 59 | EqualConv2d( 60 | in_channel, 61 | out_channel, 62 | kernel_size, 63 | padding=self.padding, 64 | stride=stride, 65 | bias=bias and not activate, 66 | ) 67 | ) 68 | 69 | if activate: 70 | if bias: 71 | layers.append(FusedLeakyReLU(out_channel)) 72 | else: 73 | layers.append(ScaledLeakyReLU(0.2)) 74 | 75 | super().__init__(*layers) 76 | 77 | 78 | class ConvBlock(nn.Module): 79 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], sn=False): 80 | super().__init__() 81 | 82 | self.conv1 = ConvLayer(in_channel, in_channel, 3, sn=sn) 83 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, sn=sn) 84 | 85 | def forward(self, input): 86 | out = self.conv1(input) 87 | out = self.conv2(out) 88 | 89 | return out 90 | 91 | 92 | def get_haar_wavelet(in_channels): 93 | haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2) 94 | haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2) 95 | haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0] 96 | 97 | haar_wav_ll = haar_wav_l.T * haar_wav_l 98 | haar_wav_lh = haar_wav_h.T * haar_wav_l 99 | haar_wav_hl = haar_wav_l.T * haar_wav_h 100 | haar_wav_hh = haar_wav_h.T * haar_wav_h 101 | 102 | return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh 103 | 104 | 105 | class HaarTransform(nn.Module): 106 | def __init__(self, in_channels): 107 | super().__init__() 108 | 109 | ll, lh, hl, hh = get_haar_wavelet(in_channels) 110 | 111 | self.register_buffer('ll', ll) 112 | self.register_buffer('lh', lh) 113 | 
self.register_buffer('hl', hl) 114 | self.register_buffer('hh', hh) 115 | 116 | def forward(self, input): 117 | ll = upfirdn2d(input, self.ll, down=2) 118 | lh = upfirdn2d(input, self.lh, down=2) 119 | hl = upfirdn2d(input, self.hl, down=2) 120 | hh = upfirdn2d(input, self.hh, down=2) 121 | 122 | return torch.cat((ll, lh, hl, hh), 1) 123 | 124 | 125 | class InverseHaarTransform(nn.Module): 126 | def __init__(self, in_channels): 127 | super().__init__() 128 | 129 | ll, lh, hl, hh = get_haar_wavelet(in_channels) 130 | 131 | self.register_buffer('ll', ll) 132 | self.register_buffer('lh', -lh) 133 | self.register_buffer('hl', -hl) 134 | self.register_buffer('hh', hh) 135 | 136 | def forward(self, input): 137 | ll, lh, hl, hh = input.chunk(4, 1) 138 | ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0)) 139 | lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0)) 140 | hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0)) 141 | hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0)) 142 | 143 | return ll + lh + hl + hh 144 | 145 | 146 | class FromRGB(nn.Module): 147 | def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1], sn=False): 148 | super().__init__() 149 | 150 | self.downsample = downsample 151 | 152 | if downsample: 153 | self.iwt = InverseHaarTransform(3) 154 | self.downsample = Downsample(blur_kernel) 155 | self.dwt = HaarTransform(3) 156 | 157 | self.conv = ConvLayer(3 * 4, out_channel, 1, sn=sn) 158 | 159 | def forward(self, input, skip=None): 160 | if self.downsample: 161 | input = self.iwt(input) 162 | input = self.downsample(input) 163 | input = self.dwt(input) 164 | 165 | out = self.conv(input) 166 | 167 | if skip is not None: 168 | out = out + skip 169 | 170 | return input, out 171 | 172 | 173 | class Discriminator(nn.Module): 174 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1], sn=False, ssd=False): 175 | super().__init__() 176 | 177 | channels = { 178 | 4: 512, 179 | 8: 512, 180 | 16: 512, 181 | 32: 512, 182 | 64: 256 * channel_multiplier, 183 | 128: 128 * channel_multiplier, 184 | 256: 64 * channel_multiplier, 185 | 512: 32 * channel_multiplier, 186 | 1024: 16 * channel_multiplier, 187 | } 188 | 189 | self.dwt = HaarTransform(3) 190 | 191 | self.from_rgbs = nn.ModuleList() 192 | self.convs = nn.ModuleList() 193 | 194 | log_size = int(math.log(size, 2)) - 1 195 | 196 | in_channel = channels[size] 197 | 198 | for i in range(log_size, 2, -1): 199 | out_channel = channels[2 ** (i - 1)] 200 | 201 | self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size, sn=sn)) 202 | self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel, sn=sn)) 203 | 204 | in_channel = out_channel 205 | 206 | self.from_rgbs.append(FromRGB(channels[4], sn=sn)) 207 | 208 | self.stddev_group = 4 209 | self.stddev_feat = 1 210 | 211 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3, sn=sn) 212 | if sn: 213 | self.final_linear = nn.Sequential( 214 | spectral_norm(nn.Linear(channels[4] * 4 * 4, channels[4])), 215 | FusedLeakyReLU(channels[4]), 216 | spectral_norm(nn.Linear(channels[4], 1)), 217 | ) 218 | else: 219 | self.final_linear = nn.Sequential( 220 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 221 | EqualLinear(channels[4], 1), 222 | ) 223 | 224 | def forward(self, input): 225 | input = self.dwt(input) 226 | out = None 227 | 228 | for from_rgb, conv in zip(self.from_rgbs, self.convs): 229 | input, out = from_rgb(input, out) 230 | out = conv(out) 231 | 232 | _, out = self.from_rgbs[-1](input, out) 233 | 234 | batch, 
channel, height, width = out.shape 235 | group = min(batch, self.stddev_group) 236 | stddev = out.view( 237 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 238 | ) 239 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 240 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 241 | stddev = stddev.repeat(group, 1, height, width) 242 | out = torch.cat([out, stddev], 1) 243 | 244 | out = self.final_conv(out) 245 | 246 | out = out.view(batch, -1) 247 | out = self.final_linear(out) 248 | 249 | return out 250 | -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.utils.checkpoint as checkpoint 8 | from timm.models.layers import to_2tuple, trunc_normal_ 9 | from torch import nn 10 | 11 | from models.basic_layers import (EqualLinear, PixelNorm, 12 | SinusoidalPositionalEmbedding, Upsample) 13 | 14 | 15 | class ToRGB(nn.Module): 16 | def __init__(self, in_channel, upsample=True, resolution=None, blur_kernel=[1, 3, 3, 1]): 17 | super().__init__() 18 | self.is_upsample = upsample 19 | self.resolution = resolution 20 | 21 | if upsample: 22 | self.upsample = Upsample(blur_kernel) 23 | 24 | self.conv = nn.Conv2d(in_channel, 3, kernel_size=1) 25 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 26 | 27 | def forward(self, input, skip=None): 28 | out = self.conv(input) 29 | out = out + self.bias 30 | 31 | if skip is not None: 32 | if self.is_upsample: 33 | skip = self.upsample(skip) 34 | 35 | out = out + skip 36 | return out 37 | 38 | def flops(self): 39 | m = self.conv 40 | kernel_ops = torch.zeros(m.weight.size()[2:]).numel() # Kw x Kh 41 | bias_ops = 1 42 | # N x Cout x H x W x (Cin x Kw x Kh + bias) 43 | flops = 1 * self.resolution * self.resolution * 3 * (m.in_channels // m.groups * kernel_ops + bias_ops) 44 | if self.is_upsample: 45 | # there is a conv used in upsample 46 | w_shape = (1, 1, 4, 4) 47 | kernel_ops = torch.zeros(w_shape[2:]).numel() # Kw x Kh 48 | # N x Cout x H x W x (Cin x Kw x Kh + bias) 49 | flops = 1 * 3 * (2 * self.resolution + 3) * (2 *self.resolution + 3) * (3 * kernel_ops) 50 | return flops 51 | 52 | 53 | class Mlp(nn.Module): 54 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 55 | super().__init__() 56 | out_features = out_features or in_features 57 | hidden_features = hidden_features or in_features 58 | self.hidden_features = hidden_features 59 | self.fc1 = nn.Linear(in_features, hidden_features) 60 | self.act = act_layer() 61 | self.fc2 = nn.Linear(hidden_features, out_features) 62 | self.drop = nn.Dropout(drop) 63 | 64 | def forward(self, x): 65 | x = self.fc1(x) 66 | x = self.act(x) 67 | x = self.drop(x) 68 | x = self.fc2(x) 69 | x = self.drop(x) 70 | return x 71 | 72 | 73 | def window_partition(x, window_size): 74 | """ 75 | Args: 76 | x: (B, H, W, C) 77 | window_size (int): window size 78 | 79 | Returns: 80 | windows: (num_windows*B, window_size, window_size, C) 81 | """ 82 | B, H, W, C = x.shape 83 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 84 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 85 | return windows 86 | 87 | 88 | def window_reverse(windows, window_size, H, W): 89 | """ 90 | Args: 91 | windows: (num_windows*B, 
window_size, window_size, C) 92 | window_size (int): Window size 93 | H (int): Height of image 94 | W (int): Width of image 95 | 96 | Returns: 97 | x: (B, H, W, C) 98 | """ 99 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 100 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 101 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 102 | return x 103 | 104 | 105 | class WindowAttention(nn.Module): 106 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 107 | It supports both of shifted and non-shifted window. 108 | 109 | Args: 110 | dim (int): Number of input channels. 111 | window_size (tuple[int]): The height and width of the window. 112 | num_heads (int): Number of attention heads. 113 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 114 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 115 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 116 | """ 117 | 118 | def __init__(self, dim, window_size, num_heads, qk_scale=None, attn_drop=0.): 119 | 120 | super().__init__() 121 | self.dim = dim 122 | self.window_size = window_size # Wh, Ww 123 | self.num_heads = num_heads 124 | head_dim = dim // num_heads 125 | self.head_dim = head_dim 126 | self.scale = qk_scale or head_dim ** -0.5 127 | 128 | # define a parameter table of relative position bias 129 | self.relative_position_bias_table = nn.Parameter( 130 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 131 | 132 | # get pair-wise relative position index for each token inside the window 133 | coords_h = torch.arange(self.window_size[0]) 134 | coords_w = torch.arange(self.window_size[1]) 135 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 136 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 137 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 138 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 139 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 140 | relative_coords[:, :, 1] += self.window_size[1] - 1 141 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 142 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 143 | self.register_buffer("relative_position_index", relative_position_index) 144 | trunc_normal_(self.relative_position_bias_table, std=.02) 145 | 146 | self.attn_drop = nn.Dropout(attn_drop) 147 | 148 | self.softmax = nn.Softmax(dim=-1) 149 | 150 | def forward(self, q, k, v, mask=None): 151 | """ 152 | Args: 153 | q: queries with shape of (num_windows*B, N, C) 154 | k: keys with shape of (num_windows*B, N, C) 155 | v: values with shape of (num_windows*B, N, C) 156 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 157 | """ 158 | B_, N, C = q.shape 159 | q = q.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 160 | k = k.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 161 | v = v.reshape(B_, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 162 | 163 | q = q * self.scale 164 | attn = (q @ k.transpose(-2, -1)) 165 | 166 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 167 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # 
Wh*Ww,Wh*Ww,nH 168 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 169 | attn = attn + relative_position_bias.unsqueeze(0) 170 | 171 | if mask is not None: 172 | nW = mask.shape[0] 173 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 174 | attn = attn.view(-1, self.num_heads, N, N) 175 | attn = self.softmax(attn) 176 | else: 177 | attn = self.softmax(attn) 178 | 179 | attn = self.attn_drop(attn) 180 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 181 | 182 | return x 183 | 184 | def extra_repr(self) -> str: 185 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 186 | 187 | def flops(self, N): 188 | # calculate flops for 1 window with token length of N 189 | flops = 0 190 | # qkv = self.qkv(x) 191 | flops += N * self.dim * 3 * self.dim 192 | # attn = (q @ k.transpose(-2, -1)) 193 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 194 | # x = (attn @ v) 195 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 196 | # x = self.proj(x) 197 | flops += N * self.dim * self.dim 198 | return flops 199 | 200 | 201 | class AdaptiveInstanceNorm(nn.Module): 202 | def __init__(self, in_channel, style_dim): 203 | super().__init__() 204 | self.norm = nn.InstanceNorm1d(in_channel) 205 | self.style = EqualLinear(style_dim, in_channel * 2) 206 | 207 | def forward(self, input, style): 208 | style = self.style(style).unsqueeze(-1) 209 | gamma, beta = style.chunk(2, 1) 210 | 211 | out = self.norm(input) 212 | out = gamma * out + beta 213 | return out 214 | 215 | 216 | class StyleSwinTransformerBlock(nn.Module): 217 | r""" StyleSwin Transformer Block. 218 | 219 | Args: 220 | dim (int): Number of input channels. 221 | input_resolution (tuple[int]): Input resulotion. 222 | num_heads (int): Number of attention heads. 223 | window_size (int): Window size. 224 | shift_size (int): Shift size for SW-MSA. 225 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 226 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 227 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 228 | drop (float, optional): Dropout rate. Default: 0.0 229 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 230 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 231 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 232 | style_dim (int): Dimension of style vector. 
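
    The forward pass implements the double attention described in the paper:
    the channel dimension is split into two halves, where the first half
    attends within regular local windows and the second half within windows
    shifted by ``window_size // 2`` (no shift is used when the input
    resolution does not exceed the window size); the two halves are then
    concatenated and projected back to ``dim`` channels. The style code
    modulates the features through adaptive instance normalization before the
    attention and before the MLP.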
233 | """ 234 | 235 | def __init__(self, dim, input_resolution, num_heads, window_size=7, 236 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 237 | act_layer=nn.GELU, style_dim=512): 238 | super().__init__() 239 | self.dim = dim 240 | self.input_resolution = input_resolution 241 | self.num_heads = num_heads 242 | self.window_size = window_size 243 | self.mlp_ratio = mlp_ratio 244 | self.shift_size = self.window_size // 2 245 | self.style_dim = style_dim 246 | if min(self.input_resolution) <= self.window_size: 247 | # if window size is larger than input resolution, we don't partition windows 248 | self.shift_size = 0 249 | self.window_size = min(self.input_resolution) 250 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 251 | 252 | self.norm1 = AdaptiveInstanceNorm(dim, style_dim) 253 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 254 | self.proj = nn.Linear(dim, dim) 255 | self.attn = nn.ModuleList([ 256 | WindowAttention( 257 | dim // 2, window_size=to_2tuple(self.window_size), num_heads=num_heads // 2, 258 | qk_scale=qk_scale, attn_drop=attn_drop), 259 | WindowAttention( 260 | dim // 2, window_size=to_2tuple(self.window_size), num_heads=num_heads // 2, 261 | qk_scale=qk_scale, attn_drop=attn_drop), 262 | ]) 263 | 264 | attn_mask1 = None 265 | attn_mask2 = None 266 | if self.shift_size > 0: 267 | # calculate attention mask for SW-MSA 268 | H, W = self.input_resolution 269 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 270 | h_slices = (slice(0, -self.window_size), 271 | slice(-self.window_size, -self.shift_size), 272 | slice(-self.shift_size, None)) 273 | w_slices = (slice(0, -self.window_size), 274 | slice(-self.window_size, -self.shift_size), 275 | slice(-self.shift_size, None)) 276 | cnt = 0 277 | for h in h_slices: 278 | for w in w_slices: 279 | img_mask[:, h, w, :] = cnt 280 | cnt += 1 281 | 282 | # nW, window_size, window_size, 1 283 | mask_windows = window_partition(img_mask, self.window_size) 284 | mask_windows = mask_windows.view(-1, 285 | self.window_size * self.window_size) 286 | attn_mask2 = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 287 | attn_mask2 = attn_mask2.masked_fill( 288 | attn_mask2 != 0, float(-100.0)).masked_fill(attn_mask2 == 0, float(0.0)) 289 | 290 | self.register_buffer("attn_mask1", attn_mask1) 291 | self.register_buffer("attn_mask2", attn_mask2) 292 | 293 | self.norm2 = AdaptiveInstanceNorm(dim, style_dim) 294 | mlp_hidden_dim = int(dim * mlp_ratio) 295 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 296 | 297 | def forward(self, x, style): 298 | H, W = self.input_resolution 299 | B, L, C = x.shape 300 | assert L == H * W, "input feature has wrong size" 301 | 302 | # Double Attn 303 | shortcut = x 304 | x = self.norm1(x.transpose(-1, -2), style).transpose(-1, -2) 305 | 306 | qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3).reshape(3 * B, H, W, C) 307 | qkv_1 = qkv[:, :, :, : C // 2].reshape(3, B, H, W, C // 2) 308 | if self.shift_size > 0: 309 | qkv_2 = torch.roll(qkv[:, :, :, C // 2:], shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)).reshape(3, B, H, W, C // 2) 310 | else: 311 | qkv_2 = qkv[:, :, :, C // 2:].reshape(3, B, H, W, C // 2) 312 | 313 | q1_windows, k1_windows, v1_windows = self.get_window_qkv(qkv_1) 314 | q2_windows, k2_windows, v2_windows = self.get_window_qkv(qkv_2) 315 | 316 | x1 = self.attn[0](q1_windows, k1_windows, v1_windows, self.attn_mask1) 317 | x2 = self.attn[1](q2_windows, k2_windows, 
v2_windows, self.attn_mask2) 318 | 319 | x1 = window_reverse(x1.view(-1, self.window_size * self.window_size, C // 2), self.window_size, H, W) 320 | x2 = window_reverse(x2.view(-1, self.window_size * self.window_size, C // 2), self.window_size, H, W) 321 | 322 | if self.shift_size > 0: 323 | x2 = torch.roll(x2, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 324 | else: 325 | x2 = x2 326 | 327 | x = torch.cat([x1.reshape(B, H * W, C // 2), x2.reshape(B, H * W, C // 2)], dim=2) 328 | x = self.proj(x) 329 | 330 | # FFN 331 | x = shortcut + x 332 | x = x + self.mlp(self.norm2(x.transpose(-1, -2), style).transpose(-1, -2)) 333 | 334 | return x 335 | 336 | def get_window_qkv(self, qkv): 337 | q, k, v = qkv[0], qkv[1], qkv[2] # B, H, W, C 338 | C = q.shape[-1] 339 | q_windows = window_partition(q, self.window_size).view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 340 | k_windows = window_partition(k, self.window_size).view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 341 | v_windows = window_partition(v, self.window_size).view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 342 | return q_windows, k_windows, v_windows 343 | 344 | def extra_repr(self) -> str: 345 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 346 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 347 | 348 | def flops(self): 349 | flops = 0 350 | H, W = self.input_resolution 351 | # norm1 352 | flops += 1 * self.style_dim * self.dim * 2 353 | flops += 2 * (H * W) * self.dim 354 | # W-MSA/SW-MSA 355 | nW = H * W / self.window_size / self.window_size 356 | for attn in self.attn: 357 | flops += nW * (attn.flops(self.window_size * self.window_size)) 358 | # mlp 359 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 360 | # norm2 361 | flops += 1 * self.style_dim * self.dim * 2 362 | flops += 2 * (H * W) * self.dim 363 | return flops 364 | 365 | 366 | class StyleBasicLayer(nn.Module): 367 | """ A basic StyleSwin layer for one stage. 368 | 369 | Args: 370 | dim (int): Number of input channels. 371 | input_resolution (tuple[int]): Input resolution. 372 | depth (int): Number of blocks. 373 | num_heads (int): Number of attention heads. 374 | window_size (int): Local window size. 375 | out_dim (int): Number of output channels. 376 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 377 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 378 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 379 | drop (float, optional): Dropout rate. Default: 0.0 380 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 381 | upsample (nn.Module | None, optional): Upsample layer at the end of the layer. Default: None 382 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 383 | style_dim (int): Dimension of style vector. 
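
    As instantiated in this codebase (``depth`` = 2), the forward pass feeds
    the first block with ``latent1`` and the second block with ``latent2``,
    optionally wrapping each call in ``torch.utils.checkpoint`` to save
    memory, and finally applies the optional ``upsample`` module to double
    the spatial resolution.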
384 | """ 385 | 386 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, out_dim=None, 387 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., upsample=None, 388 | use_checkpoint=False, style_dim=512): 389 | 390 | super().__init__() 391 | self.dim = dim 392 | self.input_resolution = input_resolution 393 | self.depth = depth 394 | self.use_checkpoint = use_checkpoint 395 | 396 | # build blocks 397 | self.blocks = nn.ModuleList([ 398 | StyleSwinTransformerBlock(dim=dim, input_resolution=input_resolution, 399 | num_heads=num_heads, window_size=window_size, 400 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 401 | drop=drop, attn_drop=attn_drop, style_dim=style_dim) 402 | for _ in range(depth)]) 403 | 404 | if upsample is not None: 405 | self.upsample = upsample(input_resolution, dim=dim, out_dim=out_dim) 406 | else: 407 | self.upsample = None 408 | 409 | def forward(self, x, latent1, latent2): 410 | if self.use_checkpoint: 411 | x = checkpoint.checkpoint(self.blocks[0], x, latent1) 412 | x = checkpoint.checkpoint(self.blocks[1], x, latent2) 413 | else: 414 | x = self.blocks[0](x, latent1) 415 | x = self.blocks[1](x, latent2) 416 | 417 | if self.upsample is not None: 418 | x = self.upsample(x) 419 | return x 420 | 421 | def extra_repr(self) -> str: 422 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 423 | 424 | def flops(self): 425 | flops = 0 426 | for blk in self.blocks: 427 | flops += blk.flops() 428 | if self.upsample is not None: 429 | flops += self.upsample.flops() 430 | return flops 431 | 432 | 433 | class BilinearUpsample(nn.Module): 434 | """ BilinearUpsample Layer. 435 | 436 | Args: 437 | input_resolution (tuple[int]): Resolution of input feature. 438 | dim (int): Number of input channels. 439 | out_dim (int): Number of output channels. 440 | """ 441 | 442 | def __init__(self, input_resolution, dim, out_dim=None): 443 | super().__init__() 444 | assert dim % 2 == 0, f"x dim are not even." 
445 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 446 | self.norm = nn.LayerNorm(dim) 447 | self.reduction = nn.Linear(dim, out_dim, bias=False) 448 | self.input_resolution = input_resolution 449 | self.dim = dim 450 | self.out_dim = out_dim 451 | self.alpha = nn.Parameter(torch.zeros(1)) 452 | self.sin_pos_embed = SinusoidalPositionalEmbedding(embedding_dim=out_dim // 2, padding_idx=0, init_size=out_dim // 2) 453 | 454 | def forward(self, x): 455 | """ 456 | x: B, H*W, C 457 | """ 458 | H, W = self.input_resolution 459 | B, L, C = x.shape 460 | assert L == H * W, "input feature has wrong size" 461 | assert C == self.dim, "wrong in PatchMerging" 462 | 463 | x = x.view(B, H, W, -1) 464 | x = x.permute(0, 3, 1, 2).contiguous() # B,C,H,W 465 | x = self.upsample(x) 466 | x = x.permute(0, 2, 3, 1).contiguous().view(B, L*4, C) # B,H,W,C 467 | x = self.norm(x) 468 | x = self.reduction(x) 469 | 470 | # Add SPE 471 | x = x.reshape(B, H * 2, W * 2, self.out_dim).permute(0, 3, 1, 2) 472 | x += self.sin_pos_embed.make_grid2d(H * 2, W * 2, B) * self.alpha 473 | x = x.permute(0, 2, 3, 1).contiguous().view(B, H * 2 * W * 2, self.out_dim) 474 | return x 475 | 476 | def extra_repr(self) -> str: 477 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 478 | 479 | def flops(self): 480 | H, W = self.input_resolution 481 | # LN 482 | flops = 4 * H * W * self.dim 483 | # proj 484 | flops += 4 * H * W * self.dim * (self.out_dim) 485 | # SPE 486 | flops += 4 * H * W * 2 487 | # bilinear 488 | flops += 4 * self.input_resolution[0] * self.input_resolution[1] * self.dim * 5 489 | return flops 490 | 491 | 492 | class ConstantInput(nn.Module): 493 | def __init__(self, channel, size=4): 494 | super().__init__() 495 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 496 | 497 | def forward(self, input): 498 | batch = input.shape[0] 499 | out = self.input.repeat(batch, 1, 1, 1) 500 | 501 | return out 502 | 503 | 504 | class Generator(nn.Module): 505 | def __init__( 506 | self, 507 | size, 508 | style_dim, 509 | n_mlp, 510 | channel_multiplier=2, 511 | lr_mlp=0.01, 512 | enable_full_resolution=8, 513 | mlp_ratio=4, 514 | use_checkpoint=False, 515 | qkv_bias=True, 516 | qk_scale=None, 517 | drop_rate=0, 518 | attn_drop_rate=0, 519 | ): 520 | super().__init__() 521 | self.style_dim = style_dim 522 | self.size = size 523 | self.mlp_ratio = mlp_ratio 524 | 525 | layers = [PixelNorm()] 526 | for _ in range(n_mlp): 527 | layers.append( 528 | EqualLinear( 529 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 530 | ) 531 | ) 532 | self.style = nn.Sequential(*layers) 533 | 534 | start = 2 535 | depths = [2, 2, 2, 2, 2, 2, 2, 2, 2] 536 | in_channels = [ 537 | 512, 538 | 512, 539 | 512, 540 | 512, 541 | 256 * channel_multiplier, 542 | 128 * channel_multiplier, 543 | 64 * channel_multiplier, 544 | 32 * channel_multiplier, 545 | 16 * channel_multiplier 546 | ] 547 | 548 | end = int(math.log(size, 2)) 549 | num_heads = [max(c // 32, 4) for c in in_channels] 550 | full_resolution_index = int(math.log(enable_full_resolution, 2)) 551 | window_sizes = [2 ** i if i <= full_resolution_index else 8 for i in range(start, end + 1)] 552 | 553 | self.input = ConstantInput(in_channels[0]) 554 | self.layers = nn.ModuleList() 555 | self.to_rgbs = nn.ModuleList() 556 | num_layers = 0 557 | 558 | for i_layer in range(start, end + 1): 559 | in_channel = in_channels[i_layer - start] 560 | layer = StyleBasicLayer(dim=in_channel, 561 | input_resolution=(2 ** i_layer,2 ** i_layer), 562 | 
depth=depths[i_layer - start], 563 | num_heads=num_heads[i_layer - start], 564 | window_size=window_sizes[i_layer - start], 565 | out_dim=in_channels[i_layer - start + 1] if (i_layer < end) else None, 566 | mlp_ratio=self.mlp_ratio, 567 | qkv_bias=qkv_bias, qk_scale=qk_scale, 568 | drop=drop_rate, attn_drop=attn_drop_rate, 569 | upsample=BilinearUpsample if (i_layer < end) else None, 570 | use_checkpoint=use_checkpoint, style_dim=style_dim) 571 | self.layers.append(layer) 572 | 573 | out_dim = in_channels[i_layer - start + 1] if (i_layer < end) else in_channels[i_layer - start] 574 | upsample = True if (i_layer < end) else False 575 | to_rgb = ToRGB(out_dim, upsample=upsample, resolution=(2 ** i_layer)) 576 | self.to_rgbs.append(to_rgb) 577 | num_layers += 2 578 | 579 | self.n_latent = num_layers 580 | self.apply(self._init_weights) 581 | 582 | def _init_weights(self, m): 583 | if isinstance(m, nn.Linear): 584 | trunc_normal_(m.weight, std=.02) 585 | if isinstance(m, nn.Linear) and m.bias is not None: 586 | nn.init.constant_(m.bias, 0) 587 | elif isinstance(m, nn.LayerNorm): 588 | if m.bias is not None: 589 | nn.init.constant_(m.bias, 0) 590 | if m.weight is not None: 591 | nn.init.constant_(m.weight, 1.0) 592 | elif isinstance(m, nn.Conv2d): 593 | nn.init.xavier_normal_(m.weight, gain=.02) 594 | if hasattr(m, 'bias') and m.bias is not None: 595 | nn.init.constant_(m.bias, 0) 596 | 597 | def forward( 598 | self, 599 | noise, 600 | return_latents=False, 601 | inject_index=None, 602 | truncation=1, 603 | truncation_latent=None, 604 | ): 605 | styles = self.style(noise) 606 | inject_index = self.n_latent 607 | 608 | if truncation < 1: 609 | style_t = [] 610 | for style in styles: 611 | style_t.append( 612 | truncation_latent + truncation * (style - truncation_latent) 613 | ) 614 | 615 | styles = torch.cat(style_t, dim=0) 616 | 617 | if styles.ndim < 3: 618 | latent = styles.unsqueeze(1).repeat(1, inject_index, 1) 619 | else: 620 | latent = styles 621 | 622 | x = self.input(latent) 623 | B, C, H, W = x.shape 624 | x = x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) 625 | 626 | count = 0 627 | skip = None 628 | for layer, to_rgb in zip(self.layers, self.to_rgbs): 629 | x = layer(x, latent[:,count,:], latent[:,count+1,:]) 630 | b, n, c = x.shape 631 | h, w = int(math.sqrt(n)), int(math.sqrt(n)) 632 | skip = to_rgb(x.transpose(-1, -2).reshape(b, c, h, w), skip) 633 | count = count + 2 634 | 635 | B, L, C = x.shape 636 | assert L == self.size * self.size 637 | x = x.reshape(B, self.size, self.size, C).permute(0, 3, 1, 2).contiguous() 638 | image = skip 639 | 640 | if return_latents: 641 | return image, latent 642 | else: 643 | return image, None 644 | 645 | def flops(self): 646 | flops = 0 647 | for _, layer in enumerate(self.layers): 648 | flops += layer.flops() 649 | for _, layer in enumerate(self.to_rgbs): 650 | flops += layer.flops() 651 | # 8 FC + PixelNorm 652 | flops += 1 * 10 * self.style_dim * self.style_dim 653 | return flops 654 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
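The dump does not include a standalone sampling script, so here is a minimal sketch (an assumption, not repository code) of how the Generator defined in models/generator.py above can be instantiated and sampled, mirroring the calls made in train_styleswin.py further below. The 256 / 512 / 8 values follow the training defaults; a GPU and a CUDA toolchain are assumed because importing the model JIT-compiles the extensions in op/.

import torch

from models.generator import Generator  # importing this JIT-compiles the CUDA ops in op/

# 256x256 generator with the training defaults (style_dim=512, 8-layer mapping MLP)
g = Generator(size=256, style_dim=512, n_mlp=8).cuda().eval()

with torch.no_grad():
    z = torch.randn(4, 512, device="cuda")   # one 512-d latent per sample
    images, _ = g(z)                          # forward returns (image, latent or None)

print(images.shape)                           # expected: torch.Size([4, 3, 256, 256])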
3 | 4 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 5 | from .upfirdn2d import upfirdn2d 6 | 7 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | from torch.utils.cpp_extension import load 11 | 12 | from torch.cuda.amp import custom_fwd, custom_bwd 13 | 14 | module_path = os.path.dirname(__file__) 15 | fused = load( 16 | "fused", 17 | sources=[ 18 | os.path.join(module_path, "fused_bias_act.cpp"), 19 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 20 | ], 21 | ) 22 | 23 | 24 | class FusedLeakyReLUFunctionBackward(Function): 25 | @staticmethod 26 | def forward(ctx, grad_output, out, negative_slope, scale): 27 | ctx.save_for_backward(out) 28 | ctx.negative_slope = negative_slope 29 | ctx.scale = scale 30 | 31 | empty = grad_output.new_empty(0) 32 | 33 | grad_input = fused.fused_bias_act( 34 | grad_output, empty, out, 3, 1, negative_slope, scale 35 | ) 36 | 37 | dim = [0] 38 | 39 | if grad_input.ndim > 2: 40 | dim += list(range(2, grad_input.ndim)) 41 | 42 | grad_bias = grad_input.sum(dim).detach() 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 51 | ) 52 | 53 | return gradgrad_out, None, None, None 54 | 55 | 56 | class FusedLeakyReLUFunction(Function): 57 | @staticmethod 58 | @custom_fwd(cast_inputs=torch.float32) 59 | def forward(ctx, input, bias, negative_slope, scale): 60 | empty = input.new_empty(0) 61 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 62 | ctx.save_for_backward(out) 63 | ctx.negative_slope = negative_slope 64 | ctx.scale = scale 65 | 66 | return out 67 | 68 | @staticmethod 69 | @custom_bwd 70 | def backward(ctx, grad_output): 71 | out, = ctx.saved_tensors 72 | 73 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 74 | grad_output, out, ctx.negative_slope, ctx.scale 75 | ) 76 | 77 | return grad_input, grad_bias, None, None 78 | 79 | 80 | class FusedLeakyReLU(nn.Module): 81 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 82 | super().__init__() 83 | 84 | self.bias = nn.Parameter(torch.zeros(channel)) 85 | self.negative_slope = negative_slope 86 | self.scale = scale 87 | 88 | def forward(self, input): 89 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 90 | 91 | 92 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 93 | if input.device.type == "cpu": 94 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 95 | return ( 96 | F.leaky_relu( 97 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 98 | ) 99 | * scale 100 | ) 101 | 102 | else: 103 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 104 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
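For reference, a minimal pure-PyTorch sketch (an assumption, not repository code) of what the fused op in op/fused_act.py above computes. It mirrors the CPU fallback branch of fused_leaky_relu: a broadcast bias add, a leaky ReLU with slope 0.2, and a sqrt(2) gain, which the CUDA kernel performs in a single pass.

import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2 ** 0.5):
    # broadcast the per-channel bias over the remaining trailing dimensions
    rest_dim = [1] * (x.ndim - bias.ndim - 1)
    return F.leaky_relu(x + bias.view(1, bias.shape[0], *rest_dim), negative_slope) * scale

x = torch.randn(2, 8, 4, 4)
b = torch.zeros(8)
out = fused_leaky_relu_reference(x, b)   # matches the CPU branch of op.fused_leaky_relu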
3 | 4 | #include 5 | 6 | 7 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 15 | int act, int grad, float alpha, float scale) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(bias); 18 | 19 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 24 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 3 | 4 | #include 5 | 6 | 7 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 8 | int up_x, int up_y, int down_x, int down_y, 9 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 10 | 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | 15 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 16 | int up_x, int up_y, int down_x, int down_y, 17 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(kernel); 20 | 21 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 26 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
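upfirdn2d, defined later in op/upfirdn2d.py, combines zero-insertion upsampling, padding, a 2D FIR filter, and downsampling in one call, with a pure-PyTorch fallback (upfirdn2d_native) for CPU tensors. A small usage sketch follows; the 4-tap binomial kernel and the padding values are illustrative assumptions, not taken from the repository, and importing the module still JIT-compiles the CUDA extension.

import torch
from op import upfirdn2d   # importing op JIT-compiles the CUDA extensions

k1d = torch.tensor([1., 3., 3., 1.])
kernel = k1d[:, None] * k1d[None, :]
kernel = kernel / kernel.sum()                      # normalized 4x4 binomial filter

x = torch.randn(1, 3, 16, 16)                       # CPU tensors take the native fallback path
y = upfirdn2d(x, kernel, up=2, down=1, pad=(1, 2))  # 2x upsample followed by the blur
# out_h = (16*2 + 1 + 2 - 4) // 1 + 1 = 32, so y has shape (1, 3, 32, 32)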
3 | 4 | import os 5 | 6 | import torch 7 | from torch.nn import functional as F 8 | from torch.autograd import Function 9 | from torch.utils.cpp_extension import load 10 | 11 | 12 | module_path = os.path.dirname(__file__) 13 | upfirdn2d_op = load( 14 | "upfirdn2d", 15 | sources=[ 16 | os.path.join(module_path, "upfirdn2d.cpp"), 17 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 18 | ], 19 | ) 20 | 21 | 22 | class UpFirDn2dBackward(Function): 23 | @staticmethod 24 | def forward( 25 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 26 | ): 27 | 28 | up_x, up_y = up 29 | down_x, down_y = down 30 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 31 | 32 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 33 | 34 | grad_input = upfirdn2d_op.upfirdn2d( 35 | grad_output, 36 | grad_kernel, 37 | down_x, 38 | down_y, 39 | up_x, 40 | up_y, 41 | g_pad_x0, 42 | g_pad_x1, 43 | g_pad_y0, 44 | g_pad_y1, 45 | ) 46 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 47 | 48 | ctx.save_for_backward(kernel) 49 | 50 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 51 | 52 | ctx.up_x = up_x 53 | ctx.up_y = up_y 54 | ctx.down_x = down_x 55 | ctx.down_y = down_y 56 | ctx.pad_x0 = pad_x0 57 | ctx.pad_x1 = pad_x1 58 | ctx.pad_y0 = pad_y0 59 | ctx.pad_y1 = pad_y1 60 | ctx.in_size = in_size 61 | ctx.out_size = out_size 62 | 63 | return grad_input 64 | 65 | @staticmethod 66 | def backward(ctx, gradgrad_input): 67 | kernel, = ctx.saved_tensors 68 | 69 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 70 | 71 | gradgrad_out = upfirdn2d_op.upfirdn2d( 72 | gradgrad_input, 73 | kernel, 74 | ctx.up_x, 75 | ctx.up_y, 76 | ctx.down_x, 77 | ctx.down_y, 78 | ctx.pad_x0, 79 | ctx.pad_x1, 80 | ctx.pad_y0, 81 | ctx.pad_y1, 82 | ) 83 | gradgrad_out = gradgrad_out.view( 84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 85 | ) 86 | 87 | return gradgrad_out, None, None, None, None, None, None, None, None 88 | 89 | 90 | class UpFirDn2d(Function): 91 | @staticmethod 92 | def forward(ctx, input, kernel, up, down, pad): 93 | up_x, up_y = up 94 | down_x, down_y = down 95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 96 | 97 | kernel_h, kernel_w = kernel.shape 98 | batch, channel, in_h, in_w = input.shape 99 | ctx.in_size = input.shape 100 | 101 | input = input.reshape(-1, in_h, in_w, 1) 102 | 103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 104 | 105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 107 | ctx.out_size = (out_h, out_w) 108 | 109 | ctx.up = (up_x, up_y) 110 | ctx.down = (down_x, down_y) 111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 112 | 113 | g_pad_x0 = kernel_w - pad_x0 - 1 114 | g_pad_y0 = kernel_h - pad_y0 - 1 115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 117 | 118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 119 | 120 | out = upfirdn2d_op.upfirdn2d( 121 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 122 | ) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = UpFirDn2dBackward.apply( 132 | grad_output, 133 | kernel, 134 | grad_kernel, 135 | ctx.up, 136 | ctx.down, 137 | ctx.pad, 138 | ctx.g_pad, 139 | ctx.in_size, 140 | ctx.out_size, 
141 | ) 142 | 143 | return grad_input, None, None, None, None 144 | 145 | 146 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 147 | if input.device.type == "cpu": 148 | out = upfirdn2d_native( 149 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 150 | ) 151 | 152 | else: 153 | out = UpFirDn2d.apply( 154 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 155 | ) 156 | 157 | return out 158 | 159 | 160 | def upfirdn2d_native( 161 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 162 | ): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad( 174 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 175 | ) 176 | out = out[ 177 | :, 178 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 179 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 180 | :, 181 | ] 182 | 183 | out = out.permute(0, 3, 1, 2) 184 | out = out.reshape( 185 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 186 | ) 187 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 188 | out = F.conv2d(out, w) 189 | out = out.reshape( 190 | -1, 191 | minor, 192 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 193 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 194 | ) 195 | out = out.permute(0, 2, 3, 1) 196 | out = out[:, ::down_y, ::down_x, :] 197 | 198 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 199 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 200 | 201 | return out.view(-1, channel, out_h, out_w) 202 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = 
pad_x1; 234 | p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6.0 
2 | lmdb 3 | timm 4 | scipy 5 | sklearn 6 | einops 7 | tensorflow==1.15.0 8 | tqdm 9 | wandb 10 | -------------------------------------------------------------------------------- /train_styleswin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import argparse 5 | import builtins 6 | import os 7 | import sys 8 | from datetime import timedelta 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import torchvision 13 | import torchvision.datasets as datasets 14 | from torch import autograd, nn, optim 15 | from torch.nn import functional as F 16 | from torch.utils import data 17 | from torchvision import transforms 18 | 19 | try: 20 | import wandb 21 | except ImportError: 22 | wandb = None 23 | 24 | import time 25 | 26 | from dataset.dataset import MultiResolutionDataset 27 | from models.discriminator import Discriminator 28 | from models.generator import Generator 29 | from utils import fid_score 30 | from utils.CRDiffAug import CR_DiffAug 31 | from utils.distributed import get_rank, reduce_loss_dict, synchronize 32 | 33 | 34 | def data_sampler(dataset, shuffle, distributed): 35 | if distributed: 36 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 37 | if shuffle: 38 | return data.RandomSampler(dataset) 39 | else: 40 | return data.SequentialSampler(dataset) 41 | 42 | 43 | def requires_grad(model, flag=True): 44 | for p in model.parameters(): 45 | p.requires_grad = flag 46 | 47 | 48 | def accumulate(model1, model2, decay=0.999): 49 | par1 = dict(model1.named_parameters()) 50 | par2 = dict(model2.named_parameters()) 51 | 52 | for k in par1.keys(): 53 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 54 | 55 | 56 | def sample_data(loader): 57 | while True: 58 | for batch in loader: 59 | yield batch 60 | 61 | 62 | def d_logistic_loss(real_pred, fake_pred): 63 | assert type(real_pred) == type(fake_pred), "real_pred must be the same type as fake_pred" 64 | real_loss = F.softplus(-real_pred) 65 | fake_loss = F.softplus(fake_pred) 66 | return real_loss.mean() + fake_loss.mean() 67 | 68 | 69 | def d_r1_loss(real_pred, real_img): 70 | grad_real, = autograd.grad( 71 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 72 | ) 73 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 74 | return grad_penalty 75 | 76 | 77 | def g_nonsaturating_loss(fake_pred): 78 | loss = F.softplus(-fake_pred).mean() 79 | return loss 80 | 81 | 82 | def set_grad_none(model, targets): 83 | for n, p in model.named_parameters(): 84 | if n in targets: 85 | p.grad = None 86 | 87 | 88 | def tensor_transform_reverse(image): 89 | assert image.dim() == 4 90 | moco_input = torch.zeros(image.size()).type_as(image) 91 | moco_input[:,0,:,:] = image[:,0,:,:] * 0.229 + 0.485 92 | moco_input[:,1,:,:] = image[:,1,:,:] * 0.224 + 0.456 93 | moco_input[:,2,:,:] = image[:,2,:,:] * 0.225 + 0.406 94 | return moco_input 95 | 96 | def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device): 97 | if get_rank() == 0 and args.tf_log: 98 | from utils.visualizer import Visualizer 99 | vis = Visualizer(args) 100 | 101 | loader = sample_data(loader) 102 | 103 | d_loss_val = 0 104 | r1_loss = torch.tensor(0.0, device=device) 105 | g_loss_val = 0 106 | accum = 0.5 ** (32 / (10 * 1000)) 107 | loss_dict = {} 108 | l2_loss = torch.nn.MSELoss() 109 | loss_dict = {} 110 | 111 | if args.distributed: 112 | g_module = generator.module 113 | 
d_module = discriminator.module 114 | else: 115 | g_module = generator 116 | d_module = discriminator 117 | 118 | print(" -- start training -- ") 119 | end = time.time() 120 | if args.ttur: 121 | args.G_lr = args.D_lr / 4 122 | if args.lr_decay: 123 | lr_decay_per_step = args.G_lr / (args.iter - args.lr_decay_start_steps) 124 | 125 | for idx in range(args.iter): 126 | i = idx + args.start_iter 127 | if i > args.iter: 128 | print("Done!") 129 | break 130 | 131 | # Train D 132 | generator.train() 133 | if not args.lmdb: 134 | this_data = next(loader) 135 | real_img = this_data[0] 136 | else: 137 | real_img = next(loader) 138 | real_img = real_img.to(device) 139 | 140 | requires_grad(generator, False) 141 | requires_grad(discriminator, True) 142 | noise = torch.randn((args.batch, 512)).cuda() 143 | 144 | fake_img, _ = generator(noise) 145 | fake_pred = discriminator(fake_img) 146 | real_pred = discriminator(real_img) 147 | d_loss = d_logistic_loss(real_pred, fake_pred) * args.gan_weight 148 | 149 | if args.bcr: 150 | real_img_cr_aug = CR_DiffAug(real_img) 151 | fake_img_cr_aug = CR_DiffAug(fake_img) 152 | fake_pred_aug = discriminator(fake_img_cr_aug) 153 | real_pred_aug = discriminator(real_img_cr_aug) 154 | d_loss += args.bcr_fake_lambda * l2_loss(fake_pred_aug, fake_pred) \ 155 | + args.bcr_real_lambda * l2_loss(real_pred_aug, real_pred) 156 | 157 | loss_dict["d"] = d_loss 158 | 159 | discriminator.zero_grad() 160 | d_loss.backward() 161 | nn.utils.clip_grad_norm_(discriminator.parameters(), 5.0) 162 | d_optim.step() 163 | 164 | d_regularize = i % args.d_reg_every == 0 165 | if d_regularize: 166 | real_img.requires_grad = True 167 | 168 | real_pred = discriminator(real_img) 169 | r1_loss = d_r1_loss(real_pred, real_img) 170 | 171 | discriminator.zero_grad() 172 | (args.gan_weight * (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0])).backward() 173 | 174 | d_optim.step() 175 | 176 | loss_dict["r1"] = r1_loss 177 | 178 | # Train G 179 | requires_grad(generator, True) 180 | requires_grad(discriminator, False) 181 | 182 | if not args.lmdb: 183 | this_data = next(loader) 184 | real_img = this_data[0] 185 | else: 186 | real_img = next(loader) 187 | real_img = real_img.to(device) 188 | 189 | noise = torch.randn((args.batch, 512)).cuda() 190 | fake_img, _ = generator(noise) 191 | fake_pred = discriminator(fake_img) 192 | g_loss = g_nonsaturating_loss(fake_pred)* args.gan_weight 193 | 194 | loss_dict["g"] = g_loss 195 | generator.zero_grad() 196 | g_loss.backward() 197 | g_optim.step() 198 | 199 | accumulate(g_ema, g_module, accum) 200 | 201 | # Finish one iteration and reduce loss dict 202 | loss_reduced = reduce_loss_dict(loss_dict) 203 | d_loss_val = loss_reduced["d"].mean().item() 204 | g_loss_val = loss_reduced["g"].mean().item() 205 | r1_val = loss_reduced["r1"].mean().item() 206 | 207 | if args.lr_decay and i > args.lr_decay_start_steps: 208 | args.G_lr -= lr_decay_per_step 209 | args.D_lr = args.G_lr * 4 if args.ttur else (args.D_lr - lr_decay_per_step) 210 | 211 | for param_group in d_optim.param_groups: 212 | param_group['lr'] = args.D_lr 213 | for param_group in g_optim.param_groups: 214 | param_group['lr'] = args.G_lr 215 | 216 | # Log, save and evaluate 217 | if get_rank() == 0: 218 | if i % args.print_freq == 0: 219 | vis_loss = { 220 | 'd_loss': d_loss_val, 221 | 'g_loss': g_loss_val, 222 | 'r1_val': r1_val, 223 | } 224 | if wandb and args.wandb: 225 | wandb.log(vis_loss, step=i) 226 | iters_time = time.time() - end 227 | end = time.time() 228 | if args.lr_decay: 229 | 
print("Iters: {}\tTime: {:.4f}\tD_loss: {:.4f}\tG_loss: {:.4f}\tR1: {:.4f}\tG_lr: {:e}\tD_lr: {:e}".format(i, iters_time, d_loss_val, g_loss_val, r1_val, args.G_lr, args.D_lr)) 230 | else: 231 | print("Iters: {}\tTime: {:.4f}\tD_loss: {:.4f}\tG_loss: {:.4f}\tR1: {:.4f}".format(i, iters_time, d_loss_val, g_loss_val, r1_val)) 232 | if args.tf_log: 233 | vis.plot_dict(vis_loss, step=(i * args.batch * int(os.environ["WORLD_SIZE"]))) 234 | 235 | if i != 0 and i % args.eval_freq == 0: 236 | torch.save( 237 | { 238 | "g": g_module.state_dict(), 239 | "d": d_module.state_dict(), 240 | "g_ema": g_ema.state_dict(), 241 | "g_optim": g_optim.state_dict(), 242 | "d_optim": d_optim.state_dict(), 243 | "args": args, 244 | }, 245 | args.checkpoint_path + f"/{str(i).zfill(6)}.pt", 246 | ) 247 | 248 | print("=> Evaluation ...") 249 | g_ema.eval() 250 | fid1 = evaluation(g_ema, args, i * args.batch * int(os.environ["WORLD_SIZE"])) 251 | fid_dict = {'fid1': fid1} 252 | if wandb and args.wandb: 253 | wandb.log({'fid': fid1}, step=i) 254 | if args.tf_log: 255 | vis.plot_dict(fid_dict, step=(i * args.batch * int(os.environ["WORLD_SIZE"]))) 256 | 257 | if i % args.save_freq == 0: 258 | torch.save( 259 | { 260 | "g": g_module.state_dict(), 261 | "d": d_module.state_dict(), 262 | "g_ema": g_ema.state_dict(), 263 | "g_optim": g_optim.state_dict(), 264 | "d_optim": d_optim.state_dict(), 265 | "args": args, 266 | }, 267 | args.checkpoint_path + f"/{str(i).zfill(6)}.pt", 268 | ) 269 | 270 | 271 | def evaluation(generator, args, steps): 272 | cnt = 0 273 | 274 | for _ in tqdm(range(args.val_num_batches)): 275 | with torch.no_grad(): 276 | noise = torch.randn((args.val_batch_size, 512)).cuda() 277 | 278 | out_sample, _ = generator(noise) 279 | out_sample = tensor_transform_reverse(out_sample) 280 | 281 | if not os.path.exists(os.path.join(args.sample_path, "eval_{}".format(str(steps)))): 282 | os.mkdir(os.path.join(args.sample_path, 283 | "eval_{}".format(str(steps)))) 284 | 285 | for j in range(args.val_batch_size): 286 | torchvision.utils.save_image( 287 | out_sample[j], 288 | os.path.join(args.sample_path, "eval_{}".format( 289 | str(steps))) + f"/{str(cnt).zfill(6)}.png", 290 | nrow=1, 291 | padding=0, 292 | normalize=True, 293 | range=(0, 1), 294 | ) 295 | cnt += 1 296 | 297 | gt_path = args.eval_gt_path 298 | device = torch.device('cuda:0') 299 | fid = fid_score.calculate_fid_given_paths([os.path.join(args.sample_path, "eval_{}".format( 300 | str(steps))), gt_path], batch_size=args.val_batch_size, device=device, dims=2048) 301 | 302 | print("Fid Score : ({:.2f}, {:.1f}M)".format(fid, steps / 1000000)) 303 | 304 | return fid 305 | 306 | 307 | if __name__ == "__main__": 308 | device = "cuda" 309 | 310 | parser = argparse.ArgumentParser() 311 | 312 | parser.add_argument("--path", type=str, default=None, help="Path of training data") 313 | parser.add_argument("--iter", type=int, default=800000) 314 | parser.add_argument("--batch", type=int, default=4) 315 | parser.add_argument("--size", type=int, default=256) 316 | parser.add_argument("--style_dim", type=int, default=512) 317 | parser.add_argument("--r1", type=float, default=10) 318 | parser.add_argument("--d_reg_every", type=int, default=16) 319 | parser.add_argument("--ckpt", type=str, default=None) 320 | parser.add_argument("--G_lr", type=float, default=0.0002) 321 | parser.add_argument("--D_lr", type=float, default=0.0002) 322 | parser.add_argument("--beta1", type=float, default=0.0) 323 | parser.add_argument("--beta2", type=float, default=0.99) 324 | 
parser.add_argument("--start_dim", type=int, default=512, help="Start dim of generator input dim") 325 | parser.add_argument("--D_channel_multiplier", type=int, default=2) 326 | parser.add_argument("--G_channel_multiplier", type=int, default=1) 327 | parser.add_argument("--local_rank", type=int, default=0) 328 | parser.add_argument("--print_freq", type=int, default=1000) 329 | parser.add_argument("--save_freq", type=int, default=20000) 330 | parser.add_argument("--eval_freq", type=int, default=50000) 331 | parser.add_argument('--workers', default=8, type=int, help='Number of workers') 332 | 333 | parser.add_argument('--checkpoint_path', default='/tmp', type=str, help='Save checkpoints') 334 | parser.add_argument('--sample_path', default='/tmp', type=str, help='Save sample') 335 | parser.add_argument('--start_iter', default=0, type=int, help='Start iter number') 336 | parser.add_argument('--tf_log', action="store_true", help='If we use tensorboard file') 337 | parser.add_argument('--gan_weight', default=1, type=float, help='Gan loss weight') 338 | parser.add_argument('--val_num_batches', default=1250, type=int, help='Num of batches will be generated during evalution') 339 | parser.add_argument('--val_batch_size', default=4, type=int, help='Batch size during evalution') 340 | parser.add_argument('--D_sn', action="store_true", help='If we use spectral norm in D') 341 | parser.add_argument('--ttur', action="store_true", help='If we use TTUR during training') 342 | parser.add_argument('--eval', action="store_true", help='Only do evaluation') 343 | parser.add_argument("--eval_iters", type=int, default=0, help="Iters of evaluation ckpt") 344 | parser.add_argument('--eval_gt_path', default='/tmp', type=str, help='Path to ground truth images to evaluate FID score') 345 | parser.add_argument('--mlp_ratio', default=4, type=int, help='MLP ratio in swin') 346 | parser.add_argument("--lr_mlp", default=0.01, type=float, help='Lr mul for 8 * fc') 347 | parser.add_argument("--bcr", action="store_true", help='If we add bcr during training') 348 | parser.add_argument("--bcr_fake_lambda", default=10, type=float, help='Bcr weight for fake data') 349 | parser.add_argument("--bcr_real_lambda", default=10, type=float, help='Bcr weight for real data') 350 | parser.add_argument("--enable_full_resolution", default=8, type=int, help='Enable full resolution attention index') 351 | parser.add_argument("--auto_resume", action="store_true", help="Auto resume from checkpoint") 352 | parser.add_argument("--lmdb", action="store_true", help='Whether to use lmdb datasets') 353 | parser.add_argument("--use_checkpoint", action="store_true", help='Whether to use checkpoint') 354 | parser.add_argument("--use_flip", action="store_true", help='Whether to use random flip in training') 355 | parser.add_argument("--wandb", action="store_true", help='Whether to use wandb record training') 356 | parser.add_argument("--project_name", type=str, default='StyleSwin', help='Project name') 357 | parser.add_argument("--lr_decay", action="store_true", help='Whether to use lr decay') 358 | parser.add_argument("--lr_decay_start_steps", default=800000, type=int, help='Steps to start lr decay') 359 | 360 | args = parser.parse_args() 361 | 362 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 363 | args.distributed = n_gpu > 1 364 | 365 | args.latent = 4096 366 | args.n_mlp = 8 367 | args.g_reg_every = 10000000 # We do not apply regularization on G 368 | 369 | if args.distributed: 370 | 
torch.cuda.set_device(args.local_rank) 371 | torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(0, 18000)) 372 | synchronize() 373 | 374 | if args.distributed and get_rank() != 0: 375 | def print_pass(*args): 376 | pass 377 | builtins.print = print_pass 378 | 379 | if get_rank() == 0: 380 | args.sample_path = os.path.join(args.sample_path, 'samples') 381 | if not os.path.exists(args.sample_path): 382 | os.mkdir(args.sample_path) 383 | 384 | generator = Generator( 385 | args.size, args.style_dim, args.n_mlp, channel_multiplier=args.G_channel_multiplier, lr_mlp=args.lr_mlp, 386 | enable_full_resolution=args.enable_full_resolution, use_checkpoint=args.use_checkpoint 387 | ).to(device) 388 | discriminator = Discriminator(args.size, channel_multiplier=args.D_channel_multiplier, sn=args.D_sn).to(device) 389 | g_ema = Generator( 390 | args.size, args.style_dim, args.n_mlp, channel_multiplier=args.G_channel_multiplier, lr_mlp=args.lr_mlp, 391 | enable_full_resolution=args.enable_full_resolution, use_checkpoint=args.use_checkpoint 392 | ).to(device) 393 | g_ema.eval() 394 | accumulate(g_ema, generator, 0) 395 | 396 | g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) 397 | d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) 398 | 399 | # Load model checkpoint. 400 | if args.ckpt is not None: 401 | print("load model: ", args.ckpt) 402 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 403 | ckpt_name = os.path.basename(args.ckpt) 404 | try: 405 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 406 | except: 407 | pass 408 | 409 | generator.load_state_dict(ckpt["g"]) 410 | g_ema.load_state_dict(ckpt["g_ema"]) 411 | try: 412 | discriminator.load_state_dict(ckpt["d"]) 413 | except: 414 | print("We don't load D.") 415 | 416 | print("-" * 80) 417 | print("Generator: ") 418 | print(generator) 419 | print("-" * 80) 420 | print("Discriminator: ") 421 | print(discriminator) 422 | 423 | if args.distributed: 424 | generator = nn.parallel.DistributedDataParallel( 425 | generator, 426 | device_ids=[args.local_rank], 427 | output_device=args.local_rank, 428 | broadcast_buffers=False, 429 | ) 430 | 431 | discriminator = nn.parallel.DistributedDataParallel( 432 | discriminator, 433 | device_ids=[args.local_rank], 434 | output_device=args.local_rank, 435 | broadcast_buffers=False, 436 | ) 437 | 438 | g_optim = optim.Adam( 439 | generator.parameters(), 440 | lr=args.G_lr * g_reg_ratio if not args.ttur else args.D_lr / 4 * g_reg_ratio, 441 | betas=(args.beta1 ** g_reg_ratio, args.beta2 ** g_reg_ratio), 442 | ) 443 | d_optim = optim.Adam( 444 | discriminator.parameters(), 445 | lr=args.D_lr * d_reg_ratio, 446 | betas=(args.beta1 ** d_reg_ratio, args.beta2 ** d_reg_ratio), 447 | ) 448 | 449 | # Load optimizer checkpoint. 
450 | if args.ckpt is not None: 451 | print("load optimizer: ", args.ckpt) 452 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 453 | ckpt_name = os.path.basename(args.ckpt) 454 | 455 | try: 456 | g_optim.load_state_dict(ckpt["g_optim"]) 457 | d_optim.load_state_dict(ckpt["d_optim"]) 458 | except: 459 | print("We don't load optimizers.") 460 | 461 | if args.eval: 462 | if get_rank() == 0: 463 | g_ema.eval() 464 | evaluation(g_ema, args, (args.eval_iters * args.batch * int(os.environ["WORLD_SIZE"]))) 465 | sys.exit(0) 466 | sys.exit(0) 467 | 468 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 469 | if args.use_flip: 470 | transform = transforms.Compose( 471 | [ 472 | transforms.Resize((args.size, args.size)), 473 | transforms.RandomHorizontalFlip(), 474 | transforms.ToTensor(), 475 | normalize 476 | ] 477 | ) 478 | else: 479 | transform = transforms.Compose( 480 | [ 481 | transforms.Resize((args.size, args.size)), 482 | transforms.ToTensor(), 483 | normalize 484 | ] 485 | ) 486 | 487 | if args.lmdb: 488 | dataset = MultiResolutionDataset(args.path, transform, args.size) 489 | else: 490 | dataset = datasets.ImageFolder(root=args.path, transform=transform) 491 | 492 | loader = data.DataLoader( 493 | dataset, 494 | batch_size=args.batch, 495 | num_workers=args.workers, 496 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 497 | drop_last=True, 498 | ) 499 | 500 | if get_rank() == 0 and wandb is not None and args.wandb: 501 | wandb.init(project=args.project_name) 502 | 503 | train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, device) 504 | -------------------------------------------------------------------------------- /utils/CRDiffAug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
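train_styleswin.py above uses CR_DiffAug, defined below, for balanced consistency regularization: the discriminator also sees an augmented copy of each real and fake batch and is penalized with an L2 term for changing its prediction on it. A minimal shape-level sketch (an assumption, not repository code):

import torch
from utils.CRDiffAug import CR_DiffAug

x = torch.rand(4, 3, 256, 256)      # a batch of real or generated images in [0, 1)
x_aug = CR_DiffAug(x)               # random flip, translation, color jitter and cutout
assert x_aug.shape == x.shape       # the augmentations keep the tensor shape unchanged
# In the training loop the bCR term is then, roughly:
#   d_loss += bcr_real_lambda * mse(D(x_aug), D(x))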
3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | def CR_DiffAug(x, flip=True, translation=True, color=True, cutout=True): 9 | if flip: 10 | x = random_flip(x, 0.5) 11 | if translation: 12 | x = rand_translation(x, 1/8) 13 | if color: 14 | aug_list = [rand_brightness, rand_saturation, rand_contrast] 15 | for func in aug_list: 16 | x = func(x) 17 | if cutout: 18 | x = rand_cutout(x) 19 | if flip or translation: 20 | x = x.contiguous() 21 | return x 22 | 23 | 24 | def random_flip(x, p): 25 | x_out = x.clone() 26 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 27 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) 28 | flip_mask = flip_prob < p 29 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) 30 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) 31 | return x_out 32 | 33 | 34 | def rand_brightness(x): 35 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 36 | return x 37 | 38 | 39 | def rand_saturation(x): 40 | x_mean = x.mean(dim=1, keepdim=True) 41 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 42 | return x 43 | 44 | 45 | def rand_contrast(x): 46 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 47 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 48 | return x 49 | 50 | 51 | def rand_translation(x, ratio=0.125): 52 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 53 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 54 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 55 | grid_batch, grid_x, grid_y = torch.meshgrid( 56 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 57 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 58 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 59 | ) 60 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 61 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 62 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 63 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous() 64 | return x 65 | 66 | 67 | def rand_cutout(x, ratio=0.5): 68 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 69 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 70 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 71 | grid_batch, grid_x, grid_y = torch.meshgrid( 72 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 73 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 74 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 75 | ) 76 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 77 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 78 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 79 | mask[grid_batch, grid_x, grid_y] = 0 80 | x = x * mask.unsqueeze(1) 81 | return x 82 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # Licensed under the MIT License. 3 | 4 | import pickle 5 | 6 | import torch 7 | from torch import distributed as dist 8 | 9 | 10 | def get_rank(): 11 | if not dist.is_available(): 12 | return 0 13 | 14 | if not dist.is_initialized(): 15 | return 0 16 | 17 | return dist.get_rank() 18 | 19 | 20 | def synchronize(): 21 | if not dist.is_available(): 22 | return 23 | 24 | if not dist.is_initialized(): 25 | return 26 | 27 | world_size = dist.get_world_size() 28 | 29 | if world_size == 1: 30 | return 31 | 32 | dist.barrier() 33 | 34 | 35 | def get_world_size(): 36 | if not dist.is_available(): 37 | return 1 38 | 39 | if not dist.is_initialized(): 40 | return 1 41 | 42 | return dist.get_world_size() 43 | 44 | 45 | def reduce_sum(tensor): 46 | if not dist.is_available(): 47 | return tensor 48 | 49 | if not dist.is_initialized(): 50 | return tensor 51 | 52 | tensor = tensor.clone() 53 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 54 | 55 | return tensor 56 | 57 | 58 | def gather_grad(params): 59 | world_size = get_world_size() 60 | 61 | if world_size == 1: 62 | return 63 | 64 | for param in params: 65 | if param.grad is not None: 66 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 67 | param.grad.data.div_(world_size) 68 | 69 | 70 | def all_gather(data): 71 | world_size = get_world_size() 72 | 73 | if world_size == 1: 74 | return [data] 75 | 76 | buffer = pickle.dumps(data) 77 | storage = torch.ByteStorage.from_buffer(buffer) 78 | tensor = torch.ByteTensor(storage).to('cuda') 79 | 80 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 81 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 82 | dist.all_gather(size_list, local_size) 83 | size_list = [int(size.item()) for size in size_list] 84 | max_size = max(size_list) 85 | 86 | tensor_list = [] 87 | for _ in size_list: 88 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 89 | 90 | if local_size != max_size: 91 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 92 | tensor = torch.cat((tensor, padding), 0) 93 | 94 | dist.all_gather(tensor_list, tensor) 95 | 96 | data_list = [] 97 | 98 | for size, tensor in zip(size_list, tensor_list): 99 | buffer = tensor.cpu().numpy().tobytes()[:size] 100 | data_list.append(pickle.loads(buffer)) 101 | 102 | return data_list 103 | 104 | 105 | def reduce_loss_dict(loss_dict): 106 | world_size = get_world_size() 107 | 108 | if world_size < 2: 109 | return loss_dict 110 | 111 | with torch.no_grad(): 112 | keys = [] 113 | losses = [] 114 | 115 | for k in sorted(loss_dict.keys()): 116 | keys.append(k) 117 | losses.append(loss_dict[k]) 118 | 119 | losses = torch.stack(losses, 0) 120 | dist.reduce(losses, dst=0) 121 | 122 | if dist.get_rank() == 0: 123 | losses /= world_size 124 | 125 | reduced_losses = {k: v for k, v in zip(keys, losses)} 126 | 127 | return reduced_losses 128 | -------------------------------------------------------------------------------- /utils/fid_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | """Calculates the Frechet Inception Distance (FID) to evalulate GANs 5 | 6 | The FID metric calculates the distance between two distributions of images. 7 | Typically, we have summary statistics (mean & covariance matrix) of one 8 | of these distributions, while the 2nd distribution is given by a GAN. 
9 | 10 | When run as a stand-alone program, it compares the distribution of 11 | images that are stored as PNG/JPEG at a specified location with a 12 | distribution given by summary statistics (in pickle format). 13 | 14 | The FID is calculated by assuming that X_1 and X_2 are the activations of 15 | the pool_3 layer of the inception net for generated samples and real world 16 | samples respectively. 17 | 18 | See --help to see further details. 19 | 20 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 21 | of Tensorflow 22 | 23 | Copyright 2018 Institute of Bioinformatics, JKU Linz 24 | 25 | Licensed under the Apache License, Version 2.0 (the "License"); 26 | you may not use this file except in compliance with the License. 27 | You may obtain a copy of the License at 28 | 29 | http://www.apache.org/licenses/LICENSE-2.0 30 | 31 | Unless required by applicable law or agreed to in writing, software 32 | distributed under the License is distributed on an "AS IS" BASIS, 33 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 34 | See the License for the specific language governing permissions and 35 | limitations under the License. 36 | """ 37 | import os 38 | import pathlib 39 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 40 | from multiprocessing import cpu_count 41 | 42 | import numpy as np 43 | import torch 44 | import torchvision.transforms as TF 45 | from PIL import Image 46 | from scipy import linalg 47 | from torch.nn.functional import adaptive_avg_pool2d 48 | 49 | try: 50 | from tqdm import tqdm 51 | except ImportError: 52 | # If tqdm is not available, provide a mock version of it 53 | def tqdm(x): 54 | return x 55 | 56 | from utils.inception import InceptionV3 57 | 58 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 59 | parser.add_argument('--batch-size', type=int, default=50, 60 | help='Batch size to use') 61 | parser.add_argument('--num-workers', type=int, default=8, 62 | help='Number of processes to use for data loading') 63 | parser.add_argument('--device', type=str, default=None, 64 | help='Device to use. Like cuda, cuda:0 or cpu') 65 | parser.add_argument('--dims', type=int, default=2048, 66 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 67 | help=('Dimensionality of Inception features to use. ' 68 | 'By default, uses pool3 features')) 69 | parser.add_argument('path', type=str, nargs=2, 70 | help=('Paths to the generated images or ' 71 | 'to .npz statistic files')) 72 | 73 | IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 74 | 'tif', 'tiff', 'webp'} 75 | 76 | 77 | class ImagePathDataset(torch.utils.data.Dataset): 78 | def __init__(self, files, transforms=None): 79 | self.files = files 80 | self.transforms = transforms 81 | 82 | def __len__(self): 83 | return len(self.files) 84 | 85 | def __getitem__(self, i): 86 | path = self.files[i] 87 | img = Image.open(path).convert('RGB') 88 | if self.transforms is not None: 89 | img = self.transforms(img) 90 | return img 91 | 92 | 93 | def get_activations(files, model, batch_size=50, dims=2048, device='cpu', num_workers=8): 94 | """Calculates the activations of the pool_3 layer for all images. 95 | 96 | Params: 97 | -- files : List of image file paths 98 | -- model : Instance of inception model 99 | -- batch_size : Batch size of images for the model to process at once. 100 | Make sure that the number of samples is a multiple of 101 | the batch size, otherwise some samples are ignored.
This 102 | behavior is retained to match the original FID score 103 | implementation. 104 | -- dims : Dimensionality of features returned by Inception 105 | -- device : Device to run calculations 106 | -- num_workers : Number of parallel dataloader workers 107 | 108 | Returns: 109 | -- A numpy array of dimension (num images, dims) that contains the 110 | activations of the given tensor when feeding inception with the 111 | query tensor. 112 | """ 113 | model.eval() 114 | 115 | if batch_size > len(files): 116 | print(('Warning: batch size is bigger than the data size. ' 117 | 'Setting batch size to data size')) 118 | batch_size = len(files) 119 | 120 | dataset = ImagePathDataset(files, transforms=TF.ToTensor()) 121 | dataloader = torch.utils.data.DataLoader(dataset, 122 | batch_size=batch_size, 123 | shuffle=False, 124 | drop_last=False, 125 | num_workers=num_workers) 126 | 127 | pred_arr = np.empty((len(files), dims)) 128 | 129 | start_idx = 0 130 | 131 | for batch in dataloader: 132 | batch = batch.to(device) 133 | 134 | with torch.no_grad(): 135 | pred = model(batch)[0] 136 | 137 | # If model output is not scalar, apply global spatial average pooling. 138 | # This happens if you choose a dimensionality not equal to 2048. 139 | if pred.size(2) != 1 or pred.size(3) != 1: 140 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 141 | 142 | pred = pred.squeeze(3).squeeze(2).cpu().numpy() 143 | 144 | pred_arr[start_idx:start_idx + pred.shape[0]] = pred 145 | 146 | start_idx = start_idx + pred.shape[0] 147 | 148 | return pred_arr 149 | 150 | 151 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 152 | """Numpy implementation of the Frechet Distance. 153 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 154 | and X_2 ~ N(mu_2, C_2) is 155 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 156 | 157 | Stable version by Dougal J. Sutherland. 158 | 159 | Params: 160 | -- mu1 : Numpy array containing the activations of a layer of the 161 | inception net (like returned by the function 'get_predictions') 162 | for generated samples. 163 | -- mu2 : The sample mean over activations, precalculated on a 164 | representative data set. 165 | -- sigma1: The covariance matrix over activations for generated samples. 166 | -- sigma2: The covariance matrix over activations, precalculated on a 167 | representative data set. 168 | 169 | Returns: 170 | -- : The Frechet Distance.
171 | """ 172 | 173 | mu1 = np.atleast_1d(mu1) 174 | mu2 = np.atleast_1d(mu2) 175 | 176 | sigma1 = np.atleast_2d(sigma1) 177 | sigma2 = np.atleast_2d(sigma2) 178 | 179 | assert mu1.shape == mu2.shape, \ 180 | 'Training and test mean vectors have different lengths' 181 | assert sigma1.shape == sigma2.shape, \ 182 | 'Training and test covariances have different dimensions' 183 | 184 | diff = mu1 - mu2 185 | 186 | # Product might be almost singular 187 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 188 | if not np.isfinite(covmean).all(): 189 | msg = ('fid calculation produces singular product; ' 190 | 'adding %s to diagonal of cov estimates') % eps 191 | print(msg) 192 | offset = np.eye(sigma1.shape[0]) * eps 193 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 194 | 195 | # Numerical error might give slight imaginary component 196 | if np.iscomplexobj(covmean): 197 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 198 | m = np.max(np.abs(covmean.imag)) 199 | raise ValueError('Imaginary component {}'.format(m)) 200 | covmean = covmean.real 201 | 202 | tr_covmean = np.trace(covmean) 203 | 204 | return (diff.dot(diff) + np.trace(sigma1) 205 | + np.trace(sigma2) - 2 * tr_covmean) 206 | 207 | 208 | def calculate_activation_statistics(files, model, batch_size=50, dims=2048, 209 | device='cpu', num_workers=8): 210 | """Calculation of the statistics used by the FID. 211 | Params: 212 | -- files : List of image file paths 213 | -- model : Instance of inception model 214 | -- batch_size : The images numpy array is split into batches with 215 | batch size batch_size. A reasonable batch size 216 | depends on the hardware. 217 | -- dims : Dimensionality of features returned by Inception 218 | -- device : Device to run calculations 219 | -- num_workers : Number of parallel dataloader workers 220 | 221 | Returns: 222 | -- mu : The mean over samples of the activations of the pool_3 layer of 223 | the inception model. 224 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 225 | the inception model.
226 | """ 227 | act = get_activations(files, model, batch_size, dims, device, num_workers) 228 | mu = np.mean(act, axis=0) 229 | sigma = np.cov(act, rowvar=False) 230 | return mu, sigma 231 | 232 | 233 | def compute_statistics_of_path(path, model, batch_size, dims, device, num_workers=8): 234 | if path.endswith('.npz'): 235 | with np.load(path) as f: 236 | m, s = f['mu'][:], f['sigma'][:] 237 | else: 238 | path = pathlib.Path(path) 239 | files = sorted([file for ext in IMAGE_EXTENSIONS 240 | for file in path.glob('*.{}'.format(ext))]) 241 | 242 | m, s = calculate_activation_statistics(files, model, batch_size, 243 | dims, device, num_workers) 244 | 245 | return m, s 246 | 247 | 248 | def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=8): 249 | """Calculates the FID of two paths""" 250 | for p in paths: 251 | if not os.path.exists(p): 252 | raise RuntimeError('Invalid path: %s' % p) 253 | 254 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 255 | 256 | model = InceptionV3([block_idx]).to(device) 257 | 258 | m1, s1 = compute_statistics_of_path(paths[0], model, batch_size, 259 | dims, device, num_workers) 260 | m2, s2 = compute_statistics_of_path(paths[1], model, batch_size, 261 | dims, device, num_workers) 262 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 263 | 264 | return fid_value 265 | 266 | 267 | def main(): 268 | args = parser.parse_args() 269 | 270 | if args.device is None: 271 | device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu') 272 | else: 273 | device = torch.device(args.device) 274 | 275 | fid_value = calculate_fid_given_paths(args.path, 276 | args.batch_size, 277 | device, 278 | args.dims, 279 | args.num_workers) 280 | print('FID: ', fid_value) 281 | 282 | 283 | if __name__ == '__main__': 284 | main() 285 | -------------------------------------------------------------------------------- /utils/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision 8 | 9 | try: 10 | from torchvision.models.utils import load_state_dict_from_url 11 | except ImportError: 12 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 13 | 14 | # Inception weights ported to Pytorch from 15 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 16 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501 17 | 18 | 19 | class InceptionV3(nn.Module): 20 | """Pretrained InceptionV3 network returning feature maps""" 21 | 22 | # Index of default block of inception to return, 23 | # corresponds to output of final average pooling 24 | DEFAULT_BLOCK_INDEX = 3 25 | 26 | # Maps feature dimensionality to their output block indices 27 | BLOCK_INDEX_BY_DIM = { 28 | 64: 0, # First max pooling features 29 | 192: 1, # Second max pooling features 30 | 768: 2, # Pre-aux classifier features 31 | 2048: 3 # Final average pooling features 32 | } 33 | 34 | def __init__(self, 35 | output_blocks=(DEFAULT_BLOCK_INDEX,), 36 | resize_input=True, 37 | normalize_input=True, 38 | requires_grad=False, 39 | use_fid_inception=True): 40 | """Build pretrained InceptionV3 41 | 42 | Parameters 43 | ---------- 44 | output_blocks : list of int 45 | Indices of blocks to return features of.
Possible values are: 46 | - 0: corresponds to output of first max pooling 47 | - 1: corresponds to output of second max pooling 48 | - 2: corresponds to output which is fed to aux classifier 49 | - 3: corresponds to output of final average pooling 50 | resize_input : bool 51 | If true, bilinearly resizes input to width and height 299 before 52 | feeding input to model. As the network without fully connected 53 | layers is fully convolutional, it should be able to handle inputs 54 | of arbitrary size, so resizing might not be strictly needed 55 | normalize_input : bool 56 | If true, scales the input from range (0, 1) to the range the 57 | pretrained Inception network expects, namely (-1, 1) 58 | requires_grad : bool 59 | If true, parameters of the model require gradients. Possibly useful 60 | for finetuning the network 61 | use_fid_inception : bool 62 | If true, uses the pretrained Inception model used in Tensorflow's 63 | FID implementation. If false, uses the pretrained Inception model 64 | available in torchvision. The FID Inception model has different 65 | weights and a slightly different structure from torchvision's 66 | Inception model. If you want to compute FID scores, you are 67 | strongly advised to set this parameter to true to get comparable 68 | results. 69 | """ 70 | super(InceptionV3, self).__init__() 71 | 72 | self.resize_input = resize_input 73 | self.normalize_input = normalize_input 74 | self.output_blocks = sorted(output_blocks) 75 | self.last_needed_block = max(output_blocks) 76 | 77 | assert self.last_needed_block <= 3, \ 78 | 'Last possible output block index is 3' 79 | 80 | self.blocks = nn.ModuleList() 81 | 82 | if use_fid_inception: 83 | inception = fid_inception_v3() 84 | else: 85 | inception = _inception_v3(pretrained=True) 86 | 87 | # Block 0: input to maxpool1 88 | block0 = [ 89 | inception.Conv2d_1a_3x3, 90 | inception.Conv2d_2a_3x3, 91 | inception.Conv2d_2b_3x3, 92 | nn.MaxPool2d(kernel_size=3, stride=2) 93 | ] 94 | self.blocks.append(nn.Sequential(*block0)) 95 | 96 | # Block 1: maxpool1 to maxpool2 97 | if self.last_needed_block >= 1: 98 | block1 = [ 99 | inception.Conv2d_3b_1x1, 100 | inception.Conv2d_4a_3x3, 101 | nn.MaxPool2d(kernel_size=3, stride=2) 102 | ] 103 | self.blocks.append(nn.Sequential(*block1)) 104 | 105 | # Block 2: maxpool2 to aux classifier 106 | if self.last_needed_block >= 2: 107 | block2 = [ 108 | inception.Mixed_5b, 109 | inception.Mixed_5c, 110 | inception.Mixed_5d, 111 | inception.Mixed_6a, 112 | inception.Mixed_6b, 113 | inception.Mixed_6c, 114 | inception.Mixed_6d, 115 | inception.Mixed_6e, 116 | ] 117 | self.blocks.append(nn.Sequential(*block2)) 118 | 119 | # Block 3: aux classifier to final avgpool 120 | if self.last_needed_block >= 3: 121 | block3 = [ 122 | inception.Mixed_7a, 123 | inception.Mixed_7b, 124 | inception.Mixed_7c, 125 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 126 | ] 127 | self.blocks.append(nn.Sequential(*block3)) 128 | 129 | for param in self.parameters(): 130 | param.requires_grad = requires_grad 131 | 132 | def forward(self, inp): 133 | """Get Inception feature maps 134 | 135 | Parameters 136 | ---------- 137 | inp : torch.autograd.Variable 138 | Input tensor of shape Bx3xHxW.
Values are expected to be in 139 | range (0, 1) 140 | 141 | Returns 142 | ------- 143 | List of torch.autograd.Variable, corresponding to the selected output 144 | block, sorted ascending by index 145 | """ 146 | outp = [] 147 | x = inp 148 | 149 | if self.resize_input: 150 | x = F.interpolate(x, 151 | size=(299, 299), 152 | mode='bilinear', 153 | align_corners=False) 154 | 155 | if self.normalize_input: 156 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 157 | 158 | for idx, block in enumerate(self.blocks): 159 | x = block(x) 160 | if idx in self.output_blocks: 161 | outp.append(x) 162 | 163 | if idx == self.last_needed_block: 164 | break 165 | 166 | return outp 167 | 168 | 169 | def _inception_v3(*args, **kwargs): 170 | """Wraps `torchvision.models.inception_v3` 171 | 172 | Skips default weight initialization if supported by torchvision version. 173 | See https://github.com/mseitzer/pytorch-fid/issues/28. 174 | """ 175 | try: 176 | version = tuple(map(int, torchvision.__version__.split('.')[:2])) 177 | except ValueError: 178 | # Just a caution against weird version strings 179 | version = (0,) 180 | 181 | if version >= (0, 6): 182 | kwargs['init_weights'] = False 183 | 184 | return torchvision.models.inception_v3(*args, **kwargs) 185 | 186 | 187 | def fid_inception_v3(): 188 | """Build pretrained Inception model for FID computation 189 | 190 | The Inception model for FID computation uses a different set of weights 191 | and has a slightly different structure than torchvision's Inception. 192 | 193 | This method first constructs torchvision's Inception and then patches the 194 | necessary parts that are different in the FID Inception model. 195 | """ 196 | inception = _inception_v3(num_classes=1008, 197 | aux_logits=False, 198 | pretrained=False) 199 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 200 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 201 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 202 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 203 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 204 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 205 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 206 | inception.Mixed_7b = FIDInceptionE_1(1280) 207 | inception.Mixed_7c = FIDInceptionE_2(2048) 208 | 209 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 210 | inception.load_state_dict(state_dict) 211 | return inception 212 | 213 | 214 | class FIDInceptionA(torchvision.models.inception.InceptionA): 215 | """InceptionA block patched for FID computation""" 216 | def __init__(self, in_channels, pool_features): 217 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 218 | 219 | def forward(self, x): 220 | branch1x1 = self.branch1x1(x) 221 | 222 | branch5x5 = self.branch5x5_1(x) 223 | branch5x5 = self.branch5x5_2(branch5x5) 224 | 225 | branch3x3dbl = self.branch3x3dbl_1(x) 226 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 227 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 228 | 229 | # Patch: Tensorflow's average pool does not use the padded zeros in 230 | # its average calculation 231 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 232 | count_include_pad=False) 233 | branch_pool = self.branch_pool(branch_pool) 234 | 235 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 236 | return torch.cat(outputs, 1) 237 | 238 | 239 | class FIDInceptionC(torchvision.models.inception.InceptionC): 240 | """InceptionC block patched
for FID computation""" 241 | def __init__(self, in_channels, channels_7x7): 242 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 243 | 244 | def forward(self, x): 245 | branch1x1 = self.branch1x1(x) 246 | 247 | branch7x7 = self.branch7x7_1(x) 248 | branch7x7 = self.branch7x7_2(branch7x7) 249 | branch7x7 = self.branch7x7_3(branch7x7) 250 | 251 | branch7x7dbl = self.branch7x7dbl_1(x) 252 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 253 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 254 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 255 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 256 | 257 | # Patch: Tensorflow's average pool does not use the padded zeros in 258 | # its average calculation 259 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 260 | count_include_pad=False) 261 | branch_pool = self.branch_pool(branch_pool) 262 | 263 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 264 | return torch.cat(outputs, 1) 265 | 266 | 267 | class FIDInceptionE_1(torchvision.models.inception.InceptionE): 268 | """First InceptionE block patched for FID computation""" 269 | def __init__(self, in_channels): 270 | super(FIDInceptionE_1, self).__init__(in_channels) 271 | 272 | def forward(self, x): 273 | branch1x1 = self.branch1x1(x) 274 | 275 | branch3x3 = self.branch3x3_1(x) 276 | branch3x3 = [ 277 | self.branch3x3_2a(branch3x3), 278 | self.branch3x3_2b(branch3x3), 279 | ] 280 | branch3x3 = torch.cat(branch3x3, 1) 281 | 282 | branch3x3dbl = self.branch3x3dbl_1(x) 283 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 284 | branch3x3dbl = [ 285 | self.branch3x3dbl_3a(branch3x3dbl), 286 | self.branch3x3dbl_3b(branch3x3dbl), 287 | ] 288 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 289 | 290 | # Patch: Tensorflow's average pool does not use the padded zeros in 291 | # its average calculation 292 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 293 | count_include_pad=False) 294 | branch_pool = self.branch_pool(branch_pool) 295 | 296 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 297 | return torch.cat(outputs, 1) 298 | 299 | 300 | class FIDInceptionE_2(torchvision.models.inception.InceptionE): 301 | """Second InceptionE block patched for FID computation""" 302 | def __init__(self, in_channels): 303 | super(FIDInceptionE_2, self).__init__(in_channels) 304 | 305 | def forward(self, x): 306 | branch1x1 = self.branch1x1(x) 307 | 308 | branch3x3 = self.branch3x3_1(x) 309 | branch3x3 = [ 310 | self.branch3x3_2a(branch3x3), 311 | self.branch3x3_2b(branch3x3), 312 | ] 313 | branch3x3 = torch.cat(branch3x3, 1) 314 | 315 | branch3x3dbl = self.branch3x3dbl_1(x) 316 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 317 | branch3x3dbl = [ 318 | self.branch3x3dbl_3a(branch3x3dbl), 319 | self.branch3x3dbl_3b(branch3x3dbl), 320 | ] 321 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 322 | 323 | # Patch: The FID Inception model uses max pooling instead of average 324 | # pooling. This is likely an error in this specific Inception 325 | # implementation, as other Inception models use average pooling here 326 | # (which matches the description in the paper).
327 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 328 | branch_pool = self.branch_pool(branch_pool) 329 | 330 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 331 | return torch.cat(outputs, 1) 332 | -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import tensorflow as tf 6 | 7 | 8 | class Visualizer(): 9 | def __init__(self, args): 10 | self.args = args 11 | self.tf = tf 12 | self.log_dir = os.path.join(args.checkpoint_path, 'logs') 13 | self.writer = tf.summary.FileWriter(self.log_dir) 14 | 15 | def plot_loss(self, loss, step, tag): 16 | summary = self.tf.Summary( 17 | value=[self.tf.Summary.Value(tag=tag, simple_value=loss)]) 18 | self.writer.add_summary(summary, step) 19 | 20 | def plot_dict(self, loss, step): 21 | for tag, value in loss.items(): 22 | summary = self.tf.Summary( 23 | value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 24 | self.writer.add_summary(summary, step) 25 | --------------------------------------------------------------------------------
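The utility modules above are self-contained, so they can be exercised without launching train_styleswin.py. The snippet below is a minimal usage sketch, not part of the repository: the folder names `real_images/` and `fake_images/` are placeholders, and it assumes the code is run from the repository root with the packages in requirements.txt installed. It computes FID between two image folders via utils/fid_score.py and applies the CR_DiffAug augmentation from utils/CRDiffAug.py to a dummy batch.

```python
# Minimal usage sketch (assumptions: run from the repository root;
# 'real_images/' and 'fake_images/' are placeholder folders of images).
import torch

from utils.CRDiffAug import CR_DiffAug
from utils.fid_score import calculate_fid_given_paths

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# FID between real images and generated samples, using the default
# 2048-dimensional pool3 features of the FID InceptionV3 network.
fid = calculate_fid_given_paths(
    ['real_images/', 'fake_images/'],
    batch_size=50,
    device=device,
    dims=2048,
    num_workers=4,
)
print('FID:', fid)

# CR_DiffAug applies random flip, translation, color jitter and cutout to a
# batch of images and returns a tensor of the same shape; here it is run on
# random noise purely to illustrate the call.
x = torch.randn(8, 3, 256, 256, device=device)
x_aug = CR_DiffAug(x)
print(x_aug.shape)  # torch.Size([8, 3, 256, 256])
```

As its module docstring notes, utils/fid_score.py can also be used as a stand-alone program (e.g. `python -m utils.fid_score <path1> <path2>` from the repository root), where each path is either a folder of images or a precomputed .npz statistics file.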