├── LICENSE ├── LICENSE-KSH ├── LICENSE-NVIDIA ├── README.md ├── dataset.py ├── examples ├── 000001.png ├── 000002.png ├── 000003.png ├── 000004.png ├── 000005.png ├── 000006.png ├── 000007.png ├── 000008.png ├── 000009.png └── 000010.png ├── imgs ├── 0k.png ├── 1M.png ├── face2webtoon.png ├── interpolation_domain_guided_encoder.png ├── interpolation_idinversion_500steps.png └── interpolation_results.png ├── interpolate.ipynb ├── model.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu └── train_encoder.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Bryan Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE-KSH: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. 
This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In-Domain GAN Inversion for Real Image Editing 2 | 3 | 4 | Based on **Seonghyeon Kim's PyTorch Implementation of StyleGAN2** 5 | 6 | [[Paper](https://arxiv.org/pdf/2004.00049.pdf)] [[Official Code](https://github.com/genforce/idinvert)] [[StyleGAN2 PyTorch](https://github.com/rosinality/stylegan2-pytorch)] 7 | 8 | ## Train Encoder 9 | 10 | ``` 11 | python train_encoder.py 12 | ``` 13 | 14 | **0k iter**\ 15 | <img src="./imgs/0k.png"> 16 | 17 | **1M iter**\ 18 | <img src="./imgs/1M.png">\ 19 | [[encoder checkpoint](https://drive.google.com/file/d/1QQuZGtHgD24Dn5E21Z2Ik25EPng58MoU/view?usp=sharing)] [[generator checkpoint](https://drive.google.com/file/d/1TH77dUsqcq50htIZT6DljFYr4T_ziJli/view)] 20 | 21 | **Note:** The encoder architecture and loss weights are different from the original implementation.
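Once both checkpoints are downloaded, inversion is a single forward pass: the encoder maps an image straight to per-layer W+ latents (`n_latents = log2(size)*2 - 2`, i.e. 14 vectors at 256px), and the generator decodes them with `input_is_latent=True`. A minimal sketch, assuming 256px models and that the checkpoints store plain state dicts under the keys `e` and `g_ema` (hypothetical — match them to whatever `train_encoder.py` actually saves):

```
import torch
from PIL import Image
from torchvision import transforms

from model import Encoder, Generator

device = "cuda"
size = 256

encoder = Encoder(size).to(device).eval()
generator = Generator(size, style_dim=512, n_mlp=8).to(device).eval()

# Hypothetical checkpoint keys -- adjust to train_encoder.py's save format
encoder.load_state_dict(torch.load("encoder.pt")["e"])
generator.load_state_dict(torch.load("generator.pt")["g_ema"])

# Assumed preprocessing: resize/crop to model size, normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize(size),
    transforms.CenterCrop(size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

img = transform(Image.open("examples/000001.png").convert("RGB"))
img = img.unsqueeze(0).to(device)

with torch.no_grad():
    w_plus = encoder(img)  # (1, n_latents, 512) per-layer W+ latents
    recon, _ = generator([w_plus], input_is_latent=True, randomize_noise=False)
```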
22 | 23 | 24 | 25 | ## Interpolation 26 | 27 | ``` 28 | interpolate.ipynb 29 | ``` 30 | 31 | **Domain-Guided Encoder (Initial projection)**\ 32 | <img src="./imgs/interpolation_domain_guided_encoder.png"> 33 | 34 | **In-Domain Inversion (500 steps)**\ 35 | <img src="./imgs/interpolation_idinversion_500steps.png"> 36 | 37 | **Interpolation Result**\ 38 | <img src="./imgs/interpolation_results.png"> 39 | 40 | ## Encoder + Model Interpolation 41 | [[Paper](https://arxiv.org/abs/2010.05334)] [[Naver Webtoon Model](https://github.com/bryandlee/naver-webtoon-faces#stylegan2)] 42 | 43 | <img src="./imgs/face2webtoon.png"> 44 | 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class MultiResolutionDataset(Dataset): 15 | def __init__(self, path, transform, resolution=256): 16 | self.env = lmdb.open( 17 | path, 18 | max_readers=32, 19 | readonly=True, 20 | lock=False, 21 | readahead=False, 22 | meminit=False, 23 | ) 24 | 25 | if not self.env: 26 | raise IOError('Cannot open lmdb dataset', path) 27 | 28 | with self.env.begin(write=False) as txn: 29 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 30 | 31 | self.resolution = resolution 32 | self.transform = transform 33 | 34 | def __len__(self): 35 | return self.length 36 | 37 | def __getitem__(self, index): 38 | with self.env.begin(write=False) as txn: 39 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 40 | img_bytes = txn.get(key) 41 | 42 | buffer = BytesIO(img_bytes) 43 | img = Image.open(buffer) 44 | img = self.transform(img) 45 | 46 | return img 47 | 48 | 49 | def resize_and_convert(img, size, resample, quality=100): 50 | img = trans_fn.resize(img, size, resample) 51 | img = trans_fn.center_crop(img, size) 52 | buffer = BytesIO() 53 | img.save(buffer, format='jpeg', quality=quality) 54 | val = buffer.getvalue() 55 | 56 | return val 57 | 58 | 59 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): 60 | imgs = [] 61 | 62 | for size in sizes: 63 | imgs.append(resize_and_convert(img, size, resample, quality)) 64 | 65 | return imgs 66 | 67 | 68 | def resize_worker(img_file, sizes, resample): 69 | i, file = img_file 70 | img = Image.open(file) 71 | img = img.convert('RGB') 72 | out = resize_multiple(img, sizes=sizes, resample=resample) 73 | 74 | return i, out 75 | 76 | 77 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): 78 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 79 | 80 | files = sorted(dataset.imgs, key=lambda x: x[0]) 81 | files = [(i, file) for i, (file, label) in enumerate(files)] 82 | total = 0 83 | 84 | with multiprocessing.Pool(n_worker) as pool: 85 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 86 | for size, img in zip(sizes, imgs): 87 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 88 | 89 | with env.begin(write=True) as txn: 90 | txn.put(key, img) 91 | 92 | total += 1 93 | 94 | with env.begin(write=True) as txn: 95 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument('--out', type=str) 101 | parser.add_argument('--size', type=str, default='128,256,512,1024') 102 |
parser.add_argument('--n_worker', type=int, default=8) 103 | parser.add_argument('--resample', type=str, default='lanczos') 104 | parser.add_argument('path', type=str) 105 | 106 | args = parser.parse_args() 107 | 108 | resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR} 109 | resample = resample_map[args.resample] 110 | 111 | sizes = [int(s.strip()) for s in args.size.split(',')] 112 | 113 | print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes)) 114 | 115 | imgset = datasets.ImageFolder(args.path) 116 | 117 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 118 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 119 | -------------------------------------------------------------------------------- /examples/000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000001.png -------------------------------------------------------------------------------- /examples/000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000002.png -------------------------------------------------------------------------------- /examples/000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000003.png -------------------------------------------------------------------------------- /examples/000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000004.png -------------------------------------------------------------------------------- /examples/000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000005.png -------------------------------------------------------------------------------- /examples/000006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000006.png -------------------------------------------------------------------------------- /examples/000007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000007.png -------------------------------------------------------------------------------- /examples/000008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000008.png -------------------------------------------------------------------------------- /examples/000009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000009.png 
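For reference, `dataset.py` above doubles as the dataset-preparation CLI: it center-crops and resizes every image under an `ImageFolder`-style root (the images must sit inside at least one subdirectory) into JPEG-encoded LMDB entries keyed `{size}-{index:05d}`. A sketch of the round trip, with placeholder paths and a typical transform chain (both are assumptions, not prescribed by the repo):

```
# Build the LMDB (run from the repo root; paths are placeholders):
#   python dataset.py --out face_lmdb --size 128,256 --n_worker 8 /path/to/images
#
# Then read it back at one resolution for training:
from torchvision import transforms
from dataset import MultiResolutionDataset

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = MultiResolutionDataset("face_lmdb", transform, resolution=256)
img = dataset[0]  # CHW tensor in [-1, 1], decoded from the 256px LMDB entry
```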
-------------------------------------------------------------------------------- /examples/000010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000010.png -------------------------------------------------------------------------------- /imgs/0k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/0k.png -------------------------------------------------------------------------------- /imgs/1M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/1M.png -------------------------------------------------------------------------------- /imgs/face2webtoon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/face2webtoon.png -------------------------------------------------------------------------------- /imgs/interpolation_domain_guided_encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_domain_guided_encoder.png -------------------------------------------------------------------------------- /imgs/interpolation_idinversion_500steps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_idinversion_500steps.png -------------------------------------------------------------------------------- /imgs/interpolation_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_results.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer('kernel', kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = 
upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer('kernel', kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | self.register_buffer('kernel', kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = F.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 128 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 166 | ) 167 | 168 | 169 | class ScaledLeakyReLU(nn.Module): 170 | def __init__(self, negative_slope=0.2): 171 | super().__init__() 172 | 173 | self.negative_slope = negative_slope 174 | 175 | def forward(self, input): 176 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 177 | 178 | return out * math.sqrt(2) 179 | 180 | 181 | class ModulatedConv2d(nn.Module): 182 | def __init__( 183 | self, 184 | in_channel, 185 | out_channel, 186 | kernel_size, 187 | style_dim, 188 | demodulate=True, 189 | upsample=False, 190 | downsample=False, 191 | blur_kernel=[1, 3, 3, 1], 192 | ): 193 | super().__init__() 194 | 195 | 
self.eps = 1e-8 196 | self.kernel_size = kernel_size 197 | self.in_channel = in_channel 198 | self.out_channel = out_channel 199 | self.upsample = upsample 200 | self.downsample = downsample 201 | 202 | if upsample: 203 | factor = 2 204 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 205 | pad0 = (p + 1) // 2 + factor - 1 206 | pad1 = p // 2 + 1 207 | 208 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 209 | 210 | if downsample: 211 | factor = 2 212 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 213 | pad0 = (p + 1) // 2 214 | pad1 = p // 2 215 | 216 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 217 | 218 | fan_in = in_channel * kernel_size ** 2 219 | self.scale = 1 / math.sqrt(fan_in) 220 | self.padding = kernel_size // 2 221 | 222 | self.weight = nn.Parameter( 223 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 224 | ) 225 | 226 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 227 | 228 | self.demodulate = demodulate 229 | 230 | def __repr__(self): 231 | return ( 232 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 233 | f'upsample={self.upsample}, downsample={self.downsample})' 234 | ) 235 | 236 | def forward(self, input, style): 237 | batch, in_channel, height, width = input.shape 238 | 239 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 240 | weight = self.scale * self.weight * style 241 | 242 | if self.demodulate: 243 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 244 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 245 | 246 | weight = weight.view( 247 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 248 | ) 249 | 250 | if self.upsample: 251 | input = input.view(1, batch * in_channel, height, width) 252 | weight = weight.view( 253 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 254 | ) 255 | weight = weight.transpose(1, 2).reshape( 256 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 257 | ) 258 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 259 | _, _, height, width = out.shape 260 | out = out.view(batch, self.out_channel, height, width) 261 | out = self.blur(out) 262 | 263 | elif self.downsample: 264 | input = self.blur(input) 265 | _, _, height, width = input.shape 266 | input = input.view(1, batch * in_channel, height, width) 267 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 268 | _, _, height, width = out.shape 269 | out = out.view(batch, self.out_channel, height, width) 270 | 271 | else: 272 | input = input.view(1, batch * in_channel, height, width) 273 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 274 | _, _, height, width = out.shape 275 | out = out.view(batch, self.out_channel, height, width) 276 | 277 | return out 278 | 279 | 280 | class NoiseInjection(nn.Module): 281 | def __init__(self): 282 | super().__init__() 283 | 284 | self.weight = nn.Parameter(torch.zeros(1)) 285 | 286 | def forward(self, image, noise=None): 287 | if noise is None: 288 | batch, _, height, width = image.shape 289 | noise = image.new_empty(batch, 1, height, width).normal_() 290 | 291 | return image + self.weight * noise 292 | 293 | 294 | class ConstantInput(nn.Module): 295 | def __init__(self, channel, size=4): 296 | super().__init__() 297 | 298 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 299 | 300 | def forward(self, input): 301 | batch = input.shape[0] 302 | out 
= self.input.repeat(batch, 1, 1, 1) 303 | 304 | return out 305 | 306 | 307 | class StyledConv(nn.Module): 308 | def __init__( 309 | self, 310 | in_channel, 311 | out_channel, 312 | kernel_size, 313 | style_dim, 314 | upsample=False, 315 | blur_kernel=[1, 3, 3, 1], 316 | demodulate=True, 317 | ): 318 | super().__init__() 319 | 320 | self.conv = ModulatedConv2d( 321 | in_channel, 322 | out_channel, 323 | kernel_size, 324 | style_dim, 325 | upsample=upsample, 326 | blur_kernel=blur_kernel, 327 | demodulate=demodulate, 328 | ) 329 | 330 | self.noise = NoiseInjection() 331 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 332 | # self.activate = ScaledLeakyReLU(0.2) 333 | self.activate = FusedLeakyReLU(out_channel) 334 | 335 | def forward(self, input, style, noise=None): 336 | out = self.conv(input, style) 337 | out = self.noise(out, noise=noise) 338 | # out = out + self.bias 339 | out = self.activate(out) 340 | 341 | return out 342 | 343 | 344 | class ToRGB(nn.Module): 345 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 346 | super().__init__() 347 | 348 | if upsample: 349 | self.upsample = Upsample(blur_kernel) 350 | 351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 353 | 354 | def forward(self, input, style, skip=None): 355 | out = self.conv(input, style) 356 | out = out + self.bias 357 | 358 | if skip is not None: 359 | skip = self.upsample(skip) 360 | 361 | out = out + skip 362 | 363 | return out 364 | 365 | 366 | class Generator(nn.Module): 367 | def __init__( 368 | self, 369 | size, 370 | style_dim, 371 | n_mlp, 372 | channel_multiplier=2, 373 | blur_kernel=[1, 3, 3, 1], 374 | lr_mlp=0.01, 375 | ): 376 | super().__init__() 377 | 378 | self.size = size 379 | 380 | self.style_dim = style_dim 381 | 382 | layers = [PixelNorm()] 383 | 384 | for i in range(n_mlp): 385 | layers.append( 386 | EqualLinear( 387 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 388 | ) 389 | ) 390 | 391 | self.style = nn.Sequential(*layers) 392 | 393 | self.channels = { 394 | 4: 512, 395 | 8: 512, 396 | 16: 512, 397 | 32: 512, 398 | 64: 256 * channel_multiplier, 399 | 128: 128 * channel_multiplier, 400 | 256: 64 * channel_multiplier, 401 | 512: 32 * channel_multiplier, 402 | 1024: 16 * channel_multiplier, 403 | } 404 | 405 | self.input = ConstantInput(self.channels[4]) 406 | self.conv1 = StyledConv( 407 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 408 | ) 409 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 410 | 411 | self.log_size = int(math.log(size, 2)) 412 | self.num_layers = (self.log_size - 2) * 2 + 1 413 | 414 | self.convs = nn.ModuleList() 415 | self.upsamples = nn.ModuleList() 416 | self.to_rgbs = nn.ModuleList() 417 | self.noises = nn.Module() 418 | 419 | in_channel = self.channels[4] 420 | 421 | for layer_idx in range(self.num_layers): 422 | res = (layer_idx + 5) // 2 423 | shape = [1, 1, 2 ** res, 2 ** res] 424 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) 425 | 426 | for i in range(3, self.log_size + 1): 427 | out_channel = self.channels[2 ** i] 428 | 429 | self.convs.append( 430 | StyledConv( 431 | in_channel, 432 | out_channel, 433 | 3, 434 | style_dim, 435 | upsample=True, 436 | blur_kernel=blur_kernel, 437 | ) 438 | ) 439 | 440 | self.convs.append( 441 | StyledConv( 442 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 443 | ) 444 | ) 445 | 446 | 
self.to_rgbs.append(ToRGB(out_channel, style_dim)) 447 | 448 | in_channel = out_channel 449 | 450 | self.n_latent = self.log_size * 2 - 2 451 | 452 | def make_noise(self): 453 | device = self.input.input.device 454 | 455 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 456 | 457 | for i in range(3, self.log_size + 1): 458 | for _ in range(2): 459 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 460 | 461 | return noises 462 | 463 | def mean_latent(self, n_latent): 464 | latent_in = torch.randn( 465 | n_latent, self.style_dim, device=self.input.input.device 466 | ) 467 | latent = self.style(latent_in).mean(0, keepdim=True) 468 | 469 | return latent 470 | 471 | def get_latent(self, input): 472 | return self.style(input) 473 | 474 | def forward( 475 | self, 476 | styles, 477 | return_latents=False, 478 | inject_index=None, 479 | truncation=1, 480 | truncation_latent=None, 481 | input_is_latent=False, 482 | noise=None, 483 | randomize_noise=True, 484 | ): 485 | if not input_is_latent: 486 | styles = [self.style(s) for s in styles] 487 | 488 | if noise is None: 489 | if randomize_noise: 490 | noise = [None] * self.num_layers 491 | else: 492 | noise = [ 493 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 494 | ] 495 | 496 | if truncation < 1: 497 | style_t = [] 498 | 499 | for style in styles: 500 | style_t.append( 501 | truncation_latent + truncation * (style - truncation_latent) 502 | ) 503 | 504 | styles = style_t 505 | 506 | if len(styles) < 2: 507 | inject_index = self.n_latent 508 | 509 | if styles[0].ndim < 3: 510 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 511 | 512 | else: 513 | latent = styles[0] 514 | 515 | else: 516 | if inject_index is None: 517 | inject_index = random.randint(1, self.n_latent - 1) 518 | 519 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 520 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 521 | 522 | latent = torch.cat([latent, latent2], 1) 523 | 524 | out = self.input(latent) 525 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 526 | 527 | skip = self.to_rgb1(out, latent[:, 1]) 528 | 529 | i = 1 530 | for conv1, conv2, noise1, noise2, to_rgb in zip( 531 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 532 | ): 533 | out = conv1(out, latent[:, i], noise=noise1) 534 | out = conv2(out, latent[:, i + 1], noise=noise2) 535 | skip = to_rgb(out, latent[:, i + 2], skip) 536 | 537 | i += 2 538 | 539 | image = skip 540 | 541 | if return_latents: 542 | return image, latent 543 | 544 | else: 545 | return image, None 546 | 547 | 548 | class ConvLayer(nn.Sequential): 549 | def __init__( 550 | self, 551 | in_channel, 552 | out_channel, 553 | kernel_size, 554 | downsample=False, 555 | blur_kernel=[1, 3, 3, 1], 556 | bias=True, 557 | activate=True, 558 | ): 559 | layers = [] 560 | 561 | if downsample: 562 | factor = 2 563 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 564 | pad0 = (p + 1) // 2 565 | pad1 = p // 2 566 | 567 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 568 | 569 | stride = 2 570 | self.padding = 0 571 | 572 | else: 573 | stride = 1 574 | self.padding = kernel_size // 2 575 | 576 | layers.append( 577 | EqualConv2d( 578 | in_channel, 579 | out_channel, 580 | kernel_size, 581 | padding=self.padding, 582 | stride=stride, 583 | bias=bias and not activate, 584 | ) 585 | ) 586 | 587 | if activate: 588 | if bias: 589 | layers.append(FusedLeakyReLU(out_channel)) 590 | 591 | else: 592 | 
layers.append(ScaledLeakyReLU(0.2)) 593 | 594 | super().__init__(*layers) 595 | 596 | 597 | class ResBlock(nn.Module): 598 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 599 | super().__init__() 600 | 601 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 602 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 603 | 604 | self.skip = ConvLayer( 605 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 606 | ) 607 | 608 | def forward(self, input): 609 | out = self.conv1(input) 610 | out = self.conv2(out) 611 | 612 | skip = self.skip(input) 613 | out = (out + skip) / math.sqrt(2) 614 | 615 | return out 616 | 617 | 618 | class Discriminator(nn.Module): 619 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 620 | super().__init__() 621 | 622 | channels = { 623 | 4: 512, 624 | 8: 512, 625 | 16: 512, 626 | 32: 512, 627 | 64: 256 * channel_multiplier, 628 | 128: 128 * channel_multiplier, 629 | 256: 64 * channel_multiplier, 630 | 512: 32 * channel_multiplier, 631 | 1024: 16 * channel_multiplier, 632 | } 633 | 634 | convs = [ConvLayer(3, channels[size], 1)] 635 | 636 | log_size = int(math.log(size, 2)) 637 | 638 | in_channel = channels[size] 639 | 640 | for i in range(log_size, 2, -1): 641 | out_channel = channels[2 ** (i - 1)] 642 | 643 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 644 | 645 | in_channel = out_channel 646 | 647 | self.convs = nn.Sequential(*convs) 648 | 649 | self.stddev_group = 4 650 | self.stddev_feat = 1 651 | 652 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 653 | self.final_linear = nn.Sequential( 654 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), 655 | EqualLinear(channels[4], 1), 656 | ) 657 | 658 | def forward(self, input): 659 | out = self.convs(input) 660 | 661 | batch, channel, height, width = out.shape 662 | group = min(batch, self.stddev_group) 663 | stddev = out.view( 664 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 665 | ) 666 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 667 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 668 | stddev = stddev.repeat(group, 1, height, width) 669 | out = torch.cat([out, stddev], 1) 670 | 671 | out = self.final_conv(out) 672 | 673 | out = out.view(batch, -1) 674 | out = self.final_linear(out) 675 | 676 | return out 677 | 678 | 679 | class Encoder(nn.Module): 680 | def __init__(self, size, w_dim=512): 681 | super().__init__() 682 | 683 | channels = { 684 | 4: 512, 685 | 8: 512, 686 | 16: 512, 687 | 32: 512, 688 | 64: 256, 689 | 128: 128, 690 | 256: 64, 691 | 512: 32, 692 | 1024: 16 693 | } 694 | 695 | self.w_dim = w_dim 696 | log_size = int(math.log(size, 2)) 697 | 698 | self.n_latents = log_size*2 - 2 699 | 700 | convs = [ConvLayer(3, channels[size], 1)] 701 | 702 | in_channel = channels[size] 703 | for i in range(log_size, 2, -1): 704 | out_channel = channels[2 ** (i - 1)] 705 | convs.append(ResBlock(in_channel, out_channel)) 706 | in_channel = out_channel 707 | 708 | convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False)) 709 | 710 | self.convs = nn.Sequential(*convs) 711 | 712 | def forward(self, input): 713 | out = self.convs(input) 714 | return out.view(len(input), self.n_latents, self.w_dim) 715 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define
CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ?
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1,
ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, 
kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * 
p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | 

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

  UpFirDn2DKernelParams p;

  auto x = input.contiguous();
  auto k = kernel.contiguous();

  p.major_dim = x.size(0);
  p.in_h = x.size(1);
  p.in_w = x.size(2);
  p.minor_dim = x.size(3);
  p.kernel_h = k.size(0);
  p.kernel_w = k.size(1);
  p.up_x = up_x;
  p.up_y = up_y;
  p.down_x = down_x;
  p.down_y = down_y;
  p.pad_x0 = pad_x0;
  p.pad_x1 = pad_x1;
  p.pad_y0 = pad_y0;
  p.pad_y1 = pad_y1;

  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
            p.down_y;
  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
            p.down_x;

  auto out =
      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());

  int mode = -1;

  int tile_out_h = -1;
  int tile_out_w = -1;

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 1;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 3 && p.kernel_w <= 3) {
    mode = 2;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 3;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 4;
    tile_out_h = 16;
    tile_out_w = 64;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 4 && p.kernel_w <= 4) {
    mode = 5;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
      p.kernel_h <= 2 && p.kernel_w <= 2) {
    mode = 6;
    tile_out_h = 8;
    tile_out_w = 32;
  }

  dim3 block_size;
  dim3 grid_size;

  if (tile_out_h > 0 && tile_out_w > 0) {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 1;
    block_size = dim3(32 * 8, 1, 1);
    grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
                     (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  } else {
    p.loop_major = (p.major_dim - 1) / 16384 + 1;
    p.loop_x = 4;
    block_size = dim3(4, 32, 1);
    grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
                     (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
                     (p.major_dim - 1) / p.loop_major + 1);
  }
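
  // AT_DISPATCH_FLOATING_TYPES_AND_HALF expands the lambda once per supported
  // dtype (float, double, half); the switch then launches the tile
  // specialization chosen above, or the generic upfirdn2d_kernel_large when
  // no specialized mode matched.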
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
    switch (mode) {
    case 1:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 2:
      upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 3:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 4:
      upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 5:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    case 6:
      upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>
          <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
                                                 x.data_ptr<scalar_t>(),
                                                 k.data_ptr<scalar_t>(), p);

      break;

    default:
      upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
          out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
          k.data_ptr<scalar_t>(), p);
    }
  });

  return out;
}

--------------------------------------------------------------------------------
/train_encoder.py:
--------------------------------------------------------------------------------

import argparse
import math
import random
import os

import numpy as np
import torch
from torch import nn, autograd, optim
from torch.nn import functional as F
from torch.utils import data
import torch.distributed as dist
import torchvision
from torchvision import transforms, utils
from tqdm import tqdm

from model import Encoder, Generator, Discriminator
from dataset import MultiResolutionDataset

try:
    from tensorboardX import SummaryWriter
except ImportError:
    SummaryWriter = None


def data_sampler(dataset, shuffle):
    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def sample_data(loader):
    while True:
        for batch in loader:
            yield batch


def d_logistic_loss(real_pred, fake_pred):
    real_loss = F.softplus(-real_pred)
    fake_loss = F.softplus(fake_pred)

    return real_loss.mean() + fake_loss.mean()


def d_r1_loss(real_pred, real_img):
    grad_real, = autograd.grad(
        outputs=real_pred.sum(), inputs=real_img, create_graph=True
    )
    grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

    return grad_penalty


def g_nonsaturating_loss(fake_pred):
    loss = F.softplus(-fake_pred).mean()

    return loss


class VGGLoss(nn.Module):
    def __init__(self, device, n_layers=5):
        super().__init__()

        feature_layers = (2, 7, 12, 21, 30)
        self.weights = (1.0, 1.0, 1.0, 1.0, 1.0)

        vgg = torchvision.models.vgg19(pretrained=True).features

        self.layers = nn.ModuleList()
        prev_layer = 0
        for next_layer in feature_layers[:n_layers]:
            layers = nn.Sequential()
            for layer in range(prev_layer, next_layer):
                layers.add_module(str(layer), vgg[layer])
            self.layers.append(layers.to(device))
            prev_layer = next_layer

        for param in self.parameters():
            param.requires_grad = False

        self.criterion = nn.L1Loss().to(device)

    def forward(self, source, target):
        loss = 0
        for layer, weight in zip(self.layers, self.weights):
            source = layer(source)
            with torch.no_grad():
                target = layer(target)
            loss += weight * self.criterion(source, target)

        return loss
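
# train() below alternates two phases per iteration: a discriminator update on
# real images versus encoder reconstructions (with lazy R1 regularization every
# args.d_reg_every steps), and an encoder update minimizing VGG-perceptual, L2,
# and non-saturating adversarial losses through the frozen generator.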

def train(args, loader, encoder, generator, discriminator, e_optim, d_optim, device):
    loader = sample_data(loader)

    pbar = range(args.iter)
    pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    e_loss_val = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    loss_dict = {}
    vgg_loss = VGGLoss(device=device)

    accum = 0.5 ** (32 / (10 * 1000))  # EMA decay factor; unused in this script

    requires_grad(generator, False)

    truncation = 0.7
    trunc = generator.mean_latent(4096).detach()
    trunc.requires_grad = False

    if SummaryWriter and args.tensorboard:
        logger = SummaryWriter(logdir='./checkpoint')

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        # D update
        requires_grad(encoder, False)
        requires_grad(discriminator, True)

        real_img = next(loader)
        real_img = real_img.to(device)

        latents = encoder(real_img)
        recon_img, _ = generator([latents],
                                 input_is_latent=True,
                                 truncation=truncation,
                                 truncation_latent=trunc,
                                 randomize_noise=False)

        recon_pred = discriminator(recon_img)
        real_pred = discriminator(real_img)
        d_loss = d_logistic_loss(real_pred, recon_pred)

        loss_dict["d"] = d_loss

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()
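
        # Lazy R1 regularization: the gradient penalty is computed only every
        # args.d_reg_every iterations and scaled up by the same factor so its
        # effective strength is unchanged; the 0 * real_pred[0] term just keeps
        # the discriminator output in the graph so every parameter receives a
        # gradient.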
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True
            real_pred = discriminator(real_img)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # E update
        requires_grad(encoder, True)
        requires_grad(discriminator, False)

        real_img = real_img.detach()
        real_img.requires_grad = False

        latents = encoder(real_img)
        recon_img, _ = generator([latents],
                                 input_is_latent=True,
                                 truncation=truncation,
                                 truncation_latent=trunc,
                                 randomize_noise=False)

        # Loss weights are applied here so the optimized objective matches the
        # logged values.
        recon_vgg_loss = vgg_loss(recon_img, real_img) * args.vgg
        loss_dict["vgg"] = recon_vgg_loss

        recon_l2_loss = F.mse_loss(recon_img, real_img) * args.l2
        loss_dict["l2"] = recon_l2_loss

        recon_pred = discriminator(recon_img)
        adv_loss = g_nonsaturating_loss(recon_pred) * args.adv
        loss_dict["adv"] = adv_loss

        e_loss = recon_vgg_loss + recon_l2_loss + adv_loss
        loss_dict["e_loss"] = e_loss

        encoder.zero_grad()
        e_loss.backward()
        e_optim.step()

        e_loss_val = loss_dict["e_loss"].item()
        vgg_loss_val = loss_dict["vgg"].item()
        l2_loss_val = loss_dict["l2"].item()
        adv_loss_val = loss_dict["adv"].item()
        d_loss_val = loss_dict["d"].item()
        r1_val = loss_dict["r1"].item()

        pbar.set_description(
            (
                f"e: {e_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; l2: {l2_loss_val:.4f}; "
                f"adv: {adv_loss_val:.4f}; d: {d_loss_val:.4f}; r1: {r1_val:.4f}"
            )
        )

        if SummaryWriter and args.tensorboard:
            logger.add_scalar('E_loss/total', e_loss_val, i)
            logger.add_scalar('E_loss/vgg', vgg_loss_val, i)
            logger.add_scalar('E_loss/l2', l2_loss_val, i)
            logger.add_scalar('E_loss/adv', adv_loss_val, i)
            logger.add_scalar('D_loss/adv', d_loss_val, i)
            logger.add_scalar('D_loss/r1', r1_val, i)

        if i % 100 == 0:
            with torch.no_grad():
                sample = torch.cat([real_img.detach(), recon_img.detach()])
                utils.save_image(
                    sample,
                    f"sample/{str(i).zfill(6)}.png",
                    nrow=int(args.batch),
                    normalize=True,
                    range=(-1, 1),
                )

        if i % 10000 == 0:
            torch.save(
                {
                    "e": encoder.state_dict(),
                    "d": discriminator.state_dict(),
                    "e_optim": e_optim.state_dict(),
                    "d_optim": d_optim.state_dict(),
                    "args": args,
                },
                f"checkpoint/encoder_{str(i).zfill(6)}.pt",
            )
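
# Example invocation (paths are illustrative). --g_ckpt must point to a
# rosinality-style StyleGAN2 checkpoint containing "args", "g_ema", "d", and
# "d_optim"; --e_ckpt optionally resumes encoder training:
#
#   python train_encoder.py --data LMDB_PATH --g_ckpt checkpoint/stylegan2.pt --tensorboard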

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--data", type=str, default=None)
    parser.add_argument("--g_ckpt", type=str, default=None)
    parser.add_argument("--e_ckpt", type=str, default=None)

    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--iter", type=int, default=1000000)
    parser.add_argument("--batch", type=int, default=8)
    parser.add_argument("--lr", type=float, default=0.0001)
    parser.add_argument("--local_rank", type=int, default=0)

    parser.add_argument("--vgg", type=float, default=1.0)
    parser.add_argument("--l2", type=float, default=1.0)
    parser.add_argument("--adv", type=float, default=0.05)
    parser.add_argument("--r1", type=float, default=10)
    parser.add_argument("--d_reg_every", type=int, default=16)

    parser.add_argument("--tensorboard", action="store_true")

    args = parser.parse_args()

    device = args.device

    args.start_iter = 0

    # Output directories used by train(); create them up front so image and
    # checkpoint saves cannot fail.
    os.makedirs("sample", exist_ok=True)
    os.makedirs("checkpoint", exist_ok=True)

    print("load generator:", args.g_ckpt)
    g_ckpt = torch.load(args.g_ckpt, map_location=lambda storage, loc: storage)
    g_args = g_ckpt['args']

    args.size = g_args.size
    args.latent = g_args.latent
    args.n_mlp = g_args.n_mlp
    args.channel_multiplier = g_args.channel_multiplier

    encoder = Encoder(args.size, args.latent).to(device)
    generator = Generator(args.size, args.latent, args.n_mlp,
                          channel_multiplier=args.channel_multiplier).to(device)
    discriminator = Discriminator(args.size, channel_multiplier=args.channel_multiplier).to(device)

    e_optim = optim.Adam(
        encoder.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
    )

    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr,
        betas=(0.9, 0.99),
    )

    generator.load_state_dict(g_ckpt["g_ema"])
    discriminator.load_state_dict(g_ckpt["d"])
    d_optim.load_state_dict(g_ckpt["d_optim"])

    if args.e_ckpt is not None:
        print("resume training:", args.e_ckpt)
        e_ckpt = torch.load(args.e_ckpt, map_location=lambda storage, loc: storage)

        encoder.load_state_dict(e_ckpt["e"])
        e_optim.load_state_dict(e_ckpt["e_optim"])
        discriminator.load_state_dict(e_ckpt["d"])
        d_optim.load_state_dict(e_ckpt["d_optim"])

        try:
            ckpt_name = os.path.basename(args.e_ckpt)
            args.start_iter = int(os.path.splitext(ckpt_name.split('_')[-1])[0])
        except ValueError:
            pass

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ]
    )

    dataset = MultiResolutionDataset(args.data, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True),
        drop_last=True,
    )

    train(args, loader, encoder, generator, discriminator, e_optim, d_optim, device)

--------------------------------------------------------------------------------
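
After training, the encoder can be used for GAN inversion. A minimal inference
sketch, assembled from the calls in train_encoder.py above (checkpoint
filenames, the preprocessing transform, and the input image path are
illustrative, not part of the repository):

import torch
from PIL import Image
from torchvision import transforms, utils

from model import Encoder, Generator

device = "cuda"

# Illustrative paths; substitute your own checkpoints.
e_ckpt = torch.load("checkpoint/encoder_100000.pt", map_location="cpu")
g_ckpt = torch.load("checkpoint/stylegan2.pt", map_location="cpu")
train_args = e_ckpt["args"]

encoder = Encoder(train_args.size, train_args.latent).to(device).eval()
encoder.load_state_dict(e_ckpt["e"])

generator = Generator(train_args.size, train_args.latent, train_args.n_mlp,
                      channel_multiplier=train_args.channel_multiplier).to(device).eval()
generator.load_state_dict(g_ckpt["g_ema"])

transform = transforms.Compose([
    transforms.Resize(train_args.size),
    transforms.CenterCrop(train_args.size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

with torch.no_grad():
    img = transform(Image.open("examples/000001.png").convert("RGB"))
    img = img.unsqueeze(0).to(device)

    latents = encoder(img)  # latent code fed to the generator, as in training
    trunc = generator.mean_latent(4096)
    recon, _ = generator([latents],
                         input_is_latent=True,
                         truncation=0.7,
                         truncation_latent=trunc,
                         randomize_noise=False)

    utils.save_image(recon, "recon.png", normalize=True, range=(-1, 1))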