├── LICENSE
├── LICENSE-KSH
├── LICENSE-NVIDIA
├── README.md
├── dataset.py
├── examples
│   ├── 000001.png
│   ├── 000002.png
│   ├── 000003.png
│   ├── 000004.png
│   ├── 000005.png
│   ├── 000006.png
│   ├── 000007.png
│   ├── 000008.png
│   ├── 000009.png
│   └── 000010.png
├── imgs
│   ├── 0k.png
│   ├── 1M.png
│   ├── face2webtoon.png
│   ├── interpolation_domain_guided_encoder.png
│   ├── interpolation_idinversion_500steps.png
│   └── interpolation_results.png
├── interpolate.ipynb
├── model.py
├── op
│   ├── __init__.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
└── train_encoder.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Bryan Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE-KSH:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Kim Seonghyeon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/LICENSE-NVIDIA:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 |
3 |
4 | Nvidia Source Code License-NC
5 |
6 | =======================================================================
7 |
8 | 1. Definitions
9 |
10 | "Licensor" means any person or entity that distributes its Work.
11 |
12 | "Software" means the original work of authorship made available under
13 | this License.
14 |
15 | "Work" means the Software and any additions to or derivative works of
16 | the Software that are made available under this License.
17 |
18 | "Nvidia Processors" means any central processing unit (CPU), graphics
19 | processing unit (GPU), field-programmable gate array (FPGA),
20 | application-specific integrated circuit (ASIC) or any combination
21 | thereof designed, made, sold, or provided by Nvidia or its affiliates.
22 |
23 | The terms "reproduce," "reproduction," "derivative works," and
24 | "distribution" have the meaning as provided under U.S. copyright law;
25 | provided, however, that for the purposes of this License, derivative
26 | works shall not include works that remain separable from, or merely
27 | link (or bind by name) to the interfaces of, the Work.
28 |
29 | Works, including the Software, are "made available" under this License
30 | by including in or with the Work either (a) a copyright notice
31 | referencing the applicability of this License to the Work, or (b) a
32 | copy of this License.
33 |
34 | 2. License Grants
35 |
36 | 2.1 Copyright Grant. Subject to the terms and conditions of this
37 | License, each Licensor grants to you a perpetual, worldwide,
38 | non-exclusive, royalty-free, copyright license to reproduce,
39 | prepare derivative works of, publicly display, publicly perform,
40 | sublicense and distribute its Work and any resulting derivative
41 | works in any form.
42 |
43 | 3. Limitations
44 |
45 | 3.1 Redistribution. You may reproduce or distribute the Work only
46 | if (a) you do so under this License, (b) you include a complete
47 | copy of this License with your distribution, and (c) you retain
48 | without modification any copyright, patent, trademark, or
49 | attribution notices that are present in the Work.
50 |
51 | 3.2 Derivative Works. You may specify that additional or different
52 | terms apply to the use, reproduction, and distribution of your
53 | derivative works of the Work ("Your Terms") only if (a) Your Terms
54 | provide that the use limitation in Section 3.3 applies to your
55 | derivative works, and (b) you identify the specific derivative
56 | works that are subject to Your Terms. Notwithstanding Your Terms,
57 | this License (including the redistribution requirements in Section
58 | 3.1) will continue to apply to the Work itself.
59 |
60 | 3.3 Use Limitation. The Work and any derivative works thereof only
61 | may be used or intended for use non-commercially. The Work or
62 | derivative works thereof may be used or intended for use by Nvidia
63 | or its affiliates commercially or non-commercially. As used herein,
64 | "non-commercially" means for research or evaluation purposes only.
65 |
66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim
67 | against any Licensor (including any claim, cross-claim or
68 | counterclaim in a lawsuit) to enforce any patents that you allege
69 | are infringed by any Work, then your rights under this License from
70 | such Licensor (including the grants in Sections 2.1 and 2.2) will
71 | terminate immediately.
72 |
73 | 3.5 Trademarks. This License does not grant any rights to use any
74 | Licensor's or its affiliates' names, logos, or trademarks, except
75 | as necessary to reproduce the notices described in this License.
76 |
77 | 3.6 Termination. If you violate any term of this License, then your
78 | rights under this License (including the grants in Sections 2.1 and
79 | 2.2) will terminate immediately.
80 |
81 | 4. Disclaimer of Warranty.
82 |
83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
87 | THIS LICENSE.
88 |
89 | 5. Limitation of Liability.
90 |
91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
99 | THE POSSIBILITY OF SUCH DAMAGES.
100 |
101 | =======================================================================
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # In-Domain GAN Inversion for Real Image Editing
2 |
3 |
4 | Based on **Seonghyeon Kim's PyTorch Implementation of StyleGAN2**
5 |
6 | [[Paper](https://arxiv.org/pdf/2004.00049.pdf)] [[Official Code](https://github.com/genforce/idinvert)] [[StyleGAN2 Pytorch](https://github.com/rosinality/stylegan2-pytorch)]
7 |
8 | ## Train Encoder
9 |
10 | ```
11 | python train_encoder.py
12 | ```
13 |
14 | **0k iter**\
15 | ![0k](./imgs/0k.png)
16 |
17 | **1M iter**\
18 | ![1M](./imgs/1M.png)\
19 | [[encoder checkpoint](https://drive.google.com/file/d/1QQuZGtHgD24Dn5E21Z2Ik25EPng58MoU/view?usp=sharing)] [[generator checkpoint](https://drive.google.com/file/d/1TH77dUsqcq50htIZT6DljFYr4T_ziJli/view)]
20 |
21 | **Note:** The encoder architecture and loss weights differ from the original implementation.
22 |
23 |
24 |
25 | ## Interpolation
26 |
27 | ```
28 | interpolate.ipynb
29 | ```
30 |
31 | **Domain-Guided Encoder (Initial projection)**\
32 | ![domain-guided encoder](./imgs/interpolation_domain_guided_encoder.png)
33 |
34 | **In-Domain Inversion (500 steps)**\
35 | ![in-domain inversion](./imgs/interpolation_idinversion_500steps.png)
36 |
37 | **Interpolation Result**\
38 | ![interpolation results](./imgs/interpolation_results.png)
39 |
40 | ## Encoder + Model Interpolation
41 | [[Paper](https://arxiv.org/abs/2010.05334)] [[Naver Webtoon Model](https://github.com/bryandlee/naver-webtoon-faces#stylegan2)]
42 |
43 | ![face2webtoon](./imgs/face2webtoon.png)
44 |
--------------------------------------------------------------------------------
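The model interpolation referenced in the last README section (arXiv:2010.05334) boils down to blending the weights of two compatible generators. A minimal sketch, assuming two `Generator` state dicts with identical keys; `sd_ffhq`, `sd_webtoon`, and the fixed blend weight are hypothetical names (the paper varies the weight per resolution):

```
import torch

from model import Generator

def lerp_state_dicts(sd_a, sd_b, alpha):
    # Entry-wise blend: alpha=0 keeps sd_a, alpha=1 gives sd_b.
    return {k: torch.lerp(sd_a[k].float(), sd_b[k].float(), alpha) for k in sd_a}

# sd_ffhq / sd_webtoon are hypothetical checkpoints of the same architecture.
# g = Generator(256, 512, 8)
# g.load_state_dict(lerp_state_dicts(sd_ffhq, sd_webtoon, 0.5))
```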
/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from io import BytesIO
3 | import multiprocessing
4 | from functools import partial
5 |
6 | from PIL import Image
7 | import lmdb
8 | from tqdm import tqdm
9 | from torchvision import datasets
10 | from torchvision.transforms import functional as trans_fn
11 | from torch.utils.data import Dataset
12 |
13 |
14 | class MultiResolutionDataset(Dataset):
15 | def __init__(self, path, transform, resolution=256):
16 | self.env = lmdb.open(
17 | path,
18 | max_readers=32,
19 | readonly=True,
20 | lock=False,
21 | readahead=False,
22 | meminit=False,
23 | )
24 |
25 | if not self.env:
26 | raise IOError('Cannot open lmdb dataset', path)
27 |
28 | with self.env.begin(write=False) as txn:
29 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
30 |
31 | self.resolution = resolution
32 | self.transform = transform
33 |
34 | def __len__(self):
35 | return self.length
36 |
37 | def __getitem__(self, index):
38 | with self.env.begin(write=False) as txn:
39 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
40 | img_bytes = txn.get(key)
41 |
42 | buffer = BytesIO(img_bytes)
43 | img = Image.open(buffer)
44 | img = self.transform(img)
45 |
46 | return img
47 |
48 |
49 | def resize_and_convert(img, size, resample, quality=100):
50 | img = trans_fn.resize(img, size, resample)
51 | img = trans_fn.center_crop(img, size)
52 | buffer = BytesIO()
53 | img.save(buffer, format='jpeg', quality=quality)
54 | val = buffer.getvalue()
55 |
56 | return val
57 |
58 |
59 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100):
60 | imgs = []
61 |
62 | for size in sizes:
63 | imgs.append(resize_and_convert(img, size, resample, quality))
64 |
65 | return imgs
66 |
67 |
68 | def resize_worker(img_file, sizes, resample):
69 | i, file = img_file
70 | img = Image.open(file)
71 | img = img.convert('RGB')
72 | out = resize_multiple(img, sizes=sizes, resample=resample)
73 |
74 | return i, out
75 |
76 |
77 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS):
78 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
79 |
80 | files = sorted(dataset.imgs, key=lambda x: x[0])
81 | files = [(i, file) for i, (file, label) in enumerate(files)]
82 | total = 0
83 |
84 | with multiprocessing.Pool(n_worker) as pool:
85 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
86 | for size, img in zip(sizes, imgs):
87 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8')
88 |
89 | with env.begin(write=True) as txn:
90 | txn.put(key, img)
91 |
92 | total += 1
93 |
94 | with env.begin(write=True) as txn:
95 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8'))
96 |
97 |
98 | if __name__ == '__main__':
99 | parser = argparse.ArgumentParser()
100 | parser.add_argument('--out', type=str)
101 | parser.add_argument('--size', type=str, default='128,256,512,1024')
102 | parser.add_argument('--n_worker', type=int, default=8)
103 | parser.add_argument('--resample', type=str, default='lanczos')
104 | parser.add_argument('path', type=str)
105 |
106 | args = parser.parse_args()
107 |
108 | resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR}
109 | resample = resample_map[args.resample]
110 |
111 | sizes = [int(s.strip()) for s in args.size.split(',')]
112 |
113 | print('Make dataset of image sizes:', ', '.join(str(s) for s in sizes))
114 |
115 | imgset = datasets.ImageFolder(args.path)
116 |
117 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
118 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
119 |
--------------------------------------------------------------------------------
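A usage sketch for the script above: build the LMDB from an `ImageFolder`-style directory, then read it back with `MultiResolutionDataset`. The paths are placeholders, and the [-1, 1] normalization is an assumption following common StyleGAN2 practice:

```
# Build the multi-resolution LMDB first (paths are placeholders):
#   python dataset.py --out data/lmdb --size 128,256 --n_worker 8 data/images

from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import MultiResolutionDataset

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # scale to [-1, 1]
])

dataset = MultiResolutionDataset('data/lmdb', transform, resolution=256)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
```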
/examples/000001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000001.png
--------------------------------------------------------------------------------
/examples/000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000002.png
--------------------------------------------------------------------------------
/examples/000003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000003.png
--------------------------------------------------------------------------------
/examples/000004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000004.png
--------------------------------------------------------------------------------
/examples/000005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000005.png
--------------------------------------------------------------------------------
/examples/000006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000006.png
--------------------------------------------------------------------------------
/examples/000007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000007.png
--------------------------------------------------------------------------------
/examples/000008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000008.png
--------------------------------------------------------------------------------
/examples/000009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000009.png
--------------------------------------------------------------------------------
/examples/000010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/examples/000010.png
--------------------------------------------------------------------------------
/imgs/0k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/0k.png
--------------------------------------------------------------------------------
/imgs/1M.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/1M.png
--------------------------------------------------------------------------------
/imgs/face2webtoon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/face2webtoon.png
--------------------------------------------------------------------------------
/imgs/interpolation_domain_guided_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_domain_guided_encoder.png
--------------------------------------------------------------------------------
/imgs/interpolation_idinversion_500steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_idinversion_500steps.png
--------------------------------------------------------------------------------
/imgs/interpolation_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bryandlee/stylegan2-encoder-pytorch/a07b25cc0ba1a9900386e3d1315b757cb9b6df4d/imgs/interpolation_results.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import functools
4 | import operator
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.autograd import Function
10 |
11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
12 |
13 |
14 | class PixelNorm(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20 |
21 |
22 | def make_kernel(k):
23 | k = torch.tensor(k, dtype=torch.float32)
24 |
25 | if k.ndim == 1:
26 | k = k[None, :] * k[:, None]
27 |
28 | k /= k.sum()
29 |
30 | return k
31 |
32 |
33 | class Upsample(nn.Module):
34 | def __init__(self, kernel, factor=2):
35 | super().__init__()
36 |
37 | self.factor = factor
38 | kernel = make_kernel(kernel) * (factor ** 2)
39 | self.register_buffer('kernel', kernel)
40 |
41 | p = kernel.shape[0] - factor
42 |
43 | pad0 = (p + 1) // 2 + factor - 1
44 | pad1 = p // 2
45 |
46 | self.pad = (pad0, pad1)
47 |
48 | def forward(self, input):
49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50 |
51 | return out
52 |
53 |
54 | class Downsample(nn.Module):
55 | def __init__(self, kernel, factor=2):
56 | super().__init__()
57 |
58 | self.factor = factor
59 | kernel = make_kernel(kernel)
60 | self.register_buffer('kernel', kernel)
61 |
62 | p = kernel.shape[0] - factor
63 |
64 | pad0 = (p + 1) // 2
65 | pad1 = p // 2
66 |
67 | self.pad = (pad0, pad1)
68 |
69 | def forward(self, input):
70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71 |
72 | return out
73 |
74 |
75 | class Blur(nn.Module):
76 | def __init__(self, kernel, pad, upsample_factor=1):
77 | super().__init__()
78 |
79 | kernel = make_kernel(kernel)
80 |
81 | if upsample_factor > 1:
82 | kernel = kernel * (upsample_factor ** 2)
83 |
84 | self.register_buffer('kernel', kernel)
85 |
86 | self.pad = pad
87 |
88 | def forward(self, input):
89 | out = upfirdn2d(input, self.kernel, pad=self.pad)
90 |
91 | return out
92 |
93 |
94 | class EqualConv2d(nn.Module):
95 | def __init__(
96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97 | ):
98 | super().__init__()
99 |
100 | self.weight = nn.Parameter(
101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102 | )
103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104 |
105 | self.stride = stride
106 | self.padding = padding
107 |
108 | if bias:
109 | self.bias = nn.Parameter(torch.zeros(out_channel))
110 |
111 | else:
112 | self.bias = None
113 |
114 | def forward(self, input):
115 | out = F.conv2d(
116 | input,
117 | self.weight * self.scale,
118 | bias=self.bias,
119 | stride=self.stride,
120 | padding=self.padding,
121 | )
122 |
123 | return out
124 |
125 | def __repr__(self):
126 | return (
127 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
128 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
129 | )
130 |
131 |
132 | class EqualLinear(nn.Module):
133 | def __init__(
134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135 | ):
136 | super().__init__()
137 |
138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139 |
140 | if bias:
141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142 |
143 | else:
144 | self.bias = None
145 |
146 | self.activation = activation
147 |
148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149 | self.lr_mul = lr_mul
150 |
151 | def forward(self, input):
152 | if self.activation:
153 | out = F.linear(input, self.weight * self.scale)
154 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
155 |
156 | else:
157 | out = F.linear(
158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
159 | )
160 |
161 | return out
162 |
163 | def __repr__(self):
164 | return (
165 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
166 | )
167 |
168 |
169 | class ScaledLeakyReLU(nn.Module):
170 | def __init__(self, negative_slope=0.2):
171 | super().__init__()
172 |
173 | self.negative_slope = negative_slope
174 |
175 | def forward(self, input):
176 | out = F.leaky_relu(input, negative_slope=self.negative_slope)
177 |
178 | return out * math.sqrt(2)
179 |
180 |
181 | class ModulatedConv2d(nn.Module):
182 | def __init__(
183 | self,
184 | in_channel,
185 | out_channel,
186 | kernel_size,
187 | style_dim,
188 | demodulate=True,
189 | upsample=False,
190 | downsample=False,
191 | blur_kernel=[1, 3, 3, 1],
192 | ):
193 | super().__init__()
194 |
195 | self.eps = 1e-8
196 | self.kernel_size = kernel_size
197 | self.in_channel = in_channel
198 | self.out_channel = out_channel
199 | self.upsample = upsample
200 | self.downsample = downsample
201 |
202 | if upsample:
203 | factor = 2
204 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
205 | pad0 = (p + 1) // 2 + factor - 1
206 | pad1 = p // 2 + 1
207 |
208 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
209 |
210 | if downsample:
211 | factor = 2
212 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
213 | pad0 = (p + 1) // 2
214 | pad1 = p // 2
215 |
216 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
217 |
218 | fan_in = in_channel * kernel_size ** 2
219 | self.scale = 1 / math.sqrt(fan_in)
220 | self.padding = kernel_size // 2
221 |
222 | self.weight = nn.Parameter(
223 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
224 | )
225 |
226 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
227 |
228 | self.demodulate = demodulate
229 |
230 | def __repr__(self):
231 | return (
232 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
233 | f'upsample={self.upsample}, downsample={self.downsample})'
234 | )
235 |
236 | def forward(self, input, style):
237 | batch, in_channel, height, width = input.shape
238 |
239 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
240 | weight = self.scale * self.weight * style
241 |
242 | if self.demodulate:
243 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
244 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
245 |
246 | weight = weight.view(
247 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
248 | )
249 |
250 | if self.upsample:
251 | input = input.view(1, batch * in_channel, height, width)
252 | weight = weight.view(
253 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
254 | )
255 | weight = weight.transpose(1, 2).reshape(
256 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
257 | )
258 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
259 | _, _, height, width = out.shape
260 | out = out.view(batch, self.out_channel, height, width)
261 | out = self.blur(out)
262 |
263 | elif self.downsample:
264 | input = self.blur(input)
265 | _, _, height, width = input.shape
266 | input = input.view(1, batch * in_channel, height, width)
267 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
268 | _, _, height, width = out.shape
269 | out = out.view(batch, self.out_channel, height, width)
270 |
271 | else:
272 | input = input.view(1, batch * in_channel, height, width)
273 | out = F.conv2d(input, weight, padding=self.padding, groups=batch)
274 | _, _, height, width = out.shape
275 | out = out.view(batch, self.out_channel, height, width)
276 |
277 | return out
278 |
279 |
280 | class NoiseInjection(nn.Module):
281 | def __init__(self):
282 | super().__init__()
283 |
284 | self.weight = nn.Parameter(torch.zeros(1))
285 |
286 | def forward(self, image, noise=None):
287 | if noise is None:
288 | batch, _, height, width = image.shape
289 | noise = image.new_empty(batch, 1, height, width).normal_()
290 |
291 | return image + self.weight * noise
292 |
293 |
294 | class ConstantInput(nn.Module):
295 | def __init__(self, channel, size=4):
296 | super().__init__()
297 |
298 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
299 |
300 | def forward(self, input):
301 | batch = input.shape[0]
302 | out = self.input.repeat(batch, 1, 1, 1)
303 |
304 | return out
305 |
306 |
307 | class StyledConv(nn.Module):
308 | def __init__(
309 | self,
310 | in_channel,
311 | out_channel,
312 | kernel_size,
313 | style_dim,
314 | upsample=False,
315 | blur_kernel=[1, 3, 3, 1],
316 | demodulate=True,
317 | ):
318 | super().__init__()
319 |
320 | self.conv = ModulatedConv2d(
321 | in_channel,
322 | out_channel,
323 | kernel_size,
324 | style_dim,
325 | upsample=upsample,
326 | blur_kernel=blur_kernel,
327 | demodulate=demodulate,
328 | )
329 |
330 | self.noise = NoiseInjection()
331 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
332 | # self.activate = ScaledLeakyReLU(0.2)
333 | self.activate = FusedLeakyReLU(out_channel)
334 |
335 | def forward(self, input, style, noise=None):
336 | out = self.conv(input, style)
337 | out = self.noise(out, noise=noise)
338 | # out = out + self.bias
339 | out = self.activate(out)
340 |
341 | return out
342 |
343 |
344 | class ToRGB(nn.Module):
345 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
346 | super().__init__()
347 |
348 | if upsample:
349 | self.upsample = Upsample(blur_kernel)
350 |
351 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
352 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
353 |
354 | def forward(self, input, style, skip=None):
355 | out = self.conv(input, style)
356 | out = out + self.bias
357 |
358 | if skip is not None:
359 | skip = self.upsample(skip)
360 |
361 | out = out + skip
362 |
363 | return out
364 |
365 |
366 | class Generator(nn.Module):
367 | def __init__(
368 | self,
369 | size,
370 | style_dim,
371 | n_mlp,
372 | channel_multiplier=2,
373 | blur_kernel=[1, 3, 3, 1],
374 | lr_mlp=0.01,
375 | ):
376 | super().__init__()
377 |
378 | self.size = size
379 |
380 | self.style_dim = style_dim
381 |
382 | layers = [PixelNorm()]
383 |
384 | for i in range(n_mlp):
385 | layers.append(
386 | EqualLinear(
387 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
388 | )
389 | )
390 |
391 | self.style = nn.Sequential(*layers)
392 |
393 | self.channels = {
394 | 4: 512,
395 | 8: 512,
396 | 16: 512,
397 | 32: 512,
398 | 64: 256 * channel_multiplier,
399 | 128: 128 * channel_multiplier,
400 | 256: 64 * channel_multiplier,
401 | 512: 32 * channel_multiplier,
402 | 1024: 16 * channel_multiplier,
403 | }
404 |
405 | self.input = ConstantInput(self.channels[4])
406 | self.conv1 = StyledConv(
407 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
408 | )
409 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
410 |
411 | self.log_size = int(math.log(size, 2))
412 | self.num_layers = (self.log_size - 2) * 2 + 1
413 |
414 | self.convs = nn.ModuleList()
415 | self.upsamples = nn.ModuleList()
416 | self.to_rgbs = nn.ModuleList()
417 | self.noises = nn.Module()
418 |
419 | in_channel = self.channels[4]
420 |
421 | for layer_idx in range(self.num_layers):
422 | res = (layer_idx + 5) // 2
423 | shape = [1, 1, 2 ** res, 2 ** res]
424 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
425 |
426 | for i in range(3, self.log_size + 1):
427 | out_channel = self.channels[2 ** i]
428 |
429 | self.convs.append(
430 | StyledConv(
431 | in_channel,
432 | out_channel,
433 | 3,
434 | style_dim,
435 | upsample=True,
436 | blur_kernel=blur_kernel,
437 | )
438 | )
439 |
440 | self.convs.append(
441 | StyledConv(
442 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
443 | )
444 | )
445 |
446 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
447 |
448 | in_channel = out_channel
449 |
450 | self.n_latent = self.log_size * 2 - 2
451 |
452 | def make_noise(self):
453 | device = self.input.input.device
454 |
455 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
456 |
457 | for i in range(3, self.log_size + 1):
458 | for _ in range(2):
459 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
460 |
461 | return noises
462 |
463 | def mean_latent(self, n_latent):
464 | latent_in = torch.randn(
465 | n_latent, self.style_dim, device=self.input.input.device
466 | )
467 | latent = self.style(latent_in).mean(0, keepdim=True)
468 |
469 | return latent
470 |
471 | def get_latent(self, input):
472 | return self.style(input)
473 |
474 | def forward(
475 | self,
476 | styles,
477 | return_latents=False,
478 | inject_index=None,
479 | truncation=1,
480 | truncation_latent=None,
481 | input_is_latent=False,
482 | noise=None,
483 | randomize_noise=True,
484 | ):
485 | if not input_is_latent:
486 | styles = [self.style(s) for s in styles]
487 |
488 | if noise is None:
489 | if randomize_noise:
490 | noise = [None] * self.num_layers
491 | else:
492 | noise = [
493 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
494 | ]
495 |
496 | if truncation < 1:
497 | style_t = []
498 |
499 | for style in styles:
500 | style_t.append(
501 | truncation_latent + truncation * (style - truncation_latent)
502 | )
503 |
504 | styles = style_t
505 |
506 | if len(styles) < 2:
507 | inject_index = self.n_latent
508 |
509 | if styles[0].ndim < 3:
510 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
511 |
512 | else:
513 | latent = styles[0]
514 |
515 | else:
516 | if inject_index is None:
517 | inject_index = random.randint(1, self.n_latent - 1)
518 |
519 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
520 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
521 |
522 | latent = torch.cat([latent, latent2], 1)
523 |
524 | out = self.input(latent)
525 | out = self.conv1(out, latent[:, 0], noise=noise[0])
526 |
527 | skip = self.to_rgb1(out, latent[:, 1])
528 |
529 | i = 1
530 | for conv1, conv2, noise1, noise2, to_rgb in zip(
531 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
532 | ):
533 | out = conv1(out, latent[:, i], noise=noise1)
534 | out = conv2(out, latent[:, i + 1], noise=noise2)
535 | skip = to_rgb(out, latent[:, i + 2], skip)
536 |
537 | i += 2
538 |
539 | image = skip
540 |
541 | if return_latents:
542 | return image, latent
543 |
544 | else:
545 | return image, None
546 |
547 |
548 | class ConvLayer(nn.Sequential):
549 | def __init__(
550 | self,
551 | in_channel,
552 | out_channel,
553 | kernel_size,
554 | downsample=False,
555 | blur_kernel=[1, 3, 3, 1],
556 | bias=True,
557 | activate=True,
558 | ):
559 | layers = []
560 |
561 | if downsample:
562 | factor = 2
563 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
564 | pad0 = (p + 1) // 2
565 | pad1 = p // 2
566 |
567 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
568 |
569 | stride = 2
570 | self.padding = 0
571 |
572 | else:
573 | stride = 1
574 | self.padding = kernel_size // 2
575 |
576 | layers.append(
577 | EqualConv2d(
578 | in_channel,
579 | out_channel,
580 | kernel_size,
581 | padding=self.padding,
582 | stride=stride,
583 | bias=bias and not activate,
584 | )
585 | )
586 |
587 | if activate:
588 | if bias:
589 | layers.append(FusedLeakyReLU(out_channel))
590 |
591 | else:
592 | layers.append(ScaledLeakyReLU(0.2))
593 |
594 | super().__init__(*layers)
595 |
596 |
597 | class ResBlock(nn.Module):
598 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
599 | super().__init__()
600 |
601 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
602 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
603 |
604 | self.skip = ConvLayer(
605 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
606 | )
607 |
608 | def forward(self, input):
609 | out = self.conv1(input)
610 | out = self.conv2(out)
611 |
612 | skip = self.skip(input)
613 | out = (out + skip) / math.sqrt(2)
614 |
615 | return out
616 |
617 |
618 | class Discriminator(nn.Module):
619 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
620 | super().__init__()
621 |
622 | channels = {
623 | 4: 512,
624 | 8: 512,
625 | 16: 512,
626 | 32: 512,
627 | 64: 256 * channel_multiplier,
628 | 128: 128 * channel_multiplier,
629 | 256: 64 * channel_multiplier,
630 | 512: 32 * channel_multiplier,
631 | 1024: 16 * channel_multiplier,
632 | }
633 |
634 | convs = [ConvLayer(3, channels[size], 1)]
635 |
636 | log_size = int(math.log(size, 2))
637 |
638 | in_channel = channels[size]
639 |
640 | for i in range(log_size, 2, -1):
641 | out_channel = channels[2 ** (i - 1)]
642 |
643 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
644 |
645 | in_channel = out_channel
646 |
647 | self.convs = nn.Sequential(*convs)
648 |
649 | self.stddev_group = 4
650 | self.stddev_feat = 1
651 |
652 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
653 | self.final_linear = nn.Sequential(
654 | EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
655 | EqualLinear(channels[4], 1),
656 | )
657 |
658 | def forward(self, input):
659 | out = self.convs(input)
660 |
661 | batch, channel, height, width = out.shape
662 | group = min(batch, self.stddev_group)
663 | stddev = out.view(
664 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
665 | )
666 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
667 | stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
668 | stddev = stddev.repeat(group, 1, height, width)
669 | out = torch.cat([out, stddev], 1)
670 |
671 | out = self.final_conv(out)
672 |
673 | out = out.view(batch, -1)
674 | out = self.final_linear(out)
675 |
676 | return out
677 |
678 |
679 | class Encoder(nn.Module):
680 | def __init__(self, size, w_dim=512):
681 | super().__init__()
682 |
683 | channels = {
684 | 4: 512,
685 | 8: 512,
686 | 16: 512,
687 | 32: 512,
688 | 64: 256,
689 | 128: 128,
690 | 256: 64,
691 | 512: 32,
692 | 1024: 16
693 | }
694 |
695 | self.w_dim = w_dim
696 | log_size = int(math.log(size, 2))
697 |
698 | self.n_latents = log_size*2 - 2
699 |
700 | convs = [ConvLayer(3, channels[size], 1)]
701 |
702 | in_channel = channels[size]
703 | for i in range(log_size, 2, -1):
704 | out_channel = channels[2 ** (i - 1)]
705 | convs.append(ResBlock(in_channel, out_channel))
706 | in_channel = out_channel
707 |
708 | convs.append(EqualConv2d(in_channel, self.n_latents*self.w_dim, 4, padding=0, bias=False))
709 |
710 | self.convs = nn.Sequential(*convs)
711 |
712 | def forward(self, input):
713 | out = self.convs(input)
714 | return out.view(len(input), self.n_latents, self.w_dim)
715 |
--------------------------------------------------------------------------------
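A minimal sketch of the encoder-based projection that `interpolate.ipynb` presumably performs, wiring the `Encoder` and `Generator` above together. A CUDA toolchain is required because the `op/` extensions are JIT-compiled on import; the checkpoint paths and state-dict keys (`'g_ema'`, `'e'`) are assumptions:

```
import torch

from model import Encoder, Generator

device = 'cuda'  # op/ JIT-compiles its CUDA extensions on import

g = Generator(256, 512, 8).to(device).eval()
e = Encoder(256, w_dim=512).to(device).eval()

# Hypothetical checkpoint layout; adjust keys/paths to the released files.
# g.load_state_dict(torch.load('generator.pt')['g_ema'])
# e.load_state_dict(torch.load('encoder.pt')['e'])

img = torch.randn(1, 3, 256, 256, device=device)  # stand-in for a real [-1, 1] image

with torch.no_grad():
    w_plus = e(img)                               # (1, 14, 512): one w per layer
    recon, _ = g([w_plus], input_is_latent=True)  # decode the W+ code
```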
/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 | module_path = os.path.dirname(__file__)
11 | fused = load(
12 | "fused",
13 | sources=[
14 | os.path.join(module_path, "fused_bias_act.cpp"),
15 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
16 | ],
17 | )
18 |
19 |
20 | class FusedLeakyReLUFunctionBackward(Function):
21 | @staticmethod
22 | def forward(ctx, grad_output, out, negative_slope, scale):
23 | ctx.save_for_backward(out)
24 | ctx.negative_slope = negative_slope
25 | ctx.scale = scale
26 |
27 | empty = grad_output.new_empty(0)
28 |
29 | grad_input = fused.fused_bias_act(
30 | grad_output, empty, out, 3, 1, negative_slope, scale
31 | )
32 |
33 | dim = [0]
34 |
35 | if grad_input.ndim > 2:
36 | dim += list(range(2, grad_input.ndim))
37 |
38 | grad_bias = grad_input.sum(dim).detach()
39 |
40 | return grad_input, grad_bias
41 |
42 | @staticmethod
43 | def backward(ctx, gradgrad_input, gradgrad_bias):
44 | out, = ctx.saved_tensors
45 | gradgrad_out = fused.fused_bias_act(
46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
47 | )
48 |
49 | return gradgrad_out, None, None, None
50 |
51 |
52 | class FusedLeakyReLUFunction(Function):
53 | @staticmethod
54 | def forward(ctx, input, bias, negative_slope, scale):
55 | empty = input.new_empty(0)
56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
57 | ctx.save_for_backward(out)
58 | ctx.negative_slope = negative_slope
59 | ctx.scale = scale
60 |
61 | return out
62 |
63 | @staticmethod
64 | def backward(ctx, grad_output):
65 | out, = ctx.saved_tensors
66 |
67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
68 | grad_output, out, ctx.negative_slope, ctx.scale
69 | )
70 |
71 | return grad_input, grad_bias, None, None
72 |
73 |
74 | class FusedLeakyReLU(nn.Module):
75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
76 | super().__init__()
77 |
78 | self.bias = nn.Parameter(torch.zeros(channel))
79 | self.negative_slope = negative_slope
80 | self.scale = scale
81 |
82 | def forward(self, input):
83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
84 |
85 |
86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
87 | if input.device.type == "cpu":
88 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
89 | return (
90 | F.leaky_relu(
91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
92 | )
93 | * scale
94 | )
95 |
96 | else:
97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
98 |
--------------------------------------------------------------------------------
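For intuition, the CPU fallback in `fused_leaky_relu` above is just bias-add, leaky ReLU, and a `sqrt(2)` rescale (to keep activation variance roughly constant). This self-contained snippet reproduces that math without building the extension:

```
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4)
bias = torch.randn(8)

# Broadcast the per-channel bias over NCHW, apply leaky ReLU, rescale by sqrt(2).
y = F.leaky_relu(x + bias.view(1, 8, 1, 1), negative_slope=0.2) * (2 ** 0.5)
print(y.shape)  # torch.Size([2, 8, 4, 4])
```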
/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 | y.data_ptr<scalar_t>(),
82 | x.data_ptr<scalar_t>(),
83 | b.data_ptr<scalar_t>(),
84 | ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 |
9 | module_path = os.path.dirname(__file__)
10 | upfirdn2d_op = load(
11 | "upfirdn2d",
12 | sources=[
13 | os.path.join(module_path, "upfirdn2d.cpp"),
14 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
15 | ],
16 | )
17 |
18 |
19 | class UpFirDn2dBackward(Function):
20 | @staticmethod
21 | def forward(
22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
23 | ):
24 |
25 | up_x, up_y = up
26 | down_x, down_y = down
27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
28 |
29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
30 |
31 | grad_input = upfirdn2d_op.upfirdn2d(
32 | grad_output,
33 | grad_kernel,
34 | down_x,
35 | down_y,
36 | up_x,
37 | up_y,
38 | g_pad_x0,
39 | g_pad_x1,
40 | g_pad_y0,
41 | g_pad_y1,
42 | )
43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
44 |
45 | ctx.save_for_backward(kernel)
46 |
47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
48 |
49 | ctx.up_x = up_x
50 | ctx.up_y = up_y
51 | ctx.down_x = down_x
52 | ctx.down_y = down_y
53 | ctx.pad_x0 = pad_x0
54 | ctx.pad_x1 = pad_x1
55 | ctx.pad_y0 = pad_y0
56 | ctx.pad_y1 = pad_y1
57 | ctx.in_size = in_size
58 | ctx.out_size = out_size
59 |
60 | return grad_input
61 |
62 | @staticmethod
63 | def backward(ctx, gradgrad_input):
64 | kernel, = ctx.saved_tensors
65 |
66 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
67 |
68 | gradgrad_out = upfirdn2d_op.upfirdn2d(
69 | gradgrad_input,
70 | kernel,
71 | ctx.up_x,
72 | ctx.up_y,
73 | ctx.down_x,
74 | ctx.down_y,
75 | ctx.pad_x0,
76 | ctx.pad_x1,
77 | ctx.pad_y0,
78 | ctx.pad_y1,
79 | )
80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
81 | gradgrad_out = gradgrad_out.view(
82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
83 | )
84 |
85 | return gradgrad_out, None, None, None, None, None, None, None, None
86 |
87 |
88 | class UpFirDn2d(Function):
89 | @staticmethod
90 | def forward(ctx, input, kernel, up, down, pad):
91 | up_x, up_y = up
92 | down_x, down_y = down
93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
94 |
95 | kernel_h, kernel_w = kernel.shape
96 | batch, channel, in_h, in_w = input.shape
97 | ctx.in_size = input.shape
98 |
99 | input = input.reshape(-1, in_h, in_w, 1)
100 |
101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
102 |
103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
105 | ctx.out_size = (out_h, out_w)
106 |
107 | ctx.up = (up_x, up_y)
108 | ctx.down = (down_x, down_y)
109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
110 |
111 | g_pad_x0 = kernel_w - pad_x0 - 1
112 | g_pad_y0 = kernel_h - pad_y0 - 1
113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
115 |
116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
117 |
118 | out = upfirdn2d_op.upfirdn2d(
119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
120 | )
121 | # out = out.view(major, out_h, out_w, minor)
122 | out = out.view(-1, channel, out_h, out_w)
123 |
124 | return out
125 |
126 | @staticmethod
127 | def backward(ctx, grad_output):
128 | kernel, grad_kernel = ctx.saved_tensors
129 |
130 | grad_input = UpFirDn2dBackward.apply(
131 | grad_output,
132 | kernel,
133 | grad_kernel,
134 | ctx.up,
135 | ctx.down,
136 | ctx.pad,
137 | ctx.g_pad,
138 | ctx.in_size,
139 | ctx.out_size,
140 | )
141 |
142 | return grad_input, None, None, None, None
143 |
144 |
145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
146 | if input.device.type == "cpu":
147 | out = upfirdn2d_native(
148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
149 | )
150 |
151 | else:
152 | out = UpFirDn2d.apply(
153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
154 | )
155 |
156 | return out
157 |
158 |
159 | def upfirdn2d_native(
160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
161 | ):
162 | _, channel, in_h, in_w = input.shape
163 | input = input.reshape(-1, in_h, in_w, 1)
164 |
165 | _, in_h, in_w, minor = input.shape
166 | kernel_h, kernel_w = kernel.shape
167 |
168 | out = input.view(-1, in_h, 1, in_w, 1, minor)
169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
171 |
172 | out = F.pad(
173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
174 | )
175 | out = out[
176 | :,
177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
179 | :,
180 | ]
181 |
182 | out = out.permute(0, 3, 1, 2)
183 | out = out.reshape(
184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
185 | )
186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
187 | out = F.conv2d(out, w)
188 | out = out.reshape(
189 | -1,
190 | minor,
191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
193 | )
194 | out = out.permute(0, 2, 3, 1)
195 | out = out[:, ::down_y, ::down_x, :]
196 |
197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
199 |
200 | return out.view(-1, channel, out_h, out_w)
201 |
--------------------------------------------------------------------------------
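A quick shape check for `upfirdn2d_native` above, reproducing the factor-2 `Upsample` configuration from `model.py` (separable `[1, 3, 3, 1]` kernel, normalized and scaled by `factor ** 2`). Note that importing `op.upfirdn2d` JIT-compiles the CUDA extension, so a CUDA toolchain is assumed even though the native path runs on CPU:

```
import torch

from op.upfirdn2d import upfirdn2d_native

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = k[None, :] * k[:, None]   # outer product -> separable 4x4 blur kernel
k = k / k.sum() * 4           # normalize, then factor**2 gain as in Upsample

x = torch.randn(1, 3, 8, 8)
p = k.shape[0] - 2                      # kernel size minus the factor
pad0, pad1 = (p + 1) // 2 + 1, p // 2   # padding Upsample computes for factor=2

y = upfirdn2d_native(x, k, 2, 2, 1, 1, pad0, pad1, pad0, pad1)
print(y.shape)  # torch.Size([1, 3, 16, 16])
```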
/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18 | int c = a / b;
19 |
20 | if (c * b > a) {
21 | c--;
22 | }
23 |
24 | return c;
25 | }
26 |
27 | struct UpFirDn2DKernelParams {
28 | int up_x;
29 | int up_y;
30 | int down_x;
31 | int down_y;
32 | int pad_x0;
33 | int pad_x1;
34 | int pad_y0;
35 | int pad_y1;
36 |
37 | int major_dim;
38 | int in_h;
39 | int in_w;
40 | int minor_dim;
41 | int kernel_h;
42 | int kernel_w;
43 | int out_h;
44 | int out_w;
45 | int loop_major;
46 | int loop_x;
47 | };
48 |
49 | template <typename scalar_t>
50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51 | const scalar_t *kernel,
52 | const UpFirDn2DKernelParams p) {
53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54 | int out_y = minor_idx / p.minor_dim;
55 | minor_idx -= out_y * p.minor_dim;
56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57 | int major_idx_base = blockIdx.z * p.loop_major;
58 |
59 | if (out_x_base >= p.out_w || out_y >= p.out_h ||
60 | major_idx_base >= p.major_dim) {
61 | return;
62 | }
63 |
64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68 |
69 | for (int loop_major = 0, major_idx = major_idx_base;
70 | loop_major < p.loop_major && major_idx < p.major_dim;
71 | loop_major++, major_idx++) {
72 | for (int loop_x = 0, out_x = out_x_base;
73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78 |
79 | const scalar_t *x_p =
80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81 | minor_idx];
82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83 | int x_px = p.minor_dim;
84 | int k_px = -p.up_x;
85 | int x_py = p.in_w * p.minor_dim;
86 | int k_py = -p.up_y * p.kernel_w;
87 |
88 | scalar_t v = 0.0f;
89 |
90 | for (int y = 0; y < h; y++) {
91 | for (int x = 0; x < w; x++) {
92 | v += static_cast(*x_p) * static_cast(*k_p);
93 | x_p += x_px;
94 | k_p += k_px;
95 | }
96 |
97 | x_p += x_py - w * x_px;
98 | k_p += k_py - w * k_px;
99 | }
100 |
101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102 | minor_idx] = v;
103 | }
104 | }
105 | }
106 |
107 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108 |           int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110 | const scalar_t *kernel,
111 | const UpFirDn2DKernelParams p) {
112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114 |
115 | __shared__ volatile float sk[kernel_h][kernel_w];
116 | __shared__ volatile float sx[tile_in_h][tile_in_w];
117 |
118 | int minor_idx = blockIdx.x;
119 | int tile_out_y = minor_idx / p.minor_dim;
120 | minor_idx -= tile_out_y * p.minor_dim;
121 | tile_out_y *= tile_out_h;
122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123 | int major_idx_base = blockIdx.z * p.loop_major;
124 |
125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126 | major_idx_base >= p.major_dim) {
127 | return;
128 | }
129 |
130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131 | tap_idx += blockDim.x) {
132 | int ky = tap_idx / kernel_w;
133 | int kx = tap_idx - ky * kernel_w;
134 | scalar_t v = 0.0;
135 |
136 | if (kx < p.kernel_w & ky < p.kernel_h) {
137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138 | }
139 |
140 | sk[ky][kx] = v;
141 | }
142 |
143 | for (int loop_major = 0, major_idx = major_idx_base;
144 | loop_major < p.loop_major & major_idx < p.major_dim;
145 | loop_major++, major_idx++) {
146 | for (int loop_x = 0, tile_out_x = tile_out_x_base;
147 | loop_x < p.loop_x & tile_out_x < p.out_w;
148 | loop_x++, tile_out_x += tile_out_w) {
149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151 | int tile_in_x = floor_div(tile_mid_x, up_x);
152 | int tile_in_y = floor_div(tile_mid_y, up_y);
153 |
154 | __syncthreads();
155 |
156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157 | in_idx += blockDim.x) {
158 | int rel_in_y = in_idx / tile_in_w;
159 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
160 | int in_x = rel_in_x + tile_in_x;
161 | int in_y = rel_in_y + tile_in_y;
162 |
163 | scalar_t v = 0.0;
164 |
165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167 | p.minor_dim +
168 | minor_idx];
169 | }
170 |
171 | sx[rel_in_y][rel_in_x] = v;
172 | }
173 |
174 | __syncthreads();
175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176 | out_idx += blockDim.x) {
177 | int rel_out_y = out_idx / tile_out_w;
178 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
179 | int out_x = rel_out_x + tile_out_x;
180 | int out_y = rel_out_y + tile_out_y;
181 |
182 | int mid_x = tile_mid_x + rel_out_x * down_x;
183 | int mid_y = tile_mid_y + rel_out_y * down_y;
184 | int in_x = floor_div(mid_x, up_x);
185 | int in_y = floor_div(mid_y, up_y);
186 | int rel_in_x = in_x - tile_in_x;
187 | int rel_in_y = in_y - tile_in_y;
188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190 |
191 | scalar_t v = 0.0;
192 |
193 | #pragma unroll
194 | for (int y = 0; y < kernel_h / up_y; y++)
195 | #pragma unroll
196 | for (int x = 0; x < kernel_w / up_x; x++)
197 | v += sx[rel_in_y + y][rel_in_x + x] *
198 | sk[kernel_y + y * up_y][kernel_x + x * up_x];
199 |
200 | if (out_x < p.out_w & out_y < p.out_h) {
201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202 | minor_idx] = v;
203 | }
204 | }
205 | }
206 | }
207 | }
208 |
209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210 | const torch::Tensor &kernel, int up_x, int up_y,
211 | int down_x, int down_y, int pad_x0, int pad_x1,
212 | int pad_y0, int pad_y1) {
213 | int curDevice = -1;
214 | cudaGetDevice(&curDevice);
215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
216 |
217 | UpFirDn2DKernelParams p;
218 |
219 | auto x = input.contiguous();
220 | auto k = kernel.contiguous();
221 |
222 | p.major_dim = x.size(0);
223 | p.in_h = x.size(1);
224 | p.in_w = x.size(2);
225 | p.minor_dim = x.size(3);
226 | p.kernel_h = k.size(0);
227 | p.kernel_w = k.size(1);
228 | p.up_x = up_x;
229 | p.up_y = up_y;
230 | p.down_x = down_x;
231 | p.down_y = down_y;
232 | p.pad_x0 = pad_x0;
233 | p.pad_x1 = pad_x1;
234 | p.pad_y0 = pad_y0;
235 | p.pad_y1 = pad_y1;
236 |
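| // NOTE (annotation): standard upfirdn output size, i.e. the number of valid
| // filter positions: out = (in * up + pad0 + pad1 - kernel) / down + 1.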
237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238 | p.down_y;
239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240 | p.down_x;
241 |
242 | auto out =
243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244 |
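| // NOTE (annotation): choose a compile-time specialization for the small
| // up/down/kernel combinations StyleGAN2 actually uses; later matches
| // override earlier ones, and mode -1 falls back to the generic kernel.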
245 | int mode = -1;
246 |
247 | int tile_out_h = -1;
248 | int tile_out_w = -1;
249 |
250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251 | p.kernel_h <= 4 && p.kernel_w <= 4) {
252 | mode = 1;
253 | tile_out_h = 16;
254 | tile_out_w = 64;
255 | }
256 |
257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258 | p.kernel_h <= 3 && p.kernel_w <= 3) {
259 | mode = 2;
260 | tile_out_h = 16;
261 | tile_out_w = 64;
262 | }
263 |
264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265 | p.kernel_h <= 4 && p.kernel_w <= 4) {
266 | mode = 3;
267 | tile_out_h = 16;
268 | tile_out_w = 64;
269 | }
270 |
271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272 | p.kernel_h <= 2 && p.kernel_w <= 2) {
273 | mode = 4;
274 | tile_out_h = 16;
275 | tile_out_w = 64;
276 | }
277 |
278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279 | p.kernel_h <= 4 && p.kernel_w <= 4) {
280 | mode = 5;
281 | tile_out_h = 8;
282 | tile_out_w = 32;
283 | }
284 |
285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286 | p.kernel_h <= 2 && p.kernel_w <= 2) {
287 | mode = 6;
288 | tile_out_h = 8;
289 | tile_out_w = 32;
290 | }
291 |
292 | dim3 block_size;
293 | dim3 grid_size;
294 |
295 | if (tile_out_h > 0 && tile_out_w > 0) {
296 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
297 | p.loop_x = 1;
298 | block_size = dim3(32 * 8, 1, 1);
299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301 | (p.major_dim - 1) / p.loop_major + 1);
302 | } else {
303 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
304 | p.loop_x = 4;
305 | block_size = dim3(4, 32, 1);
306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308 | (p.major_dim - 1) / p.loop_major + 1);
309 | }
310 |
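| // NOTE (annotation): AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the
| // lambda below once per floating dtype (including half), binding scalar_t.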
311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312 | switch (mode) {
313 | case 1:
314 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316 | x.data_ptr<scalar_t>(),
317 | k.data_ptr<scalar_t>(), p);
318 |
319 | break;
320 |
321 | case 2:
322 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324 | x.data_ptr<scalar_t>(),
325 | k.data_ptr<scalar_t>(), p);
326 |
327 | break;
328 |
329 | case 3:
330 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332 | x.data_ptr<scalar_t>(),
333 | k.data_ptr<scalar_t>(), p);
334 |
335 | break;
336 |
337 | case 4:
338 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340 | x.data_ptr<scalar_t>(),
341 | k.data_ptr<scalar_t>(), p);
342 |
343 | break;
344 |
345 | case 5:
346 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348 | x.data_ptr<scalar_t>(),
349 | k.data_ptr<scalar_t>(), p);
350 |
351 | break;
352 |
353 | case 6:
354 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355 | <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356 | x.data_ptr<scalar_t>(),
357 | k.data_ptr<scalar_t>(), p);
358 |
359 | break;
360 |
361 | default:
362 | upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364 | k.data_ptr<scalar_t>(), p);
365 | }
366 | });
367 |
368 | return out;
369 | }
--------------------------------------------------------------------------------
/train_encoder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import random
4 | import os
5 |
6 | import numpy as np
7 | import torch
8 | from torch import nn, autograd, optim
9 | from torch.nn import functional as F
10 | from torch.utils import data
11 | import torch.distributed as dist
12 | import torchvision
13 | from torchvision import transforms, utils
14 | from tqdm import tqdm
15 |
16 | from model import Encoder, Generator, Discriminator
17 | from dataset import MultiResolutionDataset
18 |
19 | try:
20 | from tensorboardX import SummaryWriter
21 | except ImportError:
22 | SummaryWriter = None
23 |
24 |
25 | def data_sampler(dataset, shuffle):
26 | if shuffle:
27 | return data.RandomSampler(dataset)
28 |
29 | else:
30 | return data.SequentialSampler(dataset)
31 |
32 |
33 | def requires_grad(model, flag=True):
34 | for p in model.parameters():
35 | p.requires_grad = flag
36 |
37 |
38 | def sample_data(loader):
39 | while True:
40 | for batch in loader:
41 | yield batch
42 |
43 |
44 | def d_logistic_loss(real_pred, fake_pred):
45 | real_loss = F.softplus(-real_pred)
46 | fake_loss = F.softplus(fake_pred)
47 |
48 | return real_loss.mean() + fake_loss.mean()
49 |
50 |
51 | def d_r1_loss(real_pred, real_img):
52 | grad_real, = autograd.grad(
53 | outputs=real_pred.sum(), inputs=real_img, create_graph=True
54 | )
55 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
56 |
57 | return grad_penalty
58 |
59 |
60 | def g_nonsaturating_loss(fake_pred):
61 | loss = F.softplus(-fake_pred).mean()
62 |
63 | return loss
64 |
65 |
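| # NOTE (annotation): perceptual loss on VGG19 features; the slices end at
| # relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 and are compared with L1
| # distance against a detached target.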
66 | class VGGLoss(nn.Module):
67 | def __init__(self, device, n_layers=5):
68 | super().__init__()
69 |
70 | feature_layers = (2, 7, 12, 21, 30)
71 | self.weights = (1.0, 1.0, 1.0, 1.0, 1.0)
72 |
73 | vgg = torchvision.models.vgg19(pretrained=True).features
74 |
75 | self.layers = nn.ModuleList()
76 | prev_layer = 0
77 | for next_layer in feature_layers[:n_layers]:
78 | layers = nn.Sequential()
79 | for layer in range(prev_layer, next_layer):
80 | layers.add_module(str(layer), vgg[layer])
81 | self.layers.append(layers.to(device))
82 | prev_layer = next_layer
83 |
84 | for param in self.parameters():
85 | param.requires_grad = False
86 |
87 | self.criterion = nn.L1Loss().to(device)
88 |
89 | def forward(self, source, target):
90 | loss = 0
91 | for layer, weight in zip(self.layers, self.weights):
92 | source = layer(source)
93 | with torch.no_grad():
94 | target = layer(target)
95 | loss += weight * self.criterion(source, target)
96 |
97 | return loss
98 |
99 |
100 | def train(args, loader, encoder, generator, discriminator, e_optim, d_optim, device):
101 | loader = sample_data(loader)
102 |
103 | pbar = range(args.iter)
104 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
105 |
106 | e_loss_val = 0
107 | d_loss_val = 0
108 | r1_loss = torch.tensor(0.0, device=device)
109 | loss_dict = {}
110 | vgg_loss = VGGLoss(device=device)
111 |
112 | accum = 0.5 ** (32 / (10 * 1000))  # EMA decay kept from the StyleGAN2 train script; unused here
113 |
114 | requires_grad(generator, False)
115 |
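| # NOTE (annotation): the generator stays frozen for the whole run; only the
| # encoder and discriminator are updated. Reconstructions are decoded with
| # truncation 0.7 toward the mean latent of 4096 random samples.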
116 | truncation = 0.7
117 | trunc = generator.mean_latent(4096).detach()
118 | trunc.requires_grad = False
119 |
120 | if SummaryWriter and args.tensorboard:
121 | logger = SummaryWriter(logdir='./checkpoint')
122 |
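| # added for robustness (not in the original, which assumes these exist):
| os.makedirs("sample", exist_ok=True)
| os.makedirs("checkpoint", exist_ok=True)
|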
123 | for idx in pbar:
124 | i = idx + args.start_iter
125 |
126 | if i > args.iter:
127 | print("Done!")
128 |
129 | break
130 |
131 | # D update
132 | requires_grad(encoder, False)
133 | requires_grad(discriminator, True)
134 |
135 |
136 | real_img = next(loader)
137 | real_img = real_img.to(device)
138 |
139 | latents = encoder(real_img)
140 | recon_img, _ = generator([latents],
141 | input_is_latent=True,
142 | truncation=truncation,
143 | truncation_latent=trunc,
144 | randomize_noise=False)
145 |
146 | recon_pred = discriminator(recon_img)
147 | real_pred = discriminator(real_img)
148 | d_loss = d_logistic_loss(real_pred, recon_pred)
149 |
150 | loss_dict["d"] = d_loss
151 |
152 | discriminator.zero_grad()
153 | d_loss.backward()
154 | d_optim.step()
155 |
156 | d_regularize = i % args.d_reg_every == 0
157 |
158 | if d_regularize:
159 | real_img.requires_grad = True
160 | real_pred = discriminator(real_img)
161 | r1_loss = d_r1_loss(real_pred, real_img)
162 |
163 | discriminator.zero_grad()
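| # lazy R1: the penalty is applied only every d_reg_every steps, so it is
| # rescaled by d_reg_every; the `0 * real_pred[0]` term keeps the
| # discriminator output in the graph without changing the gradient.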
164 | (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
165 |
166 | d_optim.step()
167 |
168 | loss_dict["r1"] = r1_loss
169 |
170 | # E update
171 | requires_grad(encoder, True)
172 | requires_grad(discriminator, False)
173 |
174 | real_img = real_img.detach()
175 | real_img.requires_grad = False
176 |
177 | latents = encoder(real_img)
178 | recon_img, _ = generator([latents],
179 | input_is_latent=True,
180 | truncation=truncation,
181 | truncation_latent=trunc,
182 | randomize_noise=False)
183 |
184 | recon_vgg_loss = vgg_loss(recon_img, real_img) * args.vgg  # weighted here so e_loss below matches the logged value
185 | loss_dict["vgg"] = recon_vgg_loss
186 |
187 | recon_l2_loss = F.mse_loss(recon_img, real_img) * args.l2
188 | loss_dict["l2"] = recon_l2_loss
189 |
190 | recon_pred = discriminator(recon_img)
191 | adv_loss = g_nonsaturating_loss(recon_pred) * args.adv
192 | loss_dict["adv"] = adv_loss
193 |
194 | e_loss = recon_vgg_loss + recon_l2_loss + adv_loss
195 | loss_dict["e_loss"] = e_loss
196 |
197 |
198 | encoder.zero_grad()
199 | e_loss.backward()
200 | e_optim.step()
201 |
202 | e_loss_val = loss_dict["e_loss"].item()
203 | vgg_loss_val = loss_dict["vgg"].item()
204 | l2_loss_val = loss_dict["l2"].item()
205 | adv_loss_val = loss_dict["adv"].item()
206 | d_loss_val = loss_dict["d"].item()
207 | r1_val = loss_dict["r1"].item()
208 |
209 | pbar.set_description(
210 | (
211 | f"e: {e_loss_val:.4f}; vgg: {vgg_loss_val:.4f}; l2: {l2_loss_val:.4f}; adv: {adv_loss_val:.4f}; d: {d_loss_val:.4f}; r1: {r1_val:.4f}; "
212 | )
213 | )
214 |
215 | if SummaryWriter and args.tensorboard:
216 | logger.add_scalar('E_loss/total', e_loss_val, i)
217 | logger.add_scalar('E_loss/vgg', vgg_loss_val, i)
218 | logger.add_scalar('E_loss/l2', l2_loss_val, i)
219 | logger.add_scalar('E_loss/adv', adv_loss_val, i)
220 | logger.add_scalar('D_loss/adv', d_loss_val, i)
221 | logger.add_scalar('D_loss/r1', r1_val, i)
222 |
223 | if i % 100 == 0:
224 | with torch.no_grad():
225 | sample = torch.cat([real_img.detach(), recon_img.detach()])
226 | utils.save_image(
227 | sample,
228 | f"sample/{str(i).zfill(6)}.png",
229 | nrow=int(args.batch),
230 | normalize=True,
231 | range=(-1, 1),  # renamed to value_range in newer torchvision
232 | )
233 |
234 | if i % 10000 == 0:
235 | torch.save(
236 | {
237 | "e": encoder.state_dict(),
238 | "d": discriminator.state_dict(),
239 | "e_optim": e_optim.state_dict(),
240 | "d_optim": d_optim.state_dict(),
241 | "args": args,
242 | },
243 | f"checkpoint/encoder_{str(i).zfill(6)}.pt",
244 | )
245 |
246 |
247 | if __name__ == "__main__":
248 | parser = argparse.ArgumentParser()
249 |
250 | parser.add_argument("--data", type=str, default=None)
251 | parser.add_argument("--g_ckpt", type=str, default=None)
252 | parser.add_argument("--e_ckpt", type=str, default=None)
253 |
254 | parser.add_argument("--device", type=str, default='cuda')
255 | parser.add_argument("--iter", type=int, default=1000000)
256 | parser.add_argument("--batch", type=int, default=8)
257 | parser.add_argument("--lr", type=float, default=0.0001)
258 | parser.add_argument("--local_rank", type=int, default=0)
259 |
260 | parser.add_argument("--vgg", type=float, default=1.0)
261 | parser.add_argument("--l2", type=float, default=1.0)
262 | parser.add_argument("--adv", type=float, default=0.05)
263 | parser.add_argument("--r1", type=float, default=10)
264 | parser.add_argument("--d_reg_every", type=int, default=16)
265 |
266 | parser.add_argument("--tensorboard", action="store_true")
267 |
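| # example invocation (hypothetical paths):
| #   python train_encoder.py --data ./lmdb_dataset --g_ckpt ./stylegan2.pt --tensorboard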
268 | args = parser.parse_args()
269 |
270 | device = args.device
271 |
272 | args.start_iter = 0
273 |
274 | print("load generator:", args.g_ckpt)
275 | g_ckpt = torch.load(args.g_ckpt, map_location=lambda storage, loc: storage)
276 | g_args = g_ckpt['args']
277 |
278 | args.size = g_args.size
279 | args.latent = g_args.latent
280 | args.n_mlp = g_args.n_mlp
281 | args.channel_multiplier = g_args.channel_multiplier
282 |
283 | encoder = Encoder(args.size, args.latent).to(device)
284 | generator = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
285 | discriminator = Discriminator(args.size, channel_multiplier=args.channel_multiplier).to(device)
286 |
287 | e_optim = optim.Adam(
288 | encoder.parameters(),
289 | lr=args.lr,
290 | betas=(0.9, 0.99),
291 | )
292 |
293 | d_optim = optim.Adam(
294 | discriminator.parameters(),
295 | lr=args.lr,
296 | betas=(0.9, 0.99),
297 | )
298 |
299 | generator.load_state_dict(g_ckpt["g_ema"])
300 | discriminator.load_state_dict(g_ckpt["d"])
301 | d_optim.load_state_dict(g_ckpt["d_optim"])
302 |
303 | if args.e_ckpt is not None:
304 | print("resume training:", args.e_ckpt)
305 | e_ckpt = torch.load(args.e_ckpt, map_location=lambda storage, loc: storage)
306 |
307 | encoder.load_state_dict(e_ckpt["e"])
308 | e_optim.load_state_dict(e_ckpt["e_optim"])
309 | discriminator.load_state_dict(e_ckpt["d"])
310 | d_optim.load_state_dict(e_ckpt["d_optim"])
311 |
312 | try:
313 | ckpt_name = os.path.basename(args.e_ckpt)
314 | args.start_iter = int(os.path.splitext(ckpt_name.split('_')[-1])[0])
315 | except ValueError:
316 | pass
317 |
318 | transform = transforms.Compose(
319 | [
320 | transforms.RandomHorizontalFlip(),
321 | transforms.ToTensor(),
322 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
323 | ]
324 | )
325 |
326 | dataset = MultiResolutionDataset(args.data, transform, args.size)
327 | loader = data.DataLoader(
328 | dataset,
329 | batch_size=args.batch,
330 | sampler=data_sampler(dataset, shuffle=True),
331 | drop_last=True,
332 | )
333 |
334 | train(args, loader, encoder, generator, discriminator, e_optim, d_optim, device)
335 |
--------------------------------------------------------------------------------